Source code for GOOD.utils.train

r"""Training utils.
"""
from typing import Union

import torch
from munch import Munch
from torch_geometric.data import Batch

from GOOD.utils.args import CommonArgs


[docs]def nan2zero_get_mask(data, task, config: Union[CommonArgs, Munch]): r""" Training data filter masks to process NAN. Args: data (Batch): input data task (str): mask function type config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.model.model_level`) Returns (Tensor): [mask (Tensor) - NAN masks for data formats, targets (Tensor) - input labels] """ if config.model.model_level == 'node': if 'train' in task: mask = data.train_mask elif task == 'id_val': mask = data.get('id_val_mask') elif task == 'id_test': mask = data.get('id_test_mask') elif task == 'val': mask = data.val_mask elif task == 'test': mask = data.test_mask else: raise ValueError(f'Task should be train/id_val/id_test/val/test, but got {task}.') else: mask = ~torch.isnan(data.y) if mask is None: return None, None targets = torch.clone(data.y).detach() assert mask.shape[0] == targets.shape[0] mask = mask.reshape(targets.shape) targets[~mask] = 0 return mask, targets
[docs]def at_stage(i, config): r""" Test if the current training stage at stage i. Args: i: Stage that is possibly 1, 2, 3, ... config: config object. Returns: At stage i. """ if i - 1 < 0: raise ValueError(f"Stage i must be equal or larger than 0, but got {i}.") if i > len(config.train.stage_stones): raise ValueError(f"Stage i should be smaller than the largest stage {len(config.train.stage_stones)}," f"but got {i}.") if i - 2 < 0: return config.train.epoch < config.train.stage_stones[i - 1] else: return config.train.stage_stones[i - 2] <= config.train.epoch < config.train.stage_stones[i - 1]