GOOD.ood_algorithms.algorithms.Coral

Implementation of the Deep CORAL algorithm from the paper “Deep CORAL: Correlation Alignment for Deep Domain Adaptation”.

Functions

compute_covariance(input_data, config)

Compute the covariance matrix of the input data.

GOOD.ood_algorithms.algorithms.Coral.compute_covariance(input_data: Tensor, config: Union[CommonArgs, Munch]) → Tensor

Compute the covariance matrix of the input data.

Parameters
  • input_data (Tensor) – feature matrix of the input data, one row per sample

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device)

config = munchify({'device': torch.device('cuda')})
Returns (Tensor):

covariance matrix of the input features
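A minimal sketch of the standard unbiased sample covariance, C = (XᵀX − n·μᵀμ) / (n − 1), for an n × d feature matrix X with column means μ. It is written in NumPy rather than torch for portability and drops the config argument (config.device only decides where tensors are allocated); each step maps onto torch.mm and Tensor.t(). This is an illustration of the technique, not the library's own code, which may differ in detail.

```python
import numpy as np


def compute_covariance(input_data: np.ndarray) -> np.ndarray:
    """Unbiased sample covariance of an (n, d) feature matrix (rows = samples)."""
    n = input_data.shape[0]
    if n < 2:
        raise ValueError("need at least two samples to estimate a covariance")
    # Column means as a (1, d) row vector.
    mean = input_data.sum(axis=0, keepdims=True) / n
    # X^T X minus n * mu^T mu gives the centered Gram matrix (d, d).
    centered_gram = input_data.T @ input_data - n * (mean.T @ mean)
    return centered_gram / (n - 1)


feats = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]])
cov = compute_covariance(feats)
print(cov.shape)  # (2, 2)
```

The result matches np.cov(feats, rowvar=False), which is a convenient cross-check.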

Classes

Coral(config)

Implementation of the Deep CORAL algorithm from the paper “Deep CORAL: Correlation Alignment for Deep Domain Adaptation”.

class GOOD.ood_algorithms.algorithms.Coral.Coral(config: Union[CommonArgs, Munch])

Bases: BaseOODAlg

Implementation of the Deep CORAL algorithm from the paper “Deep CORAL: Correlation Alignment for Deep Domain Adaptation”.

Args:

config (Union[CommonArgs, Munch]): munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)
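The central quantity in the Deep CORAL paper is the CORAL loss: the squared Frobenius distance between two feature covariance matrices, scaled by 1/(4d²), where d is the feature dimension. A minimal NumPy sketch follows; the function name coral_distance is hypothetical and not part of GOOD's API. Note that the mean of the squared entries divided by 4 gives the same value.

```python
import numpy as np


def coral_distance(cov_a: np.ndarray, cov_b: np.ndarray) -> float:
    """CORAL loss from the Deep CORAL paper: ||C_a - C_b||_F^2 / (4 d^2)."""
    d = cov_a.shape[0]
    diff = cov_a - cov_b
    return float(np.sum(diff * diff) / (4 * d * d))


# Covariances of two synthetic feature sets with different spreads.
source = np.cov(np.random.default_rng(0).normal(size=(64, 3)), rowvar=False)
target = np.cov(np.random.default_rng(1).normal(scale=2.0, size=(64, 3)), rowvar=False)
print(coral_distance(source, target))
```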

loss_postprocess(loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) → Tensor

Process the loss according to the Deep CORAL algorithm, adding a CORAL penalty weighted by config.ood.ood_param.

Parameters
  • loss (Tensor) – base loss between model predictions and input labels

  • data (Batch) – input data

  • mask (Tensor) – NaN masks for the data (marks the valid, non-NaN label entries)

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)

config = munchify({'device': torch.device('cuda'),
                   'dataset': {'num_envs': 10},
                   'ood': {'ood_param': 0.1}
                   })
Returns (Tensor):

loss based on the Deep CORAL algorithm
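A minimal NumPy sketch of such a loss post-processing step, under assumptions not stated above: every sample carries an integer environment label (env_id, hypothetical), the features come from the model (feat), environments with fewer than two samples are skipped, and the penalty is the average CORAL distance over all pairs of distinct environments, weighted by ood_param. The function names are hypothetical; this illustrates the technique and is not GOOD's code.

```python
import itertools

import numpy as np


def coral_penalty(feat: np.ndarray, env_id: np.ndarray, num_envs: int) -> float:
    """Average CORAL distance over all ordered pairs of distinct environments."""
    covs = {}
    for e in range(num_envs):
        env_feat = feat[env_id == e]
        if env_feat.shape[0] > 1:  # a covariance needs at least two samples
            covs[e] = np.cov(env_feat, rowvar=False)
    pair_losses = []
    for i, j in itertools.permutations(covs, 2):
        diff = covs[i] - covs[j]
        pair_losses.append(np.mean(diff * diff) / 4)  # = ||diff||_F^2 / (4 d^2)
    return float(np.mean(pair_losses)) if pair_losses else 0.0


def coral_loss_postprocess(loss, mask, feat, env_id, num_envs, ood_param):
    """Masked mean of the base loss plus ood_param times the CORAL penalty."""
    # Zero out masked-out (e.g. NaN-label) entries before averaging.
    mean_loss = np.where(mask, loss, 0.0).sum() / mask.sum()
    return float(mean_loss + ood_param * coral_penalty(feat, env_id, num_envs))


rng = np.random.default_rng(0)
feat = rng.normal(size=(12, 4))
env_id = np.repeat([0, 1, 2], 4)
loss = rng.random(12)
mask = np.ones(12, dtype=bool)
print(coral_loss_postprocess(loss, mask, feat, env_id, num_envs=3, ood_param=0.1))
```

Because the CORAL distance is symmetric, averaging over ordered or unordered pairs gives the same penalty.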

output_postprocess(model_output: Tensor, **kwargs) → Tensor

Process the raw output of the model and extract its feature representations.

Parameters

model_output (Tensor) – raw output of the model

Returns (Tensor):

raw predictions of the model
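A minimal sketch of the feature-caching pattern described above, assuming (this is not stated in the signature, which types model_output as a single Tensor) that the model returns a (predictions, features) pair. The class FeatureCachingPostprocessor is hypothetical and only illustrates the idea of keeping the features on the algorithm object while passing the predictions on.

```python
from typing import Any, Tuple


class FeatureCachingPostprocessor:
    """Hypothetical stand-in for the feature-caching step of output_postprocess."""

    def __init__(self) -> None:
        self.feat: Any = None  # features kept for a later CORAL penalty

    def output_postprocess(self, model_output: Tuple[Any, Any], **kwargs) -> Any:
        # Assumption: the model returns (raw_predictions, feature_representations).
        predictions, features = model_output
        self.feat = features
        return predictions


post = FeatureCachingPostprocessor()
preds = post.output_postprocess(([0.2, 0.8], [[1.0, 2.0], [3.0, 4.0]]))
```

Storing the features on the object is one plausible way to let a later loss step reach them without changing what the training loop receives as predictions.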