# Source code for GOOD.networks.models.Classifiers

r"""
Applies a linear transformation to complete classification from representations.
"""
import torch
import torch.nn as nn
from torch import Tensor

from GOOD.utils.config_reader import Union, CommonArgs, Munch


class Classifier(torch.nn.Module):
    r"""
    Applies a linear transformation to complete classification from representations.

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args
            (:obj:`config.model.dim_hidden`, :obj:`config.dataset.num_classes`)
    """

    def __init__(self, config: Union[CommonArgs, Munch]):
        super().__init__()

        # Disabled alternative head (Linear -> BatchNorm1d -> ReLU -> Linear),
        # kept for reference:
        # self.classifier = nn.Sequential(*(
        #     [nn.Linear(config.model.dim_hidden, 2 * config.model.dim_ffn), nn.BatchNorm1d(2 * config.model.dim_ffn)] +
        #     [nn.ReLU(), nn.Linear(2 * config.model.dim_ffn, config.dataset.num_classes)]
        # ))

        # Active head: one linear layer from ``dim_hidden`` features to
        # ``num_classes`` outputs. It stays wrapped in ``nn.Sequential`` so the
        # layer remains registered as ``classifier.0`` (parameter names
        # ``classifier.0.weight`` / ``classifier.0.bias``).
        self.classifier = nn.Sequential(
            nn.Linear(config.model.dim_hidden, config.dataset.num_classes)
        )

    def forward(self, feat: Tensor) -> Tensor:
        r"""
        Applies a linear transformation to feature representations.

        Args:
            feat (Tensor): feature representations; the last dimension must
                equal :obj:`config.model.dim_hidden`

        Returns (Tensor):
            label predictions, with :obj:`config.dataset.num_classes` values
            along the last dimension

        """
        return self.classifier(feat)