# Source code for GOOD.data.good_datasets.good_arxiv

"""
The GOOD-Arxiv dataset adapted from `OGB
<https://proceedings.neurips.cc/paper/2020/hash/fb60d411a5c5b72b2e7d3527cfc84fd0-Abstract.html>`_ benchmark.
"""

import itertools
import os
import os.path as osp
import random
from copy import deepcopy

import gdown
import numpy as np
import torch
from munch import Munch
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset, extract_zip
from torch_geometric.utils import degree, to_undirected
from tqdm import tqdm


class DomainGetter(object):
    r"""
    A class containing methods for data domain extraction.

    Each ``get_<domain>`` method takes the whole graph and returns one domain value per node as a tensor, so the
    result can be indexed by node id.
    """

    def get_degree(self, graph: 'Data') -> torch.Tensor:
        """
        Args:
            graph (Data): The PyG Data object.
        Returns:
            The per-node degrees of the given graph, computed from the source endpoints ``graph.edge_index[0]``.
        Raises:
            ValueError: Re-raised unchanged after logging if the degree computation fails.
        """
        try:
            return degree(graph.edge_index[0], graph.num_nodes)
        except ValueError:
            print('#E#Get degree error.')
            # Bare raise keeps the original traceback intact.
            raise

    def get_time(self, graph: 'Data') -> torch.Tensor:
        """
        Args:
            graph (Data): The PyG Data object.
        Returns:
            The year domain values of the graph, i.e. ``graph.node_year`` with size-1 dimensions squeezed out.
        """
        return graph.node_year.squeeze()
class DataInfo(object):
    r"""
    The class for data point storage. This enables tackling node data point like graph data point, facilitating data
    splits.

    The name of every attribute assigned on an instance (except ``storage`` itself) is recorded in ``storage`` in
    first-assignment order. Re-assigning an existing attribute updates its value without recording the name again,
    so ``storage`` never contains duplicates and ``__repr__`` lists each attribute exactly once.

    Args:
        idx: The identifier of the data point.
        y: The label of the data point.
    """

    def __init__(self, idx, y):
        super(DataInfo, self).__init__()
        # ``storage`` must exist before any other attribute is set, because ``__setattr__`` appends to it.
        self.storage = []
        self.idx = idx
        self.y = y

    def __repr__(self):
        fields = ', '.join(f'{key}={getattr(self, key)}' for key in self.storage)
        return f"DataInfo({fields})"

    def __setattr__(self, key, value):
        super().__setattr__(key, value)
        # Track each attribute name once; a second assignment only overwrites the value.
        if key != 'storage' and key not in self.storage:
            self.storage.append(key)
