Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions pufferlib/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,25 @@

EPSILON = 1e-6

SWEEP_CONFIG_KEYS = frozenset((
'method',
'metric',
'metric_distribution',
'goal',
'downsample',
'use_gpu',
'prune_pareto',
'sweep_only',
'max_suggestion_cost',
'early_stop_quantile',
'gpus',
'max_runs',
'match_enemy_model_path',
'match_num_games',
'match_enemy_hidden_size',
'match_enemy_num_layers',
))

def unroll_nested_dict(d):
if not isinstance(d, dict):
return d
Expand Down Expand Up @@ -145,8 +164,7 @@ def _params_from_puffer_sweep(sweep_config, only_include=None):
only_include = [p.strip() for p in sweep_config['sweep_only'].split(',')]

for name, param in sweep_config.items():
if name in ('method', 'metric', 'metric_distribution', 'goal', 'downsample', 'use_gpu', 'prune_pareto',
'sweep_only', 'max_suggestion_cost', 'early_stop_quantile', 'gpus', 'max_runs'):
if name in SWEEP_CONFIG_KEYS:
continue

assert isinstance(param, dict), f'Param {name} is not a dict'
Expand Down
36 changes: 36 additions & 0 deletions tests/test_sweep_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import ast
import configparser
from collections import defaultdict

import pufferlib.sweep


def _load_sweep_config(*paths):
parser = configparser.ConfigParser()
parser.read(paths)

args = defaultdict(dict)
for section in parser.sections():
for key in parser[section]:
try:
value = ast.literal_eval(parser[section][key])
except Exception:
value = parser[section][key]

name = key if section == 'base' else f'{section}.{key}'
cursor = args
for subkey in name.split('.'):
previous = cursor
cursor = cursor.setdefault(subkey, {})
previous[subkey] = value

return args['sweep']


def test_match_sweep_config_keys_are_not_hyperparameters():
sweep_config = _load_sweep_config('config/default.ini', 'config/breakout.ini')
spaces = pufferlib.sweep._params_from_puffer_sweep(sweep_config)

assert 'match_enemy_model_path' not in spaces
assert 'match_num_games' not in spaces
assert 'train' in spaces