GOOD.utils.metric

A metric function module consisting of a Metric class that incorporates many score and loss functions.

Classes

Metric()

Metric function module consisting of a Metric class that incorporates many score and loss functions

class GOOD.utils.metric.Metric

Bases: object

Metric function module consisting of a Metric class that incorporates many score and loss functions

acc(y_true, y_pred)

Calculate accuracy score

Parameters
  • y_true (torch.Tensor) – ground-truth labels

  • y_pred (torch.Tensor) – label predictions

Returns

accuracy score (float)
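For reference, a minimal plain-Python sketch of what accuracy measures: the fraction of positions where the predicted label equals the true label. It uses lists instead of torch tensors and assumes y_pred already holds hard class labels; how acc turns raw model outputs (probabilities or logits) into labels is not documented here, and the helper name `accuracy` is illustrative only.

```python
from typing import Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of predictions that exactly match the ground-truth labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("inputs must be non-empty")
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


# 3 of 4 predictions match the labels.
print(accuracy([0, 1, 1, 0], [0, 1, 0, 0]))  # 0.75
```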

ap(y_true, y_pred)

Calculate the average precision (AP) score

Parameters
  • y_true (torch.Tensor) – ground-truth labels

  • y_pred (torch.Tensor) – label predictions

Returns

AP score (float)
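A plain-Python sketch of (non-interpolated) average precision for a single binary task, assuming y_true holds 0/1 labels and y_pred holds real-valued scores. Samples are ranked by score, and AP = Σ_k (R_k − R_{k−1}) · P_k over distinct score thresholds, where P_k and R_k are the precision and recall at threshold k. Whether ap follows exactly this definition, and how it handles multi-task or missing labels, is an assumption; `average_precision` is a hypothetical helper name.

```python
from typing import Sequence


def average_precision(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """Non-interpolated AP: sum over thresholds of (R_k - R_{k-1}) * P_k."""
    n_pos = sum(1 for t in y_true if t == 1)
    if n_pos == 0:
        raise ValueError("AP is undefined without positive labels")
    # Rank samples from highest to lowest score.
    ranked = sorted(zip(y_score, y_true), key=lambda pair: pair[0], reverse=True)
    tp = fp = 0
    prev_recall = 0.0
    ap = 0.0
    i = 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Samples tied at the same score form a single threshold.
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        recall = tp / n_pos
        precision = tp / (tp + fp)
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


# Precision at each positive: 1, 2/3, 3/4 -> AP = 29/36.
print(round(average_precision([1, 0, 1, 1], [0.9, 0.8, 0.7, 0.1]), 4))  # 0.8056
```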

cross_entropy_with_logit(y_pred: Tensor, y_true: Tensor, **kwargs)

Calculate the cross-entropy loss from logits

Parameters
  • y_pred (torch.Tensor) – label predictions (logits)

  • y_true (torch.Tensor) – ground-truth labels

  • **kwargs – keyword arguments passed to cross_entropy()

Returns

cross-entropy loss
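To make the loss concrete, a plain-Python sketch of mean cross-entropy computed directly from logits: for a sample with logit vector z and target class y, the loss is logsumexp(z) − z_y, and the batch loss is the mean over samples (the default reduction of PyTorch's torch.nn.functional.cross_entropy). The argument order mirrors cross_entropy_with_logit (predictions first, labels second), unlike the score functions. Assuming cross_entropy() refers to the PyTorch function, options such as reduction would be passed via **kwargs; this sketch hard-codes the mean, and the helper name is hypothetical.

```python
import math
from typing import Sequence


def cross_entropy_with_logits(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean over samples of logsumexp(z) - z[y], i.e. -log softmax(z)[y]."""
    if len(logits) != len(targets) or len(targets) == 0:
        raise ValueError("need one non-empty logit row per target")
    total = 0.0
    for row, y in zip(logits, targets):
        m = max(row)  # subtract the max so exp() cannot overflow
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        total += log_sum_exp - row[y]
    return total / len(targets)


# Predictions first, labels second, as in cross_entropy_with_logit.
loss = cross_entropy_with_logits([[0.0, 0.0], [2.0, 0.0]], [1, 0])
print(loss)
```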

f1(y_true, y_pred)

Calculate F1 score

Parameters
  • y_true (torch.Tensor) – ground-truth labels

  • y_pred (torch.Tensor) – label predictions

Returns

F1 score (float)
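A plain-Python sketch of the binary F1 score, the harmonic mean of precision and recall, written here as 2·TP / (2·TP + FP + FN). It assumes y_pred already holds hard 0/1 labels. The averaging strategy f1 uses for multi-class data (micro, macro, or weighted) is not stated in this documentation, so only the binary case is shown; `f1_binary` is a hypothetical helper.

```python
from typing import Sequence


def f1_binary(y_true: Sequence[int], y_pred: Sequence[int], positive: int = 1) -> float:
    """Binary F1 = 2*TP / (2*TP + FP + FN), the harmonic mean of precision and recall."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    denom = 2 * tp + fp + fn
    # No positives in labels or predictions: report 0.0 rather than dividing by zero.
    return 2 * tp / denom if denom else 0.0


# tp=2, fp=1, fn=1 -> precision = recall = 2/3, so F1 = 2/3.
print(f1_binary([1, 1, 0, 0, 1], [1, 0, 1, 0, 1]))
```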

reg_absolute_error(y_true, y_pred)

Calculate absolute regression error

Parameters
  • y_true (torch.Tensor) – ground-truth labels

  • y_pred (torch.Tensor) – label predictions

Returns

absolute regression error (float)
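Assuming "absolute regression error" means the mean absolute error (MAE), i.e. the average of |y_true − y_pred| over all samples, a plain-Python sketch looks like this. The mean reduction (rather than a sum) is an assumption, and `mean_absolute_error` here is a local illustrative helper.

```python
from typing import Sequence


def mean_absolute_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Average of |y_true - y_pred| over all samples."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Absolute errors are 0.5, 0.5, 0.0 and 1.0.
print(mean_absolute_error([3.0, -0.5, 2.0, 7.0], [2.5, 0.0, 2.0, 8.0]))  # 0.5
```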

rmse(y_true, y_pred)

Calculate the root-mean-square error (RMSE)

Parameters
  • y_true (torch.Tensor) – ground-truth labels

  • y_pred (torch.Tensor) – label predictions

Returns

RMSE (float)
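RMSE is the square root of the mean squared difference between labels and predictions, √((1/n) Σ_i (y_i − ŷ_i)²). A plain-Python sketch on lists rather than torch tensors; the helper name is hypothetical and is not the documented method.

```python
import math
from typing import Sequence


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Square root of the mean squared difference between labels and predictions."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)
    return math.sqrt(mse)


# Squared errors are 0.25, 0.25, 0.0 and 1.0, so MSE = 0.375.
print(rmse([3.0, -0.5, 2.0, 7.0], [2.5, 0.0, 2.0, 8.0]))  # ≈ 0.612
```

On this data the RMSE (≈ 0.612) exceeds the mean absolute error (0.5), because squaring gives the single error of 1.0 more weight.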

roc_auc_score(y_true, y_pred)

Calculate the ROC-AUC score (area under the receiver operating characteristic curve)

Parameters
  • y_true (torch.Tensor) – ground-truth labels

  • y_pred (torch.Tensor) – label predictions

Returns

ROC-AUC score (float)
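ROC-AUC equals the probability that a randomly chosen positive sample is scored higher than a randomly chosen negative one, with ties counting as one half. The plain-Python sketch below computes it by comparing every positive/negative pair, assuming a single binary task with 0/1 labels and real-valued scores. It is O(P·N), so a rank-based O(n log n) formulation is preferable for large datasets. How roc_auc_score treats multi-task targets or missing labels is not covered here, and `roc_auc` is a hypothetical helper.

```python
from typing import Sequence


def roc_auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """P(score of a random positive > score of a random negative), ties count 1/2."""
    pos = [s for s, t in zip(y_score, y_true) if t == 1]
    neg = [s for s, t in zip(y_score, y_true) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in pos:  # O(P * N) pairwise comparison, fine for illustration
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Positive scores 0.35, 0.8; negative scores 0.1, 0.4 -> 3 of 4 pairs ordered correctly.
print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```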

set_loss_func(task_name)

Set the loss function for the given task

Parameters

task_name (str) – name of the task

Returns

None

set_score_func(metric_name)

Set the score (metric) function

Parameters

metric_name (str) – name of the metric

Returns

None
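set_loss_func and set_score_func select a function by name; both return None, which suggests they store the selection on the instance. The sketch below illustrates that dispatch pattern with a dictionary lookup. Everything in it is hypothetical: the class MetricSketch, the attribute names loss_func and score_func, the task names ("Binary classification", "Regression"), and the metric names ("Accuracy", "RMSE") are placeholders, not the values GOOD actually accepts. Loss functions take (y_pred, y_true), mirroring cross_entropy_with_logit, while score functions take (y_true, y_pred).

```python
import math
from typing import Callable, Dict, Optional, Sequence

Vector = Sequence[float]
Func = Callable[[Vector, Vector], float]


def _mse_loss(y_pred: Vector, y_true: Vector) -> float:
    return sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true)


def _bce_with_logits(y_pred: Vector, y_true: Vector) -> float:
    # Numerically stable form: max(z, 0) - z*y + log(1 + exp(-|z|)).
    return sum(
        max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        for z, y in zip(y_pred, y_true)
    ) / len(y_true)


def _accuracy(y_true: Vector, y_pred: Vector) -> float:
    return sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(y_true)


def _rmse(y_true: Vector, y_pred: Vector) -> float:
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


class MetricSketch:
    """Hypothetical stand-in for Metric, showing name-based dispatch only."""

    # Placeholder names, not the strings GOOD actually accepts.
    LOSSES: Dict[str, Func] = {"Binary classification": _bce_with_logits, "Regression": _mse_loss}
    SCORES: Dict[str, Func] = {"Accuracy": _accuracy, "RMSE": _rmse}

    def __init__(self) -> None:
        self.loss_func: Optional[Func] = None  # called as loss_func(y_pred, y_true)
        self.score_func: Optional[Func] = None  # called as score_func(y_true, y_pred)

    def set_loss_func(self, task_name: str) -> None:
        if task_name not in self.LOSSES:
            raise ValueError(f"unsupported task: {task_name!r}")
        self.loss_func = self.LOSSES[task_name]

    def set_score_func(self, metric_name: str) -> None:
        if metric_name not in self.SCORES:
            raise ValueError(f"unsupported metric: {metric_name!r}")
        self.score_func = self.SCORES[metric_name]


metric = MetricSketch()
metric.set_loss_func("Regression")
metric.set_score_func("Accuracy")
print(metric.loss_func([1.0, 2.0], [1.0, 4.0]))  # 2.0
```

Raising on an unknown name, rather than silently keeping the previous function, makes a misspelled task or metric fail at configuration time instead of during training.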