GOOD.ood_algorithms.algorithms.GSAT
Implementation of the GSAT algorithm from “Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism” paper
Classes
GSAT | Implementation of the GSAT algorithm from “Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism” paper
- class GOOD.ood_algorithms.algorithms.GSAT.GSAT(config: Union[CommonArgs, Munch])[source]
Bases:
BaseOODAlg
Implementation of the GSAT algorithm from “Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism” paper
- Args:
config (Union[CommonArgs, Munch]): munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)
- loss_postprocess(loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) Tensor [source]
Process loss based on GSAT algorithm
- Parameters
loss (Tensor) – base loss between model predictions and input labels
data (Batch) – input data
mask (Tensor) – NAN masks for data formats
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)
Example: config = munchify({'device': torch.device('cuda'), 'dataset': {'num_envs': 10}, 'ood': {'ood_param': 0.1}})
- Returns (Tensor):
loss based on GSAT algorithm
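The following is a minimal sketch of what a GSAT-style loss post-processing step might look like, not the library's actual implementation. It assumes the base loss arrives as per-sample values, that the mask marks which entries are valid, and that the regularizer is the Bernoulli KL term from the GSAT paper weighted by `ood_param`. The names `bernoulli_kl`, `gsat_style_loss`, `attention`, and `r` are hypothetical, and plain Python lists stand in for tensors.

```python
import math
from typing import Sequence


def bernoulli_kl(att: float, r: float, eps: float = 1e-6) -> float:
    """KL( Bernoulli(att) || Bernoulli(r) ) for one attention value."""
    # Clamp to avoid log(0) when the attention saturates at 0 or 1.
    att = min(max(att, eps), 1.0 - eps)
    return att * math.log(att / r) + (1.0 - att) * math.log((1.0 - att) / (1.0 - r))


def gsat_style_loss(
    per_sample_loss: Sequence[float],
    mask: Sequence[bool],
    attention: Sequence[float],
    r: float,
    ood_param: float,
) -> float:
    """Masked mean of the base loss plus a weighted information regularizer.

    Assumption: this combination mirrors the objective described in the
    GSAT paper; it is not copied from GOOD's source code.
    """
    # Keep only entries the mask marks as valid (e.g. non-NaN labels).
    kept = [loss for loss, valid in zip(per_sample_loss, mask) if valid]
    mean_loss = sum(kept) / len(kept)
    # Average KL between each learned attention value and the prior r.
    info_loss = sum(bernoulli_kl(a, r) for a in attention) / len(attention)
    return mean_loss + ood_param * info_loss
```

With `ood_param = 0` the result reduces to the masked mean of the base loss; larger values pull the attention values toward the prior `r`.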
- output_postprocess(model_output: Tensor, **kwargs) Tensor [source]
Process the raw output of model
- Parameters
model_output (Tensor) – model raw output
- Returns (Tensor):
model raw predictions.
- stage_control(config: Union[CommonArgs, Munch])[source]
Set variables before each epoch. Largely used for controlling multi-stage training and epoch-related parameter settings.
- Parameters
config – munchified dictionary of args.
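As an illustration of an epoch-related parameter that a hook like `stage_control` could update, the sketch below decays an attention prior `r` in steps as training progresses, in the spirit of the schedule used with GSAT. The function name `attention_prior` and all default values (`init_r`, `final_r`, `decay_interval`, `decay_r`) are assumptions for this example, not settings taken from GOOD.

```python
def attention_prior(
    epoch: int,
    init_r: float = 0.9,
    final_r: float = 0.7,
    decay_interval: int = 10,
    decay_r: float = 0.1,
) -> float:
    """Step-decay the prior r every `decay_interval` epochs, floored at `final_r`."""
    r = init_r - (epoch // decay_interval) * decay_r
    return max(r, final_r)


# Hypothetical usage: recompute the prior before each epoch begins.
schedule = [attention_prior(epoch) for epoch in range(0, 40, 10)]
```

Computing the value from the epoch number, rather than mutating it in place, keeps the schedule reproducible when training resumes from a checkpoint.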