r"""A project configuration module that reads config arguments from a file, sets automatically generated arguments, and
overwrites configuration arguments with command-line arguments.
"""
import copy
import warnings
from os.path import join as opj
from pathlib import Path
from typing import Union
import torch
from munch import Munch
from munch import munchify
from ruamel.yaml import YAML
from tap import Tap
from GOOD.definitions import STORAGE_DIR
from GOOD.utils.args import CommonArgs
from GOOD.utils.metric import Metric
def merge_dicts(dict1: dict, dict2: dict):
    """Recursively merge two dictionaries.

    Values in dict2 override values in dict1. If dict1 and dict2 both contain a dictionary under the same key,
    this function calls itself recursively to merge those dictionaries.
    Neither input is modified, and the result shares no mutable objects with either input: values coming from
    dict1 and from dict2 are both deep-copied.
    Additionally returns a list of detected duplicates.
    Adapted from https://github.com/TUM-DAML/seml/blob/master/seml/utils.py

    Parameters
    ----------
    dict1: dict
        First dict.
    dict2: dict
        Second dict. Values in dict2 will override values from dict1 in case they share the same key.

    Returns
    -------
    return_dict: dict
        Merged dictionaries.
    duplicates: list
        Keys present in both inputs whose dict1 value was overridden. Keys found while merging nested
        dictionaries are reported in dotted form (e.g. ``"train.lr"``).

    Raises
    ------
    ValueError
        If either argument is not a dict.
    """
    if not isinstance(dict1, dict):
        raise ValueError(f"Expecting dict1 to be dict, found {type(dict1)}.")
    if not isinstance(dict2, dict):
        raise ValueError(f"Expecting dict2 to be dict, found {type(dict2)}.")

    return_dict = copy.deepcopy(dict1)
    duplicates = []
    for k, v in dict2.items():
        if k in dict1 and isinstance(v, dict) and isinstance(dict1[k], dict):
            # Both sides hold a sub-dict: merge recursively and prefix nested duplicates with this key.
            return_dict[k], duplicates_k = merge_dicts(dict1[k], v)
            duplicates += [f"{k}.{dup}" for dup in duplicates_k]
        else:
            if k in dict1:
                duplicates.append(k)
            # Deep-copy so that mutating the merged result never leaks back into dict2.
            return_dict[k] = copy.deepcopy(v)
    return return_dict, duplicates
def load_config(path: str, previous_includes: Union[list, None] = None, skip_include=False) -> tuple:
    r"""Config loader.

    Loading configs from a YAML config file, recursively resolving the file's ``includes`` list. Each include
    path is interpreted relative to the directory of the file that includes it.

    Args:
        path (str): The path to your yaml configuration file.
        previous_includes (list): Included configurations. It is for the :obj:`include` configs used for recursion.
            Please leave it blank when calling this function from outside.
        skip_include (bool): If ``True``, return the file's own content as loaded, without processing ``includes``.

    Returns:
        A tuple ``(config, duplicates_warning, duplicates_error)``:
        ``config`` (dict) is the merged configuration; ``duplicates_warning`` lists keys that the main file
        overrides from its includes; ``duplicates_error`` lists keys defined by more than one include.
        When ``skip_include`` is ``True`` the two duplicate lists are ``None``.

    Raises:
        ValueError: If a cyclic include is detected.
        AttributeError: If ``includes`` is present but is not a list.
    """
    path = Path(path)
    # A None default avoids the shared-mutable-default pitfall; copy so the caller's list is never touched.
    previous_includes = [] if previous_includes is None else list(previous_includes)
    # Compare resolved paths so cycles spelled differently (e.g. via "..") are still detected
    # instead of recursing until RecursionError.
    resolved = path.resolve()
    if any(Path(p).resolve() == resolved for p in previous_includes):
        raise ValueError(
            f"Cyclic config include detected. {path} included in sequence {previous_includes}."
        )
    previous_includes = previous_includes + [path]

    yaml = YAML(typ='safe')
    # Use a context manager so the file handle is always closed.
    with open(path, "r") as config_file:
        direct_config = yaml.load(config_file)
    if skip_include:
        return direct_config, None, None
    if direct_config is None:
        # An empty YAML file loads as None; treat it as an empty config.
        direct_config = {}

    # Load config from included files.
    includes = direct_config.pop("includes", [])
    if not isinstance(includes, list):
        raise AttributeError(
            "Includes must be a list, '{}' provided".format(type(includes))
        )

    config = {}
    duplicates_warning = []
    duplicates_error = []
    for include in includes:
        include = path.parent / include
        include_config, inc_dup_warning, inc_dup_error = load_config(
            include, previous_includes
        )
        duplicates_warning += inc_dup_warning
        duplicates_error += inc_dup_error

        # Duplicates between includes are reported as errors.
        config, merge_dup_error = merge_dicts(config, include_config)
        duplicates_error += merge_dup_error

    # Duplicates between the includes and the main file are reported as warnings (the main file wins).
    config, merge_dup_warning = merge_dicts(config, direct_config)
    duplicates_warning += merge_dup_warning

    return config, duplicates_warning, duplicates_error
