GOOD.ood_algorithms.algorithms.DIR

Implementation of the DIR algorithm from the paper “Discovering Invariant Rationales for Graph Neural Networks”.

Classes

DIR(config) – Implementation of the DIR algorithm from the paper “Discovering Invariant Rationales for Graph Neural Networks”.

class GOOD.ood_algorithms.algorithms.DIR.DIR(config: Union[CommonArgs, Munch])

Bases: BaseOODAlg


Parameters

config (Union[CommonArgs, Munch]) – munchified dictionary of arguments (uses config.device, config.dataset.num_envs, config.ood.ood_param)

loss_calculate(raw_pred: Tensor, targets: Tensor, mask: Tensor, node_norm: Tensor, config: Union[CommonArgs, Munch]) → Tensor

Calculate the loss according to the DIR algorithm.

Parameters
  • raw_pred (Tensor) – model predictions

  • targets (Tensor) – ground-truth labels

  • mask (Tensor) – masks marking valid (non-NaN) target entries

  • node_norm (Tensor) – node weights for normalization (node-level prediction only)

  • config (Union[CommonArgs, Munch]) – munchified dictionary of arguments (uses config.metric.loss_func and config.model.model_level)

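from munch import munchify

# loss_func: the loss callable configured for the task (placeholder name)
config = munchify({'model': {'model_level': 'graph'},
                   'metric': {'loss_func': loss_func}})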
Returns (Tensor):

the loss computed by the DIR algorithm
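The mask parameters above are described as NaN masks over the targets. As a minimal, framework-free sketch (plain Python lists instead of tensors), the usual way such a mask reduces per-element losses to one scalar is a masked mean: entries whose target is NaN are dropped from both the sum and the count. The helpers `nan_mask` and `masked_mean` are hypothetical; how `loss_calculate` actually applies `mask`, and how `node_norm` enters for node-level tasks, is not specified here.

```python
import math
from typing import List, Sequence


def nan_mask(targets: Sequence[float]) -> List[bool]:
    """Hypothetical helper: True where the target is present (not NaN)."""
    return [not math.isnan(t) for t in targets]


def masked_mean(losses: Sequence[float], mask: Sequence[bool]) -> float:
    """Hypothetical helper: average the losses over entries where mask is True.

    Masked-out entries contribute neither to the sum nor to the count.
    """
    if len(losses) != len(mask):
        raise ValueError("losses and mask must have the same length")
    kept = [loss for loss, keep in zip(losses, mask) if keep]
    if not kept:
        raise ValueError("mask selects no elements")
    return math.fsum(kept) / len(kept)


targets = [1.0, float("nan"), 0.0, 1.0]
losses = [0.2, 5.0, 0.4, 0.6]  # the 5.0 at the NaN target is ignored
print(masked_mean(losses, nan_mask(targets)))  # roughly 0.4
```

Dividing by the number of kept entries, rather than by the full batch size, keeps the scale of the loss independent of how many labels happen to be missing in a batch.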

loss_postprocess(loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) → Tensor

Post-process the loss according to the DIR algorithm.

Parameters
  • loss (Tensor) – base loss between the model predictions and the ground-truth labels

  • data (Batch) – input data

  • mask (Tensor) – masks marking valid (non-NaN) target entries

  • config (Union[CommonArgs, Munch]) – munchified dictionary of arguments (uses config.device, config.dataset.num_envs, config.ood.ood_param)

import torch
from munch import munchify

config = munchify({'device': torch.device('cuda'),
                   'dataset': {'num_envs': 10},
                   'ood': {'ood_param': 0.1}})
Returns (Tensor):

the post-processed loss according to the DIR algorithm
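The DIR paper pairs each graph's causal rationale with different non-causal (spurious) parts, treats each pairing as an intervention, and minimises the average risk across interventions plus λ times the variance of those risks. A framework-free sketch of that objective follows, with the per-intervention risks given as plain floats. That `config.ood.ood_param` plays the role of λ is an assumption suggested by the config example, not something stated here, and `dir_objective` is a hypothetical name rather than part of the GOOD API.

```python
import statistics
from typing import Sequence


def dir_objective(intervention_risks: Sequence[float], penalty_weight: float) -> float:
    """Mean risk across interventions plus a weighted variance penalty.

    intervention_risks: one scalar loss per intervention, i.e. per pairing
        of the causal rationale with a different spurious part.
    penalty_weight: the lambda trading average risk against invariance
        (assumed here to correspond to config.ood.ood_param).
    """
    if not intervention_risks:
        raise ValueError("need at least one intervention risk")
    mean_risk = statistics.fmean(intervention_risks)
    # Population variance: how much the risk changes across interventions.
    spread = statistics.pvariance(intervention_risks)
    return mean_risk + penalty_weight * spread


# Three hypothetical interventions with illustrative risk values.
risks = [0.5, 0.7, 0.6]
print(dir_objective(risks, penalty_weight=0.1))
```

When all intervention risks are equal the penalty vanishes and the objective reduces to the mean risk; a larger penalty weight pushes the model toward rationales whose predictive power does not depend on the spurious part.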

output_postprocess(model_output: Tensor, **kwargs) → Tensor

Process the raw output of the model.

Parameters

model_output (Tensor) – model raw output

Returns (Tensor):

the model's raw predictions.
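`output_postprocess` turns whatever the model's forward pass returns into the predictions passed on to the loss. A common pattern for multi-output models is to keep auxiliary outputs on the algorithm object and return only the main prediction. The sketch below assumes, hypothetically, that the model returns either a single prediction or a tuple `(prediction, *auxiliary)`; the actual output structure of the DIR model in GOOD is not documented here, and `OutputCache` is an illustrative stand-in, not a GOOD class.

```python
from typing import Any, Tuple


class OutputCache:
    """Hypothetical stand-in for an algorithm object with an output hook."""

    def __init__(self) -> None:
        self.auxiliary: Tuple[Any, ...] = ()

    def output_postprocess(self, model_output: Any, **kwargs: Any) -> Any:
        # Tuple output: remember the extras for later loss terms,
        # hand only the first element (the prediction) onward.
        if isinstance(model_output, tuple):
            prediction, *extra = model_output
            self.auxiliary = tuple(extra)
            return prediction
        # Single output: nothing extra to remember.
        self.auxiliary = ()
        return model_output


alg = OutputCache()
print(alg.output_postprocess(("pred", "spurious_pred")))  # pred
print(alg.auxiliary)  # ('spurious_pred',)
```

Caching on the object lets a later `loss_calculate` call reach the auxiliary outputs without changing the signature through which predictions flow.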

stage_control(config: Union[CommonArgs, Munch])

Set variables before each epoch. Mainly used to control multi-stage training and epoch-dependent parameter settings.

Parameters

config (Union[CommonArgs, Munch]) – munchified dictionary of arguments.
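`stage_control` is called before each epoch to update epoch-dependent state. The sketch below shows one way such a hook might switch training stages and ramp up a penalty weight as training proceeds. The attribute names `train.stage_boundaries`, `train.ramp_power` and `train.alpha`, the polynomial ramp, and the exponent value are hypothetical choices for illustration, not settings documented for GOOD. `types.SimpleNamespace` stands in for the munchified config because both provide attribute access.

```python
from types import SimpleNamespace


class StageController:
    """Hypothetical epoch hook: tracks the training stage and a penalty weight."""

    def __init__(self) -> None:
        self.stage = 0

    def stage_control(self, config: SimpleNamespace) -> None:
        epoch = config.train.epoch
        # Stage index = number of stage boundaries the current epoch has reached.
        self.stage = sum(epoch >= b for b in config.train.stage_boundaries)
        # Illustrative schedule: the penalty weight grows polynomially with the epoch.
        config.train.alpha = config.ood.ood_param * epoch ** config.train.ramp_power


config = SimpleNamespace(
    train=SimpleNamespace(epoch=0, stage_boundaries=[5], ramp_power=1.6, alpha=0.0),
    ood=SimpleNamespace(ood_param=0.1),
)
controller = StageController()
for epoch in range(8):
    config.train.epoch = epoch
    controller.stage_control(config)
print(controller.stage, config.train.alpha)
```

Starting the penalty weight at zero lets the model first fit the task, while the growing weight gradually shifts emphasis toward the invariance term in later epochs.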