# Source code for GOOD.networks.models.EERM

r"""
The implementation of `Handling Distribution Shifts on Graphs: An Invariance Perspective <https://arxiv.org/abs/2202.02466>`_.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj, dense_to_sparse, subgraph

from GOOD import register
from .BaseGNN import GNNBasic
from .Classifiers import Classifier
from .GCNs import GCNFeatExtractor


@register.model_register
class EERMGCN(GNNBasic):
    r"""
    EERM implementation adapted from https://github.com/qitianwu/GraphOOD-EERM.

    A GCN feature extractor plus classifier, trained together with ``K`` learnable
    graph editers (``self.gl``) that generate perturbed training graphs ("environments").
    Hyper-parameters are read from ``config.ood.extra_param`` in the order
    ``[K, T, num_sample, graph-editer learning rate]``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.gnn = GCNFeatExtractor(config)
        self.p = 0.2
        # Loop counts / sample sizes must be integers (used in range() and multinomial).
        self.K = int(config.ood.extra_param[0])  # number of graph editers / environments
        self.T = int(config.ood.extra_param[1])  # editer update rounds per forward pass
        self.num_sample = int(config.ood.extra_param[2])  # entries toggled per node
        self.classifier = Classifier(config)

        self.gl = Graph_Editer(self.K, config.dataset.num_train_nodes, config.device)
        self.gl.reset_parameters()
        # The editers have their own optimizer; they are updated inside forward().
        self.gl_optimizer = torch.optim.Adam(self.gl.parameters(), lr=config.ood.extra_param[3])

    def reset_parameters(self):
        r"""Re-initialise the GNN feature extractor and the graph editers."""
        self.gnn.reset_parameters()
        # The editers are stored as ``self.gl``; resetting them restarts the
        # environment generator together with the GNN.
        if hasattr(self, 'gl'):
            self.gl.reset_parameters()

    def forward(self, *args, **kwargs):
        r"""
        Run EERM environment exploration (training) or plain inference (evaluation).

        In training mode the model works on the training-node subgraph. For each of
        ``T`` rounds, every one of the ``K`` graph editers samples a perturbed training
        graph, and the GNN + classifier loss on it is recorded. The editers are then
        updated by REINFORCE, with the variance of the ``K`` losses as the (detached)
        reward, i.e. they are pushed towards environments that maximise that variance.

        Args:
            **kwargs: must contain ``data``, a graph object exposing ``x``, ``y``,
                ``edge_index`` and, in training mode, ``train_mask``.

        Returns:
            Training mode: ``(Var, Mean)``, the variance and mean of the ``K``
            environment losses from the last round. They keep their autograd graph
            through the GNN and classifier, presumably for the caller's outer update.
            Evaluation mode: the classifier output on the full graph.
        """
        data = kwargs.get('data')
        loss_func = self.config.metric.loss_func

        if not self.training:
            return self.classifier(self.gnn(data=data))

        # Restrict to training nodes; relabelling makes node ids index the rows of ``x``.
        # NOTE(review): assumes train_mask.sum() == config.dataset.num_train_nodes,
        # the size of the editers' B matrices -- confirm.
        edge_index, _ = subgraph(data.train_mask, data.edge_index, relabel_nodes=True)
        x = data.x[data.train_mask]
        y = data.y[data.train_mask]
        # ``self.gl`` returns a new edge_index instead of editing its input, so every
        # editer always starts from the unperturbed training graph.
        orig_edge_index = edge_index

        for t in range(self.T):
            Loss, Log_p = [], 0
            for k in range(self.K):
                edge_index, log_p = self.gl(orig_edge_index, self.num_sample, k)
                raw_pred = self.classifier(self.gnn(data=Data(x=x, edge_index=edge_index, y=y)))
                # NOTE(review): assumes loss_func reduces to a scalar per environment.
                loss = loss_func(raw_pred, y)
                Loss.append(loss.view(-1))
                Log_p += log_p
            Var, Mean = torch.var_mean(torch.cat(Loss, dim=0))
            # REINFORCE step: the reward is detached and Log_p depends only on the
            # editers' logits, so this backward pass updates the editers alone.
            reward = Var.detach()
            inner_loss = - reward * Log_p
            self.gl_optimizer.zero_grad()
            inner_loss.backward()
            self.gl_optimizer.step()
        return Var, Mean
class Graph_Editer(nn.Module):
    r"""
    EERM's graph editer adapted from https://github.com/qitianwu/GraphOOD-EERM.

    Holds ``K`` learnable ``n x n`` logit matrices ``B``. Editer ``k`` perturbs a graph
    by toggling dense-adjacency entries sampled from the column-wise softmax of ``B[k]``.

    Args:
        K (int): number of editers (generated environments).
        n (int): number of nodes of the graphs to edit.
        device: device on which the intermediate dense matrices are allocated.
    """

    def __init__(self, K, n, device):
        super().__init__()
        # Explicit float32 to match the original ``torch.FloatTensor`` storage.
        self.B = nn.Parameter(torch.empty(K, n, n, dtype=torch.float))
        self.n = n
        self.device = device
        # Initialise immediately so B never holds uninitialised memory (NaN/inf).
        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-draw all editer logits uniformly from [0, 1)."""
        nn.init.uniform_(self.B)

    def forward(self, edge_index, num_sample, k):
        r"""
        Sample an edited graph from editer ``k`` and return its log-probability.

        For every node (column) ``i`` of the dense adjacency matrix, ``num_sample``
        distinct rows are drawn from ``softmax(B[k][:, i])``. The selected entries are
        toggled: an absent edge is added, an existing edge is removed.

        Args:
            edge_index: edges of the graph to edit; node ids are assumed to lie in
                ``[0, n)`` (TODO confirm callers always relabel).
            num_sample (int): entries toggled per column; at most ``n`` because
                sampling is without replacement.
            k (int): which editer (first index of ``B``) to use.

        Returns:
            ``(edge_index, log_p)``: the edited edge indices and the scalar
            log-probability of the sampled edits (the REINFORCE score term).
        """
        n = self.n
        Bk = self.B[k]
        A = to_dense_adj(edge_index, max_num_nodes=n)[0].to(torch.int)
        A_c = torch.ones(n, n, dtype=torch.int, device=self.device) - A
        # Column-wise softmax: P[:, i] is the distribution over rows for column i.
        P = torch.softmax(Bk, dim=0)
        # torch.multinomial samples along the *rows* of its input, so transpose to draw
        # S[i] from column i -- the same distribution that log_p scores below.
        S = torch.multinomial(P.t(), num_samples=num_sample)  # [n, num_sample]
        col_idx = torch.arange(0, n, device=S.device).unsqueeze(1).repeat(1, num_sample)
        M = torch.zeros(n, n, dtype=torch.float, device=self.device)
        M[S, col_idx] = 1.
        # A_c - A is +1 where there is no edge and -1 where there is one, so entries
        # selected by M are toggled and all others stay as in A.
        C = A + M * (A_c - A)
        edge_index = dense_to_sparse(C)[0]
        # log P[S[i, s], i] = Bk[S[i, s], i] - logsumexp(Bk[:, i]). The normaliser is
        # paid once per draw, hence the num_sample factor. Draws are treated as
        # independent, an approximation since sampling is without replacement.
        log_p = torch.sum(
            torch.sum(Bk[S, col_idx], dim=1) - num_sample * torch.logsumexp(Bk, dim=0)
        )
        return edge_index, log_p