Source code for GOOD.networks.models.GSATGNNs

r"""
Implementation of GSAT from `Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism <https://arxiv.org/abs/2201.12987>`_.
"""

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.nn import InstanceNorm
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import is_undirected
from torch_sparse import transpose

from GOOD import register
from GOOD.utils.config_reader import Union, CommonArgs, Munch
from .BaseGNN import GNNBasic
from .Classifiers import Classifier
from .GINs import GINFeatExtractor
from .GINvirtualnode import vGINFeatExtractor


@register.model_register
class GSATGIN(GNNBasic):
    r"""
    GSAT model with a GIN feature extractor.

    A first GNN pass produces node embeddings, from which ``ExtractorMLP`` scores
    either edges or nodes. The scores are sampled into stochastic attention, turned
    into per-edge weights, installed as edge masks on the message-passing layers,
    and a second (masked) GNN pass feeds the classifier.
    """

    def __init__(self, config: Union[CommonArgs, Munch]):
        super(GSATGIN, self).__init__(config)
        self.gnn = GINFeatExtractor(config)
        self.extractor = ExtractorMLP(config)
        self.classifier = Classifier(config)
        # Truthy -> attention is scored per edge; falsy -> per node and then lifted to edges.
        self.learn_edge_att = config.ood.extra_param[0]
        self.config = config

    def forward(self, *args, **kwargs):
        r"""
        The GSAT model implementation.

        Args:
            *args (list): argument list for the use of arguments_read. Refer to :func:`arguments_read <GOOD.networks.models.BaseGNN.GNNBasic.arguments_read>`
            **kwargs (dict): key word arguments for the use of arguments_read. Refer to :func:`arguments_read <GOOD.networks.models.BaseGNN.GNNBasic.arguments_read>`

        Returns (tuple):
            ``(logits, att, edge_att)`` -- the classifier output of the masked GNN pass,
            the sampled attention (per edge or per node depending on ``learn_edge_att``),
            and the per-edge attention that was used as the edge mask.
        """
        data = kwargs.get('data')
        emb = self.gnn(*args, without_readout=True, **kwargs)
        att_log_logits = self.extractor(emb, data.edge_index, data.batch)
        att = self.sampling(att_log_logits, self.training)

        if self.learn_edge_att:
            if is_undirected(data.edge_index):
                # Give (u, v) and (v, u) the same weight by averaging each edge with
                # its reverse. (torch_sparse.transpose with coalesced=False returns the
                # values in their original order, so it cannot be used for this.)
                reverse = self._reverse_edge_perm(data.edge_index, data.x.shape[0])
                edge_att = (att + att[reverse]) / 2
            else:
                edge_att = att
        else:
            edge_att = self.lift_node_att_to_edge_att(att, data.edge_index)

        set_masks(edge_att, self)
        try:
            logits = self.classifier(self.gnn(*args, **kwargs))
        finally:
            # Never leave the masks attached, even if the masked pass fails.
            clear_masks(self)
        return logits, att, edge_att

    @staticmethod
    def _reverse_edge_perm(edge_index, num_nodes):
        r"""
        Return ``perm`` such that edge ``perm[i]`` is the reverse of edge ``i``.

        Precondition: every edge ``(u, v)`` has a counterpart ``(v, u)`` (the caller
        checks ``is_undirected``) and node ids are ``< num_nodes``. With duplicate
        edges, the first matching edge in sorted order is chosen.
        """
        src, dst = edge_index[0], edge_index[1]
        # Encode each (row, col) pair as a single integer key.
        key = src * num_nodes + dst
        sorted_key, order = torch.sort(key)
        pos = torch.searchsorted(sorted_key, dst * num_nodes + src)
        return order[pos]

    def sampling(self, att_log_logits, training):
        r"""Sample attention from log-logits with a fixed temperature of 1."""
        att = self.concrete_sample(att_log_logits, temp=1, training=training)
        return att

    @staticmethod
    def lift_node_att_to_edge_att(node_att, edge_index):
        r"""Edge attention = product of the attention of the edge's two endpoints."""
        src_lifted_att = node_att[edge_index[0]]
        dst_lifted_att = node_att[edge_index[1]]
        edge_att = src_lifted_att * dst_lifted_att
        return edge_att

    @staticmethod
    def concrete_sample(att_log_logit, temp, training):
        r"""
        Binary-concrete relaxation.

        While training, logistic noise ``log(u) - log(1 - u)`` is added to the
        logits, the result is divided by ``temp``, and a sigmoid is applied. At
        evaluation the plain sigmoid of the logits is returned.
        """
        if training:
            # Bounds keep log() away from 0 and 1.
            random_noise = torch.empty_like(att_log_logit).uniform_(1e-10, 1 - 1e-10)
            random_noise = torch.log(random_noise) - torch.log(1.0 - random_noise)
            att_bern = ((att_log_logit + random_noise) / temp).sigmoid()
        else:
            att_bern = (att_log_logit).sigmoid()
        return att_bern
