GOOD.ood_algorithms.algorithms.Mixup
Implementation of the Mixup algorithm from the “Mixup for Node and Graph Classification” paper.
Functions
idNode(data, id_a2b, config) – Mix up nodes according to the given index.
shuffleData(data, config) – Prepare data and index for node mixup.
- GOOD.ood_algorithms.algorithms.Mixup.idNode(data: Batch, id_a2b: Tensor, config: Union[CommonArgs, Munch]) → Batch
Mix up nodes according to the given index. Modified from the “MixupForGraph/mixup.py” code.
- Parameters
data (Batch) – input data
id_a2b (Tensor) – the randomly permuted index tensor used to index each mixup pair
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device)
config = munchify({'device': torch.device('cuda')})
- Returns (Batch):
mixed-up data
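To illustrate what applying a permuted index such as id_a2b to a graph involves, here is a minimal sketch that uses plain Python lists instead of PyG Batch objects and torch tensors. The helper apply_permutation is hypothetical and not part of GOOD; it assumes that position i of the result should hold the node previously at position id_a2b[i], so node labels are gathered with id_a2b while edge endpoints (which still refer to old positions) are relabeled through the inverse permutation.

```python
from typing import List, Tuple


def apply_permutation(
    y: List[int],
    edges: List[Tuple[int, int]],
    id_a2b: List[int],
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Hypothetical helper: reorder node labels by id_a2b and relabel edges to match.

    Position i of the result holds the node that was at position id_a2b[i].
    """
    n = len(id_a2b)
    # Inverse permutation: id_b2a[old_position] = new_position.
    id_b2a = [0] * n
    for new_pos, old_pos in enumerate(id_a2b):
        id_b2a[old_pos] = new_pos
    # Labels travel with their nodes to the new positions.
    y_new = [y[old_pos] for old_pos in id_a2b]
    # Edge endpoints are old positions, so map each one through the inverse.
    edges_new = [(id_b2a[u], id_b2a[v]) for u, v in edges]
    return y_new, edges_new


# Usage: a 3-node path graph 0-1-2 with labels 10, 11, 12.
y_new, edges_new = apply_permutation([10, 11, 12], [(0, 1), (1, 2)], [2, 0, 1])
```

Relabeling edges through the inverse keeps the graph itself unchanged: the same pairs of labeled nodes stay connected, only their positions move.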
- GOOD.ood_algorithms.algorithms.Mixup.shuffleData(data: Batch, config: Union[CommonArgs, Munch]) → Tuple[Batch, Tensor]
Prepare data and index for node mixup. Modified from the “MixupForGraph/mixup.py” code.
- Parameters
data (Batch) – input data
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device)
config = munchify({'device': torch.device('cuda')})
- Returns
data (Batch) - mixed-up data
id_a2b (Tensor) - the randomly permuted index tensor used to index each mixup pair
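The index returned by shuffleData pairs each node with a mixup partner. Below is a minimal sketch of generating such an index with Python's random module. It assumes, hypothetically, that only training nodes are shuffled among themselves while every other node keeps its own position; random_pair_index and train_ids are illustrative names, not GOOD functions or arguments.

```python
import random
from typing import List, Optional, Sequence


def random_pair_index(
    num_nodes: int,
    train_ids: Sequence[int],
    seed: Optional[int] = None,
) -> List[int]:
    """Hypothetical helper: node i is paired with node id_a2b[i].

    Training positions receive a random training partner; all other
    positions are paired with themselves (identity).
    """
    rng = random.Random(seed)
    id_a2b = list(range(num_nodes))  # identity pairing by default
    shuffled = list(train_ids)
    rng.shuffle(shuffled)
    # Each training position gets a randomly chosen training partner.
    for pos, partner in zip(train_ids, shuffled):
        id_a2b[pos] = partner
    return id_a2b


# Usage: 6 nodes, of which 0, 2 and 4 are training nodes.
id_a2b = random_pair_index(6, [0, 2, 4], seed=0)
```

Because the training positions receive a permutation of the training ids and all other positions map to themselves, the result is always a valid permutation of range(num_nodes).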
Classes
Mixup(config) – Implementation of the Mixup algorithm from the “Mixup for Node and Graph Classification” paper
- class GOOD.ood_algorithms.algorithms.Mixup.Mixup(config: Union[CommonArgs, Munch])
Bases: BaseOODAlg
Implementation of the Mixup algorithm from the “Mixup for Node and Graph Classification” paper.
- Parameters
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.model.model_level, config.metric.loss_func(), config.ood.ood_param)
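For reference, the standard mixup formulation that this class builds on can be written as follows, where x_i may be raw features or hidden representations. Mapping α to config.ood.ood_param and the partner index j to id_a2b is an assumption based on the parameters listed above, not something this page states.

```latex
\begin{aligned}
\lambda &\sim \mathrm{Beta}(\alpha, \alpha) \\
\tilde{x}_i &= \lambda\, x_i + (1 - \lambda)\, x_j, \qquad j = \mathrm{id\_a2b}[i] \\
\mathcal{L}_i &= \lambda\, \ell\big(f(\tilde{x}_i), y_i\big) + (1 - \lambda)\, \ell\big(f(\tilde{x}_i), y_j\big)
\end{aligned}
```

Because cross-entropy is linear in the target distribution, weighting the two losses by λ and 1 − λ is equivalent to training on the mixed label λ y_i + (1 − λ) y_j.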
- input_preprocess(data: Batch, targets: Tensor, mask: Tensor, node_norm: Tensor, training: bool, config: Union[CommonArgs, Munch], **kwargs) → Tuple[Batch, Tensor, Tensor, Tensor]
Set the input data and mask format in preparation for mixup.
- Parameters
data (Batch) – input data
targets (Tensor) – input labels
mask (Tensor) – NaN masks for data formats
node_norm (Tensor) – node weights for normalization (for node prediction only)
training (bool) – whether the model is in training mode
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.ood.ood_param)
config = munchify({'device': torch.device('cuda'), 'ood': {'ood_param': 0.1}})
- Returns
data (Batch) - Processed input data.
targets (Tensor) - Processed input labels.
mask (Tensor) - Processed NaN masks for data formats.
node_norm (Tensor) - Processed node weights for normalization.
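Below is a minimal sketch of the per-batch state a mixup preprocessing step typically prepares, written with plain Python lists. It assumes, without confirmation from this page, that during training λ is drawn from Beta(ood_param, ood_param), that samples are paired by a random permutation, and that a pair is kept only when both members have a valid target; at evaluation time λ = 1 and every sample is paired with itself. mixup_preprocess is a hypothetical helper, not GOOD's input_preprocess.

```python
import random
from typing import List, Optional, Tuple


def mixup_preprocess(
    mask: List[bool],
    training: bool,
    ood_param: float,
    seed: Optional[int] = None,
) -> Tuple[float, List[int], List[bool]]:
    """Hypothetical helper: return (lam, id_a2b, mixed_mask) for one batch."""
    n = len(mask)
    if not training:
        # Evaluation: no mixing, every sample is paired with itself.
        return 1.0, list(range(n)), list(mask)
    rng = random.Random(seed)
    # Assumed mapping: lam ~ Beta(alpha, alpha) with alpha = ood_param.
    lam = rng.betavariate(ood_param, ood_param)
    id_a2b = list(range(n))
    rng.shuffle(id_a2b)
    # Keep an entry only if both members of the pair have a valid target.
    mixed_mask = [mask[i] and mask[id_a2b[i]] for i in range(n)]
    return lam, id_a2b, mixed_mask


# Usage: four samples, the second of which has no valid target.
lam, id_a2b, mixed = mixup_preprocess(
    [True, False, True, True], training=True, ood_param=0.1, seed=0
)
```

Small values of ood_param push λ toward 0 or 1, so most mixed samples stay close to one of the two originals.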
- loss_calculate(raw_pred: Tensor, targets: Tensor, mask: Tensor, node_norm: Tensor, config: Union[CommonArgs, Munch]) → Tensor
Calculate the loss based on the Mixup algorithm.
- Parameters
raw_pred (Tensor) – model predictions
targets (Tensor) – input labels
mask (Tensor) – NaN masks for data formats
node_norm (Tensor) – node weights for normalization (for node prediction only)
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.metric.loss_func(), config.model.model_level)
config = munchify({'model': {'model_level': 'graph'}, 'metric': {'loss_func': loss_func}}), where loss_func is the task's loss function
- Returns (Tensor):
loss based on the Mixup algorithm
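The mixup loss can be sketched as a λ-weighted combination of the loss against each sample's own target and the loss against its partner's target. mixup_loss below is a hypothetical plain-Python helper, not GOOD's loss_calculate; it assumes a per-sample loss function and that, for node-level tasks, each term is weighted by the node_norm entry of the node whose target it uses.

```python
from typing import Callable, List, Optional


def mixup_loss(
    preds: List[float],
    targets: List[float],
    mask: List[bool],
    id_a2b: List[int],
    lam: float,
    loss_fn: Callable[[float, float], float],
    node_norm: Optional[List[float]] = None,
) -> List[float]:
    """Hypothetical helper: lam * l(p_i, y_i) + (1 - lam) * l(p_i, y_j), j = id_a2b[i]."""
    out = []
    for i, p in enumerate(preds):
        j = id_a2b[i]
        loss_a = loss_fn(p, targets[i])  # loss against the sample's own target
        loss_b = loss_fn(p, targets[j])  # loss against the partner's target
        if node_norm is not None:
            # Node-level tasks: weight each term by its node's normalization weight.
            loss_a *= node_norm[i]
            loss_b *= node_norm[j]
        m = 1.0 if mask[i] else 0.0  # masked-out entries contribute zero loss
        out.append(m * (lam * loss_a + (1.0 - lam) * loss_b))
    return out


# Usage with a squared-error loss on two samples paired with each other.
def squared_error(p: float, y: float) -> float:
    return (p - y) ** 2


losses = mixup_loss([0.0, 1.0], [0.0, 1.0], [True, False], [1, 0], 0.25, squared_error)
```

With lam = 1 and id_a2b equal to the identity, this reduces to the ordinary masked loss.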