def search_tap_args(args: CommonArgs, query: str):
    r"""
    Look up an argument by name in a (possibly nested) Tap argument object.

    Names are visited in ``args.class_variables`` order. Any attribute whose type is a subclass of
    :class:`tap.Tap` is searched recursively, and the first match wins.

    Args:
        args (CommonArgs): Command line arguments.
        query (str): Name of the argument to look for.

    Returns:
        A ``(found, value)`` pair; ``(False, None)`` when the name does not occur anywhere.
    """
    for name in args.class_variables.keys():
        attr = getattr(args, name)
        if name == query:
            return True, attr
        if issubclass(type(attr), Tap):
            hit, nested_value = search_tap_args(attr, query)
            if hit:
                return hit, nested_value
    return False, None
def args2config(config: Union[CommonArgs, Munch], args: CommonArgs):
    r"""
    Overwrite config by assigned arguments.

    If an argument is not :obj:`None`, this argument has the highest priority; thus, it will overwrite the corresponding
    config. Nested sections (any ``dict``, including dict subclasses such as ``Munch``) are processed recursively, and
    every leaf key is looked up among the command line arguments with :func:`search_tap_args`.

    Args:
        config (Union[CommonArgs, Munch]): Loaded configs (a plain ``dict`` or a ``Munch``); modified in place.
        args (CommonArgs): Command line arguments.

    Returns:
        None. ``config`` is updated in place.

    Warns:
        UserWarning: For each leaf key of ``config`` that has no matching command line argument.
    """
    for key in config.keys():
        # isinstance (rather than an exact type check) also recurses into Munch / other dict subclasses.
        if isinstance(config[key], dict):
            args2config(config[key], args)
        else:
            found, value = search_tap_args(args, key)
            if found:
                # Only explicitly given (non-None) command line values override the file.
                if value is not None:
                    config[key] = value
            else:
                warnings.warn(f'Argument {key} in the chosen config yaml file are not defined in command arguments, '
                              f'which will lead to incomplete code detection and the lack of argument temporary '
                              f'modification by adding command arguments.')
