GOOD.kernel.pipelines.basic_pipeline

Training pipeline: training/evaluation structure, batch training.

Classes

Pipeline(task, model, loader, ood_algorithm, ...)

Kernel pipeline.

class GOOD.kernel.pipelines.basic_pipeline.Pipeline(task: str, model: Module, loader: Union[DataLoader, Dict[str, DataLoader]], ood_algorithm: BaseOODAlg, config: Union[CommonArgs, Munch])

Bases: object

Kernel pipeline.

Parameters
  • task (str) – Current running task: ‘train’ or ‘test’.

  • model (torch.nn.Module) – The GNN model.

  • loader (Union[DataLoader, Dict[str, DataLoader]]) – The data loader.

  • ood_algorithm (BaseOODAlg) – The OOD algorithm.

  • config (Union[CommonArgs, Munch]) – The run configuration. Please refer to GOOD Configs and command line Arguments (CA).

config_model(mode: str, load_param=False)

A model configuration utility. Responsible for moving the model from CPU to GPU and loading checkpoints.

Parameters
  • mode (str) – ‘train’ or ‘test’.

  • load_param – When True, loading the test checkpoint also loads its parameters into the GNN model.

Returns

Test score and loss if mode == ‘test’.
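To make the mode and load_param behavior concrete, here is a toy sketch in plain Python (no PyTorch). It is not GOOD's implementation: ToyModel, config_model_sketch, the device string, and the checkpoint layout (keys 'state_dict', 'score', 'loss') are hypothetical stand-ins. It assumes that in test mode the stored test score and loss are read from the checkpoint, and that parameters are copied into the model only when load_param is True.

```python
from dataclasses import dataclass, field
from typing import Optional, Tuple


@dataclass
class ToyModel:
    """Hypothetical stand-in for a torch.nn.Module."""
    device: str = "cpu"
    params: dict = field(default_factory=dict)

    def to(self, device: str) -> "ToyModel":
        # Mimics Module.to(): record the new device and return the model.
        self.device = device
        return self


def config_model_sketch(model: ToyModel, mode: str, checkpoint: dict,
                        device: str = "cuda:0",
                        load_param: bool = False) -> Optional[Tuple[float, float]]:
    """Hypothetical sketch: move the model, and in test mode read a checkpoint."""
    if mode not in ("train", "test"):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
    model.to(device)  # CPU -> GPU transfer
    if mode == "train":
        return None  # nothing to report before training starts
    if load_param:
        # Copy the stored parameters into the model only when asked to.
        model.params = dict(checkpoint["state_dict"])
    # Test mode reports the stored test score and loss.
    return checkpoint["score"], checkpoint["loss"]


ckpt = {"state_dict": {"w": 1.0}, "score": 0.81, "loss": 0.42}
print(config_model_sketch(ToyModel(), "test", ckpt, load_param=True))  # (0.81, 0.42)
```

Keeping load_param separate from mode lets the sketch report a test result without overwriting the model's current parameters.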

evaluate(split: str)

This function is designed to collect results on a dataset subset and calculate the corresponding score and loss. (For project use only)

Parameters

split (str) – A split string for choosing the corresponding dataloader. Allowed: ‘train’, ‘id_val’, ‘id_test’, ‘val’, and ‘test’.

Returns

A score and a loss.
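The following plain-Python sketch illustrates what evaluating one split involves: check the split name, collect predictions and labels batch by batch, then reduce them to one score and one loss. The metric (accuracy at a 0.5 threshold) and the loss (mean binary cross-entropy) are assumptions for illustration; in GOOD these depend on the configured task and metric. The dictionary of lists stands in for the real DataLoaders, and evaluate_sketch is a hypothetical name.

```python
import math
from typing import Dict, List, Tuple

ALLOWED_SPLITS = ("train", "id_val", "id_test", "val", "test")

# A hypothetical batch: (predicted probabilities of class 1, 0/1 labels).
Batch = Tuple[List[float], List[int]]


def evaluate_sketch(loaders: Dict[str, List[Batch]], split: str) -> Dict[str, float]:
    """Return the accuracy ('score') and mean binary cross-entropy ('loss') on one split."""
    if split not in ALLOWED_SPLITS:
        raise ValueError(f"unknown split {split!r}; allowed: {ALLOWED_SPLITS}")
    probs: List[float] = []
    labels: List[int] = []
    # Collect results batch by batch before computing any statistic.
    for batch_probs, batch_labels in loaders[split]:
        probs.extend(batch_probs)
        labels.extend(batch_labels)
    if not labels:
        raise ValueError(f"split {split!r} is empty")
    eps = 1e-12  # keeps log() away from 0
    loss = -sum(
        y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps)
        for p, y in zip(probs, labels)
    ) / len(labels)
    correct = sum((p >= 0.5) == bool(y) for p, y in zip(probs, labels))
    return {"score": correct / len(labels), "loss": loss}
```

Statistics are computed over the whole split rather than averaged per batch, so batches of different sizes are weighted correctly.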

load_task()

Launch training or testing, according to the current task.
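A minimal sketch of the dispatch that load_task performs, using a hypothetical ToyPipeline class whose methods only record that they were called. The test branch is assumed to call config_model('test', load_param=True), which is consistent with the config_model description above but is not confirmed by this page.

```python
from typing import List


class ToyPipeline:
    """Hypothetical minimal pipeline illustrating task dispatch only."""

    def __init__(self, task: str):
        if task not in ("train", "test"):
            raise ValueError(f"task must be 'train' or 'test', got {task!r}")
        self.task = task
        self.log: List[str] = []  # records which methods ran

    def train(self) -> None:
        self.log.append("train")

    def config_model(self, mode: str, load_param: bool = False) -> None:
        self.log.append(f"config_model({mode}, load_param={load_param})")

    def load_task(self) -> None:
        # 'train' runs the training loop; 'test' (assumed) loads the test checkpoint.
        if self.task == "train":
            self.train()
        else:
            self.config_model("test", load_param=True)


p = ToyPipeline("test")
p.load_task()
print(p.log)  # ['config_model(test, load_param=True)']
```

Validating the task in the constructor means an invalid value fails early instead of silently doing nothing at load time.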

save_epoch(epoch: int, train_stat: dict, id_val_stat: dict, id_test_stat: dict, val_stat: dict, test_stat: dict, config: Union[CommonArgs, Munch])

Training utility for saving checkpoints.

Parameters
  • epoch (int) – Epoch number.

  • train_stat (dict) – Training statistics.

  • id_val_stat (dict) – In-domain validation statistics.

  • id_test_stat (dict) – In-domain test statistics.

  • val_stat (dict) – OOD validation statistics.

  • test_stat (dict) – OOD test statistics.

  • config (Union[CommonArgs, Munch]) – Munchified dictionary of arguments (config.ckpt_dir, config.dataset, config.train, config.model, config.metric, config.log_path, config.ood).

Returns

None
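A hedged sketch of epoch-level checkpoint saving. To stay self-contained it writes JSON statistics instead of model weights; a real checkpoint would also store the model's parameters. The function name save_epoch_sketch, the file names last.json and best.json, selecting on the OOD validation score, and the higher-is-better assumption are all hypothetical.

```python
import json
import os
from typing import Dict, Optional


def save_epoch_sketch(epoch: int, stats: Dict[str, Dict[str, float]], ckpt_dir: str,
                      best_val_score: Optional[float]) -> float:
    """Write this epoch's statistics; also write best.json when the val score improves.

    `stats` maps split names ('train', 'id_val', 'id_test', 'val', 'test') to
    {'score': ..., 'loss': ...}. Returns the best OOD validation score seen so far,
    assuming higher scores are better.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    record = {"epoch": epoch, **stats}
    # Always overwrite the latest checkpoint.
    with open(os.path.join(ckpt_dir, "last.json"), "w") as f:
        json.dump(record, f)
    val_score = stats["val"]["score"]
    if best_val_score is None or val_score > best_val_score:
        # Model selection on the OOD validation split.
        with open(os.path.join(ckpt_dir, "best.json"), "w") as f:
            json.dump(record, f)
        return val_score
    return best_val_score
```

The caller keeps the returned best score and passes it back in the next epoch, so the function itself holds no state.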

train()

Training pipeline. (Project use only)
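The overall training loop can be sketched as: train on every batch, evaluate every split, then hand the statistics to a checkpointing step. The callables passed to the hypothetical train_sketch function stand in for train_batch, evaluate, and save_epoch; in this sketch they are passed in explicitly rather than taken from the pipeline's loader and config.

```python
from typing import Callable, Dict, Iterable, List

SPLITS = ("train", "id_val", "id_test", "val", "test")


def train_sketch(max_epoch: int,
                 train_batches: Callable[[], Iterable[object]],
                 train_batch: Callable[[object], Dict[str, float]],
                 evaluate: Callable[[str], Dict[str, float]],
                 save_epoch: Callable[[int, Dict[str, Dict[str, float]]], None]) -> List[float]:
    """Hypothetical epoch loop: train on every batch, evaluate all splits, then checkpoint.

    Returns the mean training loss of each epoch.
    """
    epoch_losses: List[float] = []
    for epoch in range(max_epoch):
        losses = [train_batch(batch)["loss"] for batch in train_batches()]
        if not losses:
            raise ValueError("no training batches")
        epoch_losses.append(sum(losses) / len(losses))
        # Evaluate every split after the epoch, then save.
        stats = {split: evaluate(split) for split in SPLITS}
        save_epoch(epoch, stats)
    return epoch_losses
```

Passing train_batches as a callable lets each epoch start a fresh pass over the data, as a re-iterable DataLoader would.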

train_batch(data: Batch, pbar) → dict

Train a batch. (Project use only)

Parameters
  • data (Batch) – Current batch of data.

  • pbar – Progress bar of the running training loop.

Returns

Calculated loss.
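A toy train_batch step in plain Python: one gradient-descent update of a one-parameter linear model y ≈ w·x under mean squared error. The model, loss, learning rate, and the name train_batch_sketch are hypothetical; the sketch only mirrors the shape of the documented method (a batch in, a loss dictionary out). The optional pbar is duck-typed: any object with a set_description(str) method, such as a tqdm progress bar, works.

```python
from typing import Dict, List, Optional, Tuple


def train_batch_sketch(w: float, batch: List[Tuple[float, float]], lr: float = 0.1,
                       pbar: Optional[object] = None) -> Tuple[float, Dict[str, float]]:
    """One gradient-descent step of y ~ w * x on a batch; returns (new_w, {'loss': mse})."""
    if not batch:
        raise ValueError("empty batch")
    n = len(batch)
    # Forward pass: mean squared error of the current weight on this batch.
    loss = sum((w * x - y) ** 2 for x, y in batch) / n
    # Backward pass: d(loss)/dw = (2/n) * sum(x * (w*x - y)).
    grad = 2.0 / n * sum(x * (w * x - y) for x, y in batch)
    # Optimizer step.
    new_w = w - lr * grad
    if pbar is not None:
        # Report the batch loss on the progress bar.
        pbar.set_description(f"loss: {loss:.4f}")
    return new_w, {"loss": loss}
```

The returned loss is the one measured before the update, matching the usual forward, backward, step order.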