Source code for GOOD.kernel.pipeline_manager

r"""A module that consists of a pipeline loader and model configuration function.
"""

from typing import Dict
from typing import Union

import torch.nn
from munch import Munch
from torch.utils.data import DataLoader

from GOOD.kernel.pipelines.basic_pipeline import Pipeline
from GOOD.ood_algorithms.algorithms.BaseOOD import BaseOODAlg
from GOOD.utils.args import CommonArgs
from GOOD.utils.initial import reset_random_seed
from GOOD.utils.register import register


def load_pipeline(name: str,
                  task: str,
                  model: torch.nn.Module,
                  loader: Union[DataLoader, Dict[str, DataLoader]],
                  ood_algorithm: BaseOODAlg,
                  config: Union[CommonArgs, Munch]
                  ) -> Pipeline:
    r"""
    A pipeline loader.

    Calls ``reset_random_seed(config)``, looks up the pipeline class registered
    under ``name`` in ``register.pipelines`` and instantiates it with the
    remaining arguments.

    Args:
        name (str): Name of the chosen pipeline; used as the key into
            ``register.pipelines``.
        task (str): Passed through unchanged to the pipeline constructor.
        model (torch.nn.Module): Passed through unchanged to the pipeline
            constructor.
        loader (Union[DataLoader, Dict[str, DataLoader]]): A single data loader
            or a string-keyed dictionary of data loaders; passed through
            unchanged to the pipeline constructor.
        ood_algorithm (BaseOODAlg): Passed through unchanged to the pipeline
            constructor.
        config (Union[CommonArgs, Munch]): Please refer to specific GNNs for
            required configs and formats. Also handed to ``reset_random_seed``.

    Returns:
        A instantiated pipeline.

    Raises:
        KeyError: If no pipeline is registered under ``name``. An error line is
            printed before the exception propagates.
    """
    # Runs first, so it is executed even when the lookup below fails.
    reset_random_seed(config)

    # Keep the ``try`` limited to the registry lookup: a KeyError raised while
    # seeding or inside the pipeline constructor must not be reported as
    # "pipeline does not exist".
    try:
        pipeline_class = register.pipelines[name]
    except KeyError:
        print(f'#E#Pipeline {name} does not exist.')
        raise

    return pipeline_class(task, model, loader, ood_algorithm, config)