def process_configs(config: Union[CommonArgs, Munch]):
    r"""
    Process loaded configs in place.

    This process includes setting storage places for datasets, tensorboard logs, logs, and checkpoints. In addition,
    we also set the random seed for each experiment round, the checkpoint saving gap, and the gpu device. Finally, we
    attach a :class:`GOOD.utils.metric.Metric` instance to the config for easy and unified access.

    Args:
        config (Union[CommonArgs, Munch]): Loaded configs with attribute-style access (e.g. a munchified dict).

    Returns:
        None. ``config`` is updated in place.
    """
    # --- Dataset setting ---
    if config.dataset.dataset_root is None:
        config.dataset.dataset_root = opj(STORAGE_DIR, 'datasets')

    # --- tensorboard directory setting ---
    config.tensorboard_logdir = opj(STORAGE_DIR, 'tensorboard', f'{config.dataset.dataset_name}')
    if config.dataset.shift_type:
        config.tensorboard_logdir = opj(config.tensorboard_logdir, config.dataset.shift_type, config.ood.ood_alg,
                                        str(config.ood.ood_param))

    # --- Round setting ---
    # NOTE(review): a falsy exp_round (None or 0) leaves random_seed untouched — confirm round 0 is not used.
    if config.exp_round:
        config.random_seed = config.exp_round * 97 + 13

    # --- Directory name definitions ---
    dataset_dirname = config.dataset.dataset_name + '_' + config.dataset.domain
    if config.dataset.shift_type:
        dataset_dirname += '_' + config.dataset.shift_type
    model_dirname = (f'{config.model.model_name}_{config.model.model_layer}l_'
                     f'{config.model.global_pool}pool_{config.model.dropout_rate}dp')
    train_dirname = f'{config.train.lr}lr_{config.train.weight_decay}wd'
    ood_dirname = config.ood.ood_alg
    if config.ood.ood_param is not None and config.ood.ood_param >= 0:
        ood_dirname += f'_{config.ood.ood_param}'
    else:
        ood_dirname += '_no_param'
    if config.ood.extra_param is not None:
        # Each extra parameter becomes an "_<value>" suffix, in order.
        ood_dirname += ''.join(f'_{param}' for param in config.ood.extra_param)

    # --- Log setting ---
    log_dir_root = opj(STORAGE_DIR, 'log', 'round' + str(config.exp_round))
    log_dirs = opj(log_dir_root, dataset_dirname, model_dirname, train_dirname, ood_dirname)
    if config.save_tag:
        log_dirs = opj(log_dirs, config.save_tag)
    config.log_path = opj(log_dirs, config.log_file + '.log')

    # --- Checkpoint setting ---
    if config.ckpt_root is None:
        config.ckpt_root = opj(STORAGE_DIR, 'checkpoints')
    if config.ckpt_dir is None:
        # Only derive the experiment-specific directory when the user did not supply one explicitly.
        config.ckpt_dir = opj(config.ckpt_root, 'round' + str(config.exp_round))
        config.ckpt_dir = opj(config.ckpt_dir, dataset_dirname, model_dirname, train_dirname, ood_dirname)
        if config.save_tag:
            config.ckpt_dir = opj(config.ckpt_dir, config.save_tag)
    config.test_ckpt = opj(config.ckpt_dir, 'best.ckpt')
    config.id_test_ckpt = opj(config.ckpt_dir, 'id_best.ckpt')

    # --- Other settings ---
    if config.train.max_epoch > 100:
        config.train.save_gap = config.train.max_epoch // 10
    config.device = torch.device(f'cuda:{config.gpu_idx}' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): presumably a far-away sentinel milestone appended after the configured ones — confirm.
    # This mutates the list, so calling process_configs twice on the same config appends it twice.
    config.train.stage_stones.append(100000)

    # --- Attach metric module ---
    config.metric = Metric()
def config_summoner(args: CommonArgs) -> Union[CommonArgs, Munch]:
    r"""
    Build the final experiment config from command line arguments.

    The YAML file at ``args.config_path`` is loaded together with its includes, overridden in place by every
    non-``None`` command line argument, converted with :func:`munch.munchify` for attribute access, and then
    post-processed by :func:`process_configs`.

    Args:
        args (CommonArgs): Command line arguments.

    Returns:
        Processed configs.
    """
    # The duplicate-key reports returned by load_config are currently not acted upon.
    raw_config, _dup_warnings, _dup_errors = load_config(args.config_path)
    args2config(raw_config, args)
    processed = munchify(raw_config)
    process_configs(processed)
    return processed