# Source code for GOOD.kernel.launch

import itertools
import os
import os.path
import sys
from pathlib import Path

from GOOD.definitions import ROOT_DIR
from GOOD.kernel.launcher_manager import load_launcher
from GOOD.utils.args import AutoArgs
from GOOD.utils.config_reader import load_config


def launch():
    """Build the experiment command list and hand it to the selected launcher.

    Steps:
      1. Parse known command-line arguments with ``AutoArgs``.
      2. Replace ``auto_args.config_root`` with the resolved ``Path`` from
         :func:`get_config_root`.
      3. Build one shell-style command string per job with
         :func:`make_list_cmds`. Each command invokes the ``goodtg``
         executable located at ``<sys.exec_prefix>/bin/goodtg``.
      4. Load the launcher named by ``auto_args.launcher`` and call it with
         ``(jobs_group, auto_args)``.
    """
    # ``goodtg`` is expected in the ``bin`` directory of the running
    # interpreter's environment (e.g. the active conda env).
    conda_goodtg = os.path.join(sys.exec_prefix, 'bin', 'goodtg')
    auto_args = AutoArgs().parse_args(known_only=True)
    # Downstream code iterates ``config_root`` as a ``Path``, so normalise it
    # in place before building commands.
    auto_args.config_root = get_config_root(auto_args)
    jobs_group = make_list_cmds(auto_args, conda_goodtg)
    launcher = load_launcher(auto_args.launcher)
    launcher(jobs_group, auto_args)
def get_config_root(auto_args):
    """Resolve the root directory of the experiment configuration tree.

    Args:
        auto_args: Parsed launch arguments; only ``auto_args.config_root``
            is read.

    Returns:
        pathlib.Path: the resolved directory. An absolute
        ``config_root`` is used as-is. A relative one is interpreted under
        ``ROOT_DIR/configs``. An unset or empty value falls back to
        ``ROOT_DIR/configs/GOOD_configs``.
    """
    requested = auto_args.config_root
    if not requested:
        return Path(ROOT_DIR, 'configs', 'GOOD_configs')
    if os.path.isabs(requested):
        return Path(requested)
    return Path(ROOT_DIR, 'configs', requested)
def _is_allowed(name, allowed):
    """Return True when ``allowed`` is empty/None (no filter) or contains ``name``."""
    return not allowed or name in allowed


def _sorted_subdirs(path):
    """Return the sub-directories of ``path`` in sorted, deterministic order."""
    return sorted(child for child in path.iterdir() if child.is_dir())


def _is_alg_config(config_path, allow_algs):
    """Return True when ``config_path`` is a runnable config selected by ``allow_algs``.

    Non-files and any file whose name contains ``'base'`` are skipped
    (presumably shared base configs rather than experiments; confirm).
    Algorithm selection is a *substring* match on the file name, e.g.
    ``'ERM'`` selects every file whose name contains ``'ERM'``.
    """
    if not config_path.is_file() or 'base' in config_path.name:
        return False
    return not allow_algs or any(alg in config_path.name for alg in allow_algs)


def _iter_ood_configs(auto_args):
    """Yield ``(dataset, domain, shift, config)`` paths that pass every allow-list.

    Expected layout: ``config_root/<dataset>/<domain>/<shift>/<config file>``.
    """
    for dataset_path in _sorted_subdirs(auto_args.config_root):
        if not _is_allowed(dataset_path.name, auto_args.allow_datasets):
            continue
        for domain_path in _sorted_subdirs(dataset_path):
            if not _is_allowed(domain_path.name, auto_args.allow_domains):
                continue
            for shift_path in _sorted_subdirs(domain_path):
                if not _is_allowed(shift_path.name, auto_args.allow_shifts):
                    continue
                for ood_config_path in sorted(shift_path.iterdir()):
                    if _is_alg_config(ood_config_path, auto_args.allow_algs):
                        yield dataset_path, domain_path, shift_path, ood_config_path


def _resolve_sweep_root(sweep_root):
    """Keep an absolute ``sweep_root``; place a relative one under ``ROOT_DIR/configs``."""
    if os.path.isabs(sweep_root):
        return Path(sweep_root)
    return Path(ROOT_DIR, 'configs', sweep_root)


def _sweep_arg_strings(sweep_config):
    """Expand a sweep config into one CLI-argument string per grid point.

    Each key of ``sweep_config`` maps to an iterable of candidate values; the
    full Cartesian product over keys is enumerated. ``extra_param`` holds an
    iterable of iterables. It is first expanded into its own product, and
    each resulting tuple is rendered as space-separated values after a
    single ``--extra_param`` flag.
    """
    grid = dict(sweep_config)  # copy so the loaded config is not mutated
    if 'extra_param' in grid:
        grid['extra_param'] = list(itertools.product(*grid['extra_param']))
    keys = list(grid.keys())
    arg_strings = []
    for value_set in itertools.product(*grid.values()):
        parts = []
        for key, value in zip(keys, value_set):
            if key == 'extra_param':
                parts.append(f'--{key} ' + ' '.join(str(v) for v in value))
            else:
                parts.append(f'--{key} {value}')
        arg_strings.append(' '.join(parts))
    return arg_strings


def make_list_cmds(auto_args, conda_goodtg):
    """Build the list of command strings for every selected experiment config.

    Args:
        auto_args: Parsed launch arguments. Reads ``config_root`` (a
            ``Path``), ``allow_datasets`` / ``allow_domains`` /
            ``allow_shifts`` (exact-name allow-lists, empty/None = all),
            ``allow_algs`` (substring allow-list, empty/None = all),
            ``sweep_root`` and ``allow_rounds``.
        conda_goodtg: Executable (path) placed at the start of every command.

    Returns:
        list[str]: the command strings, in deterministic (sorted-path) order.

        * Without ``sweep_root``, each config yields one command per entry
          of ``allow_rounds``:
          ``<goodtg> --exp_round <r> --config_path "<config>"``.
        * With ``sweep_root``, each config yields one command per
          hyper-parameter grid point read from
          ``<sweep_root>/<config stem>/<dataset>/<domain>/<shift>/base.yaml``:
          ``<goodtg> --config_path "<config>" --log_file default <sweep args>``.
    """
    # Resolved once: it does not depend on the config being visited.
    sweep_root = _resolve_sweep_root(auto_args.sweep_root) if auto_args.sweep_root else None
    args_group = []
    for dataset_path, domain_path, shift_path, ood_config_path in _iter_ood_configs(auto_args):
        if sweep_root is not None:
            sweep_path = (sweep_root / ood_config_path.stem / dataset_path.name
                          / domain_path.name / shift_path.name / 'base.yaml')
            sweep_config, _, _ = load_config(str(sweep_path))
            prefix = f'{conda_goodtg} --config_path "{ood_config_path}" --log_file default'
            args_group.extend(f'{prefix} {sweep_arg}' for sweep_arg in _sweep_arg_strings(sweep_config))
        else:
            args_group.extend(
                f'{conda_goodtg} --exp_round {exp_round} --config_path "{ood_config_path}"'
                for exp_round in auto_args.allow_rounds)
    return args_group
if __name__ == '__main__':
    # Running the module directly performs a launch with CLI arguments.
    launch()