GOOD.kernel.pipelines.basic_pipeline
Training pipeline: training/evaluation structure, batch training.
Classes
Pipeline – Kernel pipeline.
- class GOOD.kernel.pipelines.basic_pipeline.Pipeline(task: str, model: Module, loader: Union[DataLoader, Dict[str, DataLoader]], ood_algorithm: BaseOODAlg, config: Union[CommonArgs, Munch])
Bases: object
Kernel pipeline.
- Parameters
task (str) – Current running task: ‘train’ or ‘test’.
model (torch.nn.Module) – The GNN model.
loader (Union[DataLoader, Dict[str, DataLoader]]) – The data loader.
ood_algorithm (BaseOODAlg) – The OOD algorithm.
config (Union[CommonArgs, Munch]) – Please refer to GOOD Configs and command line Arguments (CA).
- config_model(mode: str, load_param=False)
A model configuration utility. Responsible for transferring the model from CPU to GPU and loading checkpoints.
- Parameters
mode (str) – ‘train’ or ‘test’.
load_param (bool) – When True, loading the test checkpoint also loads its parameters into the GNN model.
- Returns
Test score and loss if mode==’test’.
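The mode-dependent behavior of config_model can be sketched as follows. This is a minimal, hypothetical illustration, not GOOD's implementation: config_model_sketch is an invented name, a plain dict stands in for the model's parameters so the sketch runs without PyTorch, and the checkpoint layout (a pickled dict with the keys state_dict, test_score and test_loss) is an assumption. In PyTorch, the device transfer and parameter loading would correspond to model.to(device) and model.load_state_dict(...).

```python
import pickle
from typing import Any, Dict, Optional, Tuple


def config_model_sketch(
    mode: str,
    model_params: Dict[str, Any],
    ckpt_path: Optional[str] = None,
    load_param: bool = False,
) -> Optional[Tuple[float, float]]:
    """Hypothetical stand-in for Pipeline.config_model (not GOOD's code)."""
    if mode not in ("train", "test"):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

    # In PyTorch the model would be moved to the target device here,
    # e.g. model.to(device); this sketch has no device concept.

    if mode == "train":
        return None  # nothing to report before training starts

    if ckpt_path is None:
        raise ValueError("test mode requires a checkpoint path")
    with open(ckpt_path, "rb") as f:
        ckpt = pickle.load(f)  # assumed layout: a pickled dict

    if load_param:
        # Replace the current parameters with the checkpointed ones,
        # analogous to model.load_state_dict(ckpt["state_dict"]).
        model_params.clear()
        model_params.update(ckpt["state_dict"])

    # Report the stored test score and loss.
    return ckpt["test_score"], ckpt["test_loss"]
```

Returning None in train mode mirrors the Returns entry above, which only promises a score and loss when mode is ‘test’.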
- evaluate(split: str)
This function is designed to collect results and compute the score and loss on a given dataset split. (For project use only.)
- Parameters
split (str) – A split string for choosing the corresponding dataloader. Allowed: ‘train’, ‘id_val’, ‘id_test’, ‘val’, and ‘test’.
- Returns
A score and a loss.
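The split selection and score/loss aggregation described for evaluate can be sketched without any GNN machinery. This is a hypothetical illustration, not GOOD's implementation: evaluate_split is an invented name, the loaders are assumed to be a dict keyed by split name, each loader is assumed to yield (predictions, targets) batches, accuracy stands in for the configured metric, and loss_fn is a placeholder for the task's loss. Only the list of allowed split names comes from the documentation above.

```python
from typing import Callable, Dict, Iterable, Sequence, Tuple

# Split names listed in the evaluate() documentation.
ALLOWED_SPLITS = ("train", "id_val", "id_test", "val", "test")

# Assumed batch format: (predicted labels, true labels).
Batch = Tuple[Sequence[int], Sequence[int]]


def evaluate_split(
    loaders: Dict[str, Iterable[Batch]],
    split: str,
    loss_fn: Callable[[Sequence[int], Sequence[int]], float],
) -> Tuple[float, float]:
    """Return (accuracy, mean loss) over one split; hypothetical sketch."""
    if split not in ALLOWED_SPLITS:
        raise ValueError(f"unknown split {split!r}; expected one of {ALLOWED_SPLITS}")
    loader = loaders.get(split)
    if loader is None:
        raise KeyError(f"no loader configured for split {split!r}")

    correct = 0
    total = 0
    loss_sum = 0.0
    for preds, targets in loader:
        n = len(targets)
        correct += sum(int(p == t) for p, t in zip(preds, targets))
        # loss_fn returns a per-batch mean, so weight it by batch size.
        loss_sum += loss_fn(preds, targets) * n
        total += n

    if total == 0:
        raise ValueError(f"split {split!r} is empty")
    return correct / total, loss_sum / total
```

Weighting each batch's mean loss by its size keeps the reported loss a true per-sample average even when the last batch is smaller than the others.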
- save_epoch(epoch: int, train_stat: dir, id_val_stat: dir, id_test_stat: dir, val_stat: dir, test_stat: dir, config: Union[CommonArgs, Munch])
Training utility for saving checkpoints.
- Parameters
epoch (int) – Epoch number.
train_stat (dir) – Training statistics.
id_val_stat (dir) – In-domain validation statistics.
id_test_stat (dir) – In-domain test statistics.
val_stat (dir) – OOD validation statistics.
test_stat (dir) – OOD test statistics.
config (Union[CommonArgs, Munch]) – Munchified dictionary of args (config.ckpt_dir, config.dataset, config.train, config.model, config.metric, config.log_path, config.ood).
- Returns
None
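The checkpoint-saving idea behind save_epoch can be sketched with the standard library. This is a hypothetical illustration, not GOOD's implementation: save_epoch_sketch is an invented name, and the file names last.ckpt and best.ckpt, the checkpoint keys, the use of pickle instead of torch.save, and the lower_better flag are all assumptions. Unlike the documented method, which returns None, the sketch returns the running best validation score so the caller can pass it back in on the next epoch.

```python
import os
import pickle
from typing import Any, Dict, Optional


def save_epoch_sketch(
    epoch: int,
    state: Dict[str, Any],
    val_stat: Dict[str, float],
    test_stat: Dict[str, float],
    ckpt_dir: str,
    best_val_score: Optional[float],
    lower_better: bool = False,
) -> Optional[float]:
    """Save last.ckpt every epoch and best.ckpt on improvement (hypothetical)."""
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt = {
        "epoch": epoch,
        "state_dict": state,
        "val_score": val_stat["score"],
        "val_loss": val_stat["loss"],
        "test_score": test_stat["score"],
        "test_loss": test_stat["loss"],
    }

    # Always overwrite the most recent checkpoint.
    with open(os.path.join(ckpt_dir, "last.ckpt"), "wb") as f:
        pickle.dump(ckpt, f)

    # Model selection on the (OOD) validation score.
    score = val_stat["score"]
    improved = best_val_score is None or (
        score < best_val_score if lower_better else score > best_val_score
    )
    if improved:
        with open(os.path.join(ckpt_dir, "best.ckpt"), "wb") as f:
            pickle.dump(ckpt, f)
        return score
    return best_val_score
```

Writing last.ckpt every epoch allows resuming from the most recent state, while best.ckpt keeps the model selected by validation score. The sketch selects on val_stat only; selecting on id_val_stat instead would be an equally simple variation.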