GOOD.networks.models.SRGCNs

GCN implementation of the SRGNN algorithm from the “Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data” paper.

Functions

KMM(X, Xtest, config[, _A, _sigma, beta])

Kernel mean matching (KMM) to compute the weight for each training instance

cmd(X, X_test[, K])

Central moment discrepancy (CMD).

l2diff(x1, x2)

Standard Euclidean (L2) norm of the difference between x1 and x2.

moment_diff(sx1, sx2, k)

Difference between the k-th moments of two samples.

pairwise_distances(x[, y])

Compute pairwise squared Euclidean distances.

GOOD.networks.models.SRGCNs.KMM(X, Xtest, config: Union[CommonArgs, Munch], _A=None, _sigma=10.0, beta=0.2)

Kernel mean matching (KMM) to compute the weight for each training instance

Parameters
  • X (Tensor) – training instances to be matched

  • Xtest (Tensor) – IID samples whose distribution the training instances are matched to

  • config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device)

  • _A (numpy array) – one-hot matrix of the training instance labels

  • _sigma (float) – normalization term

  • beta (float) – regularization weight

Returns

  • KMM_weight (numpy array) – KMM weight for each training instance

  • MMD_dist (Tensor) – maximum mean discrepancy (MMD) between X and Xtest
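To show the idea behind kernel mean matching, here is a minimal sketch, not the GOOD implementation. It uses NumPy arrays instead of torch Tensors, a single Gaussian kernel with a hypothetical bandwidth `gamma`, and box constraints `lower <= w_i <= upper`. It minimizes the KMM quadratic objective 0.5 * w^T H w - kappa^T w with projected gradient descent instead of a QP solver. The names `rbf_kernel` and `kmm_weights` and all of their parameters are assumptions. The documented function additionally takes label information through `_A`, which this sketch ignores.

```python
import numpy as np


def rbf_kernel(a: np.ndarray, b: np.ndarray, gamma: float) -> np.ndarray:
    """Gaussian kernel matrix exp(-gamma * ||a_i - b_j||^2) between the rows of a and b."""
    sq_dist = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2.0 * a @ b.T
    return np.exp(-gamma * np.maximum(sq_dist, 0.0))  # clamp tiny negative round-off


def kmm_weights(X, Xtest, gamma=1.0, lower=0.2, upper=10.0, steps=500):
    """Hypothetical KMM sketch: reweight rows of X so their kernel mean matches that of Xtest."""
    n, m = X.shape[0], Xtest.shape[0]
    H = rbf_kernel(X, X, gamma)                                # n x n Gram matrix of training rows
    kappa = (n / m) * rbf_kernel(X, Xtest, gamma).sum(axis=1)  # linear term of the QP
    step = 1.0 / np.linalg.eigvalsh(H).max()                   # 1/L step for the quadratic objective
    w = np.ones(n)                                             # start from uniform weights
    for _ in range(steps):
        grad = H @ w - kappa                                   # gradient of 0.5 w^T H w - kappa^T w
        w = np.clip(w - step * grad, lower, upper)             # project back onto the box
    return w


# Training rows 0-1 sit where the test samples are; rows 2-3 are far away.
X = np.array([[0.0], [0.0], [5.0], [5.0]])
Xtest = np.zeros((4, 1))
print(kmm_weights(X, Xtest))  # rows 0-1 are up-weighted, rows 2-3 end at the lower bound
```

Projected gradient descent keeps the sketch dependency-free. An exact QP solver would also enforce the additional equality constraints that KMM formulations usually add.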

GOOD.networks.models.SRGCNs.cmd(X, X_test, K=5)

Central moment discrepancy (CMD), an objective function for Keras models (Theano or TensorFlow backend). References: Zellinger, Werner, et al. “Robust unsupervised domain adaptation for neural networks via moment alignment.”; Zellinger, Werner, et al. “Central moment discrepancy (CMD) for domain-invariant representation learning.”, ICLR, 2017.

Parameters
  • X (Tensor) – training instances

  • X_test (Tensor) – IID samples

  • K (int) – highest moment order used in the approximation

Returns (Tensor):

central moment discrepancy
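CMD compares the means of two samples and then their per-feature central moments of order 2 through K, summing the L2 distances. The NumPy sketch below follows that recipe. It omits the 1/|b-a|^k scaling factors that the CMD paper applies to features bounded in [a, b]. The function name `cmd_sketch` is hypothetical, and this is not the GOOD source.

```python
import numpy as np


def cmd_sketch(X: np.ndarray, X_test: np.ndarray, K: int = 5) -> float:
    """Unnormalized central moment discrepancy between the rows of X and X_test."""
    mean_x, mean_t = X.mean(axis=0), X_test.mean(axis=0)
    centered_x, centered_t = X - mean_x, X_test - mean_t
    # first term: L2 distance between the means
    total = np.linalg.norm(mean_x - mean_t)
    # remaining terms: L2 distance between element-wise central moments of order 2..K
    for k in range(2, K + 1):
        moment_x = (centered_x ** k).mean(axis=0)
        moment_t = (centered_t ** k).mean(axis=0)
        total += np.linalg.norm(moment_x - moment_t)
    return float(total)


rng = np.random.default_rng(0)
source = rng.normal(size=(500, 8))
print(cmd_sketch(source, source))        # 0.0 for identical samples
print(cmd_sketch(source, source + 1.0))  # about 2.83 = sqrt(8): a pure shift only moves the mean
```

The shifted example shows why the mean term is kept separate. Translating every feature leaves all central moments unchanged, so only the first term contributes.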

GOOD.networks.models.SRGCNs.l2diff(x1, x2)

Standard Euclidean (L2) norm of the difference between x1 and x2.

GOOD.networks.models.SRGCNs.moment_diff(sx1, sx2, k)

Difference between the k-th moments of two samples.
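Here is an illustrative NumPy reading of the two helpers. l2diff computes the Euclidean norm of x1 - x2, and moment_diff computes the L2 distance between the per-feature k-th moments of two samples. These are central moments only if the caller passes mean-centered samples. The `_sketch` names are hypothetical, and these functions are assumptions rather than the GOOD code.

```python
import numpy as np


def l2diff_sketch(x1: np.ndarray, x2: np.ndarray) -> float:
    """Euclidean (L2) norm of the difference x1 - x2."""
    return float(np.sqrt(((x1 - x2) ** 2).sum()))


def moment_diff_sketch(sx1: np.ndarray, sx2: np.ndarray, k: int) -> float:
    """L2 distance between the per-feature k-th moments of two samples (rows = instances)."""
    moment1 = (sx1 ** k).mean(axis=0)  # k-th moment of each feature column
    moment2 = (sx2 ** k).mean(axis=0)
    return l2diff_sketch(moment1, moment2)


print(l2diff_sketch(np.array([0.0, 0.0]), np.array([3.0, 4.0])))  # 5.0
a = np.array([[-1.0], [1.0]])  # centered sample with second moment 1
b = np.array([[-2.0], [2.0]])  # centered sample with second moment 4
print(moment_diff_sketch(a, b, 2))  # 3.0
```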

GOOD.networks.models.SRGCNs.pairwise_distances(x, y=None)

Compute pairwise squared Euclidean distances.

Parameters
  • x (Tensor) – an N×d matrix

  • y (Tensor) – an optional M×d matrix

Returns (Tensor):

dist, an N×M matrix where dist[i,j] is the squared Euclidean distance between x[i,:] and y[j,:], i.e., dist[i,j] = ||x[i,:] - y[j,:]||^2. If y is not given, y = x is used.
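The squared distances can be computed without explicit loops by expanding ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i·y_j. Below is a NumPy sketch of that identity. The function name is hypothetical, and the final clamp is an added safeguard against negative values from floating-point round-off, not something the documentation states.

```python
from typing import Optional

import numpy as np


def pairwise_sq_distances(x: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
    """Return the N x M matrix with dist[i, j] = ||x[i, :] - y[j, :]||^2 (y defaults to x)."""
    if y is None:
        y = x
    x_norm = (x ** 2).sum(axis=1)[:, None]  # N x 1 column of squared row norms
    y_norm = (y ** 2).sum(axis=1)[None, :]  # 1 x M row of squared row norms
    dist = x_norm + y_norm - 2.0 * x @ y.T  # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    return np.maximum(dist, 0.0)            # clamp small negatives caused by round-off


x = np.array([[0.0, 0.0], [1.0, 1.0]])
y = np.array([[3.0, 4.0]])
print(pairwise_sq_distances(x, y))  # 2 x 1 matrix holding 25 and 13
```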

Classes

SRGCN(config)

The Graph Neural Network modified from the "Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data" paper and "Semi-supervised Classification with Graph Convolutional Networks" paper.

class GOOD.networks.models.SRGCNs.SRGCN(config: Union[CommonArgs, Munch])

Bases: GNNBasic

The Graph Neural Network modified from the “Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data” paper and “Semi-supervised Classification with Graph Convolutional Networks” paper.

Parameters

config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.model.dim_hidden, config.model.model_layer, config.dataset.dim_node, config.dataset.num_classes)

forward(*args, **kwargs) → Tuple[Tensor, Tensor]

The SRGCN model implementation.

Parameters
  • *args (list) – argument list consumed by arguments_read; refer to arguments_read.

  • **kwargs (dict) – keyword arguments consumed by arguments_read; refer to arguments_read.

Returns (Tuple[Tensor, Tensor]):

[label predictions, features]
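As background on the GCN half of the model, the layer rule from “Semi-supervised Classification with Graph Convolutional Networks” is H' = σ(D^-1/2 (A + I) D^-1/2 H W), where D is the degree matrix of A + I. The NumPy sketch below stacks two such layers. Mirroring the documented return value, it returns both class predictions and hidden node features. The helper names, the weight shapes, and the choice of ReLU and softmax are assumptions for illustration. This is not the SRGCN code, which reads its inputs through arguments_read and sizes its layers from config.

```python
import numpy as np


def normalized_adjacency(adj: np.ndarray) -> np.ndarray:
    """Symmetric GCN normalization D^-1/2 (A + I) D^-1/2."""
    a_hat = adj + np.eye(adj.shape[0])  # add self-loops
    deg_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return deg_inv_sqrt[:, None] * a_hat * deg_inv_sqrt[None, :]


def gcn_forward(adj, features, w_hidden, w_out):
    """Hypothetical two-layer GCN; returns (class probabilities, hidden node features)."""
    a_norm = normalized_adjacency(adj)
    hidden = np.maximum(a_norm @ features @ w_hidden, 0.0)  # first GCN layer + ReLU
    logits = a_norm @ hidden @ w_out                        # second GCN layer
    logits -= logits.max(axis=1, keepdims=True)             # numerically stable softmax
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    return probs, hidden


rng = np.random.default_rng(0)
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)  # path graph 0-1-2
features = rng.normal(size=(3, 4))                               # 3 nodes, 4 input features
probs, hidden = gcn_forward(adj, features, rng.normal(size=(4, 8)), rng.normal(size=(8, 2)))
print(probs.shape, hidden.shape)  # (3, 2) (3, 8)
```

Returning `hidden` alongside the predictions lets a caller compare training-node and IID-node representations with a discrepancy such as `cmd`. The Shift-Robust GNN paper describes this kind of regularization.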