Source code for GOOD.kernel.pipelines.basic_pipeline

r"""Training pipeline: training/evaluation structure, batch training.
"""
import datetime
import os
import shutil
from typing import Dict
from typing import Union

import numpy as np
import torch
import torch.nn
from munch import Munch
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from tqdm import tqdm

from GOOD.ood_algorithms.algorithms.BaseOOD import BaseOODAlg
from GOOD.utils.args import CommonArgs
from GOOD.utils.evaluation import eval_data_preprocess, eval_score
from GOOD.utils.logger import pbar_setting
from GOOD.utils.register import register
from GOOD.utils.train import nan2zero_get_mask


[docs]@register.pipeline_register
class Pipeline:
    r"""
    Kernel pipeline.

    Args:
        task (str): Current running task: 'train' or 'test'.
        model (torch.nn.Module): The GNN model.
        loader (Union[DataLoader, Dict[str, DataLoader]]): The data loader.
        ood_algorithm (BaseOODAlg): The OOD algorithm.
        config (Union[CommonArgs, Munch]): Please refer to :ref:`configs:GOOD Configs and command line Arguments (CA)`.
    """

    def __init__(self, task: str, model: torch.nn.Module, loader: Union[DataLoader, Dict[str, DataLoader]],
                 ood_algorithm: BaseOODAlg, config: Union[CommonArgs, Munch]):
        super(Pipeline, self).__init__()
        self.task: str = task
        self.model: torch.nn.Module = model
        self.loader: Union[DataLoader, Dict[str, DataLoader]] = loader
        self.ood_algorithm: BaseOODAlg = ood_algorithm
        self.config: Union[CommonArgs, Munch] = config
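    # Usage sketch (illustrative only; the variable names below are placeholders, not the
    # project's confirmed API): a Pipeline is built from a config, a GNN model, a dict of
    # dataloaders keyed by split, and an OOD algorithm, then launched via load_task().
    #
    #   loader = {'train': train_loader, 'eval_train': eval_train_loader,
    #             'id_val': id_val_loader, 'id_test': id_test_loader,
    #             'val': val_loader, 'test': test_loader}
    #   pipeline = Pipeline(task='train', model=model, loader=loader,
    #                       ood_algorithm=ood_algorithm, config=config)
    #   pipeline.load_task()    # 'train' -> train(); 'test' -> config_model('test')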
[docs]    def train_batch(self, data: Batch, pbar) -> dict:
        r"""
        Train a batch. (Project use only)

        Args:
            data (Batch): Current batch of data.

        Returns:
            Calculated loss.
        """
        data = data.to(self.config.device)

        self.ood_algorithm.optimizer.zero_grad()

        mask, targets = nan2zero_get_mask(data, 'train', self.config)
        node_norm = data.get('node_norm') if self.config.model.model_level == 'node' else None
        node_norm = node_norm.reshape(targets.shape) if node_norm is not None else None
        data, targets, mask, node_norm = self.ood_algorithm.input_preprocess(data, targets, mask, node_norm,
                                                                             self.model.training, self.config)
        edge_weight = data.get('edge_norm') if self.config.model.model_level == 'node' else None

        model_output = self.model(data=data, edge_weight=edge_weight, ood_algorithm=self.ood_algorithm)
        raw_pred = self.ood_algorithm.output_postprocess(model_output)

        loss = self.ood_algorithm.loss_calculate(raw_pred, targets, mask, node_norm, self.config)
        loss = self.ood_algorithm.loss_postprocess(loss, data, mask, self.config)
        self.ood_algorithm.backward(loss)

        return {'loss': loss.detach()}
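    # Note on the OOD-algorithm hooks used above (illustrative summary, not an API reference):
    # each batch passes through the BaseOODAlg callbacks in the order seen in the method body,
    #   input_preprocess -> model forward -> output_postprocess -> loss_calculate
    #   -> loss_postprocess -> backward,
    # so an ERM-style algorithm can leave data and predictions unchanged and simply return the
    # masked loss, while OOD methods override individual hooks.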
[docs]    def train(self):
        r"""
        Training pipeline. (Project use only)
        """
        # config model
        print('#D#Config model')
        self.config_model('train')

        # Load training utils
        print('#D#Load training utils')
        self.ood_algorithm.set_up(self.model, self.config)

        # train the model
        for epoch in range(self.config.train.ctn_epoch, self.config.train.max_epoch):
            self.config.train.epoch = epoch
            print(f'#IN#Epoch {epoch}:')

            mean_loss = 0
            spec_loss = 0

            self.ood_algorithm.stage_control(self.config)

            pbar = tqdm(enumerate(self.loader['train']), total=len(self.loader['train']), **pbar_setting)
            for index, data in pbar:
                if data.batch is not None and (data.batch[-1] < self.config.train.train_bs - 1):
                    continue

                # Parameter for DANN
                p = (index / len(self.loader['train']) + epoch) / self.config.train.max_epoch
                self.config.train.alpha = 2. / (1. + np.exp(-10 * p)) - 1

                # train a batch
                train_stat = self.train_batch(data, pbar)
                mean_loss = (mean_loss * index + self.ood_algorithm.mean_loss) / (index + 1)

                if self.ood_algorithm.spec_loss is not None:
                    if isinstance(self.ood_algorithm.spec_loss, dict):
                        desc = f'ML: {mean_loss:.4f}|'
                        for loss_name, loss_value in self.ood_algorithm.spec_loss.items():
                            if not isinstance(spec_loss, dict):
                                spec_loss = dict()
                            if loss_name not in spec_loss.keys():
                                spec_loss[loss_name] = 0
                            spec_loss[loss_name] = (spec_loss[loss_name] * index + loss_value) / (index + 1)
                            desc += f'{loss_name}: {spec_loss[loss_name]:.4f}|'
                        pbar.set_description(desc[:-1])
                    else:
                        spec_loss = (spec_loss * index + self.ood_algorithm.spec_loss) / (index + 1)
                        pbar.set_description(f'M/S Loss: {mean_loss:.4f}/{spec_loss:.4f}')
                else:
                    pbar.set_description(f'Loss: {mean_loss:.4f}')

            # Eval training score

            # Epoch val
            print('#IN#\nEvaluating...')
            if self.ood_algorithm.spec_loss is not None:
                if isinstance(self.ood_algorithm.spec_loss, dict):
                    desc = f'ML: {mean_loss:.4f}|'
                    for loss_name, loss_value in self.ood_algorithm.spec_loss.items():
                        desc += f'{loss_name}: {spec_loss[loss_name]:.4f}|'
                    print(f'#IN#Approximated ' + desc[:-1])
                else:
                    print(f'#IN#Approximated average M/S Loss {mean_loss:.4f}/{spec_loss:.4f}')
            else:
                print(f'#IN#Approximated average training loss {mean_loss.cpu().item():.4f}')

            epoch_train_stat = self.evaluate('eval_train')
            id_val_stat = self.evaluate('id_val')
            id_test_stat = self.evaluate('id_test')
            val_stat = self.evaluate('val')
            test_stat = self.evaluate('test')

            # checkpoints save
            self.save_epoch(epoch, epoch_train_stat, id_val_stat, id_test_stat, val_stat, test_stat, self.config)

            # --- scheduler step ---
            self.ood_algorithm.scheduler.step()

        print('#IN#Training end.')
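    # Worked example (illustrative only) of the DANN-style schedule set in train() above:
    # alpha = 2 / (1 + exp(-10 * p)) - 1 rises smoothly from 0 toward 1 as the overall
    # training progress p goes from 0 to 1.
    #
    #   import numpy as np
    #   for p in (0.0, 0.1, 0.5, 1.0):
    #       alpha = 2. / (1. + np.exp(-10 * p)) - 1
    #       print(f'p={p:.1f} -> alpha={alpha:.3f}')   # 0.000, 0.462, 0.987, 1.000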
[docs]    @torch.no_grad()
    def evaluate(self, split: str):
        r"""
        This function is designed to collect data results and calculate scores and loss given a dataset subset.
        (For project use only)

        Args:
            split (str): A split string for choosing the corresponding dataloader. Allowed: 'train', 'id_val',
                'id_test', 'val', and 'test'.

        Returns:
            A score and a loss.
        """
        stat = {'score': None, 'loss': None}
        if self.loader.get(split) is None:
            return stat
        self.model.eval()

        loss_all = []
        mask_all = []
        pred_all = []
        target_all = []
        pbar = tqdm(self.loader[split], desc=f'Eval {split.capitalize()}', total=len(self.loader[split]),
                    **pbar_setting)
        for data in pbar:
            data: Batch = data.to(self.config.device)

            mask, targets = nan2zero_get_mask(data, split, self.config)
            if mask is None:
                return stat
            node_norm = torch.ones_like(targets, device=self.config.device) if self.config.model.model_level == 'node' else None
            data, targets, mask, node_norm = self.ood_algorithm.input_preprocess(data, targets, mask, node_norm,
                                                                                 self.model.training, self.config)
            model_output = self.model(data=data, edge_weight=None, ood_algorithm=self.ood_algorithm)
            raw_preds = self.ood_algorithm.output_postprocess(model_output)

            # --------------- Loss collection ------------------
            loss: torch.Tensor = self.config.metric.loss_func(raw_preds, targets, reduction='none') * mask
            mask_all.append(mask)
            loss_all.append(loss)

            # ------------- Score data collection ------------------
            pred, target = eval_data_preprocess(data.y, raw_preds, mask, self.config)
            pred_all.append(pred)
            target_all.append(target)

        # ------- Loss calculation -------
        loss_all = torch.cat(loss_all)
        mask_all = torch.cat(mask_all)
        stat['loss'] = loss_all.sum() / mask_all.sum()

        # --------------- Metric calculation including ROC_AUC, Accuracy, AP. --------------------
        stat['score'] = eval_score(pred_all, target_all, self.config)

        print(f'#IN#\n{split.capitalize()} {self.config.metric.score_name}: {stat["score"]:.4f}\n'
              f'{split.capitalize()} Loss: {stat["loss"]:.4f}')

        self.model.train()

        return {'score': stat['score'], 'loss': stat['loss']}
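    # Worked example (illustrative only) of the masked loss aggregation used above: per-element
    # losses are zeroed where mask == 0, then averaged over the number of valid targets rather
    # than over all targets.
    #
    #   import torch
    #   loss = torch.tensor([1.0, 2.0, 3.0, 4.0])
    #   mask = torch.tensor([1.0, 1.0, 0.0, 1.0])
    #   masked = loss * mask                  # tensor([1., 2., 0., 4.])
    #   mean = masked.sum() / mask.sum()      # 7 / 3 ≈ 2.3333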
[docs]    def load_task(self):
        r"""
        Launch a training or a test.
        """
        if self.task == 'train':
            self.train()
        elif self.task == 'test':
            # config model
            print('#D#Config model and output the best checkpoint info...')
            test_score, test_loss = self.config_model('test')
[docs]    def config_model(self, mode: str, load_param=False):
        r"""
        A model configuration utility. Responsible for moving the model from CPU to GPU and for loading checkpoints.

        Args:
            mode (str): 'train' or 'test'.
            load_param: When True, loading the test checkpoint also loads its parameters into the GNN model.

        Returns:
            Test score and loss if mode=='test'.
        """
        self.model.to(self.config.device)
        self.model.train()

        # load checkpoint
        if mode == 'train' and self.config.train.tr_ctn:
            ckpt = torch.load(os.path.join(self.config.ckpt_dir, f'last.ckpt'))
            self.model.load_state_dict(ckpt['state_dict'])
            best_ckpt = torch.load(os.path.join(self.config.ckpt_dir, f'best.ckpt'))
            self.config.metric.best_stat['score'] = best_ckpt['val_score']
            self.config.metric.best_stat['loss'] = best_ckpt['val_loss']
            self.config.train.ctn_epoch = ckpt['epoch'] + 1
            print(f'#IN#Continue training from Epoch {ckpt["epoch"]}...')

        if mode == 'test':
            try:
                ckpt = torch.load(self.config.test_ckpt, map_location=self.config.device)
            except FileNotFoundError:
                print(f'#E#Checkpoint not found at {os.path.abspath(self.config.test_ckpt)}')
                exit(1)
            if os.path.exists(self.config.id_test_ckpt):
                id_ckpt = torch.load(self.config.id_test_ckpt, map_location=self.config.device)
                # model.load_state_dict(id_ckpt['state_dict'])
                print(f'#IN#Loading best In-Domain Checkpoint {id_ckpt["epoch"]}...')
                print(f'#IN#Checkpoint {id_ckpt["epoch"]}: \n-----------------------------------\n'
                      f'Train {self.config.metric.score_name}: {id_ckpt["train_score"]:.4f}\n'
                      f'Train Loss: {id_ckpt["train_loss"].item():.4f}\n'
                      f'ID Validation {self.config.metric.score_name}: {id_ckpt["id_val_score"]:.4f}\n'
                      f'ID Validation Loss: {id_ckpt["id_val_loss"].item():.4f}\n'
                      f'ID Test {self.config.metric.score_name}: {id_ckpt["id_test_score"]:.4f}\n'
                      f'ID Test Loss: {id_ckpt["id_test_loss"].item():.4f}\n'
                      f'OOD Validation {self.config.metric.score_name}: {id_ckpt["val_score"]:.4f}\n'
                      f'OOD Validation Loss: {id_ckpt["val_loss"].item():.4f}\n'
                      f'OOD Test {self.config.metric.score_name}: {id_ckpt["test_score"]:.4f}\n'
                      f'OOD Test Loss: {id_ckpt["test_loss"].item():.4f}\n')
                print(f'#IN#Loading best Out-of-Domain Checkpoint {ckpt["epoch"]}...')
                print(f'#IN#Checkpoint {ckpt["epoch"]}: \n-----------------------------------\n'
                      f'Train {self.config.metric.score_name}: {ckpt["train_score"]:.4f}\n'
                      f'Train Loss: {ckpt["train_loss"].item():.4f}\n'
                      f'ID Validation {self.config.metric.score_name}: {ckpt["id_val_score"]:.4f}\n'
                      f'ID Validation Loss: {ckpt["id_val_loss"].item():.4f}\n'
                      f'ID Test {self.config.metric.score_name}: {ckpt["id_test_score"]:.4f}\n'
                      f'ID Test Loss: {ckpt["id_test_loss"].item():.4f}\n'
                      f'OOD Validation {self.config.metric.score_name}: {ckpt["val_score"]:.4f}\n'
                      f'OOD Validation Loss: {ckpt["val_loss"].item():.4f}\n'
                      f'OOD Test {self.config.metric.score_name}: {ckpt["test_score"]:.4f}\n'
                      f'OOD Test Loss: {ckpt["test_loss"].item():.4f}\n')
                print(f'#IN#ChartInfo {id_ckpt["id_test_score"]:.4f} {id_ckpt["test_score"]:.4f} '
                      f'{ckpt["id_test_score"]:.4f} {ckpt["test_score"]:.4f} {ckpt["val_score"]:.4f}', end='')
            else:
                print(f'#IN#No In-Domain checkpoint.')
                # model.load_state_dict(ckpt['state_dict'])
                print(f'#IN#Loading best Checkpoint {ckpt["epoch"]}...')
                print(f'#IN#Checkpoint {ckpt["epoch"]}: \n-----------------------------------\n'
                      f'Train {self.config.metric.score_name}: {ckpt["train_score"]:.4f}\n'
                      f'Train Loss: {ckpt["train_loss"].item():.4f}\n'
                      f'Validation {self.config.metric.score_name}: {ckpt["val_score"]:.4f}\n'
                      f'Validation Loss: {ckpt["val_loss"].item():.4f}\n'
                      f'Test {self.config.metric.score_name}: {ckpt["test_score"]:.4f}\n'
                      f'Test Loss: {ckpt["test_loss"].item():.4f}\n')
                print(f'#IN#ChartInfo {ckpt["test_score"]:.4f} {ckpt["val_score"]:.4f}', end='')

            if load_param:
                if self.config.ood.ood_alg != 'EERM':
                    self.model.load_state_dict(ckpt['state_dict'])
                else:
                    self.model.gnn.load_state_dict(ckpt['state_dict'])
            return ckpt["test_score"], ckpt["test_loss"]
[docs]    def save_epoch(self, epoch: int, train_stat: dict, id_val_stat: dict, id_test_stat: dict, val_stat: dict,
                   test_stat: dict, config: Union[CommonArgs, Munch]):
        r"""
        Training util for checkpoint saving.

        Args:
            epoch (int): epoch number
            train_stat (dict): train statistics
            id_val_stat (dict): in-domain validation statistics
            id_test_stat (dict): in-domain test statistics
            val_stat (dict): ood validation statistics
            test_stat (dict): ood test statistics
            config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.ckpt_dir`, :obj:`config.dataset`, :obj:`config.train`, :obj:`config.model`, :obj:`config.metric`, :obj:`config.log_path`, :obj:`config.ood`)

        Returns:
            None
        """
        state_dict = self.model.state_dict() if config.ood.ood_alg != 'EERM' else self.model.gnn.state_dict()
        ckpt = {
            'state_dict': state_dict,
            'train_score': train_stat['score'],
            'train_loss': train_stat['loss'],
            'id_val_score': id_val_stat['score'],
            'id_val_loss': id_val_stat['loss'],
            'id_test_score': id_test_stat['score'],
            'id_test_loss': id_test_stat['loss'],
            'val_score': val_stat['score'],
            'val_loss': val_stat['loss'],
            'test_score': test_stat['score'],
            'test_loss': test_stat['loss'],
            'time': datetime.datetime.now().strftime('%b%d %Hh %M:%S'),
            'model': {
                'model name': f'{config.model.model_name} {config.model.model_level} layers',
                'dim_hidden': config.model.dim_hidden,
                'dim_ffn': config.model.dim_ffn,
                'global pooling': config.model.global_pool
            },
            'dataset': config.dataset.dataset_name,
            'train': {
                'weight_decay': config.train.weight_decay,
                'learning_rate': config.train.lr,
                'mile stone': config.train.mile_stones,
                'shift_type': config.dataset.shift_type,
                'Batch size': f'{config.train.train_bs}, {config.train.val_bs}, {config.train.test_bs}'
            },
            'OOD': {
                'OOD alg': config.ood.ood_alg,
                'OOD param': config.ood.ood_param,
                'number of environments': config.dataset.num_envs
            },
            'log file': config.log_path,
            'epoch': epoch,
            'max epoch': config.train.max_epoch
        }
        if not (config.metric.best_stat['score'] is None
                or config.metric.lower_better * val_stat['score'] < config.metric.lower_better *
                config.metric.best_stat['score']
                or (id_val_stat.get('score') and (
                        config.metric.id_best_stat['score'] is None
                        or config.metric.lower_better * id_val_stat['score'] < config.metric.lower_better *
                        config.metric.id_best_stat['score']))
                or epoch % config.train.save_gap == 0):
            return

        if not os.path.exists(config.ckpt_dir):
            os.makedirs(config.ckpt_dir)
            print(f'#W#Directory does not exist. Have built it automatically.\n'
                  f'{os.path.abspath(config.ckpt_dir)}')
        saved_file = os.path.join(config.ckpt_dir, f'{epoch}.ckpt')
        torch.save(ckpt, saved_file)
        shutil.copy(saved_file, os.path.join(config.ckpt_dir, f'last.ckpt'))

        # --- In-Domain checkpoint ---
        if id_val_stat.get('score') and (
                config.metric.id_best_stat['score'] is None
                or config.metric.lower_better * id_val_stat['score'] < config.metric.lower_better *
                config.metric.id_best_stat['score']):
            config.metric.id_best_stat['score'] = id_val_stat['score']
            config.metric.id_best_stat['loss'] = id_val_stat['loss']
            shutil.copy(saved_file, os.path.join(config.ckpt_dir, f'id_best.ckpt'))
            print('#IM#Saved a new best In-Domain checkpoint.')

        # --- Out-Of-Domain checkpoint ---
        # if id_val_stat.get('score'):
        #     if not (config.metric.lower_better * id_val_stat['score'] < config.metric.lower_better * val_stat['score']):
        #         return
        if config.metric.best_stat['score'] is None \
                or config.metric.lower_better * val_stat['score'] < config.metric.lower_better * \
                config.metric.best_stat['score']:
            config.metric.best_stat['score'] = val_stat['score']
            config.metric.best_stat['loss'] = val_stat['loss']
            shutil.copy(saved_file, os.path.join(config.ckpt_dir, f'best.ckpt'))
            print('#IM#Saved a new best checkpoint.')

        if config.clean_save:
            os.unlink(saved_file)
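# Worked example (illustrative only) of the `lower_better` comparison used in save_epoch,
# assuming `config.metric.lower_better` takes the values +1 (lower score is better, e.g. MAE)
# or -1 (higher score is better, e.g. ROC-AUC); the sign flip lets one `<` test cover both.
#
#   def is_better(new, best, lower_better):
#       return best is None or lower_better * new < lower_better * best
#
#   is_better(0.85, 0.80, lower_better=-1)   # True: a higher AUC replaces the best checkpoint
#   is_better(0.30, 0.25, lower_better=+1)   # False: a higher MAE does not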