GOOD.ood_algorithms.algorithms.CIGA
Implementation of the CIGA algorithm from the “Learning Causally Invariant Representations for Out-of-Distribution Generalization on Graphs” paper.
Copied from “https://github.com/LFhase/GOOD”.
Functions
- GOOD.ood_algorithms.algorithms.CIGA.get_contrast_loss(causal_rep, labels, norm=<function normalize>, contrast_t=1.0, sampling='mul')
- GOOD.ood_algorithms.algorithms.CIGA.get_irm_loss(causal_pred, labels, batch_env_idx, criterion=<function cross_entropy>)
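The signature of get_contrast_loss suggests a supervised contrastive loss over causal representations: `norm` normalizes the representations and `contrast_t` acts as a temperature. The NumPy code below is a minimal sketch under that assumption, not GOOD's implementation. `supcon_loss_sketch` and `l2_normalize` are hypothetical names. The `sampling` option (default 'mul') is not modeled because these docs do not say what it does. In the sketch, samples with the same label are positives and all other samples contribute to the denominator.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Row-wise L2 normalization (stands in for the `norm` argument)."""
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)


def supcon_loss_sketch(causal_rep: np.ndarray, labels: np.ndarray,
                       contrast_t: float = 1.0) -> float:
    """Hypothetical supervised contrastive loss on causal representations."""
    z = l2_normalize(causal_rep)
    n = z.shape[0]
    logits = z @ z.T / contrast_t                   # cosine similarity / temperature
    self_mask = np.eye(n, dtype=bool)
    logits = np.where(self_mask, -np.inf, logits)   # never compare a sample with itself
    # numerically stable row-wise log-softmax
    row_max = logits.max(axis=1, keepdims=True)
    shifted = logits - row_max
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # positives: other samples that share the anchor's label
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask
    pos_count = pos_mask.sum(axis=1)
    valid = pos_count > 0
    if not valid.any():
        return 0.0                                  # no positive pairs in this batch
    # average log-probability assigned to positives, per anchor
    pos_log_prob = np.where(pos_mask, log_prob, 0.0).sum(axis=1)
    return float(-(pos_log_prob[valid] / pos_count[valid]).mean())
```

Representations that cluster by label give a lower loss than the same representations paired with shuffled labels. This is the behavior the contrastive term is meant to reward.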
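The name and arguments of get_irm_loss point to an IRM-style penalty computed per environment (`batch_env_idx`). The code below sketches the standard IRMv1 penalty: for each environment, take the gradient of the mean loss with respect to a dummy scalar scale w on the logits at w = 1, then square it. This is an assumption about what the helper computes, not a copy of it, and `irm_penalty_sketch` is a hypothetical name. GOOD's helper accepts an arbitrary `criterion`. The sketch hard-codes cross-entropy so that the gradient has a closed form, sum_k softmax(z)_k * z_k - z_y, and needs no autograd.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def irm_penalty_sketch(causal_pred: np.ndarray, labels: np.ndarray,
                       batch_env_idx: np.ndarray) -> float:
    """Hypothetical IRMv1 penalty with cross-entropy, summed over environments."""
    penalty = 0.0
    for env in np.unique(batch_env_idx):
        sel = batch_env_idx == env
        z, y = causal_pred[sel], labels[sel]
        p = softmax(z)
        # d/dw mean CE(w * z, y) at w = 1  =  mean(sum_k p_k z_k - z_y)
        grad = np.mean((p * z).sum(axis=1) - z[np.arange(len(y)), y])
        penalty += grad ** 2
    return float(penalty)
```

This sketch sums the per-environment terms. Averaging them instead is an equally common choice.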
Classes
- class GOOD.ood_algorithms.algorithms.CIGA.CIGA(config: Union[CommonArgs, Munch])
Bases: BaseOODAlg
Implementation of the CIGA algorithm from the “Learning Causally Invariant Representations for Out-of-Distribution Generalization on Graphs” paper.
- Args:
config (Union[CommonArgs, Munch]): munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)
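The class reads three nested fields from `config`. The sketch below builds an attribute-accessible config that contains those key paths. It uses `types.SimpleNamespace` as a stdlib stand-in for `munch.munchify`, and `to_namespace` is a hypothetical helper. Only the key paths come from the docstring above. The values are made up, and a real GOOD config carries many more fields.

```python
from types import SimpleNamespace


def to_namespace(d):
    """Recursively turn nested dicts into attribute-accessible objects
    (illustrative stand-in for munch.munchify)."""
    if isinstance(d, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in d.items()})
    return d


# Hypothetical values; only the key paths are taken from the CIGA docstring.
config = to_namespace({
    'device': 'cpu',
    'dataset': {'num_envs': 3},
    'ood': {'ood_param': 0.5},
})

print(config.device, config.dataset.num_envs, config.ood.ood_param)  # prints: cpu 3 0.5
```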
- loss_calculate(raw_pred: Tensor, targets: Tensor, mask: Tensor, node_norm: Tensor, config: Union[CommonArgs, Munch]) → Tensor
Calculate the base loss between model predictions and input labels.
- Parameters
raw_pred (Tensor) – model predictions
targets (Tensor) – input labels
mask (Tensor) – NaN masks for data formats
node_norm (Tensor) – node weights for normalization (for node prediction only)
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.metric.loss_func(), config.model.model_level)
Example: config = munchify({'model': {'model_level': 'graph'}, 'metric': {'loss_func': torch.nn.functional.cross_entropy}})
- Returns (Tensor):
base loss between model predictions and input labels
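The parameter list suggests that `mask` flags the label positions that are valid (not NaN). The NumPy sketch below computes a per-element cross-entropy and zeroes the masked-out entries. Cross-entropy stands in for `config.metric.loss_func`, and `masked_ce_sketch` is a hypothetical name. The sketch ignores `node_norm` and any CIGA-specific terms, so treat it as an illustration of the masking idea only, not as GOOD's code.

```python
import numpy as np


def masked_ce_sketch(raw_pred: np.ndarray, targets: np.ndarray,
                     mask: np.ndarray) -> np.ndarray:
    """Hypothetical per-element cross-entropy; masked (NaN-label) entries become 0."""
    # Masked rows get a dummy class index 0 so indexing stays valid.
    safe_targets = np.where(mask, targets, 0).astype(int)
    shifted = raw_pred - raw_pred.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    per_elem = -log_probs[np.arange(len(targets)), safe_targets]
    return per_elem * mask      # left unreduced; averaging is done afterwards
```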
- loss_postprocess(loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) → Tensor
Post-process the base loss into the loss used for training.
- Parameters
loss (Tensor) – base loss between model predictions and input labels
data (Batch) – input data
mask (Tensor) – NaN masks for data formats
config (Union[CommonArgs, Munch]) – munchified dictionary of args
- Returns (Tensor):
processed loss
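The CIGA paper's objective adds a contrastive term on the causal representations to the prediction loss. The sketch below shows one plausible shape for this post-processing step. It averages the base loss over valid entries and adds a weighted contrastive term, assuming that `config.ood.ood_param` is that weight. `postprocess_sketch` and its arguments are hypothetical. The real method receives the whole `Batch` and `config`, and it may add further terms.

```python
import numpy as np


def postprocess_sketch(per_elem_loss: np.ndarray, mask: np.ndarray,
                       contrast_loss: float, ood_param: float) -> float:
    """Hypothetical: masked mean of the base loss plus a weighted contrastive term."""
    base = float((per_elem_loss * mask).sum() / mask.sum())  # mean over valid entries only
    return base + ood_param * contrast_loss
```

For example, a base loss of [1.0, 0.0, 3.0] with the middle entry masked out averages to 2.0. With contrast_loss = 0.5 and ood_param = 2.0, the total is 3.0.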