GOOD.ood_algorithms.algorithms.BaseOOD
Base class for OOD algorithms
Classes
BaseOODAlg – Base class for OOD algorithms
- class GOOD.ood_algorithms.algorithms.BaseOOD.BaseOODAlg(config: Union[CommonArgs, Munch])
Bases: ABC
Base class for OOD algorithms
- Args:
config (Union[CommonArgs, Munch]): munchified dictionary of args
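The hooks of this class are presumably called in a fixed order during each training step. The sketch below shows one plausible ordering using a stand-in class with the same method names. `RecordingAlg` and `train_step` are hypothetical names made up for this example, the stand-in does not inherit from BaseOODAlg, plain lists replace tensors, and the call order itself is an assumption rather than something stated on this page.

```python
class RecordingAlg:
    """Hypothetical stand-in that records the order in which hooks are called."""

    def __init__(self):
        self.calls = []

    def input_preprocess(self, data, targets, mask, node_norm, training, config):
        self.calls.append("input_preprocess")
        return data, targets, mask, node_norm

    def output_postprocess(self, model_output):
        self.calls.append("output_postprocess")
        return model_output

    def loss_calculate(self, raw_pred, targets, mask, node_norm, config):
        self.calls.append("loss_calculate")
        # Squared error per element, zeroed where the label is missing.
        return [(p - t) ** 2 * m for p, t, m in zip(raw_pred, targets, mask)]

    def loss_postprocess(self, loss, data, mask, config):
        self.calls.append("loss_postprocess")
        return sum(loss) / sum(mask)

    def backward(self, loss):
        # A real implementation would backpropagate and step the optimizer.
        self.calls.append("backward")


def train_step(alg, model, data, targets, mask, config=None):
    """One assumed training step: preprocess, predict, compute loss, update."""
    data, targets, mask, node_norm = alg.input_preprocess(
        data, targets, mask, None, True, config)
    raw_pred = alg.output_postprocess(model(data))
    loss = alg.loss_calculate(raw_pred, targets, mask, node_norm, config)
    loss = alg.loss_postprocess(loss, data, mask, config)
    alg.backward(loss)
    return loss


alg = RecordingAlg()
loss = train_step(alg, model=lambda xs: [2 * x for x in xs],
                  data=[1.0, 2.0], targets=[2.0, 3.0], mask=[1.0, 1.0])
```

After the call, `alg.calls` lists the hooks in the order they ran, which makes it easy to see where a subclass override would take effect.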
- backward(loss)
Backpropagate the loss and update the model parameters.
- Parameters
loss – target loss
- input_preprocess(data: Batch, targets: Tensor, mask: Tensor, node_norm: Tensor, training: bool, config: Union[CommonArgs, Munch], **kwargs) -> Tuple[Batch, Tensor, Tensor, Tensor]
Prepare the input data and put it in the expected format.
- Parameters
data (Batch) – input data
targets (Tensor) – input labels
mask (Tensor) – NaN masks for data formats
node_norm (Tensor) – node weights for normalization (for node prediction only)
training (bool) – whether the model is in training mode
config (Union[CommonArgs, Munch]) – munchified dictionary of args
- Returns
data (Batch) - Processed input data.
targets (Tensor) - Processed input labels.
mask (Tensor) - Processed NaN masks for data formats.
node_norm (Tensor) - Processed node weights for normalization.
- loss_calculate(raw_pred: Tensor, targets: Tensor, mask: Tensor, node_norm: Tensor, config: Union[CommonArgs, Munch]) -> Tensor
Calculate the prediction loss without any special OOD constraints.
- Parameters
raw_pred (Tensor) – model predictions
targets (Tensor) – input labels
mask (Tensor) – NaN masks for data formats
node_norm (Tensor) – node weights for normalization (for node prediction only)
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.metric.loss_func(), config.model.model_level)
Example: config = munchify({'model': {'model_level': 'graph'}, 'metric': {'loss_func': Accuracy}})
- Returns (Tensor):
cross entropy loss
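As a rough illustration of what a masked, unreduced loss looks like, the sketch below uses plain Python lists in place of tensors. The helper name `masked_bce` and the choice of binary cross-entropy are assumptions made for this example. In the library the loss comes from config.metric.loss_func, and the real method may also apply node_norm for node-level tasks, which this sketch leaves out.

```python
import math

def masked_bce(probs, targets, mask):
    """Per-element binary cross-entropy, zeroed where mask is 0.

    Hypothetical helper: lists stand in for tensors, and BCE stands in
    for whatever config.metric.loss_func is configured to be.
    """
    losses = []
    for p, t, m in zip(probs, targets, mask):
        # Entries with mask == 0 had no label; their target is assumed to
        # have been replaced by a placeholder, so the product is well defined.
        elem = -(t * math.log(p) + (1 - t) * math.log(1 - p))
        losses.append(elem * m)
    return losses

probs = [0.9, 0.2, 0.5]
targets = [1.0, 0.0, 0.0]   # last target is a placeholder for a missing label
mask = [1.0, 1.0, 0.0]
per_element = masked_bce(probs, targets, mask)
```

Multiplying by the mask leaves the unlabelled entry with zero loss, so it contributes nothing to any later reduction or gradient.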
- loss_postprocess(loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) -> Tensor
Process loss
- Parameters
loss (Tensor) – base loss between model predictions and input labels
data (Batch) – input data
mask (Tensor) – NaN masks for data formats
config (Union[CommonArgs, Munch]) – munchified dictionary of args
- Returns (Tensor):
processed loss
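A minimal sketch of one plausible post-processing step: averaging a masked per-element loss over the labelled entries only, with plain lists standing in for tensors. The function name `masked_mean` is hypothetical, and whether the base class reduces the loss in exactly this way is an assumption. Algorithm-specific subclasses would presumably add their OOD penalty terms at this stage.

```python
def masked_mean(per_element_loss, mask):
    """Average the loss over entries whose mask is non-zero."""
    valid = sum(mask)
    if valid == 0:
        # No labelled entries in this batch; avoid dividing by zero.
        return 0.0
    return sum(per_element_loss) / valid

per_element_loss = [0.2, 0.4, 0.0]   # last entry already zeroed by the mask
mask = [1.0, 1.0, 0.0]
mean_loss = masked_mean(per_element_loss, mask)
```

Dividing by the number of valid entries rather than the batch size keeps the loss scale independent of how many labels are missing.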
- output_postprocess(model_output: Tensor, **kwargs) -> Tensor
Process the raw output of model
- Parameters
model_output (Tensor) – model raw output
- Returns (Tensor):
model raw predictions
- set_up(model: Module, config: Union[CommonArgs, Munch])
Set up the optimizer and learning-rate scheduler for training.
- Parameters
model (torch.nn.Module) – model for setup
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.train.lr, config.metric, config.train.mile_stones)
- Returns
None
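The presence of config.train.lr and config.train.mile_stones suggests a step-wise learning-rate schedule. The sketch below computes such a schedule in plain Python, assuming the learning rate is multiplied by a fixed factor each time a milestone epoch is reached. The function name `lr_at_epoch` and the factor `gamma=0.1` are illustrative assumptions, not values taken from the library.

```python
from bisect import bisect_right

def lr_at_epoch(base_lr, mile_stones, epoch, gamma=0.1):
    """Learning rate at `epoch` for a milestone (multi-step) schedule."""
    # Count how many milestones have been reached so far.
    reached = bisect_right(sorted(mile_stones), epoch)
    return base_lr * gamma ** reached

# Learning rates for epochs 0..5 with milestones at epochs 2 and 4.
schedule = [lr_at_epoch(1e-3, [2, 4], e) for e in range(6)]
```

With these inputs the rate stays at the base value for epochs 0 and 1, drops by a factor of ten at epoch 2, and drops again at epoch 4.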