GOOD.ood_algorithms.algorithms.SRGNN

Implementation of the SRGNN algorithm from “Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data” paper

Functions

KMM(X, Xtest, config[, _A, _sigma, beta])

Kernel mean matching (KMM) to compute the weight for each training instance

cmd(X, X_test[, K])

Central moment discrepancy (CMD).

l2diff(x1, x2)

Euclidean (L2) norm of the difference between two tensors

moment_diff(sx1, sx2, k)

Difference between the k-th moments of two samples

pairwise_distances(x[, y])

Compute pairwise squared Euclidean distances

GOOD.ood_algorithms.algorithms.SRGNN.KMM(X, Xtest, config: Union[CommonArgs, Munch], _A=None, _sigma=10.0, beta=0.2)

Kernel mean matching (KMM) to compute the weight for each training instance

Parameters
  • X (Tensor) – training instances to be matched

  • Xtest (Tensor) – IID samples to match the training instances

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device)

  • _A (numpy array) – one-hot matrix of the training instance labels

  • _sigma (float) – normalization term

  • beta (float) – regularization weight

Returns

  • KMM_weight (numpy array) - KMM weight assigned to each training instance

  • MMD_dist (Tensor) - MMD distance
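For orientation, the textbook KMM formulation is sketched below. Here x_i are the n rows of X, x'_j the m rows of Xtest, and k is a kernel (for example a Gaussian kernel); the weight vector is written w to keep it distinct from the beta argument. How this function's _A, _sigma, and beta arguments enter the optimization is not stated in this docstring, so the constraint set shown is the standard one and not necessarily the one implemented here.

```latex
% Empirical (biased) squared MMD between training and test samples
\widehat{\mathrm{MMD}}^2 = \frac{1}{n^2}\sum_{i,j=1}^{n} k(x_i, x_j)
  - \frac{2}{nm}\sum_{i=1}^{n}\sum_{j=1}^{m} k(x_i, x'_j)
  + \frac{1}{m^2}\sum_{i,j=1}^{m} k(x'_i, x'_j)

% KMM: reweight training instances so their kernel mean matches the test mean
\min_{w \in \mathbb{R}^n} \; \tfrac{1}{2}\, w^\top K w - \kappa^\top w
\quad \text{s.t.} \quad 0 \le w_i \le B, \qquad \Bigl|\sum_{i=1}^{n} w_i - n\Bigr| \le n\epsilon,

K_{ij} = k(x_i, x_j), \qquad \kappa_i = \frac{n}{m}\sum_{j=1}^{m} k(x_i, x'_j)
```

The objective equals (n^2/2) times the squared MMD between the w-weighted training mean and the test mean, minus a constant independent of w, so minimizing it shrinks that weighted MMD; with w = 1 it recovers the unweighted MMD^2 above up to the same scaling and constant.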

GOOD.ood_algorithms.algorithms.SRGNN.cmd(X, X_test, K=5)

Central moment discrepancy (CMD) between two samples, used as a domain-alignment objective. References: Zellinger, Werner, et al. “Robust unsupervised domain adaptation for neural networks via moment alignment.” Zellinger, Werner, et al. “Central moment discrepancy (CMD) for domain-invariant representation learning.” ICLR, 2017.

Parameters
  • X (Tensor) – training instances

  • X_test (Tensor) – IID samples

  • K (int) – highest moment order used in the approximation (moments of order 1 to K are compared)

Returns (Tensor):

central moment discrepancy
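For reference, the CMD of order K as defined by Zellinger et al. is given below, for features bounded in [a, b]^d, with powers taken element-wise. The empirical estimate replaces each expectation with a column-wise sample mean. Whether this implementation applies the 1/|b - a|^k factors is not stated in the docstring; they all equal 1 when b - a = 1 (for example, features in [0, 1]).

```latex
\mathrm{CMD}_K(X, Y) = \frac{1}{|b-a|}\,\bigl\lVert \mathbb{E}(X) - \mathbb{E}(Y) \bigr\rVert_2
  + \sum_{k=2}^{K} \frac{1}{|b-a|^k}\,\bigl\lVert c_k(X) - c_k(Y) \bigr\rVert_2,
\qquad c_k(X) = \mathbb{E}\bigl((X - \mathbb{E}(X))^k\bigr)
```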

GOOD.ood_algorithms.algorithms.SRGNN.l2diff(x1, x2)

Euclidean (L2) norm of the difference between two tensors

GOOD.ood_algorithms.algorithms.SRGNN.moment_diff(sx1, sx2, k)

Difference between the k-th moments of two samples
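The helpers above compose into CMD roughly as follows. This is a minimal PyTorch sketch, assuming 2-D (samples × features) inputs and column-wise moments; it omits the paper's 1/|b - a|^k scaling (that is, it assumes features lie in an interval of width 1) and is not necessarily identical to this module's source.

```python
import torch


def l2diff(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Euclidean (L2) norm of the difference between two tensors."""
    return torch.linalg.vector_norm(x1 - x2, ord=2)


def moment_diff(sx1: torch.Tensor, sx2: torch.Tensor, k: int) -> torch.Tensor:
    """L2 distance between the column-wise k-th moments of two (centered) samples."""
    ss1 = sx1.pow(k).mean(0)
    ss2 = sx2.pow(k).mean(0)
    return l2diff(ss1, ss2)


def cmd(x: torch.Tensor, x_test: torch.Tensor, K: int = 5) -> torch.Tensor:
    """Unscaled CMD sketch: mean gap plus central-moment gaps of orders 2..K."""
    mx, mx_test = x.mean(0), x_test.mean(0)
    sx, sx_test = x - mx, x_test - mx_test  # center each sample on its own mean
    total = l2diff(mx, mx_test)             # order-1 term
    for k in range(2, K + 1):
        total = total + moment_diff(sx, sx_test, k)
    return total


if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn(200, 8)
    b = torch.randn(200, 8) * 2.0 + 1.0  # shifted and rescaled sample
    print(cmd(a, a).item())        # identical samples: 0.0
    print(cmd(a, b).item() > 0.0)  # True
```

Because every term is a norm of a moment gap, CMD is zero for identical samples and grows as their first K moments drift apart.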

GOOD.ood_algorithms.algorithms.SRGNN.pairwise_distances(x, y=None)

Compute pairwise squared Euclidean distances

Parameters
  • x (Tensor) – an N×d matrix

  • y (Tensor) – an optional M×d matrix

Returns (Tensor):

dist (an N×M matrix), where dist[i,j] is the squared Euclidean distance between x[i,:] and y[j,:], i.e. dist[i,j] = ||x[i,:] - y[j,:]||^2. If y is not given, y = x is used.
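The definition above can be computed with a single matrix product by expanding ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b. This is a minimal sketch assuming 2-D floating-point tensors; the final clamp is an added safeguard against rounding, not something the docstring specifies, and the module may compute the distances differently.

```python
from typing import Optional

import torch


def pairwise_distances(x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Squared Euclidean distances: dist[i, j] = ||x[i, :] - y[j, :]||^2 (y defaults to x)."""
    if y is None:
        y = x
    x_norm = (x ** 2).sum(1).view(-1, 1)  # N x 1 column of squared row norms
    y_norm = (y ** 2).sum(1).view(1, -1)  # 1 x M row of squared row norms
    # Broadcasting gives an N x M matrix; one matmul supplies all dot products.
    dist = x_norm + y_norm - 2.0 * (x @ y.T)
    # Floating-point cancellation can yield tiny negatives; clamp them to zero.
    return dist.clamp(min=0.0)


if __name__ == "__main__":
    pts = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
    print(pairwise_distances(pts))  # tensor([[ 0., 25.], [25.,  0.]])
```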

Classes

SRGNN(config)

Implementation of the SRGNN algorithm from "Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data" paper

class GOOD.ood_algorithms.algorithms.SRGNN.SRGNN(config: Union[CommonArgs, Munch])

Bases: BaseOODAlg

Implementation of the SRGNN algorithm from “Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data” paper

Parameters

config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)

input_preprocess(data: Batch, targets: Tensor, mask: Tensor, node_norm: Tensor, training: bool, config: Union[CommonArgs, Munch], **kwargs) → Tuple[Batch, Tensor, Tensor, Tensor]

Set the input data and mask formats to prepare them for SRGNN

Parameters
  • data (Batch) – input data

  • targets (Tensor) – input labels

  • mask (Tensor) – NaN masks for data formats

  • node_norm (Tensor) – node weights for normalization (for node prediction only)

  • training (bool) – whether the model is in training mode

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.ood.ood_param)

from munch import munchify
import torch
config = munchify({'device': torch.device('cuda'),
                   'ood': {'ood_param': 0.1}
                   })
Returns

  • data (Batch) - Processed input data.

  • targets (Tensor) - Processed input labels.

  • mask (Tensor) - Processed NaN masks for data formats.

  • node_norm (Tensor) - Processed node weights for normalization.

loss_postprocess(loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) → Tensor

Process the loss according to the SRGNN algorithm

Parameters
  • loss (Tensor) – base loss between model predictions and input labels

  • data (Batch) – input data

  • mask (Tensor) – NaN masks for data formats

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)

from munch import munchify
import torch
config = munchify({'device': torch.device('cuda'),
                   'dataset': {'num_envs': 10},
                   'ood': {'ood_param': 0.1}
                   })
Returns (Tensor):

Loss adjusted according to the SRGNN algorithm
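One plausible shape for this loss, written as a hypothetical standalone function: the element-wise task loss is averaged over valid mask entries, and a feature-discrepancy penalty between environments, scaled by ood_param, is added. The names srgnn_style_loss, feat, env_id, and mean_gap are illustrative assumptions, not part of this module; the pairing over config.dataset.num_envs environments is inferred only from the config fields listed above, and the module's actual pairing and averaging may differ. In SRGNN the discrepancy would presumably be cmd; mean_gap is a simpler stand-in that keeps the sketch short.

```python
from typing import Callable

import torch

Discrepancy = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def srgnn_style_loss(loss: torch.Tensor, mask: torch.Tensor, feat: torch.Tensor,
                     env_id: torch.Tensor, num_envs: int, ood_param: float,
                     discrepancy: Discrepancy) -> torch.Tensor:
    """Hypothetical: masked mean task loss + ood_param * mean pairwise env discrepancy."""
    # Average the element-wise loss over valid (mask == 1) entries only.
    base = (loss * mask).sum() / mask.sum()
    penalties = []
    for i in range(num_envs):
        feat_i = feat[env_id == i]
        # Assumption: the discrepancy is symmetric, so each unordered pair is used once.
        for j in range(i + 1, num_envs):
            feat_j = feat[env_id == j]
            # Moment estimates need at least two samples per environment.
            if feat_i.shape[0] > 1 and feat_j.shape[0] > 1:
                penalties.append(discrepancy(feat_i, feat_j))
    if not penalties:
        return base
    return base + ood_param * torch.stack(penalties).mean()


def mean_gap(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in discrepancy: L2 gap between the two sample means."""
    return torch.linalg.vector_norm(a.mean(0) - b.mean(0), ord=2)
```

Passing the discrepancy as a callable keeps the loss assembly separate from the choice of moment-matching metric, so cmd can be swapped in without changing the loop.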

output_postprocess(model_output: Tensor, **kwargs) → Tensor

Process the raw model output and extract the feature representations

Parameters

model_output (Tensor) – raw model output

Returns (Tensor):

raw model predictions