GOOD.utils.evaluation

Evaluation: model evaluation functions.

Functions

eval_data_preprocess(y, raw_pred, mask, config)

Preprocess data for evaluations by converting data into np.ndarray or List[np.ndarray] (Multi-task) format.

eval_score(pred_all, target_all, config)

Calculate metric scores given preprocessed prediction values and ground truth values.

GOOD.utils.evaluation.eval_data_preprocess(y: Tensor, raw_pred: Tensor, mask: Tensor, config: Union[CommonArgs, Munch]) → Tuple[Union[ndarray, List], Union[ndarray, List]][source]

Preprocess data for evaluations by converting data into np.ndarray or List[np.ndarray] (Multi-task) format. When the task of the dataset is not multi-task, data is converted into np.ndarray. When it is multi-task, data is converted into List[np.ndarray] in which each np.ndarray in the list represents one task. For example, GOOD-PCBA is a 128-task binary classification dataset. Therefore, the output list will contain 128 elements.

Parameters
  • y (torch.Tensor) – Ground truth values.

  • raw_pred (torch.Tensor) – Raw prediction values without softmax or sigmoid.

  • mask (torch.Tensor) – Ground truth NaN mask for removing empty labels.

  • config (Union[CommonArgs, Munch]) – The required config is config.metric.dataset_task.

Returns

Processed prediction values and ground truth values.
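A minimal usage sketch for collecting per-batch predictions and targets, assuming a trained model, a dataloader yielding PyG-style Data batches with a y attribute, and a config with config.metric.dataset_task set; the mask shown (non-NaN entries of y) is illustrative only and not necessarily the library's own mask helper.

    import torch
    from GOOD.utils.evaluation import eval_data_preprocess, eval_score

    # Hypothetical evaluation loop; `model`, `loader`, and `config` are assumed
    # to be constructed elsewhere (config.metric.dataset_task must be set).
    pred_all, target_all = [], []
    model.eval()
    with torch.no_grad():
        for data in loader:                      # batches carrying a `y` label tensor
            raw_pred = model(data)               # raw outputs, no softmax/sigmoid applied
            mask = ~torch.isnan(data.y)          # illustrative NaN mask for missing labels
            pred, target = eval_data_preprocess(data.y, raw_pred, mask, config)
            pred_all.append(pred)
            target_all.append(target)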

GOOD.utils.evaluation.eval_score(pred_all: Union[List[ndarray], List[List[ndarray]]], target_all: Union[List[ndarray], List[List[ndarray]]], config: Union[CommonArgs, Munch]) → Union[ndarray, float][source]

Calculate metric scores given preprocessed prediction values and ground truth values.

Parameters
  • pred_all (Union[List[np.ndarray], List[List[np.ndarray]]]) – Prediction value list. It is a list of the pred outputs of eval_data_preprocess().

  • target_all (Union[List[np.ndarray], List[List[np.ndarray]]]) – Ground truth value list. It is a list of the target outputs of eval_data_preprocess().

  • config (Union[CommonArgs, Munch]) – The required config is config.metric.score_func, which is the function used for score calculation (e.g., GOOD.utils.metric.Metric.acc()).

Returns

A float score value.
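Continuing the sketch above, the collected lists can then be passed to eval_score; config.metric.score_func is assumed to already point to a metric function such as accuracy or ROC-AUC.

    # Compute the final metric over all collected evaluation batches.
    score = eval_score(pred_all, target_all, config)   # requires config.metric.score_func
    print(f'Evaluation score: {score}')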