GOOD.networks.models.SRGCNs
GCN implementation of the SRGNN algorithm from the “Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data” paper.
Functions
- KMM – Kernel mean matching (KMM) to compute the weight for each training instance.
- cmd – Central moment discrepancy (CMD).
- Standard Euclidean norm.
- Difference between moments.
- pairwise_distances – Computation tool for pairwise distances.
- GOOD.networks.models.SRGCNs.KMM(X, Xtest, config: Union[CommonArgs, Munch], _A=None, _sigma=10.0, beta=0.2)
Kernel mean matching (KMM) to compute the weight for each training instance
- Parameters
X (Tensor) – training instances to be matched
Xtest (Tensor) – IID samples to match the training instances
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.device)
_A (numpy array) – one-hot matrix of the training instance labels
_sigma (float) – normalization term
beta (float) – regularization weight
- Returns
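KMM_weight (numpy array) – weight assigned to each training instance so that the weighted training instances match the IID samples
MMD_dist (Tensor) – maximum mean discrepancy (MMD) distance between the training instances and the IID samples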
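To illustrate the idea behind KMM, here is a minimal NumPy sketch, not the GOOD implementation. It assumes a single Gaussian kernel with a hypothetical `bandwidth` and a hypothetical upper bound `upper` on the weights, and it swaps the quadratic-program solver for plain projected gradient descent. It also omits the label-based constraint that `_A` provides, does not model `_sigma` or `beta`, and works on NumPy arrays instead of Tensors. The helper names `gaussian_kernel`, `kmm_weights` and `weighted_mmd2` are hypothetical.

```python
import numpy as np


def gaussian_kernel(a, b, bandwidth=1.0):
    # Squared Euclidean distances between every row of a and every row of b.
    sq = (a**2).sum(1)[:, None] + (b**2).sum(1)[None, :] - 2.0 * a @ b.T
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * bandwidth**2))


def kmm_weights(x_train, x_test, bandwidth=1.0, upper=10.0, n_iter=500):
    """Hypothetical KMM sketch: minimize 0.5 * w^T K w - kappa^T w
    subject to 0 <= w <= upper, via projected gradient descent."""
    n, m = len(x_train), len(x_test)
    k_train = gaussian_kernel(x_train, x_train, bandwidth)
    kappa = (n / m) * gaussian_kernel(x_train, x_test, bandwidth).sum(1)
    # Step 1/L, where L is the largest eigenvalue of the (PSD) kernel matrix.
    step = 1.0 / np.linalg.eigvalsh(k_train).max()
    w = np.ones(n)  # start from uniform weights
    for _ in range(n_iter):
        grad = k_train @ w - kappa
        w = np.clip(w - step * grad, 0.0, upper)  # project onto the box
    return w


def weighted_mmd2(x_train, x_test, w, bandwidth=1.0):
    # Squared MMD between the w-reweighted training sample and the test sample.
    n, m = len(x_train), len(x_test)
    k_xx = gaussian_kernel(x_train, x_train, bandwidth)
    k_xt = gaussian_kernel(x_train, x_test, bandwidth)
    k_tt = gaussian_kernel(x_test, x_test, bandwidth)
    return w @ k_xx @ w / n**2 - 2.0 * (w @ k_xt.sum(1)) / (n * m) + k_tt.mean()


# Usage: the training points are shifted relative to the IID sample.
rng = np.random.default_rng(0)
x_train = rng.normal(0.0, 1.0, size=(60, 2))
x_test = rng.normal(1.0, 1.0, size=(40, 2))
w = kmm_weights(x_train, x_test)
print(weighted_mmd2(x_train, x_test, np.ones(60)), weighted_mmd2(x_train, x_test, w))
```

The objective equals the weighted squared MMD up to a positive factor and a constant, and a projected step of size 1/L never increases it, so starting from uniform weights the returned weights can only lower the MMD relative to the unweighted training set.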
- GOOD.networks.models.SRGCNs.cmd(X, X_test, K=5)
Central moment discrepancy (CMD): objective function for Keras models (Theano or TensorFlow backend). See Zellinger, Werner, et al. “Robust unsupervised domain adaptation for neural networks via moment alignment.” and Zellinger, Werner, et al. “Central moment discrepancy (CMD) for domain-invariant representation learning.”, ICLR, 2017.
- Parameters
X (Tensor) – training instances
X_test (Tensor) – IID samples
K (int) – number of central moments to match (orders 1 through K)
- Returns (Tensor):
central moment discrepancy
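The CMD computation can be sketched in a few lines of NumPy. This follows the CMD formula of Zellinger et al. under the assumption that features are bounded in [0, 1], so the range normalization 1/|b - a|^k equals 1; `cmd_sketch` is a hypothetical name and this is not necessarily how GOOD computes it. The first term is the standard Euclidean norm of the difference between the means, and each later term is the difference between central moments of order 2 through K.

```python
import numpy as np


def cmd_sketch(x, x_test, k=5):
    """Hypothetical NumPy sketch of central moment discrepancy (CMD).

    Assumes features lie in [0, 1], so the 1 / |b - a|^j scaling is 1."""
    mean_x, mean_t = x.mean(0), x_test.mean(0)
    # Order 1: Euclidean norm of the difference between the means.
    total = np.linalg.norm(mean_x - mean_t)
    centered_x, centered_t = x - mean_x, x_test - mean_t
    # Orders 2..k: Euclidean norm of the difference between central moments.
    for order in range(2, k + 1):
        moment_x = (centered_x**order).mean(0)
        moment_t = (centered_t**order).mean(0)
        total += np.linalg.norm(moment_x - moment_t)
    return total


# Same mean, different spread: only the even-order moments contribute.
x = np.array([[0.0], [1.0]])
x_test = np.array([[0.5], [0.5]])
print(cmd_sketch(x, x_test))  # 0.3125
```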
- GOOD.networks.models.SRGCNs.pairwise_distances(x, y=None)
computation tool for pairwise distances
- Parameters
x (Tensor) – a Nxd matrix
y (Tensor) – an optional Mxd matrix
- Returns (Tensor):
dist is an NxM matrix where dist[i,j] is the squared Euclidean distance between x[i,:] and y[j,:], i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2. If y is not given, y=x is used.
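Squared distances of this kind are commonly computed without an explicit double loop by expanding ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i·y_j. A minimal NumPy sketch of that approach follows; `pairwise_sq_distances` is a hypothetical name, and GOOD's own function operates on Tensors and may differ in details such as clamping.

```python
import numpy as np


def pairwise_sq_distances(x, y=None):
    """Hypothetical sketch: dist[i, j] = ||x[i, :] - y[j, :]||^2."""
    if y is None:
        y = x  # same convention as documented: compare x with itself
    x_norm = (x**2).sum(1)[:, None]  # shape (N, 1)
    y_norm = (y**2).sum(1)[None, :]  # shape (1, M)
    dist = x_norm + y_norm - 2.0 * x @ y.T  # broadcasts to (N, M)
    # Floating-point cancellation can give tiny negatives; clamp at zero.
    return np.maximum(dist, 0.0)


x = np.array([[0.0, 0.0], [3.0, 4.0]])
# Zeros on the diagonal, 25 (= 3^2 + 4^2) off the diagonal.
print(pairwise_sq_distances(x))
```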
Classes
- SRGCN – The Graph Neural Network modified from the “Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data” paper and “Semi-supervised Classification with Graph Convolutional Networks” paper.
- class GOOD.networks.models.SRGCNs.SRGCN(config: Union[CommonArgs, Munch])
Bases: GNNBasic
The Graph Neural Network modified from the “Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data” paper and “Semi-supervised Classification with Graph Convolutional Networks” paper.
- Parameters
config (Union[CommonArgs, Munch]) – munchified dictionary of args (config.model.dim_hidden, config.model.model_layer, config.dataset.dim_node, config.dataset.num_classes)
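For orientation, the config fields above can be read as the shape of a GCN encoder followed by a linear classifier. The NumPy sketch below is a hypothetical illustration, not SRGCN's actual layers: it assumes `model_layer` GCN layers (the first mapping `dim_node` to `dim_hidden`, the rest `dim_hidden` to `dim_hidden`), uses the symmetric normalization D^-1/2 (A + I) D^-1/2 from the GCN paper, omits dropout, normalization and pooling, and returns `(logits, features)` to mirror the pair that `forward` returns. The names `normalized_adjacency` and `gcn_forward` are hypothetical.

```python
import numpy as np


def normalized_adjacency(edge_index, num_nodes):
    # GCN propagation matrix D^-1/2 (A + I) D^-1/2; edges listed in both directions.
    a = np.eye(num_nodes)
    a[edge_index[0], edge_index[1]] = 1.0
    d_inv_sqrt = 1.0 / np.sqrt(a.sum(1))
    return d_inv_sqrt[:, None] * a * d_inv_sqrt[None, :]


def gcn_forward(x, edge_index, layer_weights, classifier):
    """Hypothetical sketch: GCN layers with ReLU, then a linear classifier."""
    a_hat = normalized_adjacency(edge_index, x.shape[0])
    h = x
    for w in layer_weights:
        h = np.maximum(a_hat @ h @ w, 0.0)  # propagate, transform, ReLU
    return h @ classifier, h  # (label logits, node features)


# Hypothetical values standing in for config.dataset.* and config.model.*.
dim_node, dim_hidden, num_classes, model_layer = 4, 8, 3, 2
rng = np.random.default_rng(0)
layer_weights = [rng.normal(size=(dim_node, dim_hidden))]
layer_weights += [rng.normal(size=(dim_hidden, dim_hidden)) for _ in range(model_layer - 1)]
classifier = rng.normal(size=(dim_hidden, num_classes))

x = rng.normal(size=(5, dim_node))
edge_index = np.array([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]])
logits, features = gcn_forward(x, edge_index, layer_weights, classifier)
print(logits.shape, features.shape)  # (5, 3) (5, 8)
```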
- forward(*args, **kwargs) → Tuple[Tensor, Tensor]
The SRGCN model implementation.
- Parameters
*args (list) – argument list for the use of arguments_read. Refer to arguments_read
**kwargs (dict) – keyword arguments for the use of arguments_read. Refer to arguments_read
- Returns (Tuple[Tensor, Tensor]):
(label predictions, features)
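The two returned outputs fit a shift-robust objective in the spirit of the SR-GNN paper: a classification loss on the predictions with per-instance weights (for example from KMM), plus a distribution-discrepancy penalty (for example CMD) between the features of the training nodes and those of an IID sample. The NumPy sketch below is hypothetical: the name `shift_robust_loss`, the weight `lam`, and the unscaled CMD (features assumed in [0, 1]) are assumptions, not the GOOD training code.

```python
import numpy as np


def cmd(x, x_test, k=5):
    # CMD without range scaling (assumes features in [0, 1]).
    mx, mt = x.mean(0), x_test.mean(0)
    total = np.linalg.norm(mx - mt)
    for order in range(2, k + 1):
        diff = ((x - mx) ** order).mean(0) - ((x_test - mt) ** order).mean(0)
        total += np.linalg.norm(diff)
    return total


def shift_robust_loss(logits, labels, weights, feats_train, feats_iid, lam=1.0):
    """Hypothetical sketch: instance-weighted cross-entropy + lam * CMD."""
    shifted = logits - logits.max(1, keepdims=True)  # stable log-softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]  # per-node cross-entropy
    return (weights * nll).mean() + lam * cmd(feats_train, feats_iid)


# Uniform logits over 3 classes and identical feature samples give log(3).
logits = np.zeros((4, 3))
labels = np.array([0, 1, 2, 0])
feats = np.full((4, 2), 0.5)
print(shift_robust_loss(logits, labels, np.ones(4), feats, feats))
```

Raising `lam` trades classification accuracy on the biased training nodes for closer alignment between training and IID feature distributions.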