"""
Implementation of the Mixup algorithm from `"Mixup for Node and Graph Classification"
<https://dl.acm.org/doi/abs/10.1145/3442381.3449796>`_ paper
"""
import copy
from typing import Tuple
import numpy as np
import torch
from torch import Tensor
from torch_geometric.data import Batch
from GOOD import register
from GOOD.utils.config_reader import Union, CommonArgs, Munch
from .BaseOOD import BaseOODAlg
[docs]def idNode(data: Batch, id_a2b: Tensor, config: Union[CommonArgs, Munch]) -> Batch:
r"""
Mixup node according to given index. Modified from `"MixupForGraph/mixup.py"
<https://github.com/vanoracai/MixupForGraph/blob/76c2f8b7138b597bdd95a33b0bb32376e3f55227/mixup.py#L46>`_ code.
Args:
data (Batch): input data
id_a2b (Tensor): the random permuted index tensor to index each mixup pair
config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.device`)
.. code-block:: python
config = munchify({device: torch.device('cuda')})
Returns (Batch):
mixed-up data
"""
data.x = None
data.y[data.val_id] = -1
data.y[data.test_id] = -1
data.y = data.y[id_a2b]
data.train_id = None
data.test_id = None
data.val_id = None
id_b2a = torch.zeros(id_a2b.shape[0], dtype=torch.long, device=config.device)
id_b2a[id_a2b] = torch.arange(0, id_a2b.shape[0], dtype=torch.long, device=config.device)
row = data.edge_index[0]
col = data.edge_index[1]
row = id_b2a[row]
col = id_b2a[col]
data.edge_index = torch.stack([row, col], dim=0)
return data
[docs]def shuffleData(data: Batch, config: Union[CommonArgs, Munch]) -> Tuple[Batch, Tensor]:
r"""
Prepare data and index for node mixup. Modified from `"MixupForGraph/mixup.py"
<https://github.com/vanoracai/MixupForGraph/blob/76c2f8b7138b597bdd95a33b0bb32376e3f55227/mixup.py#L46>`_ code.
Args:
data (Batch): input data
config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.device`)
.. code-block:: python
config = munchify({device: torch.device('cuda')})
Returns:
[data (Batch) - mixed-up data,
id_a2b (Tensor) - the random permuted index tensor to index each mixup pair]
"""
data = copy.deepcopy(data)
data.train_id = torch.nonzero(data.train_mask)
data.val_id = torch.nonzero(data.val_mask)
data.test_id = torch.nonzero(data.test_mask)
# --- identify new id by providing old id value ---
id_a2b = torch.arange(data.num_nodes, device=config.device)
train_id_shuffle = copy.deepcopy(data.train_id)
# random.shuffle(train_id_shuffle)
train_id_shuffle = train_id_shuffle[torch.randperm(train_id_shuffle.shape[0])]
id_a2b[data.train_id] = train_id_shuffle
data = idNode(data, id_a2b, config)
return data, id_a2b
[docs]@register.ood_alg_register
class Mixup(BaseOODAlg):
r"""
Implementation of the Mixup algorithm from `"Mixup for Node and Graph Classification"
<https://dl.acm.org/doi/abs/10.1145/3442381.3449796>`_ paper
Args:
config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.device`, :obj:`config.model.model_level`, :obj:`config.metric.loss_func()`, :obj:`config.ood.ood_param`)
"""
def __init__(self, config: Union[CommonArgs, Munch]):
super(Mixup, self).__init__(config)
self.lam = None
self.data_perm = None
self.id_a2b = None
[docs] def loss_calculate(self, raw_pred: Tensor, targets: Tensor, mask: Tensor, node_norm: Tensor,
config: Union[CommonArgs, Munch]) -> Tensor:
r"""
Calculate loss based on Mixup algorithm
Args:
raw_pred (Tensor): model predictions
targets (Tensor): input labels
mask (Tensor): NAN masks for data formats
node_norm (Tensor): node weights for normalization (for node prediction only)
config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.metric.loss_func()`, :obj:`config.model.model_level`)
.. code-block:: python
config = munchify({model: {model_level: str('graph')},
metric: {loss_func()}
})
Returns (Tensor):
loss based on Mixup algorithm
"""
loss_a = config.metric.loss_func(raw_pred, targets, reduction='none') * mask
loss_b = config.metric.loss_func(raw_pred, targets[self.id_a2b], reduction='none') * mask
if config.model.model_level == 'node':
loss_a = loss_a * node_norm * mask.sum()
loss_b = loss_b * node_norm[self.id_a2b] * mask.sum()
loss = self.lam * loss_a + (1 - self.lam) * loss_b
return loss