Source code for GOOD.networks.models.BaseGNN

"""
Base classes for Graph Neural Networks
"""
import torch
import torch.nn as nn
from torch_geometric.data.batch import Batch
from torch import Tensor
from GOOD.utils.config_reader import Union, CommonArgs, Munch
from .Pooling import GlobalMeanPool, GlobalMaxPool, IdenticalPool
from torch.nn import Identity


class GNNBasic(torch.nn.Module):
    r"""
    Base class for graph neural networks

    Stores the configuration object and provides :func:`~arguments_read`, a helper
    that normalises the different ways a model can be called, and :func:`~probs`,
    which turns the forward output into probabilities.

    Args:
        config (Union[CommonArgs, Munch]): configuration object, kept as ``self.config``
        *args (list): argument list for the use of :func:`~arguments_read`
        **kwargs (dict): key word arguments for the use of :func:`~arguments_read`
    """

    def __init__(self, config: Union[CommonArgs, Munch], *args, **kwargs):
        super().__init__()
        self.config = config

    def arguments_read(self, *args, **kwargs):
        r"""
        It is an argument reading function for diverse model input formats.
        Support formats are:
        ``model(x, edge_index)``
        ``model(x, edge_index, batch)``
        ``model(data=data)``.

        Notes:
            edge_weight is optional for node prediction tasks.

        Args:
            *args: [x, edge_index, [batch]]
            **kwargs: data, [x, edge_index, batch], [edge_weight], [edge_attr], [batch_size]

        Returns:
            ``(x, edge_index, edge_weight, batch)`` when ``config.model.model_level == 'node'``;
            ``(x, edge_index, edge_attr, batch, batch_size)`` when ``config.dataset.dim_edge`` is truthy;
            ``(x, edge_index, batch, batch_size)`` otherwise.

        Raises:
            AssertionError: if neither ``data`` nor positional tensors are given and
                ``x`` / ``edge_index`` are missing from the keyword arguments.
            ValueError: if positional arguments are given but there are not 2 or 3 of them.
        """
        # A missing or falsy ``data`` entry means the tensors are passed individually.
        data: Batch = kwargs.get('data') or None

        if data is None:
            if not args:
                assert 'x' in kwargs, "arguments_read requires 'x' when no data/args are given"
                assert 'edge_index' in kwargs, "arguments_read requires 'edge_index' when no data/args are given"
                x, edge_index = kwargs['x'], kwargs['edge_index']
                batch = kwargs.get('batch')
            elif len(args) == 2:
                x, edge_index = args
                batch = None
            elif len(args) == 3:
                x, edge_index, batch = args
            else:
                raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}")

            if batch is None:
                # No batch vector given: assign every node to graph 0. The vector is
                # created on the same device as ``x`` (instead of a hard-coded CUDA
                # device) so CPU-only runs work and devices never mismatch.
                batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device)
        else:
            x, edge_index, batch = data.x, data.edge_index, data.batch

        model_level = self.config.model.model_level

        if model_level == 'node':
            edge_weight = kwargs.get('edge_weight')
            return x, edge_index, edge_weight, batch

        # --- Maybe batch size --- Reason: some method may filter graphs leading inconsistent of batch size
        # Falls back to (last batch index + 1) when no explicit batch_size is passed.
        batch_size: int = kwargs.get('batch_size') or (batch[-1].item() + 1)

        if self.config.dataset.dim_edge:
            # Edge attributes come from ``data`` when available; when the tensors were
            # passed individually, an optional ``edge_attr`` keyword is used instead
            # of dereferencing a missing ``data`` object.
            edge_attr = data.edge_attr if data is not None else kwargs.get('edge_attr')
            return x, edge_index, edge_attr, batch, batch_size

        return x, edge_index, batch, batch_size

    def probs(self, *args, **kwargs):
        r"""
        Call the model with the given arguments and apply softmax along ``dim=1``.

        Args:
            *args: positional arguments forwarded to ``self(...)``
            **kwargs: keyword arguments forwarded to ``self(...)``

        Returns:
            The forward output after ``softmax(dim=1)``.
        """
        # nodes x classes
        return self(*args, **kwargs).softmax(dim=1)
class BasicEncoder(torch.nn.Module):
    r"""
    Base GNN feature encoder.

    Builds the per-layer activation, normalisation and dropout modules plus the
    readout module; it defines no convolution layers and no ``forward`` itself.

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.model.dim_hidden`, :obj:`config.model.model_layer`, :obj:`config.model.model_level`, :obj:`config.model.global_pool`, :obj:`config.model.dropout_rate`)
        **kwargs: ``no_bn`` (optional) - when truthy, batch normalisation layers are replaced by :class:`~torch.nn.Identity`

    .. code-block:: python

        config = munchify({model: {dim_hidden: int(300),
                           model_layer: int(5),
                           model_level: str('node'),
                           global_pool: str('mean'),
                           dropout_rate: float(0.5),}
                           })

    """

    def __init__(self, config: Union[CommonArgs, Munch], **kwargs):
        # Cooperative multiple inheritance: look at the class that follows
        # BasicEncoder in the concrete type's MRO. A plain nn.Module takes no
        # constructor arguments; any other next class is handed ``config``.
        mro = type(self).mro()
        next_in_mro = mro[mro.index(__class__) + 1]
        if next_in_mro is torch.nn.Module:
            super().__init__()
        else:
            super().__init__(config)

        num_layer = config.model.model_layer
        dim_hidden = config.model.dim_hidden
        dropout_rate = config.model.dropout_rate
        # Each module family is stored as one first-layer module plus a container
        # holding the remaining ``num_layer - 1`` modules.
        num_rest = num_layer - 1

        self.relu1 = nn.ReLU()
        self.relus = nn.ModuleList([nn.ReLU() for _ in range(num_rest)])

        if kwargs.get('no_bn'):
            self.batch_norm1 = Identity()
            # Use a ModuleList (not a plain list) so these layers are registered
            # submodules, consistent with the batch-norm branch below. Identity has
            # no parameters or buffers, so state_dict keys are unaffected.
            self.batch_norms = nn.ModuleList([Identity() for _ in range(num_rest)])
        else:
            self.batch_norm1 = nn.BatchNorm1d(dim_hidden)
            self.batch_norms = nn.ModuleList([nn.BatchNorm1d(dim_hidden) for _ in range(num_rest)])

        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropouts = nn.ModuleList([nn.Dropout(dropout_rate) for _ in range(num_rest)])

        # Readout: identity pooling for node-level models, mean pooling when
        # requested, and max pooling for any other ``global_pool`` value.
        if config.model.model_level == 'node':
            self.readout = IdenticalPool()
        elif config.model.global_pool == 'mean':
            self.readout = GlobalMeanPool()
        else:
            self.readout = GlobalMaxPool()