# Source code for GOOD.networks.model_manager

r"""A module that consists of a GNN model loader and a model configuration function.
"""

import torch

from GOOD import register
from GOOD.utils.config_reader import Union, CommonArgs, Munch
from GOOD.utils.initial import reset_random_seed


def load_model(name: str, config: Union[CommonArgs, Munch]) -> torch.nn.Module:
    r"""
    A model loader.

    Resets the random seed from ``config``, looks up the GNN class registered
    under ``name`` in ``register.models``, and instantiates it with ``config``.

    Args:
        name (str): Name of the chosen GNN.
        config (Union[CommonArgs, Munch]): Please refer to specific GNNs for required configs and formats.

    Returns:
        An instantiated GNN model.

    Raises:
        KeyError: If no model is registered under ``name``. A ``KeyError`` raised
            while constructing the model itself propagates unchanged, rather than
            being reported as a missing model.
    """
    # Seed before construction so that parameter initialization is reproducible.
    reset_random_seed(config)
    # Keep the try clause minimal: only the registry lookup signals "unknown model".
    try:
        model_cls = register.models[name]
    except KeyError:
        print(f'#E#Model {name} does not exist.')
        raise
    return model_cls(config)