GOOD.ood_algorithms.algorithms.Mixup

Implementation of the Mixup algorithm from the “Mixup for Node and Graph Classification” paper.

Functions

idNode(data, id_a2b, config)

Mix up nodes according to the given index.

shuffleData(data, config)

Prepare data and index for node mixup.

GOOD.ood_algorithms.algorithms.Mixup.idNode(data: Batch, id_a2b: Tensor, config: Union[CommonArgs, Munch]) → Batch

Mix up nodes according to the given index. Modified from the “MixupForGraph/mixup.py” code.

Parameters
  • data (Batch) – input data

  • id_a2b (Tensor) – the randomly permuted index tensor used to index each mixup pair

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device)

config = munchify({'device': torch.device('cuda')})
Returns (Batch):

mixed-up data
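The general idea of mixing each node with the partner selected by a permuted index can be sketched as follows. This is a hypothetical NumPy helper (`mix_by_index`, with NumPy arrays standing in for PyTorch tensors and a fixed mixing ratio `lam`), not GOOD's `idNode`, which operates on a `Batch` and follows the MixupForGraph code.

```python
import numpy as np


def mix_by_index(x, y_onehot, id_a2b, lam):
    """Hypothetical helper: mix every node i with its partner id_a2b[i].

    x:        node features, shape (num_nodes, num_features)
    y_onehot: one-hot labels, shape (num_nodes, num_classes)
    id_a2b:   permutation of node indices pairing node i with id_a2b[i]
    lam:      mixing ratio in [0, 1]
    """
    # Fancy indexing gathers the partner row for every node.
    x_mix = lam * x + (1.0 - lam) * x[id_a2b]
    y_mix = lam * y_onehot + (1.0 - lam) * y_onehot[id_a2b]
    return x_mix, y_mix


x = np.array([[0.0, 0.0], [2.0, 4.0], [10.0, 10.0]])
y = np.eye(2)[[0, 1, 1]]          # one-hot labels for classes 0, 1, 1
id_a2b = np.array([1, 0, 2])      # swap nodes 0 and 1, keep node 2 paired with itself
x_mix, y_mix = mix_by_index(x, y, id_a2b, lam=0.75)
```

A node paired with itself (node 2 here) is left unchanged, which is why an identity index disables mixing.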

GOOD.ood_algorithms.algorithms.Mixup.shuffleData(data: Batch, config: Union[CommonArgs, Munch]) → Tuple[Batch, Tensor]

Prepare data and index for node mixup. Modified from the “MixupForGraph/mixup.py” code.

Parameters
  • data (Batch) – input data

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device)

config = munchify({'device': torch.device('cuda')})
Returns

  • data (Batch) - mixed-up data.

  • id_a2b (Tensor) - the randomly permuted index tensor used to index each mixup pair.
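One way to build such a pairing index for node mixup is to shuffle only the training nodes and leave every other node paired with itself. The sketch below assumes this train-only shuffling; `make_pair_index` is a hypothetical NumPy helper, not GOOD's `shuffleData`, which also prepares the `Batch` itself.

```python
import numpy as np


def make_pair_index(train_mask, rng):
    """Hypothetical helper: pair training nodes with random training partners.

    Non-training nodes map to themselves, so only training nodes get mixed.
    """
    num_nodes = train_mask.shape[0]
    id_a2b = np.arange(num_nodes)              # identity pairing by default
    train_ids = np.flatnonzero(train_mask)     # indices of training nodes
    id_a2b[train_ids] = rng.permutation(train_ids)
    return id_a2b


train_mask = np.array([True, False, True, True, False, True])
id_a2b = make_pair_index(train_mask, np.random.default_rng(0))
```

Because the result is still a permutation of all node indices, it can be used directly for fancy indexing of node features and labels.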

Classes

Mixup(config)

Implementation of the Mixup algorithm from the “Mixup for Node and Graph Classification” paper.

class GOOD.ood_algorithms.algorithms.Mixup.Mixup(config: Union[CommonArgs, Munch])

Bases: BaseOODAlg

Implementation of the Mixup algorithm from the “Mixup for Node and Graph Classification” paper.

Args:

config (Union[CommonArgs, Munch]): munchified dictionary of args (config.device, config.model.model_level, config.metric.loss_func(), config.ood.ood_param)

input_preprocess(data: Batch, targets: Tensor, mask: Tensor, node_norm: Tensor, training: bool, config: Union[CommonArgs, Munch], **kwargs) → Tuple[Batch, Tensor, Tensor, Tensor]

Set the input data and mask format to prepare for mixup.

Parameters
  • data (Batch) – input data

  • targets (Tensor) – input labels

  • mask (Tensor) – NaN masks for data formats

  • node_norm (Tensor) – node weights for normalization (for node prediction only)

  • training (bool) – whether the model is in training mode

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.ood.ood_param)

config = munchify({'device': torch.device('cuda'),
                   'ood': {'ood_param': 0.1}
                   })
Returns

  • data (Batch) - Processed input data.

  • targets (Tensor) - Processed input labels.

  • mask (Tensor) - Processed NaN masks for data formats.

  • node_norm (Tensor) - Processed node weights for normalization.
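A minimal sketch of what mixup preprocessing can look like, assuming (as in standard mixup) that `config.ood.ood_param` serves as the concentration `alpha` of a Beta(alpha, alpha) distribution for the mixing ratio, and that a sample is kept only when both it and its partner are valid. `prepare_mixup` and its arguments are hypothetical; this is not GOOD's `input_preprocess`.

```python
import numpy as np


def prepare_mixup(mask, training, alpha, rng):
    """Hypothetical sketch of mixup preprocessing.

    alpha is assumed to play the role of config.ood.ood_param.
    Returns the mixing ratio, the pairing index and the combined mask.
    """
    n = mask.shape[0]
    if training:
        lam = rng.beta(alpha, alpha)           # mixing ratio sampled in [0, 1]
        id_a2b = rng.permutation(n)            # random partner per sample
        # Keep a sample only if both it and its partner have valid labels.
        mask = mask & mask[id_a2b]
    else:
        lam = 1.0                              # no mixing at evaluation time
        id_a2b = np.arange(n)
    return lam, id_a2b, mask
```

At evaluation time the identity index together with `lam = 1.0` reduces mixup to ordinary training on the original samples.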

loss_calculate(raw_pred: Tensor, targets: Tensor, mask: Tensor, node_norm: Tensor, config: Union[CommonArgs, Munch]) → Tensor

Calculate the loss based on the Mixup algorithm.

Parameters
  • raw_pred (Tensor) – model predictions

  • targets (Tensor) – input labels

  • mask (Tensor) – NaN masks for data formats

  • node_norm (Tensor) – node weights for normalization (for node prediction only)

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.metric.loss_func(), config.model.model_level)

config = munchify({'model': {'model_level': 'graph'},
                   'metric': {'loss_func': loss_func}  # loss_func: the task's loss callable
                   })
Returns (Tensor):

loss based on the Mixup algorithm
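The mixup objective scores each prediction against both labels of its pair, weighted by the mixing ratio. The sketch below is a hypothetical NumPy illustration: a plain cross-entropy stands in for `config.metric.loss_func()`, the result is a masked mean, and the `node_norm` weighting used for node-level tasks is omitted.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Per-sample cross-entropy from raw logits and integer labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(labels.shape[0]), labels]


def mixup_loss(raw_pred, targets, id_a2b, lam, mask):
    """Hypothetical sketch: masked mean of the lam-weighted loss over both labels."""
    loss = (lam * cross_entropy(raw_pred, targets)
            + (1.0 - lam) * cross_entropy(raw_pred, targets[id_a2b]))
    loss = loss * mask                     # zero out invalid samples
    return float(loss.sum() / mask.sum())  # mean over valid samples only


raw_pred = np.array([[10.0, 0.0], [0.0, 10.0]])
targets = np.array([0, 1])
id_a2b = np.array([1, 0])
mask = np.array([True, True])
```

With `lam = 1.0` only the original labels count; with `lam = 0.0` only the partners' labels do.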