# Source code for GOOD.utils.metric

r"""A metric function module consisting of a Metric class that incorporates many score and loss functions.
"""

from math import sqrt

import torch
from sklearn.metrics import roc_auc_score as sk_roc_auc, mean_squared_error, \
    accuracy_score, average_precision_score, mean_absolute_error, f1_score
from torch.nn.functional import cross_entropy, l1_loss, binary_cross_entropy_with_logits





class _UnknownNameError(ValueError, AssertionError):
    r"""
    Raised when a task or metric name has no registered function.

    Subclasses :class:`ValueError` (the idiomatic type for a bad argument value) and also
    :class:`AssertionError`, so callers that caught the ``AssertionError`` raised by earlier
    ``assert``-based validation keep working.
    """


class Metric(object):
    r"""
    Metric function module: a registry of loss functions (keyed by task name) and score
    functions (keyed by metric name), plus the currently selected ``loss_func`` /
    ``score_func``.
    """

    def __init__(self):
        # Task name -> loss function; selected with :meth:`set_loss_func`.
        self.task2loss = {
            'Binary classification': binary_cross_entropy_with_logits,
            'Multi-label classification': self.cross_entropy_with_logit,
            'Regression': l1_loss,
        }
        # Metric name -> score function; selected with :meth:`set_score_func`.
        self.score_name2score = {
            'RMSE': self.rmse,
            'MAE': mean_absolute_error,
            'Average Precision': self.ap,
            'F1': self.f1,
            'ROC-AUC': self.roc_auc_score,
            'Accuracy': self.acc,
        }
        # Defaults used until set_loss_func / set_score_func are called.
        self.loss_func = self.cross_entropy_with_logit
        self.score_func = self.roc_auc_score
        self.dataset_task = ''
        self.score_name = ''
        # +1 when a lower score is better (RMSE / MAE), -1 otherwise; see set_score_func.
        self.lower_better = -1
        # Best-so-far statistics slots; not read or updated by any method of this class.
        self.best_stat = {'score': None, 'loss': float('inf')}
        self.id_best_stat = {'score': None, 'loss': float('inf')}

    def set_loss_func(self, task_name):
        r"""
        Set the loss function

        Args:
            task_name (str): name of task; must be a key of ``self.task2loss``

        Returns:
            None

        Raises:
            ValueError: if ``task_name`` is not registered (also an ``AssertionError`` for
                backward compatibility). ``loss_func`` and ``dataset_task`` are left
                unchanged in that case.
        """
        loss_func = self.task2loss.get(task_name)
        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
        # Validate before assigning so a bad name cannot leave loss_func as None.
        if loss_func is None:
            raise _UnknownNameError(
                f"Unknown task {task_name!r}; expected one of {sorted(self.task2loss)}")
        self.dataset_task = task_name
        self.loss_func = loss_func

    def set_score_func(self, metric_name):
        r"""
        Set the metric function

        Args:
            metric_name (str): name of metric; must be a key of ``self.score_name2score``

        Returns:
            None

        Raises:
            ValueError: if ``metric_name`` is not registered (also an ``AssertionError`` for
                backward compatibility). ``score_func``, ``score_name`` and ``lower_better``
                are left unchanged in that case.
        """
        score_func = self.score_name2score.get(metric_name)
        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
        if score_func is None:
            raise _UnknownNameError(
                f"Unknown metric {metric_name!r}; expected one of {sorted(self.score_name2score)}")
        self.score_func = score_func
        self.score_name = metric_name.upper()
        # Error metrics are minimized; every other registered score is maximized.
        self.lower_better = 1 if self.score_name in ('RMSE', 'MAE') else -1

    @staticmethod
    def _as_detached_tensor(data):
        r"""
        Convert ``data`` to a tensor detached from the autograd graph.

        Tensors are detached directly, which avoids the copy-construct warning that
        ``torch.tensor(tensor)`` emits. Anything else goes through ``torch.tensor``.
        The resulting values, dtype and device match ``torch.tensor(data)``.
        """
        if isinstance(data, torch.Tensor):
            return data.detach()
        return torch.tensor(data)

    def _predicted_labels(self, y_pred):
        r"""
        Turn raw predictions into hard labels for the current ``dataset_task``.

        Binary classification rounds each prediction. Any other task takes the argmax
        along dim 1.
        """
        pred = self._as_detached_tensor(y_pred)
        if self.dataset_task == "Binary classification":
            # NOTE(review): rounding thresholds at 0.5, which assumes y_pred holds
            # probabilities rather than raw logits — confirm against callers.
            return pred.round()
        return torch.argmax(pred, dim=1)

    def f1(self, y_true, y_pred):
        r"""
        Calculate F1 score

        Args:
            y_true (torch.tensor): input labels
            y_pred (torch.tensor): label predictions

        Returns (float):
            F1 score
        """
        true = self._as_detached_tensor(y_true)
        # NOTE(review): average='micro' over all labels equals accuracy for single-label
        # data; confirm this (rather than average='binary') is intended for binary tasks.
        return f1_score(true, self._predicted_labels(y_pred), average='micro')

    def ap(self, y_true, y_pred):
        r"""
        Calculate AP score

        Args:
            y_true (torch.tensor): input labels (cast to integer)
            y_pred (torch.tensor): label predictions

        Returns (float):
            AP score
        """
        return average_precision_score(self._as_detached_tensor(y_true).long(),
                                       self._as_detached_tensor(y_pred))

    def roc_auc_score(self, y_true, y_pred):
        r"""
        Calculate roc_auc score

        Args:
            y_true (torch.tensor): input labels (cast to integer)
            y_pred (torch.tensor): label predictions

        Returns (float):
            roc_auc score
        """
        # NOTE(review): for multi-class input, multi_class='ovo' requires per-class
        # probabilities that sum to 1 — confirm y_pred is normalized in that case.
        return sk_roc_auc(self._as_detached_tensor(y_true).long(),
                          self._as_detached_tensor(y_pred), multi_class='ovo')

    def reg_absolute_error(self, y_true, y_pred):
        r"""
        Calculate absolute regression error

        Args:
            y_true (torch.tensor): input labels
            y_pred (torch.tensor): label predictions

        Returns (float):
            mean absolute regression error
        """
        return mean_absolute_error(self._as_detached_tensor(y_true),
                                   self._as_detached_tensor(y_pred))

    def acc(self, y_true, y_pred):
        r"""
        Calculate accuracy score

        Args:
            y_true (torch.tensor): input labels
            y_pred (torch.tensor): label predictions

        Returns (float):
            accuracy score
        """
        true = self._as_detached_tensor(y_true)
        return accuracy_score(true, self._predicted_labels(y_pred))

    def rmse(self, y_true, y_pred):
        r"""
        Calculate RMSE

        Args:
            y_true (torch.tensor): input labels
            y_pred (torch.tensor): label predictions

        Returns (float):
            RMSE (square root of the mean squared error)
        """
        return sqrt(mean_squared_error(y_true, y_pred))

    def cross_entropy_with_logit(self, y_pred: torch.Tensor, y_true: torch.Tensor, **kwargs):
        r"""
        Calculate cross entropy loss

        Args:
            y_pred (torch.tensor): label predictions (logits)
            y_true (torch.tensor): input labels (cast to integer class indices)
            **kwargs: key word arguments for the use of :func:`~torch.nn.functional.cross_entropy`

        Returns:
            cross entropy loss
        """
        return cross_entropy(y_pred, y_true.long(), **kwargs)