"""
Implementation of the Deep Coral algorithm from `"Deep CORAL: Correlation Alignment for Deep Domain Adaptation"
<https://link.springer.com/chapter/10.1007/978-3-319-49409-8_35>`_ paper
"""
from itertools import combinations

import torch
from torch import Tensor
from torch_geometric.data import Batch

from GOOD import register
from GOOD.utils.config_reader import Union, CommonArgs, Munch
from .BaseOOD import BaseOODAlg
def compute_covariance(input_data: Tensor, config: "Union[CommonArgs, Munch]") -> Tensor:
    r"""
    Compute the unbiased sample covariance matrix of the input data.

    Implements the Deep CORAL definition
    :math:`C = \frac{1}{n-1}\left(D^\top D - \frac{1}{n}(\mathbf{1}^\top D)^\top(\mathbf{1}^\top D)\right)`,
    evaluated in the algebraically equivalent centred form
    :math:`C = \frac{1}{n-1}(D - \bar{D})^\top (D - \bar{D})`.

    Args:
        input_data (Tensor): 2-D feature matrix of shape ``(n, d)``, one row per sample
        config (Union[CommonArgs, Munch]): munchified dictionary of args. Kept for
            backward compatibility only; the result is computed on ``input_data``'s
            own device and dtype, so ``config.device`` is not read.

    Returns (Tensor):
        covariance matrix of the input features, of shape ``(d, d)``

    Raises:
        ValueError: if ``input_data`` has fewer than two rows, for which the
            sample covariance (division by ``n - 1``) is undefined.
    """
    n = input_data.shape[0]  # batch size
    if n < 2:
        raise ValueError(f"compute_covariance needs at least 2 samples, got {n}")
    # Centring each feature column first avoids the cancellation error of the
    # expanded D^T D - n * mean^T mean form and needs no helper ones-vector.
    centered = input_data - input_data.mean(dim=0, keepdim=True)
    return torch.mm(centered.t(), centered) / (n - 1)
@register.ood_alg_register
class Coral(BaseOODAlg):
    r"""
    Implementation of the Deep Coral algorithm from `"Deep CORAL: Correlation Alignment for Deep Domain Adaptation"
    <https://link.springer.com/chapter/10.1007/978-3-319-49409-8_35>`_ paper

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.device`, :obj:`config.dataset.num_envs`, :obj:`config.ood.ood_param`)
    """

    def __init__(self, config: Union[CommonArgs, Munch]):
        super().__init__(config)
        # Feature representations cached by ``output_postprocess`` and read by
        # ``loss_postprocess`` to compute per-environment covariances.
        self.feat = None

    def output_postprocess(self, model_output: Tensor, **kwargs) -> Tensor:
        r"""
        Process the raw output of model; get feature representations

        The model output is indexed as a pair: element ``0`` is returned as the
        predictions and element ``1`` is cached in ``self.feat`` for the CORAL penalty.

        Args:
            model_output (Tensor): model raw output

        Returns (Tensor):
            model raw predictions
        """
        self.feat = model_output[1]
        return model_output[0]

    def loss_postprocess(self, loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) -> Tensor:
        r"""
        Process loss based on Deep Coral algorithm

        For every environment with at least two samples in the batch, the sample
        covariance of the cached features is computed. The CORAL penalty is the mean,
        over all pairs of such environments, of
        :math:`\frac{1}{4d^2}\lVert C_i - C_j\rVert_F^2`. It is zero when fewer than two
        environments qualify. A NaN penalty is replaced by zero.

        Args:
            loss (Tensor): base loss between model predictions and input labels
            data (Batch): input data; :obj:`data.env_id` selects each environment's samples
            mask (Tensor): NAN masks for data formats
            config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.device`, :obj:`config.dataset.num_envs`, :obj:`config.ood.ood_param`)

        .. code-block:: python

            config = munchify({device: torch.device('cuda'),
                               dataset: {num_envs: int(10)},
                               ood: {ood_param: float(0.1)}
                               })

        Returns (Tensor):
            loss based on Deep Coral algorithm: the masked mean of ``loss`` plus
            ``config.ood.ood_param`` times the CORAL penalty. Also stores the two
            parts in ``self.mean_loss`` and ``self.spec_loss``.
        """
        # Per-environment sample covariance. Environments with fewer than two
        # samples have no sample covariance and are skipped.
        covariance_matrices = []
        for env in range(config.dataset.num_envs):
            env_feat = self.feat[data.env_id == env]
            if env_feat.shape[0] > 1:
                covariance_matrices.append(compute_covariance(env_feat, config))

        # ||C_i - C_j||_F^2 / (4 d^2): torch.mean over the d x d entries supplies 1/d^2.
        # The distance is symmetric, so averaging over unordered pairs yields the same
        # mean as averaging over both (i, j) and (j, i), with half the work.
        pair_losses = [
            torch.mean((cov_i - cov_j) ** 2) / 4
            for cov_i, cov_j in combinations(covariance_matrices, 2)
        ]
        if pair_losses:
            coral_loss = sum(pair_losses) / len(pair_losses)
        else:
            # No pair to align: zero penalty, created with the features' dtype and
            # device rather than as a CPU integer tensor.
            coral_loss = self.feat.new_zeros(())

        spec_loss = config.ood.ood_param * coral_loss
        if torch.isnan(spec_loss):
            # Discard a NaN penalty but keep ``spec_loss`` a tensor, as on every other path.
            spec_loss = torch.zeros_like(spec_loss)

        mean_loss = loss.sum() / mask.sum()
        self.mean_loss = mean_loss
        self.spec_loss = spec_loss
        return mean_loss + spec_loss