GOOD.ood_algorithms.algorithms.GSAT

Implementation of the GSAT algorithm from the paper “Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism”.

Classes

GSAT(config)

Implementation of the GSAT algorithm from the paper "Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism".

class GOOD.ood_algorithms.algorithms.GSAT.GSAT(config: Union[CommonArgs, Munch])

Bases: BaseOODAlg

Implementation of the GSAT algorithm from the paper “Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism”.

Args:

config (Union[CommonArgs, Munch]): munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)

loss_postprocess(loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) → Tensor

Process loss based on GSAT algorithm

Parameters
  • loss (Tensor) – base loss between model predictions and input labels

  • data (Batch) – input data

  • mask (Tensor) – NaN mask for the labels

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)

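import torch
from munch import munchify

# Keys must be strings; bare names such as device would raise a NameError.
config = munchify({'device': torch.device('cuda'),
                   'dataset': {'num_envs': 10},
                   'ood': {'ood_param': 0.1}
                   })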
Returns (Tensor):

loss based on GSAT algorithm
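
The way the base loss and the GSAT term are combined can be illustrated with a minimal sketch. It assumes, following the GSAT paper rather than anything stated on this page, that the extra term is the KL divergence between a Bernoulli attention distribution `att` and a Bernoulli prior `r`, weighted by `config.ood.ood_param`. The function name `gsat_loss_sketch` and the arguments `att`, `r`, and `eps` are hypothetical and are not part of the documented signature; the mask is assumed to be 1 for valid entries.

```python
import torch


def gsat_loss_sketch(loss, mask, att, r, ood_param, eps=1e-6):
    """Hypothetical helper: masked mean of the base loss plus a weighted KL term."""
    mask = mask.float()
    # Average the element-wise base loss over valid (unmasked) entries only.
    mean_loss = (loss * mask).sum() / mask.sum()
    # KL(Bernoulli(att) || Bernoulli(r)), averaged over all attention values.
    # eps keeps the logarithms finite when att is exactly 0 or 1.
    info_loss = (att * torch.log(att / r + eps)
                 + (1 - att) * torch.log((1 - att) / (1 - r) + eps)).mean()
    # ood_param plays the role of config.ood.ood_param (assumed weighting).
    return mean_loss + ood_param * info_loss


# Toy inputs: the third loss entry is masked out.
loss = torch.tensor([0.5, 1.5, 0.0])
mask = torch.tensor([1.0, 1.0, 0.0])
att = torch.tensor([0.7, 0.7])
total = gsat_loss_sketch(loss, mask, att, r=0.7, ood_param=0.1)
```

When the attention values equal the prior, the KL term is close to zero and the result is essentially the masked mean of the base loss; attention values that move away from `r` increase the total.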

output_postprocess(model_output: Tensor, **kwargs) → Tensor

Process the raw output of the model.

Parameters

model_output (Tensor) – model raw output

Returns (Tensor):

model raw predictions.

stage_control(config: Union[CommonArgs, Munch])

Set variables before each epoch. Largely used for controlling multi-stage training and epoch-related parameter settings.

Parameters

config – munchified dictionary of args.
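
As an illustration of an epoch-related parameter setting, the sketch below shows a step-wise decay of the Bernoulli prior `r` of the kind described for GSAT training. This page does not document which variables `stage_control` actually updates, so the function name `prior_r_schedule`, its arguments, and the default values are all assumptions made for the example.

```python
def prior_r_schedule(epoch, init_r=0.9, final_r=0.7, decay_interval=10, decay_r=0.1):
    """Hypothetical schedule: lower r by decay_r every decay_interval epochs."""
    # Integer division counts how many full decay intervals have passed.
    r = init_r - (epoch // decay_interval) * decay_r
    # Never let the prior drop below its final value.
    return max(r, final_r)


# Values a per-epoch hook could compute before training each epoch.
schedule = [prior_r_schedule(epoch) for epoch in (0, 10, 20, 100)]
```

A hook called once per epoch could store the returned value on the algorithm object so that the loss computation for that epoch uses it.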