GOOD.ood_algorithms.algorithms.SRGNN
Implementation of the SRGNN algorithm from the paper “Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data”.
Functions
KMM | Kernel mean matching (KMM) to compute the weight for each training instance
cmd | Central moment discrepancy (CMD)
| Standard Euclidean norm
| Difference between moments
pairwise_distances | Computation tool for pairwise distances
- GOOD.ood_algorithms.algorithms.SRGNN.KMM(X, Xtest, config: Union[CommonArgs, Munch], _A=None, _sigma=10.0, beta=0.2)
Kernel mean matching (KMM) to compute the weight for each training instance
- Parameters
X (Tensor) – training instances to be matched
Xtest (Tensor) – IID samples to match the training instances
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device)
_A (numpy array) – one-hot matrix of the training instance labels
_sigma (float) – normalization term
beta (float) – regularization weight
- Returns
KMM_weight (numpy array) - KMM_weight to match each training instance
MMD_dist (Tensor) - MMD distance
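The second return value is a maximum mean discrepancy (MMD) between the training instances and the IID samples. As a rough illustration of what such a distance measures, here is a minimal sketch of the standard biased estimator of squared MMD. It assumes a single Gaussian kernel on plain Python lists; the kernel actually used by `KMM` is not stated on this page, and `rbf_kernel`, `mean_kernel`, `mmd_squared`, and `bandwidth` are hypothetical names, not part of the GOOD API. The sketch does not cover how the KMM weights themselves are computed.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def rbf_kernel(a: Sequence[float], b: Sequence[float], bandwidth: float) -> float:
    """Gaussian kernel exp(-||a - b||^2 / (2 * bandwidth^2)) (assumed kernel)."""
    sq_dist = sum((p - q) ** 2 for p, q in zip(a, b))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mean_kernel(x: Matrix, y: Matrix, bandwidth: float) -> float:
    """Average kernel value over all pairs (x[i], y[j])."""
    total = sum(rbf_kernel(a, b, bandwidth) for a in x for b in y)
    return total / (len(x) * len(y))


def mmd_squared(x: Matrix, x_test: Matrix, bandwidth: float = 1.0) -> float:
    """Biased estimate of squared MMD: mean k(x,x) - 2 mean k(x,y) + mean k(y,y)."""
    return (
        mean_kernel(x, x, bandwidth)
        - 2.0 * mean_kernel(x, x_test, bandwidth)
        + mean_kernel(x_test, x_test, bandwidth)
    )


train = [[0.0], [1.0]]
near = [[0.5], [1.5]]      # slightly shifted sample
far = [[100.0], [101.0]]   # strongly shifted sample

same_score = mmd_squared(train, train)
near_score = mmd_squared(train, near)
far_score = mmd_squared(train, far)
```

Identical samples give a score of zero, and the score grows as the two samples move apart, which is the property a reweighting scheme such as KMM tries to exploit.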
- GOOD.ood_algorithms.algorithms.SRGNN.cmd(X, X_test, K=5)
Central moment discrepancy (CMD). Objective function for Keras models (Theano or TensorFlow backend). References: Zellinger, Werner, et al. “Robust unsupervised domain adaptation for neural networks via moment alignment.”; Zellinger, Werner, et al. “Central moment discrepancy (CMD) for domain-invariant representation learning.” ICLR, 2017.
- Parameters
X (Tensor) – training instances
X_test (Tensor) – IID samples
K (int) – number of approximation degrees
- Returns (Tensor):
central moment discrepancy
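To make the quantity concrete, the sketch below computes a central moment discrepancy on plain Python lists. It assumes the common unscaled form, that is, the Euclidean norm of the difference of means plus the Euclidean norms of the differences of the central moments of order 2 through K. Whether `cmd` applies any additional scaling is not stated on this page, and the helper names (`column_means`, `l2_norm_diff`, `central_moment`, `cmd_sketch`) are hypothetical.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def column_means(m: Matrix) -> List[float]:
    """Mean of every feature column."""
    return [sum(col) / len(m) for col in zip(*m)]


def l2_norm_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean norm of the difference between two vectors."""
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))


def central_moment(m: Matrix, means: Sequence[float], k: int) -> List[float]:
    """k-th central moment of every feature column."""
    return [
        sum((v - mu) ** k for v in col) / len(m)
        for col, mu in zip(zip(*m), means)
    ]


def cmd_sketch(x: Matrix, x_test: Matrix, K: int = 5) -> float:
    """Unscaled CMD: mean difference plus moment differences of order 2..K."""
    mean_x, mean_t = column_means(x), column_means(x_test)
    total = l2_norm_diff(mean_x, mean_t)
    for k in range(2, K + 1):
        total += l2_norm_diff(
            central_moment(x, mean_x, k), central_moment(x_test, mean_t, k)
        )
    return total


x = [[0.0], [2.0]]
shifted = [[1.0], [3.0]]   # same shape, mean moved by 1
spread = [[-1.0], [3.0]]   # same mean, larger variance

shift_only = cmd_sketch(x, shifted)        # only the mean term contributes
variance_only = cmd_sketch(x, spread, K=2) # only the second moment contributes
```

In the first case only the mean term is non-zero, and in the second only the order-2 term is, which shows how each order in the sum captures a different kind of distribution shift.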
- GOOD.ood_algorithms.algorithms.SRGNN.pairwise_distances(x, y=None)
computation tool for pairwise distances
- Parameters
x (Tensor) – an Nxd matrix
y (Tensor) – an optional Mxd matrix
- Returns (Tensor):
dist, an NxM matrix where dist[i,j] is the squared norm between x[i,:] and y[j,:], i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2. If y is not given, y=x is used.
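The documented formula can be illustrated without any tensor library. The following is a minimal sketch on nested Python lists, not the library's implementation; `pairwise_sq_distances` is a hypothetical name chosen for this example.

```python
from typing import List, Optional

Matrix = List[List[float]]


def pairwise_sq_distances(x: Matrix, y: Optional[Matrix] = None) -> Matrix:
    """Return an N x M matrix with dist[i][j] = ||x[i] - y[j]||^2."""
    if y is None:
        y = x  # documented default: compare x against itself
    return [
        [sum((a - b) ** 2 for a, b in zip(row_x, row_y)) for row_y in y]
        for row_x in x
    ]


x = [[0.0, 0.0], [3.0, 4.0]]                # N = 2, d = 2
y = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]]    # M = 3, d = 2

dist = pairwise_sq_distances(x, y)    # 2 x 3 matrix
self_dist = pairwise_sq_distances(x)  # 2 x 2 matrix with a zero diagonal
```

Note that the entries are squared distances, so the point (3, 4) is at 25, not 5, from the origin.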
Classes
SRGNN | Implementation of the SRGNN algorithm from the paper "Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data"
- class GOOD.ood_algorithms.algorithms.SRGNN.SRGNN(config: Union[CommonArgs, Munch])
Bases: BaseOODAlg
Implementation of the SRGNN algorithm from the paper “Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data”.
- Args:
config (Union[CommonArgs, Munch]): munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)
- input_preprocess(data: Batch, targets: Tensor, mask: Tensor, node_norm: Tensor, training: bool, config: Union[CommonArgs, Munch], **kwargs) → Tuple[Batch, Tensor, Tensor, Tensor]
Set input data and mask format to prepare for SRGNN
- Parameters
data (Batch) – input data
targets (Tensor) – input labels
mask (Tensor) – NAN masks for data formats
node_norm (Tensor) – node weights for normalization (for node prediction only)
training (bool) – whether the task is training
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.ood.ood_param)
config = munchify({'device': torch.device('cuda'), 'ood': {'ood_param': float(0.1)}})
- Returns
data (Batch) - Processed input data.
targets (Tensor) - Processed input labels.
mask (Tensor) - Processed NAN masks for data formats.
node_norm (Tensor) - Processed node weights for normalization.
- loss_postprocess(loss: Tensor, data: Batch, mask: Tensor, config: Union[CommonArgs, Munch], **kwargs) → Tensor
Process loss based on SRGNN algorithm
- Parameters
loss (Tensor) – base loss between model predictions and input labels
data (Batch) – input data
mask (Tensor) – NAN masks for data formats
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device, config.dataset.num_envs, config.ood.ood_param)
config = munchify({'device': torch.device('cuda'), 'dataset': {'num_envs': int(10)}, 'ood': {'ood_param': float(0.1)}})
- Returns (Tensor):
loss based on SRGNN algorithm