Source code for GOOD.ood_algorithms.algorithms.DANN

"""
Implementation of the DANN algorithm from `"Domain-Adversarial Training of Neural Networks"
<https://www.jmlr.org/papers/volume17/15-239/15-239.pdf>`_ paper
"""
import torch
from torch import Tensor
from torch_geometric.data import Batch
from GOOD import register
from GOOD.utils.config_reader import Union, CommonArgs, Munch
from .BaseOOD import BaseOODAlg


[docs]@register.ood_alg_register class DANN(BaseOODAlg): r""" Implementation of the DANN algorithm from `"Domain-Adversarial Training of Neural Networks" <https://www.jmlr.org/papers/volume17/15-239/15-239.pdf>`_ paper Args: config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.model.model_level`, :obj:`config.metric.cross_entropy_with_logit()`, :obj:`config.ood.ood_param`) """ def __init__(self, config: Union[CommonArgs, Munch]): super(DANN, self).__init__(config) self.dc_pred = None
[docs] def output_postprocess(self, model_output: Tensor, **kwargs) -> Tensor: r""" Process the raw output of model; get domain classifier predictions Args: model_output (Tensor): model raw output Returns (Tensor): model raw predictions """ self.dc_pred = model_output[1] return model_output[0]
[docs] def loss_postprocess(self, loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) -> Tensor: r""" Process loss based on DANN algorithm Args: loss (Tensor): base loss between model predictions and input labels data (Batch): input data mask (Tensor): NAN masks for data formats config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.model.model_level`, :obj:`config.metric.cross_entropy_with_logit()`, :obj:`config.ood.ood_param`) .. code-block:: python config = munchify({model: {model_level: str('graph')}, metric: {cross_entropy_with_logit()}, ood: {ood_param: float(0.1)} }) Returns (Tensor): loss based on DANN algorithm """ if config.model.model_level == 'node': dc_loss: torch.Tensor = config.metric.cross_entropy_with_logit(self.dc_pred[data.train_mask], data.env_id[data.train_mask], reduction='none') else: dc_loss: torch.Tensor = config.metric.cross_entropy_with_logit(self.dc_pred, data.env_id, reduction='none') # else: # dc_loss: torch.Tensor = binary_cross_entropy_with_logits(dc_pred, torch.nn.functional.one_hot(data.env_id % config.dataset.num_envs, num_classes=config.dataset.num_envs).float(), reduction='none') * mask spec_loss = config.ood.ood_param * dc_loss.mean() mean_loss = loss.sum() / mask.sum() loss = mean_loss + spec_loss self.mean_loss = mean_loss self.spec_loss = spec_loss return loss