GOOD.ood_algorithms.algorithms.BaseOOD

Base class for OOD algorithms

Classes

BaseOODAlg(config)

Base class for OOD algorithms

class GOOD.ood_algorithms.algorithms.BaseOOD.BaseOODAlg(config: Union[CommonArgs, Munch])

Bases: ABC

Base class for OOD algorithms

Parameters

config (Union[CommonArgs, Munch]) – munchified dictionary of args
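Taken together, the methods below form a set of hooks around a single training step. The sketch shows one plausible calling order, inferred from the method signatures rather than copied from GOOD's training loop. `MiniAlg` and `train_step` are hypothetical, plain Python lists stand in for tensors so the example runs without PyTorch, and `loss_calculate` uses a squared error purely for illustration.

```python
from abc import ABC
from typing import Callable, List


class MiniAlg(ABC):
    """Hypothetical stand-in mirroring the BaseOODAlg hook names (not GOOD's implementation)."""

    def __init__(self):
        self.calls: List[str] = []  # records the order in which hooks run

    def input_preprocess(self, data, targets, mask, node_norm, training, config, **kwargs):
        self.calls.append("input_preprocess")
        return data, targets, mask, node_norm  # pass-through

    def output_postprocess(self, model_output, **kwargs):
        self.calls.append("output_postprocess")
        return model_output  # raw output used directly as raw predictions

    def loss_calculate(self, raw_pred, targets, mask, node_norm, config):
        self.calls.append("loss_calculate")
        # element-wise squared error, zeroed where the mask is 0
        return [(p - t) ** 2 * m for p, t, m in zip(raw_pred, targets, mask)]

    def loss_postprocess(self, loss, data, mask, config, **kwargs):
        self.calls.append("loss_postprocess")
        return sum(loss) / sum(mask)  # mean over the valid (unmasked) entries

    def backward(self, loss):
        self.calls.append("backward")  # a real algorithm would back-propagate and step here


def train_step(alg: MiniAlg, model: Callable, data, targets, mask, node_norm, config) -> float:
    """One hypothetical training step that calls the hooks in order."""
    data, targets, mask, node_norm = alg.input_preprocess(
        data, targets, mask, node_norm, training=True, config=config)
    raw_pred = alg.output_postprocess(model(data))
    loss = alg.loss_calculate(raw_pred, targets, mask, node_norm, config)
    loss = alg.loss_postprocess(loss, data, mask, config)
    alg.backward(loss)
    return loss
```

With a model that doubles its inputs, `train_step(MiniAlg(), lambda xs: [2 * x for x in xs], [1.0, 2.0, 3.0], [2.0, 5.0, 0.0], [1, 1, 0], [1.0, 1.0, 1.0], None)` returns 0.5: the masked third entry contributes nothing and the remaining squared errors 0 and 1 are averaged.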

backward(loss)

Back-propagate the loss and update the model parameters.

Parameters

loss – the loss to back-propagate
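Read literally, the description amounts to two calls: back-propagate the loss, then step the optimizer. The sketch below assumes the optimizer was stored on the instance (for example by set_up) and that gradients are zeroed elsewhere before the forward pass. `Param`, `SquaredErrorLoss`, `SGD`, and `SketchAlg` are hypothetical pure-Python stand-ins that mimic the PyTorch `loss.backward()` / `optimizer.step()` interface so the example runs without PyTorch.

```python
class Param:
    """Hypothetical stand-in for a trainable scalar parameter and its gradient."""

    def __init__(self, value: float):
        self.value = value
        self.grad = 0.0


class SquaredErrorLoss:
    """Hypothetical scalar loss L = (w * x - y) ** 2 that can back-propagate into w."""

    def __init__(self, w: Param, x: float, y: float):
        self.w, self.x, self.y = w, x, y
        self.value = (w.value * x - y) ** 2

    def backward(self):
        # dL/dw = 2 * (w * x - y) * x; accumulated into .grad, as autograd does
        self.w.grad += 2 * (self.w.value * self.x - self.y) * self.x


class SGD:
    """Hypothetical plain gradient-descent optimizer with a zero_grad/step interface."""

    def __init__(self, params, lr: float):
        self.params, self.lr = list(params), lr

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad


class SketchAlg:
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def backward(self, loss):
        loss.backward()        # 1) compute gradients of the loss w.r.t. the parameters
        self.optimizer.step()  # 2) update the parameters with those gradients
```

For w = 1, x = 2, y = 3 and lr = 0.1, the gradient is 2 * (2 - 3) * 2 = -4, so one call to `backward` moves w to about 1.4.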

input_preprocess(data: Batch, targets: Tensor, mask: Tensor, node_norm: Tensor, training: bool, config: Union[CommonArgs, Munch], **kwargs) → Tuple[Batch, Tensor, Tensor, Tensor]

Format and prepare the input data.

Parameters
  • data (Batch) – input data

  • targets (Tensor) – input labels

  • mask (Tensor) – NaN masks for data formats

  • node_norm (Tensor) – node weights for normalization (for node prediction only)

  • training (bool) – whether the task is training

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args

Returns

  • data (Batch) – Processed input data.

  • targets (Tensor) – Processed input labels.

  • mask (Tensor) – Processed NaN masks for data formats.

  • node_norm (Tensor) – Processed node weights for normalization.
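Because the method receives the training flag, a subclass can transform inputs only while training. As an illustration, the hypothetical override below applies Mixup (convex combinations of pairs of samples and of their one-hot labels) in training mode and passes everything through unchanged otherwise. It is not GOOD's implementation: inputs are fixed-length feature rows instead of graph batches, alpha=0.4 is an arbitrary choice, and mask and node_norm are returned untouched, which assumes every label in the batch is valid.

```python
import random
from typing import List, Optional

Matrix = List[List[float]]  # stand-in for a 2-D tensor: one row per sample


class MixupSketch:
    """Hypothetical algorithm overriding input_preprocess with Mixup (not GOOD's code)."""

    def __init__(self, alpha: float = 0.4, seed: Optional[int] = None):
        self.alpha = alpha
        self.rng = random.Random(seed)

    def input_preprocess(self, data: Matrix, targets: Matrix, mask, node_norm,
                         training: bool, config, **kwargs):
        if not training:
            return data, targets, mask, node_norm  # evaluation: inputs pass through unchanged

        lam = self.rng.betavariate(self.alpha, self.alpha)  # mixing coefficient in [0, 1]
        perm = list(range(len(data)))
        self.rng.shuffle(perm)  # perm[i] is the partner sample mixed into row i

        def mix(a: List[float], b: List[float]) -> List[float]:
            return [lam * u + (1 - lam) * v for u, v in zip(a, b)]

        data = [mix(data[i], data[j]) for i, j in enumerate(perm)]
        targets = [mix(targets[i], targets[j]) for i, j in enumerate(perm)]
        return data, targets, mask, node_norm
```

Since each mixed label row is a convex combination of two one-hot rows, it still sums to 1.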

loss_calculate(raw_pred: Tensor, targets: Tensor, mask: Tensor, node_norm: Tensor, config: Union[CommonArgs, Munch]) → Tensor

Calculate the prediction loss without any special OOD constraints.

Parameters
  • raw_pred (Tensor) – model predictions

  • targets (Tensor) – input labels

  • mask (Tensor) – NaN masks for data formats

  • node_norm (Tensor) – node weights for normalization (for node prediction only)

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.metric.loss_func(), config.model.model_level)

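For example, a graph-level configuration could look like:

import torch.nn.functional as F
from munch import munchify

config = munchify({'model': {'model_level': 'graph'},
                   'metric': {'loss_func': F.cross_entropy}})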
Returns (Tensor):

the prediction loss computed with config.metric.loss_func (e.g. cross entropy)
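The parameters suggest an element-wise loss that ignores masked entries and, for node-level tasks, is weighted by node_norm. A minimal sketch under those assumptions: `cross_entropy` is a hypothetical pure-Python stand-in for config.metric.loss_func (computed from raw logits with the usual max subtraction for numerical stability), `loss_calculate_sketch` is hypothetical, lists stand in for tensors, and the exact weighting GOOD applies may differ.

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Per-sample cross entropy from raw logits: logsumexp(logits) - logits[target]."""
    m = max(logits)  # subtract the max before exponentiating to avoid overflow
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]


def loss_calculate_sketch(raw_pred: Sequence[Sequence[float]], targets: Sequence[int],
                          mask: Sequence[bool], node_norm: Sequence[float],
                          model_level: str) -> List[float]:
    """Hypothetical element-wise loss: zero where mask is False, node_norm-weighted for node tasks."""
    # no reduction here; masked-out entries become 0.0
    loss = [cross_entropy(p, t) * m for p, t, m in zip(raw_pred, targets, mask)]
    if model_level == "node":
        loss = [v * w for v, w in zip(loss, node_norm)]
    return loss
```

Masked-out entries come back as 0.0, so a later reduction can divide by the number of valid entries.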

loss_postprocess(loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) → Tensor

Post-process the base loss, for example by adding algorithm-specific OOD regularization terms.

Parameters
  • loss (Tensor) – base loss between model predictions and input labels

  • data (Batch) – input data

  • mask (Tensor) – NaN masks for data formats

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args

Returns (Tensor):

processed loss
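Algorithms that add OOD regularization typically do so at this stage. The sketch below swaps in a V-REx-style penalty (variance of the per-environment mean risks) as one example. `loss_postprocess_sketch`, `env_ids`, and `penalty_weight` are hypothetical; a real implementation would presumably read environment labels from data and the weight from config. At least one unmasked entry is assumed. With penalty_weight = 0 it reduces to the plain mean over valid entries.

```python
from statistics import fmean, pvariance
from typing import Dict, List, Sequence


def loss_postprocess_sketch(loss: Sequence[float], mask: Sequence[bool],
                            env_ids: Sequence[int], penalty_weight: float) -> float:
    """Hypothetical V-REx-style post-processing of an element-wise loss."""
    # group the valid (unmasked) element losses by environment
    per_env: Dict[int, List[float]] = {}
    for v, m, e in zip(loss, mask, env_ids):
        if m:
            per_env.setdefault(e, []).append(v)
    env_risks = [fmean(vs) for vs in per_env.values()]

    # plain mean over valid entries, then penalize disagreement between environment risks
    mean_loss = sum(v for v, m in zip(loss, mask) if m) / sum(mask)
    return mean_loss + penalty_weight * pvariance(env_risks)
```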

output_postprocess(model_output: Tensor, **kwargs) → Tensor

Process the raw output of the model.

Parameters

model_output (Tensor) – raw model output

Returns (Tensor):

raw model predictions
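The method receives whatever the model's forward pass returned and must hand back raw predictions. One reason an algorithm might override it, shown as a hypothetical sketch: if the model returns a tuple (for example predictions plus features needed by an extra loss term), the hook can split it and cache the auxiliary part on the instance for a later loss hook. The tuple convention and the `aux` attribute are assumptions, not part of the documented API.

```python
from typing import Any, Optional, Tuple, Union


class AuxOutputSketch:
    """Hypothetical algorithm whose model may return (predictions, auxiliary_output)."""

    def __init__(self):
        self.aux: Optional[Any] = None  # cached for use in a later loss hook

    def output_postprocess(self, model_output: Union[Any, Tuple[Any, Any]], **kwargs):
        if isinstance(model_output, tuple):
            raw_pred, self.aux = model_output  # keep the extra output for loss_postprocess
            return raw_pred
        return model_output  # plain output: pass through as raw predictions
```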

set_up(model: Module, config: Union[CommonArgs, Munch])

Set up the optimizer and learning-rate scheduler for training.

Parameters
  • model (torch.nn.Module) – model for setup

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.train.lr, config.metric, config.train.mile_stones)

Returns

None
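The parameter list names config.train.lr and config.train.mile_stones, which suggests a step schedule that decays the learning rate at fixed epochs. Assuming a MultiStepLR-style scheduler (in PyTorch presumably something like `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.train.mile_stones, gamma=...)`) with a hypothetical decay factor gamma = 0.1, since the factor is not stated here, the learning rate in effect at a given epoch can be computed as follows. `multistep_lr` is a hypothetical helper.

```python
from bisect import bisect_right
from typing import Sequence


def multistep_lr(base_lr: float, milestones: Sequence[int], epoch: int, gamma: float = 0.1) -> float:
    """Learning rate at `epoch` under a milestone (step) decay.

    The rate is multiplied by gamma once for every milestone the epoch count has reached,
    matching the behaviour documented for PyTorch's MultiStepLR.
    """
    passed = bisect_right(sorted(milestones), epoch)  # number of milestones already reached
    return base_lr * gamma ** passed
```

For example, with lr = 1e-3 and milestones [30, 80], epochs 0 to 29 train at 1e-3, epochs 30 to 79 at 1e-4, and later epochs at 1e-5.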

stage_control(config)

Set variables before each epoch. Mainly used to control multi-stage training and epoch-related parameter settings.

Parameters

config – munchified dictionary of args.
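A common use of this hook is switching an OOD penalty on only after a warm-up phase. The hypothetical sketch below derives the current stage from a list of stage start epochs and sets the penalty weight accordingly. For brevity it takes the epoch number directly rather than reading it from config, and `stage_starts`, `penalty_weight`, and `transitions` are illustrative names, not GOOD configuration fields.

```python
from bisect import bisect_right
from typing import List, Sequence


class StagedSketch:
    """Hypothetical two-stage schedule: no OOD penalty during warm-up, full penalty afterwards."""

    def __init__(self, stage_starts: Sequence[int], penalty_weight: float):
        self.stage_starts = sorted(stage_starts)  # first epoch of each stage after stage 0
        self.full_penalty = penalty_weight
        self.stage = 0
        self.penalty_weight = 0.0
        self.transitions: List[int] = []  # epochs at which the stage changed (for illustration)

    def stage_control(self, epoch: int):
        # called once before each epoch; stage = number of stage starts already reached
        stage = bisect_right(self.stage_starts, epoch)
        if stage != self.stage:
            self.stage = stage
            self.transitions.append(epoch)
        self.penalty_weight = self.full_penalty if self.stage >= 1 else 0.0
```

With stage_starts=[5] and penalty_weight=10.0, calling stage_control for epochs 0 to 7 keeps the weight at 0.0 for epochs 0 to 4 and sets it to 10.0 from epoch 5 on.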