@register.model_register
class GSATvGIN(GSATGIN):
    r"""
    GSAT with a virtual-node GIN as its feature extractor.
    """

    def __init__(self, config: Union[CommonArgs, Munch]):
        super().__init__(config)
        # Everything else comes from GSATGIN; only the backbone is replaced.
        self.gnn = vGINFeatExtractor(config)
class ExtractorMLP(nn.Module):
    r"""
    Attention scorer producing one log-logit per edge or per node.

    With ``learn_edge_att`` truthy, each edge is scored from the concatenated
    embeddings of its two endpoints; otherwise each node is scored from its own
    embedding.
    """

    def __init__(self, config: Union[CommonArgs, Munch]):
        super().__init__()
        dim = config.model.dim_hidden
        self.learn_edge_att = config.ood.extra_param[0]
        # Edge mode consumes two concatenated embeddings, so widths are doubled.
        in_scale, mid_scale = (2, 4) if self.learn_edge_att else (1, 2)
        self.feature_extractor = MLP(
            [dim * in_scale, dim * mid_scale, dim, 1],
            dropout=config.model.dropout_rate,
        )

    def forward(self, emb, edge_index, batch):
        if not self.learn_edge_att:
            return self.feature_extractor(emb, batch)
        src, dst = edge_index
        pair_emb = torch.cat([emb[src], emb[dst]], dim=-1)
        # Each edge inherits the graph id of its first endpoint.
        return self.feature_extractor(pair_emb, batch[src])
class BatchSequential(nn.Sequential):
    r"""
    ``nn.Sequential`` that also passes a ``batch`` vector to ``InstanceNorm`` layers.

    ``InstanceNorm`` layers are skipped entirely when ``batch`` is empty.
    """

    def forward(self, inputs, batch):
        out = inputs
        for layer in self._modules.values():
            if not isinstance(layer, InstanceNorm):
                out = layer(out)
            elif batch.shape[0] != 0:
                out = layer(out, batch)
        return out
class MLP(BatchSequential):
    r"""
    Stack of ``Linear`` layers sized by consecutive entries of ``channels``.

    Every layer except the last is followed by ``InstanceNorm``, ``ReLU`` and
    ``Dropout(dropout)``.
    """

    def __init__(self, channels, dropout, bias=True):
        layers = []
        n_linear = len(channels) - 1
        pairs = zip(channels[:-1], channels[1:])
        for idx, (c_in, c_out) in enumerate(pairs, start=1):
            layers.append(nn.Linear(c_in, c_out, bias))
            if idx < n_linear:
                layers.extend([InstanceNorm(c_out), nn.ReLU(), nn.Dropout(dropout)])
        super(MLP, self).__init__(*layers)
def set_masks(mask: Tensor, model: nn.Module):
    r"""
    Install ``mask`` as the edge mask of every ``MessagePassing`` layer in ``model``.

    Modified from https://github.com/wuyxin/dir-gnn.
    """
    mp_layers = (m for m in model.modules() if isinstance(m, MessagePassing))
    for layer in mp_layers:
        # Presumably tells torch_geometric not to sigmoid the mask again.
        layer._apply_sigmoid = False
        # Both the dunder and single-underscore names are written; presumably to
        # support several torch_geometric versions -- confirm.
        layer.__explain__ = True
        layer._explain = True
        layer.__edge_mask__ = mask
        layer._edge_mask = mask
def clear_masks(model: nn.Module):
    r"""
    Disable explain mode and drop the edge mask on every ``MessagePassing`` layer.

    This reverses what ``set_masks`` installs, including ``_apply_sigmoid``.

    Modified from https://github.com/wuyxin/dir-gnn.
    """
    for module in model.modules():
        if isinstance(module, MessagePassing):
            module.__explain__ = False
            module._explain = False
            module.__edge_mask__ = None
            module._edge_mask = None
            # set_masks switches this off; restore torch_geometric's default (True)
            # so the layer is not left modified for later explainer use.
            module._apply_sigmoid = True