"""
GCN implementation of the Deep CORAL algorithm from the `"Deep CORAL: Correlation Alignment for Deep Domain Adaptation"
<https://link.springer.com/chapter/10.1007/978-3-319-49409-8_35>`_ paper.
"""
from typing import Tuple
import torch
from GOOD import register
from GOOD.utils.config_reader import Union, CommonArgs, Munch
from .BaseGNN import GNNBasic
from .Classifiers import Classifier
from .GCNs import GCNFeatExtractor
@register.model_register
class Coral_GCN(GNNBasic):
    r"""
    The Graph Neural Network modified from the `"Deep CORAL: Correlation Alignment for Deep Domain Adaptation"
    <https://link.springer.com/chapter/10.1007/978-3-319-49409-8_35>`_ paper and `"Semi-supervised Classification with Graph Convolutional Networks"
    <https://arxiv.org/abs/1609.02907>`_ paper.

    The model is a GCN feature extractor (:class:`GCNFeatExtractor`) followed by a
    :class:`Classifier`. :meth:`forward` returns both the classifier output and the
    extracted features. Presumably the features are consumed by a CORAL alignment
    loss outside this class; confirm against the training loop.

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.model.dim_hidden`, :obj:`config.model.model_layer`, :obj:`config.dataset.dim_node`, :obj:`config.dataset.num_classes`)
    """

    def __init__(self, config: Union[CommonArgs, Munch]):
        super().__init__(config)
        # Produces the readout representation that is fed to the classifier and
        # also returned from forward().
        self.feat_encoder = GCNFeatExtractor(config)
        self.classifier = Classifier(config)
        # NOTE(review): this class only ever initialises this attribute to None.
        # It is kept for compatibility, presumably because external code reads or
        # sets it. Confirm before removing.
        self.graph_repr = None

    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        The Deep CORAL-GCN model implementation.

        Args:
            *args (list): argument list for the use of arguments_read. Refer to :func:`arguments_read <GOOD.networks.models.BaseGNN.GNNBasic.arguments_read>`
            **kwargs (dict): key word arguments for the use of arguments_read. Refer to :func:`arguments_read <GOOD.networks.models.BaseGNN.GNNBasic.arguments_read>`

        Returns:
            Tuple[Tensor, Tensor]: ``(out, out_readout)``. ``out`` is the classifier
            output (label predictions). ``out_readout`` is the feature-extractor
            readout that was passed to the classifier.
        """
        # All inputs are forwarded unchanged to the feature extractor.
        out_readout = self.feat_encoder(*args, **kwargs)
        out = self.classifier(out_readout)
        return out, out_readout