# Source code for GOOD.ood_algorithms.algorithms.GSAT

"""
Implementation of the GSAT algorithm from `"Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism" <https://arxiv.org/abs/2201.12987>`_ paper
"""
from typing import Tuple

import torch
from torch import Tensor
from torch_geometric.data import Batch

from GOOD import register
from GOOD.utils.config_reader import Union, CommonArgs, Munch
from GOOD.utils.initial import reset_random_seed
from GOOD.utils.train import at_stage
from .BaseOOD import BaseOODAlg


@register.ood_alg_register
class GSAT(BaseOODAlg):
    r"""
    Implementation of the GSAT algorithm from `"Interpretable and Generalizable Graph Learning via Stochastic
    Attention Mechanism" <https://arxiv.org/abs/2201.12987>`_ paper

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args. This class reads
            :obj:`config.ood.extra_param[1]` (number of epochs between two decays of the prior ratio ``r``),
            :obj:`config.ood.extra_param[2]` (lower bound of ``r``), :obj:`config.ood.ood_param`
            (weight of the information loss) and :obj:`config.train.epoch` (current epoch).
    """

    def __init__(self, config: Union[CommonArgs, Munch]):
        super(GSAT, self).__init__(config)
        # Attention tensors returned by the model alongside its predictions;
        # populated by output_postprocess() and consumed by loss_postprocess().
        self.att = None
        self.edge_att = None
        # Prior-ratio schedule: r starts at get_r()'s init_r, drops by decay_r every
        # decay_interval epochs, and never goes below final_r.
        self.decay_r = 0.1
        self.decay_interval = config.ood.extra_param[1]
        self.final_r = config.ood.extra_param[2]  # 0.5 or 0.7

    def stage_control(self, config: Union[CommonArgs, Munch]):
        r"""
        Set values before each epoch. Largely used for controlling multi-stage training and epoch related
        parameter settings.

        The first time :func:`at_stage` reports stage 1 while this algorithm is still in stage 0, the random
        seed is reset and the algorithm switches to stage 1 (this happens only once).

        Args:
            config: munchified dictionary of args.
        """
        # NOTE(review): self.stage is not set in this class; presumably initialized to 0 by BaseOODAlg.
        if self.stage == 0 and at_stage(1, config):
            reset_random_seed(config)
            self.stage = 1

    def output_postprocess(self, model_output: Tensor, **kwargs) -> Tensor:
        r"""
        Process the raw output of model

        The GSAT model returns a 3-tuple ``(raw_out, att, edge_att)``. The two attention tensors are stored on
        the algorithm so that :meth:`loss_postprocess` can compute the information loss.

        Args:
            model_output (Tensor): model raw output, unpacked as ``(raw_out, att, edge_att)``

        Returns (Tensor):
            model raw predictions.
        """
        raw_out, self.att, self.edge_att = model_output
        return raw_out

    def loss_postprocess(self, loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch],
                         **kwargs) -> Tensor:
        r"""
        Process loss based on GSAT algorithm

        The returned loss is the masked mean of the base loss plus ``config.ood.ood_param`` times the GSAT
        information loss: the KL divergence between the Bernoulli attention ``self.att`` and a Bernoulli prior
        with ratio ``r`` (see :meth:`get_r`), averaged over all attention entries. Both parts are also kept
        as ``self.mean_loss`` and ``self.spec_loss``.

        Args:
            loss (Tensor): base loss between model predictions and input labels
            data (Batch): input data
            mask (Tensor): NAN masks for data formats
            config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.train.epoch`,
                :obj:`config.ood.ood_param`)

        .. code-block:: python

            config = munchify({'train': {'epoch': 0},
                               'ood': {'ood_param': 1.0,
                                       'extra_param': [None, 10, 0.7]}  # index 0 is not read by GSAT
                               })

        Returns (Tensor):
            loss based on GSAT algorithm
        """
        att = self.att
        eps = 1e-6
        r = self.get_r(self.decay_interval, self.decay_r, config.train.epoch, final_r=self.final_r)
        # Elementwise KL(Bern(att) || Bern(r)); eps keeps the logs finite when att is exactly 0 or 1.
        info_loss = (att * torch.log(att / r + eps)
                     + (1 - att) * torch.log((1 - att) / (1 - r + eps) + eps)).mean()

        # Average over valid (unmasked) entries only: loss.mean() would also count masked-out positions
        # in the denominator and shrink the loss whenever labels are missing. Identical to loss.mean()
        # when nothing is masked. clamp(min=1) keeps an all-masked batch at 0 instead of 0/0 = NaN.
        # NOTE(review): assumes mask is a 0/1 (or bool) tensor broadcastable against loss — TODO confirm.
        self.mean_loss = (loss * mask).sum() / mask.sum().clamp(min=1)
        self.spec_loss = config.ood.ood_param * info_loss
        loss = self.mean_loss + self.spec_loss
        return loss

    def get_r(self, decay_interval, decay_r, current_epoch, init_r=0.9, final_r=0.5):
        r"""
        Step-decay schedule for the GSAT prior ratio ``r``.

        ``r`` equals ``init_r`` at epoch 0 and decreases by ``decay_r`` every ``decay_interval`` epochs,
        never going below ``final_r``.

        Args:
            decay_interval: number of epochs between two decay steps; must be positive.
            decay_r: amount subtracted from ``r`` at each decay step.
            current_epoch: current training epoch.
            init_r: value of ``r`` before any decay.
            final_r: lower bound of ``r``.

        Returns:
            the ratio ``r`` for ``current_epoch``.

        Raises:
            ValueError: if ``decay_interval`` is not positive (zero would divide by zero; a negative value
                would make ``r`` grow without bound).
        """
        if decay_interval <= 0:
            raise ValueError(f"decay_interval must be positive, got {decay_interval}")
        r = init_r - current_epoch // decay_interval * decay_r
        return max(r, final_r)