Source code for GOOD.kernel.launchers.harvest_launcher

import copy
import os
import shlex
import shutil
from pathlib import Path
import distutils.dir_util

import numpy as np
from ruamel.yaml import YAML
from tqdm import tqdm

from GOOD import config_summoner
from GOOD import register
from GOOD.definitions import ROOT_DIR
from GOOD.utils.args import AutoArgs
from GOOD.utils.args import args_parser
from GOOD.utils.config_reader import load_config, args2config, merge_dicts
from .basic_launcher import Launcher
from typing import Literal
from cilog import create_logger


[docs]@register.launcher_register class HarvestLauncher(Launcher): def __init__(self): super(HarvestLauncher, self).__init__() self.watch = False self.pick_reference = [-1] self.test_index = -2 self.logger = create_logger('Harvest', file='result_table.md', use_color=False) def __call__(self, jobs_group, auto_args: AutoArgs): result_dict = self.harvest_all_fruits(jobs_group) best_fruits = self.picky_farmer(result_dict) if auto_args.sweep_root: self.process_final_root(auto_args) self.update_best_config(auto_args, best_fruits) else: self.show_your_fruits(best_fruits) def update_best_config(self, auto_args, best_fruits): for ddsa_key in best_fruits.keys(): final_path = (auto_args.final_root / '/'.join(ddsa_key.split(' '))).with_suffix('.yaml') top_config, _, _ = load_config(final_path, skip_include=True) whole_config, _, _ = load_config(final_path) args_list = shlex.split(best_fruits[ddsa_key][0]) args = args_parser(['--config_path', str(final_path)] + args_list) args2config(whole_config, args) args_keys = [item[2:] for item in args_list if item.startswith('--')] # print(args_keys) modified_config = self.filter_config(whole_config, args_keys) # print(top_config) # print(modified_config) final_top_config, _ = merge_dicts(top_config, modified_config) # print(final_top_config) yaml = YAML() yaml.indent(offset=2) # yaml.dump(final_top_config, sys.stdout) yaml.dump(final_top_config, final_path) def process_final_root(self, auto_args): if auto_args.final_root is None: auto_args.final_root = auto_args.config_root else: if os.path.isabs(auto_args.final_root): auto_args.final_root = Path(auto_args.final_root) else: auto_args.final_root = Path(ROOT_DIR, 'configs', auto_args.final_root) if auto_args.final_root.exists(): ans = input(f'Overwrite {auto_args.final_root} by {auto_args.config_root}? [y/n]') while ans != 'y' and ans != 'n': ans = input(f'Invalid input: {ans}. Please answer y or n.') if ans == 'y': distutils.dir_util.copy_tree(str(auto_args.config_root), str(auto_args.final_root)) elif ans == 'n': pass else: raise ValueError(f'Unexpected value {ans}.') else: shutil.copytree(auto_args.config_root, auto_args.final_root) def picky_farmer(self, result_dict): best_fruits = dict() sorted_fruits = dict() for ddsa_key in result_dict.keys(): dataset, domain, shift, algorithm = ddsa_key.split(' ') for key, value in result_dict[ddsa_key].items(): result_dict[ddsa_key][key] = np.stack([np.mean(value, axis=1), np.std(value, axis=1)], axis=1) # lambda x: x[1][?, 0] - ? denotes the result used to choose the best setting. if self.watch: sorted_fruits[ddsa_key] = sorted(list(result_dict[ddsa_key].items()), key=lambda x: sum(x[1][i, 0] for i in self.pick_reference), reverse=True if 'ZINC' not in dataset else False) else: if 'ZINC' in dataset: best_fruits[ddsa_key] = min(list(result_dict[ddsa_key].items()), key=lambda x: sum(x[1][i, 0] for i in self.pick_reference)) else: best_fruits[ddsa_key] = max(list(result_dict[ddsa_key].items()), key=lambda x: sum(x[1][i, 0] for i in self.pick_reference)) # best_fruits[ddsa_key] = sorted_fruits[ddsa_key][0] if self.watch: print(sorted_fruits) exit(0) # print(best_fruits) return best_fruits def filter_config(self, config: dict, target_keys): new_config = copy.deepcopy(config) for key in config.keys(): if type(config[key]) is dict: new_config[key] = self.filter_config(config[key], target_keys) if not new_config[key]: new_config.pop(key) else: if key not in target_keys: new_config.pop(key) return new_config def harvest_all_fruits(self, jobs_group): all_finished = True result_dict = dict() for cmd_args in tqdm(jobs_group, desc='Harvesting ^_^'): args = args_parser(shlex.split(cmd_args)[1:]) config = config_summoner(args) last_line = self.harvest(config.log_path) if not last_line.startswith('INFO: ChartInfo'): print(cmd_args, 'Unfinished') all_finished = False continue result = last_line.split(' ')[2:] num_result = len(result) key_args = shlex.split(cmd_args)[1:] round_index = key_args.index('--exp_round') key_args = key_args[:round_index] + key_args[round_index + 2:] config_path_index = key_args.index('--config_path') key_args.pop(config_path_index) # Remove --config_path config_path = Path(key_args.pop(config_path_index)) # Remove and save its value config_path_parents = config_path.parents dataset, domain, shift, algorithm = config_path_parents[2].stem, config_path_parents[1].stem, \ config_path_parents[0].stem, config_path.stem ddsa_key = ' '.join([dataset, domain, shift, algorithm]) if ddsa_key not in result_dict.keys(): result_dict[ddsa_key] = dict() key_str = ' '.join(key_args) if key_str not in result_dict[ddsa_key].keys(): result_dict[ddsa_key][key_str] = [[] for _ in range(num_result)] # print(f'{ddsa_key}_{key_str}: {result}') result_dict[ddsa_key][key_str] = [r + [eval(result[i])] for i, r in enumerate(result_dict[ddsa_key][key_str])] # if not all_finished: # print('Please launch unfinished jobs using other launchers before harvesting.') # exit(1) return result_dict @staticmethod def new_container(key, dictionary, container: Literal['dict', 'list'] = 'dict'): if key not in dictionary: dictionary[key] = eval(container)() def show_your_fruits(self, best_fruits): format_best_fruits = dict() for ddsa_key, result in best_fruits.items(): dataset, domain, shift, algorithm = ddsa_key.split(' ') # self.new_container(dataset, format_best_fruits) # self.new_container(domain, format_best_fruits[dataset]) # self.new_container(shift, format_best_fruits[dataset][domain]) self.new_container(algorithm, format_best_fruits) dds_key = f'{dataset} {domain} {shift}' result = result[1][self.test_index] if 'ZINC' not in dataset: result = f'{result[0] * 100:.2f}({result[1] * 100:.2f})' else: result = f'{result[0]:.4f}({result[1]:.4f})' # format_best_fruits[dataset][domain][shift][algorithm] = result format_best_fruits[algorithm][dds_key] = result # for dataset, domain_shift_algorithm in format_best_fruits.items(): # for domain, shift_algorithm in domain_shift_algorithm.items(): # for shift, algorithm_result in shift_algorithm.items(): # self.logger.table_fromlist([[dataset, f'{domain}-{shift}']] + list(algorithm_result.items())) for algorithm, dds_key_result in format_best_fruits.items(): headers = ['Method', *list(dds_key_result.keys())] data_row = [algorithm, *list(dds_key_result.values())] self.logger.table_fromlist([headers, data_row])