# Source code for GOOD.data.good_datasets.good_twitter

"""
The GOOD-SST2 dataset. Adapted from `DIG <https://github.com/divelab/DIG>`_.
"""
import itertools
import os
import os.path as osp
import random
from copy import deepcopy

import gdown
import numpy as np
import torch
from dig.xgraph.dataset import SentiGraphDataset
from munch import Munch
from torch_geometric.data import InMemoryDataset, extract_zip, Data
from tqdm import tqdm


class DomainGetter():
    r"""
    A class containing methods for data domain extraction.

    Each ``get_<domain>`` method maps a single PyG graph to its domain value;
    ``GOODTwitter.get_domain_sorted_list`` looks these methods up by name
    (``getattr(domain_getter, f'get_{domain}')``).
    """

    def __init__(self):
        pass

    def get_length(self, data: Data) -> int:
        """
        Return the sentence length of a graph.

        Args:
            data (Data): A PyG graph data object.

        Returns:
            The number of nodes in ``data`` (rows of ``data.x``), used as the length of the sentence.
        """
        return data.x.shape[0]
from GOOD import register
@register.dataset_register
class GOODTwitter(InMemoryDataset):
    r"""
    The GOOD-Twitter dataset. Adapted from `DIG <https://github.com/divelab/DIG>`_.

    The processed data consists of 13 files (see :attr:`processed_file_names`):
    3 for ``no_shift`` (train/val/test), then 5 each for ``covariate`` and
    ``concept`` (train/val/test/id_val/id_test).

    Args:
        root (str): The dataset saving root.
        domain (str): The domain selection. Allowed: 'length'.
        shift (str): The distributional shift we pick. Allowed: 'no_shift', 'covariate', and 'concept'.
        subset (str): The split set. Allowed: 'train', 'id_val', 'id_test', 'val', and 'test'. When shift='no_shift',
            'id_val' and 'id_test' are not applicable.
        transform: Forwarded to :class:`InMemoryDataset`.
        pre_transform: Forwarded to :class:`InMemoryDataset`.
        generate (bool): The flag for regenerating dataset. True: regenerate. False: download.

    Raises:
        ValueError: If ``shift`` or ``subset`` is unknown, or if 'id_val'/'id_test' is requested with
            ``shift='no_shift'``.
    """

    # Index of each shift's first file within ``processed_file_names``.
    _SHIFT_OFFSETS = {'no_shift': 0, 'covariate': 3, 'concept': 8}
    # Position of each subset inside a shift's group of files.
    _SUBSET_OFFSETS = {'train': 0, 'val': 1, 'test': 2, 'id_val': 3, 'id_test': 4}

    def __init__(self, root: str, domain: str, shift: str = 'no_shift', subset: str = 'train', transform=None,
                 pre_transform=None, generate: bool = False):
        # Validate before ``super().__init__``, which may download or process the whole dataset.
        if shift not in self._SHIFT_OFFSETS:
            raise ValueError(f'Unknown shift: {shift}.')
        if subset not in self._SUBSET_OFFSETS:
            raise ValueError(f'Unknown subset: {subset}.')
        if shift == 'no_shift' and subset in ('id_val', 'id_test'):
            # no_shift only has train/val/test files; offsets 3/4 would silently load covariate files.
            raise ValueError(f"Subset '{subset}' is not available when shift='no_shift'.")

        self.name = self.__class__.__name__
        self.domain = domain
        self.metric = 'Accuracy'
        self.task = 'Multi-label classification'
        self.url = 'https://drive.google.com/file/d/1um-ruqg58ULRkMZLOekF5FmLv4ZhGDnS/view?usp=sharing'

        self.generate = generate

        super().__init__(root, transform, pre_transform)

        subset_pt = self._SHIFT_OFFSETS[shift] + self._SUBSET_OFFSETS[subset]
        self.data, self.slices = torch.load(self.processed_paths[subset_pt])

    @property
    def raw_dir(self):
        return osp.join(self.root)

    def _download(self):
        # Skip when the raw folder already exists or when the dataset is to be regenerated locally.
        if os.path.exists(osp.join(self.raw_dir, self.name)) or self.generate:
            return
        if not os.path.exists(self.raw_dir):
            os.makedirs(self.raw_dir)
        self.download()

    def download(self):
        """Download the zipped processed dataset from Google Drive and extract it into ``raw_dir``."""
        path = gdown.download(self.url, output=osp.join(self.raw_dir, self.name + '.zip'), fuzzy=True)
        if path is None:
            # gdown may report failure by returning None; fail here instead of inside extract_zip.
            raise RuntimeError(f'Failed to download {self.name} from {self.url}.')
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, self.domain, 'processed')

    @property
    def processed_file_names(self):
        return ['no_shift_train.pt', 'no_shift_val.pt', 'no_shift_test.pt',
                'covariate_train.pt', 'covariate_val.pt', 'covariate_test.pt', 'covariate_id_val.pt',
                'covariate_id_test.pt',
                'concept_train.pt', 'concept_val.pt', 'concept_test.pt', 'concept_id_val.pt',
                'concept_id_test.pt']

    @staticmethod
    def _split_id_val_test(train_list, id_test_ratio=0.15):
        """
        Shuffle ``train_list`` in place and carve in-distribution val/test subsets off its tail.

        Args:
            train_list (list): Training graphs (shuffled in place).
            id_test_ratio (float): Fraction of ``train_list`` put into each of id_val and id_test.

        Returns:
            ``(train_list, id_val_list, id_test_list)``; the two ID lists each hold
            ``int(len(train_list) * id_test_ratio)`` items.
        """
        num_id_test = int(len(train_list) * id_test_ratio)
        random.shuffle(train_list)
        # Explicit bounds instead of negative slices: with num_id_test == 0, ``train_list[:-0]`` would be
        # empty and every sample would end up in id_test.
        train_end = len(train_list) - 2 * num_id_test
        id_val_end = len(train_list) - num_id_test
        return train_list[:train_end], train_list[train_end:id_val_end], train_list[id_val_end:]

    def get_no_shift_list(self, data_list):
        """
        Randomly split graphs 60/20/20 into train/val/test.

        Each training graph gets a random ``env_id`` in [0, 9].

        Returns:
            ``[train_list, val_list, test_list]``.
        """
        random.shuffle(data_list)

        num_data = data_list.__len__()
        train_ratio = 0.6
        val_ratio = 0.2
        train_split = int(num_data * train_ratio)
        val_split = int(num_data * (train_ratio + val_ratio))
        train_list, val_list, test_list = data_list[: train_split], data_list[train_split: val_split], \
            data_list[val_split:]

        for data in train_list:
            data.env_id = random.randint(0, 9)

        all_env_list = [train_list, val_list, test_list]

        return all_env_list

    def get_covariate_shift_list(self, sorted_data_list):
        """
        Build a covariate-shift split from graphs sorted by domain value.

        The sorted list is cut roughly 50/25/25 into train / OOD val / OOD test. Cuts are only made where
        ``domain_id`` changes, so no domain straddles two splits. Training graphs are then grouped into
        up to 10 environments the same way. Finally, ID val/test subsets are taken from the shuffled
        training graphs.

        Returns:
            ``[train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]``.
        """
        num_data = sorted_data_list.__len__()
        train_ratio = 0.5
        val_ratio = 0.25
        train_split = int(num_data * train_ratio)
        val_split = int(num_data * (train_ratio + val_ratio))
        train_val_test_split = [0, train_split, val_split]
        train_val_test_list = [[], [], []]
        cur_env_id = -1
        cur_domain_id = None
        for i, data in enumerate(sorted_data_list):
            # Advance to the next split only after its start index AND at a domain boundary.
            if cur_env_id < 2 and i >= train_val_test_split[cur_env_id + 1] and data.domain_id != cur_domain_id:
                cur_env_id += 1
            cur_domain_id = data.domain_id
            train_val_test_list[cur_env_id].append(data)

        train_list, ood_val_list, ood_test_list = train_val_test_list

        # Compose domains to environments
        num_env_train = 10
        num_per_env = len(train_list) // num_env_train
        cur_env_id = -1
        cur_domain_id = None
        for i, data in enumerate(train_list):
            if cur_env_id < 9 and i >= (cur_env_id + 1) * num_per_env and data.domain_id != cur_domain_id:
                cur_env_id += 1
            cur_domain_id = data.domain_id
            data.env_id = cur_env_id

        train_list, id_val_list, id_test_list = self._split_id_val_test(train_list)

        all_env_list = [train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]

        return all_env_list

    def get_concept_shift_list(self, sorted_domain_split_data_list):
        """
        Build a concept-shift split from per-domain lists of graphs.

        For each graph, ``pyx`` is the NaN-ignoring mean of ``data.y`` (NaN replaced by 0). Each domain list
        has its mean ``pyx`` appended as its last element (mutates the input lists). Five splits are then
        filled in order, each with its own label-bias strength (``bias_connect``) and picking ratio. Splits
        0, 2, 3 are training environments, split 1 is OOD test, and split 4 takes all remaining graphs as
        OOD val. Relies on ``self.num_data`` having been set (done in :meth:`process`).

        Returns:
            ``[train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]``.
        """
        # Calculate concept probability for each domain
        global_pyx = []
        for each_domain_datas in tqdm(sorted_domain_split_data_list):
            pyx = []
            for data in each_domain_datas:
                data.pyx = torch.tensor(np.nanmean(data.y).item())
                if torch.isnan(data.pyx):
                    data.pyx = torch.tensor(0.)
                pyx.append(data.pyx.item())
                global_pyx.append(data.pyx.item())
            pyx = sum(pyx) / each_domain_datas.__len__()
            # NOTE: the domain mean is stored as the trailing element; loops below use ``len(...) - 1``.
            each_domain_datas.append(pyx)

        global_mean_pyx = np.mean(global_pyx)

        bias_connect = [0.95, 0.95, 0.9, 0.85, 0.5]
        is_train_split = [True, False, True, True, False]
        is_val_split = [False if i < len(is_train_split) - 1 else True for i in range(len(is_train_split))]
        is_test_split = [not (tr_sp or val_sp) for tr_sp, val_sp in zip(is_train_split, is_val_split)]

        split_picking_ratio = [0.3, 0.5, 0.6, 1, 1]

        # Per split and domain: +1/-1 sign deciding which side of the global mean is "aligned". Domains
        # covering the first half of the data get +1, the rest -1; non-train splits get the opposite sign.
        order_connect = [[] for _ in range(len(bias_connect))]
        cur_num = 0
        for i in range(len(sorted_domain_split_data_list)):
            randc = 1 if cur_num < self.num_data / 2 else - 1
            cur_num += sorted_domain_split_data_list[i].__len__() - 1
            for j in range(len(order_connect)):
                order_connect[j].append(randc if is_train_split[j] else - randc)

        env_list = [[] for _ in range(len(bias_connect))]
        cur_split = 0
        env_id = -1
        while cur_split < len(env_list):
            if is_train_split[cur_split]:
                env_id += 1

            for domain_id, each_domain_datas in enumerate(sorted_domain_split_data_list):
                pyx_mean = each_domain_datas[-1]
                pop_items = []
                both_label_domain = [False, False]
                label_data_candidate = [None, None]
                both_label_include = [False, False]
                for i in range(len(each_domain_datas) - 1):
                    data = each_domain_datas[i]
                    picking_rand = random.random()
                    data_rand = random.random()  # random num for data point
                    if cur_split == len(env_list) - 1:
                        # Last split takes everything that is left.
                        data.env_id = env_id
                        env_list[cur_split].append(data)
                        pop_items.append(data)
                    else:
                        if order_connect[cur_split][domain_id] * (data.pyx - global_mean_pyx) > 0:
                            both_label_domain[0] = True
                            if data_rand < bias_connect[cur_split] and picking_rand < split_picking_ratio[cur_split]:
                                both_label_include[0] = True
                                data.env_id = env_id
                                env_list[cur_split].append(data)
                                pop_items.append(data)
                            else:
                                label_data_candidate[0] = data
                        else:
                            both_label_domain[1] = True
                            if data_rand > bias_connect[cur_split] and picking_rand < split_picking_ratio[cur_split]:
                                both_label_include[1] = True
                                data.env_id = env_id
                                env_list[cur_split].append(data)
                                pop_items.append(data)
                            else:
                                label_data_candidate[1] = data
                # --- Add extra data: avoid extreme label imbalance ---
                if both_label_domain[0] and both_label_domain[1] and (both_label_include[0] or both_label_include[1]):
                    extra_data = None
                    if not both_label_include[0]:
                        extra_data = label_data_candidate[0]
                    if not both_label_include[1]:
                        extra_data = label_data_candidate[1]
                    if extra_data is not None:
                        extra_data.env_id = env_id
                        env_list[cur_split].append(extra_data)
                        pop_items.append(extra_data)
                # Remove assigned graphs so later splits only see what is left.
                for pop_item in pop_items:
                    each_domain_datas.remove(pop_item)

            cur_split += 1
            num_train = sum([len(env) for i, env in enumerate(env_list) if is_train_split[i]])
            num_val = sum([len(env) for i, env in enumerate(env_list) if is_val_split[i]])
            num_test = sum([len(env) for i, env in enumerate(env_list) if is_test_split[i]])
            print("#D#train: %d, val: %d, test: %d" % (num_train, num_val, num_test))

        # True split
        train_list = list(itertools.chain(*[env for i, env in enumerate(env_list) if is_train_split[i]]))
        ood_val_list = list(itertools.chain(*[env for i, env in enumerate(env_list) if is_val_split[i]]))
        ood_test_list = list(itertools.chain(*[env for i, env in enumerate(env_list) if is_test_split[i]]))

        train_list, id_val_list, id_test_list = self._split_id_val_test(train_list)

        all_env_list = [train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]

        return all_env_list

    def get_domain_sorted_list(self, data_list, domain='length'):
        """
        Attach the domain value to each graph, sort by it, and group graphs into domains.

        Sets ``data.<domain>`` via ``DomainGetter.get_<domain>`` and ``data.domain_id`` (a 1-element
        LongTensor) on every graph.

        Returns:
            ``(sorted_data_list, sorted_domain_split_data_list)`` where the second is a list with one list
            of graphs per distinct domain value, in ascending order.
        """
        domain_getter = DomainGetter()
        for data in tqdm(data_list):
            data.__setattr__(domain, getattr(domain_getter, f'get_{domain}')(data))

        sorted_data_list = sorted(data_list, key=lambda data: getattr(data, domain))

        # Assign domain id
        cur_domain_id = -1
        cur_domain = None
        sorted_domain_split_data_list = []
        for data in sorted_data_list:
            if getattr(data, domain) != cur_domain:
                cur_domain = getattr(data, domain)
                cur_domain_id += 1
                sorted_domain_split_data_list.append([])
            data.domain_id = torch.LongTensor([cur_domain_id])
            sorted_domain_split_data_list[cur_domain_id].append(data)

        return sorted_data_list, sorted_domain_split_data_list

    def process(self):
        """Generate all no-shift, covariate-shift and concept-shift splits and save them to ``processed_paths``."""
        dataset = SentiGraphDataset(root=self.root, name='Graph-Twitter')
        print('Load data done!')
        data_list = []
        for i, data in enumerate(dataset):
            data.idx = i
            data.sentence_tokens = dataset.supplement['sentence_tokens'][str(i)]
            data_list.append(data)
        self.num_data = data_list.__len__()
        print('Extract data done!')

        no_shift_list = self.get_no_shift_list(deepcopy(data_list))
        print('#IN#No shift dataset done!')
        sorted_data_list, sorted_domain_split_data_list = self.get_domain_sorted_list(data_list, domain=self.domain)
        covariate_shift_list = self.get_covariate_shift_list(deepcopy(sorted_data_list))
        print()
        print('#IN#Covariate shift dataset done!')
        concept_shift_list = self.get_concept_shift_list(deepcopy(sorted_domain_split_data_list))
        print()
        print('#IN#Concept shift dataset done!')

        # 3 + 5 + 5 lists, in the same order as ``processed_file_names``.
        all_data_list = no_shift_list + covariate_shift_list + concept_shift_list
        for i, final_data_list in enumerate(all_data_list):
            data, slices = self.collate(final_data_list)
            torch.save((data, slices), self.processed_paths[i])

    @staticmethod
    def load(dataset_root: str, domain: str, shift: str = 'no_shift', generate: bool = False):
        r"""
        A staticmethod for dataset loading. This method instantiates dataset class, constructing train, id_val, id_test,
        ood_val (val), and ood_test (test) splits. Besides, it collects several dataset meta information for further
        utilization.

        Args:
            dataset_root (str): The dataset saving root.
            domain (str): The domain selection. Allowed: 'length'.
            shift (str): The distributional shift we pick. Allowed: 'no_shift', 'covariate', and 'concept'.
            generate (bool): The flag for regenerating dataset. True: regenerate. False: download.

        Returns:
            dataset or dataset splits (``id_val``/``id_test`` are None when ``shift='no_shift'``).
            dataset meta info.
        """
        meta_info = Munch()
        meta_info.dataset_type = 'nlp'
        meta_info.model_level = 'graph'

        train_dataset = GOODTwitter(root=dataset_root,
                                    domain=domain, shift=shift, subset='train', generate=generate)
        id_val_dataset = GOODTwitter(root=dataset_root,
                                     domain=domain, shift=shift, subset='id_val',
                                     generate=generate) if shift != 'no_shift' else None
        id_test_dataset = GOODTwitter(root=dataset_root,
                                      domain=domain, shift=shift, subset='id_test',
                                      generate=generate) if shift != 'no_shift' else None
        val_dataset = GOODTwitter(root=dataset_root,
                                  domain=domain, shift=shift, subset='val', generate=generate)
        test_dataset = GOODTwitter(root=dataset_root,
                                   domain=domain, shift=shift, subset='test', generate=generate)

        meta_info.dim_node = train_dataset.num_node_features
        meta_info.dim_edge = train_dataset.num_edge_features

        meta_info.num_envs = torch.unique(train_dataset.data.env_id).shape[0]

        # Define networks' output shape.
        if train_dataset.task == 'Binary classification':
            meta_info.num_classes = train_dataset.data.y.shape[1]
        elif train_dataset.task == 'Regression':
            meta_info.num_classes = 1
        elif train_dataset.task == 'Multi-label classification':
            meta_info.num_classes = torch.unique(train_dataset.data.y).shape[0]

        # --- clear buffer dataset._data_list ---
        train_dataset._data_list = None
        # Compare with None explicitly: dataset truthiness goes through __len__.
        if id_val_dataset is not None:
            id_val_dataset._data_list = None
            id_test_dataset._data_list = None
        val_dataset._data_list = None
        test_dataset._data_list = None

        return {'train': train_dataset, 'id_val': id_val_dataset, 'id_test': id_test_dataset,
                'val': val_dataset, 'test': test_dataset, 'task': train_dataset.task,
                'metric': train_dataset.metric}, meta_info