diff --git a/pufferlib/sweep.py b/pufferlib/sweep.py index 36e27bf42a..6ef8446423 100644 --- a/pufferlib/sweep.py +++ b/pufferlib/sweep.py @@ -23,6 +23,25 @@ EPSILON = 1e-6 +SWEEP_CONFIG_KEYS = frozenset(( + 'method', + 'metric', + 'metric_distribution', + 'goal', + 'downsample', + 'use_gpu', + 'prune_pareto', + 'sweep_only', + 'max_suggestion_cost', + 'early_stop_quantile', + 'gpus', + 'max_runs', + 'match_enemy_model_path', + 'match_num_games', + 'match_enemy_hidden_size', + 'match_enemy_num_layers', +)) + def unroll_nested_dict(d): if not isinstance(d, dict): return d @@ -145,8 +164,7 @@ def _params_from_puffer_sweep(sweep_config, only_include=None): only_include = [p.strip() for p in sweep_config['sweep_only'].split(',')] for name, param in sweep_config.items(): - if name in ('method', 'metric', 'metric_distribution', 'goal', 'downsample', 'use_gpu', 'prune_pareto', - 'sweep_only', 'max_suggestion_cost', 'early_stop_quantile', 'gpus', 'max_runs'): + if name in SWEEP_CONFIG_KEYS: continue assert isinstance(param, dict), f'Param {name} is not a dict' diff --git a/tests/test_sweep_config.py b/tests/test_sweep_config.py new file mode 100644 index 0000000000..f7b0a3cdba --- /dev/null +++ b/tests/test_sweep_config.py @@ -0,0 +1,36 @@ +import ast +import configparser +from collections import defaultdict + +import pufferlib.sweep + + +def _load_sweep_config(*paths): + parser = configparser.ConfigParser() + parser.read(paths) + + args = defaultdict(dict) + for section in parser.sections(): + for key in parser[section]: + try: + value = ast.literal_eval(parser[section][key]) + except Exception: + value = parser[section][key] + + name = key if section == 'base' else f'{section}.{key}' + cursor = args + for subkey in name.split('.'): + previous = cursor + cursor = cursor.setdefault(subkey, {}) + previous[subkey] = value + + return args['sweep'] + + +def test_match_sweep_config_keys_are_not_hyperparameters(): + sweep_config = _load_sweep_config('config/default.ini', 'config/breakout.ini') + spaces = pufferlib.sweep._params_from_puffer_sweep(sweep_config) + + assert 'match_enemy_model_path' not in spaces + assert 'match_num_games' not in spaces + assert 'train' in spaces