r"""
Applies a linear transformation to complete classification from representations.
"""
import torch
import torch.nn as nn
from torch import Tensor
from GOOD.utils.config_reader import Union, CommonArgs, Munch
class Classifier(torch.nn.Module):
    r"""
    Applies a linear transformation to complete classification from representations.

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.model.dim_hidden`, :obj:`config.dataset.num_classes`)
    """

    def __init__(self, config: Union[CommonArgs, Munch]):
        super().__init__()
        # A single linear layer maps the hidden representation to one score per
        # class. It is kept inside an nn.Sequential so the parameters are still
        # registered as ``classifier.0.weight`` / ``classifier.0.bias``.
        self.classifier = nn.Sequential(
            nn.Linear(config.model.dim_hidden, config.dataset.num_classes)
        )

    def forward(self, feat: Tensor) -> Tensor:
        r"""
        Applies a linear transformation to feature representations.

        Args:
            feat (Tensor): feature representations; the last dimension must match
                :obj:`config.model.dim_hidden`, the input size of the linear layer

        Returns (Tensor):
            label predictions: the raw output of the linear layer, whose last
            dimension is :obj:`config.dataset.num_classes` (no activation or
            normalization is applied here)
        """
        return self.classifier(feat)