Source code for GOOD.data.good_loaders.BaseLoader

import random

from torch_geometric.loader import DataLoader, GraphSAINTRandomWalkSampler

from GOOD import register
from GOOD.utils.config_reader import Union, CommonArgs, Munch
from GOOD.utils.initial import reset_random_seed
from typing import List, Iterator
from torch.utils.data.sampler import Sampler
from torch_geometric.data.dataset import Dataset
import numpy as np
import torch

@register.dataloader_register
class BaseDataLoader(Munch):
    r"""
    A ``Munch`` (attribute-accessible dict) mapping dataset split names to data loaders.

    Instances are normally created through :meth:`setup`, which fills the keys
    ``'train'``, ``'eval_train'``, ``'id_val'``, ``'id_test'``, ``'val'`` and ``'test'``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def setup(cls, dataset, config: Union[CommonArgs, Munch]):
        r"""
        Create PyG data loaders for every dataset split.

        Args:
            dataset: A GOOD dataset. For node-level tasks
                (``config.model.model_level == 'node'``) it is indexed with
                ``dataset[0]`` to get the single graph and must expose
                ``processed_dir``. Otherwise it is a dict-like object indexed by
                the split names ``'train'``, ``'val'``, ``'test'`` and, optionally,
                ``'id_val'`` / ``'id_test'``.
            config: Required configs:
                ``config.random_seed``
                ``config.model.model_level``
                ``config.train.train_bs``
                ``config.model.model_layer`` (node-level only)
                ``config.train.num_steps`` (node-level only)
                ``config.ood.ood_alg`` (node-level only)
                ``config.train.val_bs`` (graph-level only)
                ``config.train.test_bs`` (graph-level only)
                ``config.num_workers`` (graph-level only)

        Returns:
            A ``BaseDataLoader`` keyed by split name.

            * Node level: ``'train'`` holds a ``GraphSAINTRandomWalkSampler``.
              Under the ``'EERM'`` OOD algorithm it instead holds ``[graph]``.
              Every other split holds its own ``[graph]`` list.
            * Graph level: each split holds a PyG ``DataLoader``. Only
              ``'train'`` shuffles, and ``'eval_train'`` re-reads the training
              split with ``val_bs``. ``'id_val'`` / ``'id_test'`` are ``None``
              when ``dataset.get(key)`` is falsy (missing or empty).
        """
        reset_random_seed(config)

        def seed_worker(worker_id):
            # Re-seed numpy / random inside each DataLoader worker from the
            # per-worker torch seed, so multi-process loading is reproducible.
            worker_seed = torch.initial_seed() % 2 ** 32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        # All graph-level loaders share this one generator, seeded from config.random_seed.
        g = torch.Generator()
        g.manual_seed(config.random_seed)

        if config.model.model_level == 'node':
            graph = dataset[0]
            # NOTE(review): the sampler is built even when ood_alg == 'EERM',
            # where it is then discarded. Construction (sample_coverage=100) may
            # be costly and may write to save_dir. It is kept unconditional so
            # RNG consumption and cache side effects match previous runs.
            # Consider building it only when it is actually used.
            sampler = GraphSAINTRandomWalkSampler(graph, batch_size=config.train.train_bs,
                                                  walk_length=config.model.model_layer,
                                                  num_steps=config.train.num_steps,
                                                  sample_coverage=100,
                                                  save_dir=dataset.processed_dir)
            train_loader = [graph] if config.ood.ood_alg == 'EERM' else sampler
            loader = {'train': train_loader}
            # A separate one-element list per split: full-batch evaluation on the whole graph.
            loader.update({split: [graph] for split in ('eval_train', 'id_val', 'id_test', 'val', 'test')})
        else:
            def make_loader(split, batch_size, shuffle=False):
                # Shared construction for every graph-level split.
                return DataLoader(dataset[split], batch_size=batch_size, shuffle=shuffle,
                                  num_workers=config.num_workers, worker_init_fn=seed_worker,
                                  generator=g)

            loader = {
                'train': make_loader('train', config.train.train_bs, shuffle=True),
                # Non-shuffled pass over the training split for evaluation.
                'eval_train': make_loader('train', config.train.val_bs),
                'id_val': make_loader('id_val', config.train.val_bs) if dataset.get('id_val') else None,
                'id_test': make_loader('id_test', config.train.test_bs) if dataset.get('id_test') else None,
                'val': make_loader('val', config.train.val_bs),
                'test': make_loader('test', config.train.test_bs),
            }

        return cls(loader)