# -*- coding: utf-8 -*-
# Author: Zhi Qiao <mingshan_ai@163.com>
# License: BSD 2 clause
import numpy as np
import pickle
import os
from ..utils.check import label_check
from .binaryclass import evaluator as binary_eval
from .multilabel import evaluator as multilabel_eval
from .multilabel import evaluator as multiclass_eval
def check_evalu_type(hat_y, y):
    """Infer the evaluation type from the layout and values of the labels.

    Parameters
    ----------
    hat_y : array-like
        Predictions. Only used to validate that they can be converted to
        float and that their shape equals the shape of ``y``.
    y : array-like
        Ground-truth labels, one row per sample. A 1-D vector is treated as
        a single label column, i.e. as shape ``(n_samples, 1)``.

    Returns
    -------
    str
        ``'binaryclass'``  - one column, label values are exactly {0, 1};
        ``'multiclass'``   - several columns, values {0, 1}, and no row has
        more than one positive entry;
        ``'multilabel'``   - several columns, values {0, 1}, and at least one
        row has more than one positive entry;
        ``'regression'``   - one column with more than two distinct values.

    Raises
    ------
    ValueError
        If the inputs cannot be converted to float arrays, the shapes differ,
        ``y`` is not 1-D/2-D, ``y`` has no samples, or the label values do
        not fit any of the supported layouts. ``ValueError`` derives from
        ``Exception``, so callers catching ``Exception`` keep working.

    Notes
    -----
    Label values are truncated to integers before the value space is
    inspected, so e.g. ``0.2`` and ``0.9`` both count as ``0``.
    """
    try:
        hat_y = np.array(hat_y).astype(float)
        y = np.array(y).astype(float)
    except (TypeError, ValueError):
        raise ValueError('not support current data type of hat_y, y')

    if np.shape(hat_y) != np.shape(y):
        raise ValueError('the data shape is inconsistent between y and hat_y')

    if y.ndim == 1:
        # A flat label vector means one label per sample.
        y = y.reshape(-1, 1)
    elif y.ndim != 2:
        raise ValueError(
            'y must be 1-D or 2-D, got {0}-D data'.format(y.ndim))

    n_samples, label_n = y.shape
    if n_samples == 0:
        # Without any sample there is no label width to agree on.
        raise ValueError('label_n is inconformity in data')

    int_labels = y.astype(int)
    value_space = set(np.unique(int_labels).tolist())

    if len(value_space) <= 1:
        raise ValueError('value space size <=1 is unvalid')

    if len(value_space) == 2:
        if value_space != {0, 1}:
            raise ValueError('odd value exist in label value space')
        if label_n == 1:
            return 'binaryclass'
        # At most one positive per row -> mutually exclusive classes.
        if int_labels.sum(axis=1).max() == 1:
            return 'multiclass'
        return 'multilabel'

    # More than two distinct values only makes sense for a single target.
    if label_n == 1:
        return 'regression'
    raise ValueError('odd value exist in label value space')
# Dispatch table: evaluation-type name -> evaluator callable.
# 'regression' has no evaluator assigned (its entry is None).
# NOTE(review): multiclass_eval is imported from .multilabel at the top of
# this file, so 'multiclass' and 'multilabel' share one evaluator - confirm
# that this is intended.
evalu_func_mapping_dict = dict([
    ('binaryclass', binary_eval),
    ('multilabel', multilabel_eval),
    ('multiclass', multiclass_eval),
    ('regression', None),
])
def func(hat_y, y, evalu_type=None):
    """Evaluate predictions with the evaluator matching the evaluation type.

    The value returned by ``label_check(y, hat_y, evalu_type)`` is used as
    the key into ``evalu_func_mapping_dict``; the selected evaluator is then
    called as ``evaluator(hat_y, y)`` and its result is returned unchanged.
    The chosen evaluation type is printed to stdout.

    Parameters
    ----------
    hat_y : array-like
        Predictions, forwarded to ``label_check`` and to the evaluator.
    y : array-like
        Ground-truth labels, forwarded to ``label_check`` and to the
        evaluator.
    evalu_type : str or None, optional
        Evaluation-type hint forwarded to ``label_check``.

    Returns
    -------
    object
        Whatever the selected evaluator returns.

    Raises
    ------
    KeyError
        If the resolved evaluation type has no entry in
        ``evalu_func_mapping_dict``.
    NotImplementedError
        If the resolved evaluation type is registered without an evaluator
        (its entry is ``None``).
    """
    evalu_type = label_check(y, hat_y, evalu_type)
    print ('current data evaluate using {0} evaluation-type'.format(evalu_type))
    evalu_func = evalu_func_mapping_dict[evalu_type]
    if evalu_func is None:
        # Fail with an explicit message instead of the opaque
        # "'NoneType' object is not callable" raised by calling None.
        raise NotImplementedError(
            'no evaluator is available for {0} evaluation-type'.format(evalu_type))
    return evalu_func(hat_y, y)
if __name__ == '__main__':
    # Smoke demo: run the evaluator on two hand-written (hat_y, y) pairs and
    # print each result. The first pair has a flat 0/1 label vector with
    # single-column predictions; the second has three 0/1 label columns.
    demo_cases = (
        (np.array([[0.3], [0.8]]),
         np.array([0., 1.])),
        (np.array([[0.3, 0.7, 0.1], [0.1, 0.2, 0.8]]),
         np.array([[0., 1., 0.], [1., 0., 1.]])),
    )
    for hat_y, y in demo_cases:
        z = func(hat_y, y)
        print (z)