GOOD.ood_algorithms.algorithms.CIGA

Implementation of the CIGA algorithm from the “Learning Causally Invariant Representations for Out-of-Distribution Generalization on Graphs” paper.

Copied from “https://github.com/LFhase/GOOD”.

Functions

get_contrast_loss(causal_rep, labels[, ...])

get_irm_loss(causal_pred, labels, batch_env_idx)

GOOD.ood_algorithms.algorithms.CIGA.get_contrast_loss(causal_rep, labels, norm=<function normalize>, contrast_t=1.0, sampling='mul')
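A minimal sketch of a supervised contrastive loss on causal representations, in the spirit of get_contrast_loss. It assumes sampling='mul' means every other same-label graph in the batch acts as a positive and that norm is an L2 normalization; sup_contrast_loss is a hypothetical name, and GOOD's own function may differ in details such as numerical stabilization.

```python
import torch
import torch.nn.functional as F


def sup_contrast_loss(causal_rep: torch.Tensor, labels: torch.Tensor,
                      contrast_t: float = 1.0) -> torch.Tensor:
    """Hypothetical SupCon-style loss: same-label graphs in a batch are positives."""
    rep = F.normalize(causal_rep, dim=1)          # unit-norm embeddings
    sim = rep @ rep.T / contrast_t                # cosine similarity / temperature
    n = labels.size(0)
    self_mask = torch.eye(n, device=rep.device).bool()
    # Exclude each anchor's similarity to itself from the softmax denominator.
    sim = sim.masked_fill(self_mask, float('-inf'))
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    # Positives: other samples that share the anchor's label.
    pos = (labels.view(-1, 1) == labels.view(1, -1)) & ~self_mask
    has_pos = pos.any(dim=1)                      # anchors with at least one positive
    # Average log-probability over each anchor's positives, then over anchors.
    mean_log_prob_pos = (log_prob.masked_fill(~pos, 0.0).sum(1)[has_pos]
                         / pos.sum(1)[has_pos])
    return -mean_log_prob_pos.mean()
```

Embeddings that cluster by label yield a lower loss than embeddings where same-label samples point in opposite directions.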
GOOD.ood_algorithms.algorithms.CIGA.get_irm_loss(causal_pred, labels, batch_env_idx, criterion=<function cross_entropy>)
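A sketch of an IRMv1-style penalty, which is one common reading of get_irm_loss: for each environment, differentiate that environment's risk with respect to a dummy scalar multiplier w on the predictions, evaluated at w = 1, and sum the squared gradients. It assumes batch_env_idx holds one integer environment id per sample; irm_penalty is a hypothetical name, and the library may combine the per-environment gradients differently.

```python
import torch
import torch.nn.functional as F


def irm_penalty(causal_pred, labels, batch_env_idx, criterion=F.cross_entropy):
    """Hypothetical IRMv1 penalty: squared gradient of each environment's risk
    w.r.t. a dummy classifier scale w, evaluated at w = 1."""
    dummy_w = torch.tensor(1.0, device=causal_pred.device, requires_grad=True)
    penalty = causal_pred.new_zeros(())
    for env in torch.unique(batch_env_idx):
        in_env = batch_env_idx == env
        env_loss = criterion(causal_pred[in_env] * dummy_w, labels[in_env])
        # create_graph=True keeps the penalty differentiable w.r.t. model parameters.
        grad = torch.autograd.grad(env_loss, dummy_w, create_graph=True)[0]
        penalty = penalty + grad.pow(2)
    return penalty
```

With all-zero logits, rescaling by w changes nothing, so the penalty is exactly zero.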

Classes

CIGA(config)

Implementation of the CIGA algorithm from the "Learning Causally Invariant Representations for Out-of-Distribution Generalization on Graphs" paper.

class GOOD.ood_algorithms.algorithms.CIGA.CIGA(config: Union[CommonArgs, Munch])

Bases: BaseOODAlg

Implementation of the CIGA algorithm from the “Learning Causally Invariant Representations for Out-of-Distribution Generalization on Graphs” paper.

Args:

config (Union[CommonArgs, Munch]): munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)
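The class reads the three config fields listed above. A minimal sketch of building such a config, assuming the munch package provides munchify (a small stand-in is defined if it is not installed); the values are hypothetical and only the key paths come from this page.

```python
from types import SimpleNamespace

try:
    from munch import munchify  # attribute-style dicts, as used for GOOD configs
except ImportError:
    def munchify(obj):
        """Minimal stand-in: recursively wrap dicts for attribute access."""
        if isinstance(obj, dict):
            return SimpleNamespace(**{k: munchify(v) for k, v in obj.items()})
        return obj

# Hypothetical values; only the key paths come from the documentation above.
config = munchify({
    'device': 'cpu',
    'dataset': {'num_envs': 2},
    'ood': {'ood_param': 1.0},
})

print(config.device, config.dataset.num_envs, config.ood.ood_param)
```

A real GOOD config carries many more fields than these three.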

loss_calculate(raw_pred: Tensor, targets: Tensor, mask: Tensor, node_norm: Tensor, config: Union[CommonArgs, Munch]) → Tensor

Calculate the prediction loss for the CIGA algorithm

Parameters
  • raw_pred (Tensor) – model predictions

  • targets (Tensor) – input labels

  • mask (Tensor) – NaN masks for data formats

  • node_norm (Tensor) – node weights for normalization (for node prediction only)

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.metric.loss_func(), config.model.model_level)

from munch import munchify

config = munchify({'model': {'model_level': 'graph'},
                   'metric': {'loss_func': loss_func}})  # loss_func: the loss callable
Returns (Tensor):

loss for the CIGA algorithm
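The parameters suggest the usual masked, per-element loss pattern. A sketch under stated assumptions: loss_func accepts reduction='none', NaN targets have already been replaced so that mask marks the valid entries, and node_norm weights nodes for node-level tasks. masked_elementwise_loss is a hypothetical helper illustrating the pattern, not CIGA's own loss.

```python
import torch
import torch.nn.functional as F
from types import SimpleNamespace


def masked_elementwise_loss(raw_pred, targets, mask, node_norm, config):
    """Hypothetical: per-element loss with invalid (masked-out) entries zeroed."""
    # reduction='none' keeps one loss value per prediction entry.
    loss = config.metric.loss_func(raw_pred, targets, reduction='none') * mask
    if config.model.model_level == 'node':
        # Weight each node's loss by node_norm, scaled by the number of valid entries.
        loss = loss * node_norm * mask.sum()
    return loss


config = SimpleNamespace(
    metric=SimpleNamespace(loss_func=F.binary_cross_entropy_with_logits),
    model=SimpleNamespace(model_level='graph'),
)
raw_pred = torch.zeros(2, 2)                     # logits for 2 graphs x 2 tasks
targets = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])   # last target treated as missing
loss = masked_elementwise_loss(raw_pred, targets, mask, None, config)
```

With zero logits every valid entry contributes log 2 and the masked entry contributes 0.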

loss_postprocess(loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) → Tensor

Post-process the base loss into the final training loss

Parameters
  • loss (Tensor) – base loss between model predictions and input labels

  • data (Batch) – input data

  • mask (Tensor) – NaN masks for data formats

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args

Returns (Tensor):

processed loss
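One plausible shape for this step, assuming it averages the masked loss over valid entries and adds a regularizer, such as the output of get_contrast_loss, weighted by config.ood.ood_param. combine_losses is hypothetical, and the actual CIGA objective may include further terms.

```python
import torch


def combine_losses(loss: torch.Tensor, mask: torch.Tensor,
                   contrast_loss: torch.Tensor, ood_param: float) -> torch.Tensor:
    """Hypothetical: mean prediction loss over valid entries plus a weighted penalty."""
    mean_loss = loss.sum() / mask.sum()   # average only over unmasked entries
    return mean_loss + ood_param * contrast_loss


loss = torch.tensor([1.0, 2.0, 0.0])      # last entry already zeroed by the mask
mask = torch.tensor([1.0, 1.0, 0.0])
total = combine_losses(loss, mask, torch.tensor(0.5), ood_param=2.0)
```

Dividing by mask.sum() rather than the tensor length keeps masked-out entries from diluting the average.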

output_postprocess(model_output: Tensor, **kwargs) → Tensor

Process the raw output of the model and apply the linear classifier

Parameters

model_output (Tensor) – model raw output

Returns (Tensor):

model raw predictions with the linear classifier applied
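A hypothetical sketch of an output post-processing step that caches the model output (for instance, for a later contrastive term) and applies a linear classifier to produce raw predictions. The class name, the cached attribute, and where the classifier is stored are assumptions, not GOOD's implementation.

```python
import torch
from torch import nn


class OutputPostprocessSketch:
    """Hypothetical: keep the representation for later loss terms, return logits."""

    def __init__(self, rep_dim: int, num_classes: int):
        self.classifier = nn.Linear(rep_dim, num_classes)
        self.rep_out = None

    def output_postprocess(self, model_output: torch.Tensor) -> torch.Tensor:
        self.rep_out = model_output             # cached, e.g. for a contrastive loss
        return self.classifier(model_output)    # linear classifier -> raw predictions


post = OutputPostprocessSketch(rep_dim=8, num_classes=3)
reps = torch.randn(5, 8)
preds = post.output_postprocess(reps)
```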