metrics.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import defaultdict

import numpy as np

def calculate_metric(predict, gt):
    """
    Count tp, fp and fn for one sample; tp + fn covers 100% of the gold entities.
    """
    tp, fp, fn = 0, 0, 0
    for entity_predict in predict:
        flag = 0
        for entity_gt in gt:
            # if entity_predict[0] == entity_gt[0] and entity_predict[1] == entity_gt[1]:  # require both fields to match
            if entity_predict[1] == entity_gt[1]:  # only require the question start offsets to match
                flag = 1
                tp += 1
                break
        if flag == 0:
            fp += 1  # not actually a question, but predicted as one
    fn = len(gt) - tp
    return np.array([tp, fp, fn])
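
# Worked example (hypothetical data, not from the original module): entities
# are assumed here to be (label, start, end) tuples, so only index 1, the
# start offset, is compared. Given
#     predict = [("Q", 0, 10), ("Q", 25, 40)]
#     gt      = [("Q", 0, 12), ("Q", 50, 60)]
# the first prediction shares a start offset with a gold entity (tp=1), the
# second matches nothing (fp=1), and one gold entity goes unmatched (fn=1),
# so calculate_metric(predict, gt) returns array([1, 1, 1]).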

def get_p_r_f(tp, fp, fn):
    """Turn tp/fp/fn counts into precision, recall and F1 (0 when undefined)."""
    p = tp / (tp + fp) if tp + fp != 0 else 0
    r = tp / (tp + fn) if tp + fn != 0 else 0
    f1 = 2 * p * r / (p + r) if p + r != 0 else 0
    return np.array([p, r, f1])
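
# Worked example: with the counts tp=1, fp=1, fn=1 from above,
# p = 1 / (1 + 1) = 0.5, r = 1 / (1 + 1) = 0.5 and
# f1 = 2 * 0.5 * 0.5 / (0.5 + 0.5) = 0.5,
# so get_p_r_f(1, 1, 1) returns array([0.5, 0.5, 0.5]).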

def classification_report(metrics_matrix, label_list, id2label, total_count, digits=2, suffix=False):
    """Format a per-label precision/recall/F1 table with a micro-averaged summary row."""
    name_width = max(len(label) for label in label_list)
    last_line_heading = 'micro-f1'
    width = max(name_width, len(last_line_heading), digits)

    headers = ["precision", "recall", "f1-score", "support"]
    head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
    report = head_fmt.format(u'', *headers, width=width)
    report += u'\n\n'

    row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'
    ps, rs, f1s, s = [], [], [], []
    for label_id, label_matrix in enumerate(metrics_matrix):
        type_name = id2label[label_id]
        p, r, f1 = get_p_r_f(label_matrix[0], label_matrix[1], label_matrix[2])
        nb_true = total_count[label_id]
        report += row_fmt.format(type_name, p, r, f1, nb_true, width=width, digits=digits)
        ps.append(p)
        rs.append(r)
        f1s.append(f1)
        s.append(nb_true)
    report += u'\n'

    # micro average: pool tp/fp/fn over all labels, then compute p/r/f1 once
    micro_metrics = np.sum(metrics_matrix, axis=0)
    micro_metrics = get_p_r_f(micro_metrics[0], micro_metrics[1], micro_metrics[2])
    report += row_fmt.format(last_line_heading,
                             micro_metrics[0],
                             micro_metrics[1],
                             micro_metrics[2],
                             np.sum(s),
                             width=width, digits=digits)
    return report
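

if __name__ == "__main__":
    # Minimal usage sketch with made-up data. The (label, start, end) entity
    # format and the two-label id2label mapping are illustrative assumptions,
    # not part of the original module.
    predictions = {0: [("Q", 0, 10), ("Q", 25, 40)], 1: [("A", 12, 20)]}
    golds = {0: [("Q", 0, 12), ("Q", 50, 60)], 1: [("A", 12, 20)]}
    id2label = {0: "question", 1: "answer"}
    label_list = list(id2label.values())

    # One (tp, fp, fn) row per label id; classification_report iterates over
    # these rows and micro-averages them for the summary line.
    metrics_matrix = np.stack(
        [calculate_metric(predictions[i], golds[i]) for i in sorted(id2label)]
    )
    total_count = [len(golds[i]) for i in sorted(id2label)]
    print(classification_report(metrics_matrix, label_list, id2label, total_count))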