GOOD.ood_algorithms.algorithms.Coral
Implementation of the Deep CORAL algorithm from the paper “Deep CORAL: Correlation Alignment for Deep Domain Adaptation”.
Functions
- GOOD.ood_algorithms.algorithms.Coral.compute_covariance(input_data: Tensor, config: Union[CommonArgs, Munch]) → Tensor
Compute the covariance matrix of the input data.
- Parameters
input_data (Tensor) – features of the input data
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device), e.g.
config = munchify({'device': torch.device('cuda')})
- Returns (Tensor):
covariance matrix of the input features
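As a rough illustration of what the returned matrix contains, the sketch below implements the unbiased covariance estimator defined in the Deep CORAL paper, C = (DᵀD − (1/n)(1ᵀD)ᵀ(1ᵀD)) / (n − 1), where D is an n × d feature matrix with one row per sample. It uses nested Python lists instead of tensors so it runs without PyTorch, drops the `config`/device argument, and the name `covariance_sketch` is hypothetical. Whether the library function follows this exact formula is an assumption.

```python
from typing import List

Matrix = List[List[float]]


def covariance_sketch(features: Matrix) -> Matrix:
    """Hypothetical list-based version of the Deep CORAL covariance.

    C = (D^T D - (1/n) (1^T D)^T (1^T D)) / (n - 1), rows of D are samples.
    """
    n = len(features)
    if n < 2:
        raise ValueError("need at least two samples to estimate a covariance")
    d = len(features[0])
    # Column sums 1^T D: one entry per feature dimension.
    col_sum = [sum(row[k] for row in features) for k in range(d)]
    cov = [[0.0] * d for _ in range(d)]
    for a in range(d):
        for b in range(d):
            # (D^T D)[a][b]: sum over samples of x_a * x_b.
            dtd = sum(row[a] * row[b] for row in features)
            # Subtract the mean correction, then normalise by n - 1 (unbiased).
            cov[a][b] = (dtd - col_sum[a] * col_sum[b] / n) / (n - 1)
    return cov
```

For example, `covariance_sketch([[1.0, 2.0], [3.0, 6.0]])` gives `[[2.0, 4.0], [4.0, 8.0]]`. Subtracting the column-sum correction from DᵀD is algebraically the same as centring the features first; the single-pass form simply avoids building a centred copy of the data.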
Classes
- class GOOD.ood_algorithms.algorithms.Coral.Coral(config: Union[CommonArgs, Munch])
Bases: BaseOODAlg
Implementation of the Deep CORAL algorithm from the paper “Deep CORAL: Correlation Alignment for Deep Domain Adaptation”.
- Args:
config (Union[CommonArgs, Munch]): munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)
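For orientation, the Deep CORAL paper defines the CORAL loss between two feature covariance matrices as their squared Frobenius distance scaled by 1/(4d²), with d the feature dimension. The second line shows one plausible way to extend it to several training environments; the pairwise averaging over the set P of distinct environment pairs (out of config.dataset.num_envs environments) and the reading of λ as config.ood.ood_param are assumptions, not taken from the GOOD source.

```latex
\ell_{\mathrm{CORAL}}(C_i, C_j) = \frac{1}{4d^2}\,\lVert C_i - C_j \rVert_F^2,
\qquad
\mathcal{L} = \mathcal{L}_{\mathrm{task}} + \lambda \cdot \frac{1}{|P|} \sum_{(i,j) \in P} \ell_{\mathrm{CORAL}}(C_i, C_j)
```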
- loss_postprocess(loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) → Tensor
Process the loss according to the Deep CORAL algorithm.
- Parameters
loss (Tensor) – base loss between model predictions and input labels
data (Batch) – input data
mask (Tensor) – NaN masks for data formats
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param), e.g.
config = munchify({'device': torch.device('cuda'), 'dataset': {'num_envs': 10}, 'ood': {'ood_param': 0.1}})
- Returns (Tensor):
loss based on the Deep CORAL algorithm
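The kind of computation `loss_postprocess` performs can be sketched in plain Python. This is a hypothetical reconstruction under stated assumptions, not GOOD's implementation: `features` are assumed to be per-sample embeddings of the batch, `env_ids` the environment index of each sample, environments with fewer than two samples are skipped, and the penalty is the mean CORAL distance over all pairs of environments, scaled by `ood_param` and added to the masked mean of the base loss. All function and argument names here are illustrative.

```python
from typing import Dict, List, Sequence

Matrix = List[List[float]]


def covariance(features: Matrix) -> Matrix:
    """Unbiased (n - 1 denominator) covariance of an n x d feature matrix."""
    n, d = len(features), len(features[0])
    means = [sum(row[k] for row in features) / n for k in range(d)]
    return [
        [sum((row[a] - means[a]) * (row[b] - means[b]) for row in features) / (n - 1)
         for b in range(d)]
        for a in range(d)
    ]


def coral_distance(c_i: Matrix, c_j: Matrix) -> float:
    """Squared Frobenius distance scaled by 1 / (4 d^2), as in the Deep CORAL paper."""
    d = len(c_i)
    sq = sum((c_i[a][b] - c_j[a][b]) ** 2 for a in range(d) for b in range(d))
    return sq / (4 * d * d)


def coral_loss_postprocess(
    per_sample_loss: Sequence[float],
    mask: Sequence[bool],
    features: Matrix,
    env_ids: Sequence[int],
    num_envs: int,
    ood_param: float,
) -> float:
    """Hypothetical sketch: masked mean task loss plus a weighted CORAL penalty."""
    # Mean of the base loss over unmasked entries only.
    valid = [l for l, m in zip(per_sample_loss, mask) if m]
    mean_loss = sum(valid) / len(valid)

    # One covariance per environment; environments with < 2 samples are skipped.
    covs: Dict[int, Matrix] = {}
    for env in range(num_envs):
        env_feats = [f for f, e in zip(features, env_ids) if e == env]
        if len(env_feats) > 1:
            covs[env] = covariance(env_feats)

    # Average the CORAL distance over each unordered pair of distinct environments.
    pair_losses = [coral_distance(covs[i], covs[j]) for i in covs for j in covs if i < j]
    penalty = sum(pair_losses) / len(pair_losses) if pair_losses else 0.0
    return mean_loss + ood_param * penalty
```

Because the CORAL distance is symmetric, averaging over unordered pairs gives the same penalty as averaging over ordered pairs while doing half the work. With one-dimensional features whose per-environment variances are 2 and 8, the penalty is (2 − 8)² / 4 = 9, so `ood_param=0.1` adds 0.9 to the mean task loss.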