123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from collections import defaultdict
- import numpy as np
def calculate_metric(predict, gt):
    """Count true positives, false positives and false negatives.

    A prediction is a true positive when its start position (element ``[1]``)
    equals that of a ground-truth entity not yet matched by an earlier
    prediction; otherwise it is a false positive. Every ground-truth entity
    left unmatched is a false negative, so ``tp + fn == len(gt)`` always
    holds (tp + fn covers 100% of the ground truth).

    Args:
        predict: predicted entities; only element ``[1]`` is compared.
        gt: ground-truth entities; only element ``[1]`` is compared.

    Returns:
        ``np.array([tp, fp, fn])``.
    """
    matched = [False] * len(gt)
    tp, fp = 0, 0
    for entity_predict in predict:
        for gt_index, entity_gt in enumerate(gt):
            # Only the start position is compared; element [0] is not.
            # Each ground-truth entity may be consumed once, so duplicate
            # predictions cannot push tp above len(gt) (and fn below zero).
            if not matched[gt_index] and entity_predict[1] == entity_gt[1]:
                matched[gt_index] = True
                tp += 1
                break
        else:
            # No remaining ground-truth entity matches this prediction.
            fp += 1
    fn = len(gt) - tp
    return np.array([tp, fp, fn])
def get_p_r_f(tp, fp, fn):
    """Return ``np.array([precision, recall, f1])`` computed from counts.

    Any ratio whose denominator is zero is reported as 0 instead of raising.
    """
    def _ratio(numerator, denominator):
        # Zero denominator (e.g. no predictions or no gold items) -> 0.
        return numerator / denominator if denominator != 0 else 0

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * precision * recall, precision + recall)
    return np.array([precision, recall, f1])
def classification_report(metrics_matrix, label_list, id2label, total_count, digits=2, suffix=False):
    """Build a plain-text precision/recall/F1 report, one row per label.

    Args:
        metrics_matrix: sequence of per-label count rows; row ``i`` holds
            ``(tp, fp, fn)`` for label id ``i``.
        label_list: label names; only used to size the name column.
        id2label: mapping from label id (row index) to the name printed on
            that row.
        total_count: per-label support values, indexed by label id.
        digits: decimals shown for precision, recall and F1.
        suffix: unused; kept for signature compatibility.

    Returns:
        The report as a string. After the per-label rows comes a
        ``micro-f1`` line whose precision/recall/F1 are computed from the
        tp/fp/fn counts summed over all labels, together with the summed
        support.
    """
    last_line_heading = 'micro-f1'
    rows = list(metrics_matrix)
    type_names = [id2label[label_id] for label_id in range(len(rows))]

    # The name column must fit the declared labels, the names actually
    # printed and the summary heading. Building a single list also avoids
    # the ValueError max() raises when label_list is empty.
    width = max([len(label) for label in label_list]
                + [len(name) for name in type_names]
                + [len(last_line_heading), digits])

    headers = ["precision", "recall", "f1-score", "support"]
    head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
    report = head_fmt.format(u'', *headers, width=width)
    report += u'\n\n'

    row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'
    supports = []
    for label_id, label_matrix in enumerate(rows):
        p, r, f1 = get_p_r_f(label_matrix[0], label_matrix[1], label_matrix[2])
        nb_true = total_count[label_id]
        report += row_fmt.format(type_names[label_id], p, r, f1, nb_true,
                                 width=width, digits=digits)
        supports.append(nb_true)
    report += u'\n'

    # Micro average: pool the counts of every label, then derive p/r/f1.
    # With no rows there is nothing to pool; np.sum would return a scalar
    # that cannot be indexed, so fall back to zero counts.
    totals = np.sum(rows, axis=0) if rows else np.zeros(3)
    micro_p, micro_r, micro_f1 = get_p_r_f(totals[0], totals[1], totals[2])
    total_support = np.sum(supports) if supports else 0
    report += row_fmt.format(last_line_heading, micro_p, micro_r, micro_f1,
                             total_support, width=width, digits=digits)
    return report
|