r"""
Implementation of `Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism <https://arxiv.org/abs/2201.12987>`_.
"""
import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.nn import InstanceNorm
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import is_undirected
from torch_sparse import transpose
from GOOD import register
from GOOD.utils.config_reader import Union, CommonArgs, Munch
from .BaseGNN import GNNBasic
from .Classifiers import Classifier
from .GINs import GINFeatExtractor
from .GINvirtualnode import vGINFeatExtractor
@register.model_register
class GSATGIN(GNNBasic):
    r"""
    GSAT model that runs the same GIN feature extractor twice: once to produce
    attention scores, and once more, with the sampled attention installed as an
    edge mask, to produce the prediction.

    Args:
        config (Union[CommonArgs, Munch]): model configuration. This class reads
            :obj:`config.ood.extra_param[0]` as the switch between edge-level attention
            (truthy) and node-level attention lifted to edges (falsy); the config is also
            forwarded to the feature extractor, the attention extractor and the classifier.
    """

    def __init__(self, config: Union[CommonArgs, Munch]):
        super(GSATGIN, self).__init__(config)
        self.gnn = GINFeatExtractor(config)
        self.extractor = ExtractorMLP(config)
        self.classifier = Classifier(config)
        self.learn_edge_att = config.ood.extra_param[0]
        self.config = config

    def forward(self, *args, **kwargs):
        r"""
        The GSAT model implementation.

        Args:
            *args (list): argument list for the use of arguments_read. Refer to :func:`arguments_read <GOOD.networks.models.BaseGNN.GNNBasic.arguments_read>`
            **kwargs (dict): key word arguments for the use of arguments_read. Refer to :func:`arguments_read <GOOD.networks.models.BaseGNN.GNNBasic.arguments_read>`.
                The batch is read from the ``data`` keyword.

        Returns (tuple):
            ``(logits, att, edge_att)``: the label predictions, the sampled attention as
            produced by the extractor, and the per-edge attention that was used as the
            edge mask. The latter two are returned for loss calculations.
        """
        data = kwargs.get('data')

        # First pass: embeddings without graph readout, fed to the attention extractor.
        emb = self.gnn(*args, without_readout=True, **kwargs)
        att_log_logits = self.extractor(emb, data.edge_index, data.batch)
        att = self.sampling(att_log_logits, self.training)

        if self.learn_edge_att:
            if is_undirected(data.edge_index):
                # Average every edge's attention with that of its reverse edge so both
                # directions of an undirected edge share one value. The previous code added
                # ``transpose(..., coalesced=False)[1]``, which is the value tensor in its
                # original order, so the "average" was just ``att`` again.
                reverse = self._reverse_edge_permutation(data.edge_index, data.x.shape[0])
                edge_att = (att + att[reverse]) / 2
            else:
                edge_att = att
        else:
            edge_att = self.lift_node_att_to_edge_att(att, data.edge_index)

        # Second pass: rerun the GNN with the attention installed as an edge mask.
        set_masks(edge_att, self)
        try:
            logits = self.classifier(self.gnn(*args, **kwargs))
        finally:
            # Always remove the masks, even when the masked pass raises; otherwise the
            # stale mask would be applied by the next call's first (unmasked) pass.
            clear_masks(self)
        return logits, att, edge_att

    def sampling(self, att_log_logits, training):
        r"""
        Turn attention log-logits into attention values with a fixed temperature of 1.

        Args:
            att_log_logits (Tensor): raw attention scores.
            training (bool): whether to inject sampling noise.

        Returns (Tensor):
            Attention values in (0, 1).
        """
        att = self.concrete_sample(att_log_logits, temp=1, training=training)
        return att

    @staticmethod
    def _reverse_edge_permutation(edge_index, num_nodes):
        r"""
        For every edge ``(u, v)`` find the position of edge ``(v, u)`` in ``edge_index``.

        Args:
            edge_index (Tensor): edge indices of shape ``[2, E]``; every edge is expected to
                have its reverse present (the caller checks this with ``is_undirected``).
            num_nodes (int): number of nodes, used to encode an edge as one integer key.

        Returns (Tensor):
            Long tensor of shape ``[E]`` whose ``k``-th entry indexes the reverse of edge ``k``.
        """
        row, col = edge_index
        edge_key = row * num_nodes + col
        reverse_key = col * num_nodes + row
        sorted_key, order = torch.sort(edge_key)
        # Binary-search each reversed key among the sorted forward keys, then map the hit
        # back to the edge's original position.
        return order[torch.searchsorted(sorted_key, reverse_key)]

    @staticmethod
    def lift_node_att_to_edge_att(node_att, edge_index):
        r"""
        Derive per-edge attention as the product of the attention of both endpoints.

        Args:
            node_att (Tensor): attention indexed by node.
            edge_index (Tensor): edge indices; row 0 holds sources, row 1 destinations.

        Returns (Tensor):
            Attention indexed by edge.
        """
        src_lifted_att = node_att[edge_index[0]]
        dst_lifted_att = node_att[edge_index[1]]
        edge_att = src_lifted_att * dst_lifted_att
        return edge_att

    @staticmethod
    def concrete_sample(att_log_logit, temp, training):
        r"""
        Map log-logits to (0, 1), with a noisy relaxed-Bernoulli sample during training.

        Args:
            att_log_logit (Tensor): raw attention scores.
            temp (float): temperature dividing the noisy logits; only used when training.
            training (bool): if true, add logistic noise before the sigmoid; otherwise
                return the plain sigmoid of the scores.

        Returns (Tensor):
            Tensor with the same shape as ``att_log_logit``.
        """
        if training:
            # u ~ Uniform(1e-10, 1 - 1e-10); log(u) - log(1 - u) is logistic noise. The
            # clipped range keeps both logarithms finite.
            random_noise = torch.empty_like(att_log_logit).uniform_(1e-10, 1 - 1e-10)
            random_noise = torch.log(random_noise) - torch.log(1.0 - random_noise)
            att_bern = ((att_log_logit + random_noise) / temp).sigmoid()
        else:
            att_bern = (att_log_logit).sigmoid()
        return att_bern