from GOOD import register
@register.dataset_register
class GOODArxiv(InMemoryDataset):
    r"""
    The GOOD-Arxiv dataset adapted from `OGB
    <https://proceedings.neurips.cc/paper/2020/hash/fb60d411a5c5b72b2e7d3527cfc84fd0-Abstract.html>`_ benchmark.

    Three processed files are kept per domain (``no_shift.pt``, ``covariate.pt``, ``concept.pt``); the one matching
    ``shift`` is loaded into ``self.data`` / ``self.slices``.

    Args:
        root (str): The dataset saving root.
        domain (str): The domain selection. Allowed: 'degree' and 'time'.
        shift (str): The distributional shift we pick. Allowed: 'no_shift', 'covariate', and 'concept'.
        generate (bool): The flag for regenerating dataset. True: regenerate. False: download.

    Raises:
        KeyError: If ``shift`` is not one of the allowed values.
    """

    def __init__(self, root: str, domain: str, shift: str = 'no_shift', transform=None, pre_transform=None,
                 generate: bool = False):
        # These attributes are read by the directory/file-name properties used inside the parent constructor,
        # so they must be set before ``super().__init__``.
        self.name = self.__class__.__name__
        self.domain = domain
        self.metric = 'Accuracy'
        self.task = 'Multi-label classification'
        self.url = 'https://drive.google.com/file/d/1r1OTQJ5YxQAAYJiYfyDmCknmpVmiUksi/view?usp=sharing'

        self.generate = generate

        super().__init__(root, transform, pre_transform)
        shift_mode = {'no_shift': 0, 'covariate': 1, 'concept': 2}
        subset_pt = shift_mode[shift]

        # NOTE(review): ``torch.load`` unpickles the file, and the processed files normally come from a downloaded
        # archive. Only load archives from a trusted source.
        self.data, self.slices = torch.load(self.processed_paths[subset_pt])

    @property
    def raw_dir(self):
        """
        Returns:
            The raw data directory, which is the dataset root itself.
        """
        return osp.join(self.root)

    def _download(self):
        """
        Download the dataset archive unless ``<raw_dir>/<name>`` already exists or regeneration was requested.
        """
        if os.path.exists(osp.join(self.raw_dir, self.name)) or self.generate:
            return
        # ``exist_ok`` avoids failing when the directory appears between a check and the creation.
        os.makedirs(self.raw_dir, exist_ok=True)
        self.download()

    def download(self):
        """
        Fetch the zip archive from ``self.url`` into ``raw_dir``, extract it there and delete the archive.

        Raises:
            RuntimeError: If the download does not yield a file path.
        """
        path = gdown.download(self.url, output=osp.join(self.raw_dir, self.name + '.zip'), fuzzy=True)
        # NOTE(review): some gdown versions appear to return None on failure instead of raising - confirm against
        # the pinned version. Failing here gives a clearer error than passing None to ``extract_zip``.
        if path is None:
            raise RuntimeError(f'Failed to download {self.name} from {self.url}')
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    @property
    def processed_dir(self):
        """
        Returns:
            The directory holding the processed files: ``<root>/<name>/<domain>/processed``.
        """
        return osp.join(self.root, self.name, self.domain, 'processed')

    @property
    def processed_file_names(self):
        """
        Returns:
            The processed file names, ordered as no shift, covariate shift, concept shift.
        """
        return ['no_shift.pt', 'covariate.pt', 'concept.pt']

    @staticmethod
    def _split_id_sets(train_list, num_id_test):
        """
        Carve the in-domain validation and test sets off the tail of ``train_list``.

        Explicit non-negative bounds are used instead of negative slicing so that ``num_id_test == 0`` keeps the
        whole list as training data (``train_list[:-0]`` would be empty and ``train_list[-0:]`` would be the
        whole list).

        Args:
            train_list (list): The (already shuffled) training data points.
            num_id_test (int): The size of each of the two in-domain held-out sets.
        Returns:
            The remaining train list, the in-domain validation list and the in-domain test list.
        """
        num_train = len(train_list)
        id_val_start = max(0, num_train - 2 * num_id_test)
        id_test_start = max(0, num_train - num_id_test)
        return train_list[:id_val_start], train_list[id_val_start:id_test_start], train_list[id_test_start:]

    def assign_masks(self, train_list, val_list, test_list, id_val_list, id_test_list, graph):
        """
        Write the split masks, environment ids and domain ids of the given data points onto ``graph``.

        Args:
            train_list (list): The training data points; each must carry ``idx``, ``env_id`` and ``domain_id``.
            val_list (list): The OOD validation data points (``idx`` and ``domain_id``).
            test_list (list): The OOD test data points (``idx`` and ``domain_id``).
            id_val_list (list): The in-domain validation data points (``idx`` and ``domain_id``).
            id_test_list (list): The in-domain test data points (``idx`` and ``domain_id``).
            graph (Data): The graph to annotate; it is modified in place.
        Returns:
            The same ``graph`` with ``train_mask``, ``val_mask``, ``test_mask``, ``id_val_mask``, ``id_test_mask``,
            ``env_id``, ``domain`` and ``domain_id`` set. ``env_id`` is -1 outside the training set and
            ``domain_id`` is -1 for nodes that belong to no split.
        """
        num_data = self.num_data
        env_id = - torch.ones((num_data,), dtype=torch.long)
        domain_id = - torch.ones((num_data,), dtype=torch.long)

        splits = (('train_mask', train_list), ('val_mask', val_list), ('test_mask', test_list),
                  ('id_val_mask', id_val_list), ('id_test_mask', id_test_list))
        masks = {}
        for mask_name, data_list in splits:
            mask = torch.zeros((num_data,), dtype=torch.bool)
            for data in data_list:
                mask[data.idx] = True
                domain_id[data.idx] = data.domain_id
            masks[mask_name] = mask

        # Only training points carry an environment id.
        for data in train_list:
            env_id[data.idx] = data.env_id

        for mask_name, mask in masks.items():
            setattr(graph, mask_name, mask)
        graph.env_id = env_id
        graph.domain = self.domain
        graph.domain_id = domain_id

        return graph

    def get_no_shift_graph(self, graph):
        """
        Build the no-shift split: a random 60/20/20 train/val/test partition of all nodes.

        Args:
            graph (Data): The graph to annotate; it is modified in place.
        Returns:
            The same ``graph`` with ``train_mask``, ``val_mask``, ``test_mask``, ``env_id`` and ``domain`` set.
            Training nodes get a random environment id drawn from ``[0, 9)``; all other nodes get -1.
        """
        num_data = self.num_data
        node_indices = torch.randperm(num_data)

        # The same ratios are used for every domain; the test set is whatever remains.
        train_ratio = 0.6
        val_ratio = 0.2

        train_split = int(num_data * train_ratio)
        val_split = int(num_data * (train_ratio + val_ratio))
        train_indices = node_indices[:train_split]
        val_indices = node_indices[train_split:val_split]
        test_indices = node_indices[val_split:]

        train_mask, val_mask, test_mask = (torch.zeros((num_data,), dtype=torch.bool) for _ in range(3))
        env_id = - torch.ones((num_data,), dtype=torch.long)

        train_mask[train_indices] = True
        val_mask[val_indices] = True
        test_mask[test_indices] = True
        env_id[train_indices] = torch.randint(0, 9, (train_indices.shape[0],))

        graph.train_mask = train_mask
        graph.val_mask = val_mask
        graph.test_mask = test_mask
        graph.env_id = env_id
        graph.domain = self.domain

        return graph

    def get_covariate_shift_graph(self, sorted_data_list, graph):
        """
        Build the covariate-shift split by cutting the domain-sorted data points into train / OOD val / OOD test.

        A cut is only made where the domain id changes, so one domain is never spread over two sets.

        Args:
            sorted_data_list (list): The data points sorted by domain value, each carrying ``domain_id``.
            graph (Data): The graph to annotate; it is modified in place.
        Returns:
            The annotated graph produced by :meth:`assign_masks`.
        """
        num_data = self.num_data
        if self.domain == 'degree':
            # For the degree domain the order is reversed, so the largest degrees come first (training side).
            sorted_data_list = sorted_data_list[::-1]
            train_ratio = 0.6
            val_ratio = 0.2
        else:
            train_ratio = 0.5  # This is not a fault, it is for specific year splits.
            val_ratio = 0.2
        id_test_ratio = 0.1

        train_split = int(num_data * train_ratio)
        val_split = int(num_data * (train_ratio + val_ratio))

        train_val_test_split = [0, train_split, val_split]
        train_val_test_list = [[], [], []]
        cur_env_id = -1
        cur_domain_id = None
        for i, data in enumerate(sorted_data_list):
            # Move to the next set once its start index is reached, but only at a domain boundary.
            if cur_env_id < 2 and i >= train_val_test_split[cur_env_id + 1] and data.domain_id != cur_domain_id:
                cur_env_id += 1
            cur_domain_id = data.domain_id
            train_val_test_list[cur_env_id].append(data)

        train_list, ood_val_list, ood_test_list = train_val_test_list

        # Compose domains to environments
        num_env_train = 10
        num_per_env = len(train_list) // num_env_train
        cur_env_id = -1
        cur_domain_id = None
        for i, data in enumerate(train_list):
            if cur_env_id < num_env_train - 1 and i >= (cur_env_id + 1) * num_per_env \
                    and data.domain_id != cur_domain_id:
                cur_env_id += 1
            cur_domain_id = data.domain_id
            data.env_id = cur_env_id

        num_id_test = int(num_data * id_test_ratio)
        random.shuffle(train_list)
        train_list, id_val_list, id_test_list = self._split_id_sets(train_list, num_id_test)

        return self.assign_masks(train_list, ood_val_list, ood_test_list, id_val_list, id_test_list, graph)

    def get_concept_shift_graph(self, sorted_domain_split_data_list, graph):
        """
        Build the concept-shift split by sampling data points into five splits with split-specific bias.

        Splits 0, 2 and 3 are training environments, split 1 is the OOD test set and split 4 (which takes every
        remaining data point) is the OOD validation set. Two random numbers are drawn per data point per split,
        in a fixed order, so the result depends on the state of the ``random`` module.

        Args:
            sorted_domain_split_data_list (list): One list of data points per domain, in domain order. Each inner
                list is modified in place: a domain mean is appended and selected data points are removed.
            graph (Data): The graph to annotate; it is modified in place.
        Returns:
            The annotated graph produced by :meth:`assign_masks`.
        """
        # Calculate concept probability for each domain
        global_pyx = []
        for each_domain_datas in tqdm(sorted_domain_split_data_list):
            pyx = []
            for data in each_domain_datas:
                data.pyx = torch.tensor(np.nanmean(data.y).item())
                if torch.isnan(data.pyx):
                    data.pyx = torch.tensor(0.)
                pyx.append(data.pyx.item())
                global_pyx.append(data.pyx.item())
            # The domain mean is stored as the last element of the domain list; the loops below skip it.
            each_domain_datas.append(sum(pyx) / len(each_domain_datas))

        global_mean_pyx = np.mean(global_pyx)

        bias_connect = [0.95, 0.95, 0.9, 0.85, 0.5]
        is_train_split = [True, False, True, True, False]
        # The last split is the validation split; any split that is neither train nor val is a test split.
        is_val_split = [i == len(is_train_split) - 1 for i in range(len(is_train_split))]
        is_test_split = [not (tr_sp or val_sp) for tr_sp, val_sp in zip(is_train_split, is_val_split)]

        split_picking_ratio = [0.4, 0.6, 0.5, 1, 1]

        # Per split and per domain: the sign that decides which side of the global mean is the favoured side.
        # Domains covering the first half of the data get +1, later ones -1; non-train splits use the opposite sign.
        order_connect = [[] for _ in range(len(bias_connect))]
        cur_num = 0
        for each_domain_datas in sorted_domain_split_data_list:
            randc = 1 if cur_num < self.num_data / 2 else - 1
            cur_num += len(each_domain_datas) - 1  # minus one for the appended domain mean
            for j, is_train in enumerate(is_train_split):
                order_connect[j].append(randc if is_train else - randc)

        env_list = [[] for _ in range(len(bias_connect))]
        last_split = len(env_list) - 1
        env_id = -1
        for cur_split in range(len(env_list)):
            if is_train_split[cur_split]:
                env_id += 1

            for domain_id, each_domain_datas in enumerate(sorted_domain_split_data_list):
                pop_items = []
                both_label_domain = [False, False]
                label_data_candidate = [None, None]
                both_label_include = [False, False]
                for data in each_domain_datas[:-1]:
                    # Both numbers are drawn for every data point, even when unused, to keep the random stream fixed.
                    picking_rand = random.random()
                    data_rand = random.random()  # random num for data point
                    if cur_split == last_split:
                        # The final split collects everything that is left.
                        data.env_id = env_id
                        env_list[cur_split].append(data)
                        pop_items.append(data)
                    elif order_connect[cur_split][domain_id] * (data.pyx - global_mean_pyx) > 0:
                        both_label_domain[0] = True
                        if data_rand < bias_connect[cur_split] and picking_rand < split_picking_ratio[cur_split]:
                            both_label_include[0] = True
                            data.env_id = env_id
                            env_list[cur_split].append(data)
                            pop_items.append(data)
                        else:
                            label_data_candidate[0] = data
                    else:
                        both_label_domain[1] = True
                        if data_rand > bias_connect[cur_split] and picking_rand < split_picking_ratio[cur_split]:
                            both_label_include[1] = True
                            data.env_id = env_id
                            env_list[cur_split].append(data)
                            pop_items.append(data)
                        else:
                            label_data_candidate[1] = data

                # --- Add extra data: avoid extreme label imbalance ---
                if both_label_domain[0] and both_label_domain[1] and (both_label_include[0] or both_label_include[1]):
                    extra_data = None
                    if not both_label_include[0]:
                        extra_data = label_data_candidate[0]
                    if not both_label_include[1]:
                        extra_data = label_data_candidate[1]
                    if extra_data:
                        extra_data.env_id = env_id
                        env_list[cur_split].append(extra_data)
                        pop_items.append(extra_data)

                # Drop the selected data points in one pass (by identity) instead of one ``list.remove`` per item.
                if pop_items:
                    popped_ids = {id(pop_item) for pop_item in pop_items}
                    each_domain_datas[:] = [item for item in each_domain_datas if id(item) not in popped_ids]

        num_train = sum(len(env) for i, env in enumerate(env_list) if is_train_split[i])
        num_val = sum(len(env) for i, env in enumerate(env_list) if is_val_split[i])
        num_test = sum(len(env) for i, env in enumerate(env_list) if is_test_split[i])
        print("#D#train: %d, val: %d, test: %d" % (num_train, num_val, num_test))

        train_list = list(itertools.chain(*[env for i, env in enumerate(env_list) if is_train_split[i]]))
        ood_val_list = list(itertools.chain(*[env for i, env in enumerate(env_list) if is_val_split[i]]))
        ood_test_list = list(itertools.chain(*[env for i, env in enumerate(env_list) if is_test_split[i]]))

        id_test_ratio = 0.15
        num_id_test = int(len(train_list) * id_test_ratio)
        random.shuffle(train_list)
        train_list, id_val_list, id_test_list = self._split_id_sets(train_list, num_id_test)

        return self.assign_masks(train_list, ood_val_list, ood_test_list, id_val_list, id_test_list, graph)

    def get_domain_sorted_indices(self, graph, domain='degree'):
        """
        Compute the domain value of every node and group the nodes by it.

        Args:
            graph (Data): The graph; the per-node domain values are stored on it under the name ``domain``.
            domain (str): The domain name; ``DomainGetter.get_<domain>`` is used to extract the values.
        Returns:
            The data points sorted by domain value, and the same data points grouped into one list per distinct
            domain value (in sorted order). Every data point carries ``idx``, ``y``, the domain value and a
            ``domain_id`` (a one-element ``LongTensor``).
        """
        domain_getter = DomainGetter()
        setattr(graph, domain, getattr(domain_getter, f'get_{domain}')(graph))

        # Fetch the per-node values once instead of once per node.
        domain_values = getattr(graph, domain)
        data_list = []
        for i in tqdm(range(self.num_data)):
            data_info = DataInfo(idx=i, y=graph.y[i])
            setattr(data_info, domain, domain_values[i])
            data_list.append(data_info)

        sorted_data_list = sorted(data_list, key=lambda data: getattr(data, domain))

        # Assign domain id
        cur_domain_id = -1
        cur_domain = None
        sorted_domain_split_data_list = []
        for data in sorted_data_list:
            if getattr(data, domain) != cur_domain:
                cur_domain = getattr(data, domain)
                cur_domain_id += 1
                sorted_domain_split_data_list.append([])
            data.domain_id = torch.LongTensor([cur_domain_id])
            sorted_domain_split_data_list[cur_domain_id].append(data)

        return sorted_data_list, sorted_domain_split_data_list

    def process(self):
        """
        Regenerate the three processed split files from OGB ``ogbn-arxiv``.

        Regeneration is deliberately disabled: the method raises immediately, and the code below the ``raise`` is
        unreachable. It is kept as a record of how the released files were produced.

        Raises:
            Exception: Always.
        """
        raise Exception('OGB changed the Arxiv version. Regeneration will not be valid anymore.')
        from ogb.nodeproppred import PygNodePropPredDataset
        dataset = PygNodePropPredDataset(root=self.root, name='ogbn-arxiv')
        graph = dataset[0]
        graph.edge_index = to_undirected(graph.edge_index, graph.num_nodes)
        graph.y = graph.y.squeeze()
        print('Load data done!')

        self.num_data = graph.x.shape[0]
        print('Extract data done!')

        no_shift_graph = self.get_no_shift_graph(deepcopy(graph))
        print('#IN#No shift dataset done!')
        sorted_data_list, sorted_domain_split_data_list = self.get_domain_sorted_indices(graph, domain=self.domain)
        covariate_shift_graph = self.get_covariate_shift_graph(deepcopy(sorted_data_list), deepcopy(graph))
        print()
        print('#IN#Covariate shift dataset done!')
        concept_shift_graph = self.get_concept_shift_graph(deepcopy(sorted_domain_split_data_list), deepcopy(graph))
        print()
        print('#IN#Concept shift dataset done!')

        all_split_graph = [no_shift_graph, covariate_shift_graph, concept_shift_graph]
        for i, final_graph in enumerate(all_split_graph):
            data, slices = self.collate([final_graph])
            torch.save((data, slices), self.processed_paths[i])

    @staticmethod
    def load(dataset_root: str, domain: str, shift: str = 'no_shift', generate: bool = False):
        r"""
        A staticmethod for dataset loading. This method instantiates dataset class, constructing train, id_val,
        id_test, ood_val (val), and ood_test (test) splits. Besides, it collects several dataset meta information
        for further utilization.

        Args:
            dataset_root (str): The dataset saving root.
            domain (str): The domain selection. Allowed: 'degree' and 'time'.
            shift (str): The distributional shift we pick. Allowed: 'no_shift', 'covariate', and 'concept'.
            generate (bool): The flag for regenerating dataset. True: regenerate. False: download.

        Returns:
            dataset or dataset splits.
            dataset meta info.
        """
        meta_info = Munch()
        meta_info.dataset_type = 'real'
        meta_info.model_level = 'node'

        dataset = GOODArxiv(root=dataset_root, domain=domain, shift=shift, generate=generate)
        dataset.data.x = dataset.data.x.to(torch.float32)
        meta_info.dim_node = dataset.num_node_features
        meta_info.dim_edge = dataset.num_edge_features

        # Environment ids are non-negative for training nodes; -1 marks nodes without an environment.
        meta_info.num_envs = (torch.unique(dataset.data.env_id) >= 0).sum()
        meta_info.num_train_nodes = dataset[0].train_mask.sum()

        # Define networks' output shape.
        if dataset.task == 'Binary classification':
            meta_info.num_classes = dataset.data.y.shape[1]
        elif dataset.task == 'Regression':
            meta_info.num_classes = 1
        elif dataset.task == 'Multi-label classification':
            meta_info.num_classes = torch.unique(dataset.data.y).shape[0]

        # --- clear buffer dataset._data_list ---
        dataset._data_list = None

        return dataset, meta_info