@register.model_register
class GSATvGIN(GSATGIN):
    r"""
    GSAT variant that uses the virtual-node GIN as its feature extractor.

    Args:
        config (Union[CommonArgs, Munch]): model configuration, passed unchanged to the
            parent constructor and to the virtual-node feature extractor.
    """

    def __init__(self, config: Union[CommonArgs, Munch]):
        super().__init__(config)
        # The parent builds a plain GIN backbone; replace it with the virtual-node one.
        self.gnn = vGINFeatExtractor(config)
class BatchSequential(nn.Sequential):
    r"""
    Sequential container that additionally hands the batch vector to every
    :class:`InstanceNorm` layer it contains.
    """

    def forward(self, inputs, batch):
        r"""
        Run the contained layers in order.

        Args:
            inputs (Tensor): input passed to the first layer.
            batch (Tensor): batch vector forwarded to :class:`InstanceNorm` layers.

        Returns (Tensor):
            Output of the last layer (or ``inputs`` itself if there are no layers).
        """
        for layer in self._modules.values():
            if not isinstance(layer, InstanceNorm):
                inputs = layer(inputs)
            elif batch.shape[0] != 0:
                # Normalisation layers also need the batch assignment; with an empty
                # batch vector the layer is skipped and the input passes through as is.
                inputs = layer(inputs, batch)
        return inputs
class MLP(BatchSequential):
    r"""
    Multi-layer perceptron assembled as a :class:`BatchSequential`.

    Every linear layer except the last one is followed by instance normalisation,
    ReLU and dropout.

    Args:
        channels (list): layer widths; consecutive entries give the input and output
            size of each linear layer.
        dropout (float): dropout probability used after every hidden layer.
        bias (bool): whether the linear layers have a bias term.
    """

    def __init__(self, channels, dropout, bias=True):
        layers = []
        final_idx = len(channels) - 1
        # Walk over consecutive (in, out) width pairs; idx is the position of the output width.
        for idx, (in_dim, out_dim) in enumerate(zip(channels, channels[1:]), start=1):
            layers.append(nn.Linear(in_dim, out_dim, bias))
            if idx < final_idx:
                layers.append(InstanceNorm(out_dim))
                layers.append(nn.ReLU())
                layers.append(nn.Dropout(dropout))
        super().__init__(*layers)
def set_masks(mask: Tensor, model: nn.Module):
    r"""
    Install ``mask`` as the edge mask of every :class:`MessagePassing` layer in ``model``.

    Modified from https://github.com/wuyxin/dir-gnn.

    Args:
        mask (Tensor): edge mask to attach to each message-passing layer.
        model (nn.Module): module whose sub-modules are scanned.
    """
    message_passing_layers = (
        layer for layer in model.modules() if isinstance(layer, MessagePassing)
    )
    for layer in message_passing_layers:
        layer._apply_sigmoid = False
        # Both the dunder and the single-underscore attribute spellings are written --
        # presumably to cover different PyTorch Geometric versions (TODO confirm).
        layer.__explain__ = layer._explain = True
        layer.__edge_mask__ = layer._edge_mask = mask
def clear_masks(model: nn.Module):
    r"""
    Remove the edge mask from every :class:`MessagePassing` layer in ``model`` and
    switch its explain flags off again.

    Modified from https://github.com/wuyxin/dir-gnn.

    Args:
        model (nn.Module): module whose sub-modules are scanned.
    """
    message_passing_layers = (
        layer for layer in model.modules() if isinstance(layer, MessagePassing)
    )
    for layer in message_passing_layers:
        # Reset both attribute spellings that set_masks writes.
        layer.__explain__ = layer._explain = False
        layer.__edge_mask__ = layer._edge_mask = None