# info_retrieval.py

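"""BM25-based information retrieval over a SQLite inverted index.

Pipeline: clean the query text, segment it into terms, look up each
term's posting list in SQLite, recall the most frequent candidate
documents, then rank them with BM25 and return the topK document ids.
"""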
import os
import time
import math
import shutil
import sqlite3
from collections import Counter

from config import sqlite_path, sqlite_copy_path, log_msg
from data_preprocessing import DataPreProcessing
from word_segment import Word_Segment
from heap_sort import heap_sort


class Info_Retrieval:

    def __init__(self, data_process, logger=None, n_grams_flag=False):
        # data preprocessing instance
        self.dpp = data_process
        # logger configuration
        self.logger = logger
        # word segmentation algorithm
        self.word_seg = Word_Segment(n_grams_flag=n_grams_flag)
        # BM25 parameters (the common defaults: k1 = 1.5, b = 0.75)
        self.k1, self.b = 1.5, 0.75
        # paths to the sqlite database and its backup copy
        self.sqlite_path = sqlite_path
        self.sqlite_copy_path = sqlite_copy_path

    def __call__(self, sentence, topK=50):
        # normalize and clean the query sentence
        sentence = self.dpp.content_clear_func(sentence)
        # segment the query into terms
        seg_list = self.word_seg(sentence)
        # log the segmentation result
        if self.logger:
            self.logger.info(log_msg.format(id="text duplicate check",
                                            type="info_retrieve segmentation",
                                            message=seg_list))
        # dictionary of accumulated BM25 scores, keyed by document id
        scores_dict = dict()
        # the "doc_data_statistics" row stores the total document count and total document length
        doc_data = self.sqlite_fetch("doc_data_statistics")
        # total document count and average document length
        doc_param = [int(doc_data[1]), int(doc_data[2]) / int(doc_data[1])]
        # initialize the recalled-document list and the per-term data list
        recall_doc_list, term_data_list = [], []
        # look up index data for each distinct query term
        for word in set(seg_list):
            # fetch the term's row from the sqlite database
            term_data = self.sqlite_fetch(word)
            if term_data is not None:
                # document-id list and posting data for this term
                doc_id_list = str(term_data[1]).split('\n')
                term_data_list.append((doc_id_list, term_data[2]))
                recall_doc_list.extend(doc_id_list)
        # recall the documents that appear most often across the posting lists
        recall_doc_set = set([ele[0] for ele in Counter(recall_doc_list).most_common(500)])
        # return [ele[0] for ele in Counter(recall_doc_list).most_common(topK)], seg_list
        # accumulate BM25 scores term by term over the recalled documents
        for term_data in term_data_list:
            scores_dict = self.BM25(scores_dict, term_data, doc_param, recall_doc_set)
        # sort the score dictionary and keep the topK results
        # scores_list = heap_sort(list(scores_dict.items()), topK)
        scores_list = sorted(scores_dict.items(), key=lambda x: x[1], reverse=True)[:topK]
        # extract the sorted document ids
        scores_sort_list = [ele[0] for ele in scores_list]
        return scores_sort_list, seg_list
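
    # The scoring below follows the BM25 form
    #     score(q, d) = sum_t idf(t) * (k1 * tf) / (tf + k1 * (1 - b + b * dl / avgdl))
    # with idf(t) = ln((N - df + 0.5) / (df + 0.5) + 1). Note that it uses
    # k1 rather than the textbook (k1 + 1) in the numerator; since that
    # factor is constant per term, the ranking order is unchanged.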
    # BM25 scoring of the documents in one term's posting list
    def BM25(self, scores_dict, term_data, doc_param, recall_doc_set):
        # total document count and average document length
        doc_count, avg_doc_length = doc_param[0], doc_param[1]
        # document frequency: the number of documents containing this term
        doc_id_list = term_data[0]
        doc_freq = len(doc_id_list)
        # compute the idf weight
        idf_weight = math.log((doc_count - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
        term_docs = term_data[1].split('\n')
        # score every document in this term's posting list
        for i, doc in enumerate(term_docs):
            # skip documents outside the recall set
            doc_id = doc_id_list[i]
            if doc_id not in recall_doc_set:
                continue
            # term frequency in the document, and the document length
            tf, dl = map(int, doc.split('\t'))
            # BM25 score (the "+ 1" in the idf above keeps the weight positive)
            score = (idf_weight * self.k1 * tf) / (tf + self.k1 * (1 - self.b + self.b * dl / avg_doc_length))
            # accumulate the score under the document id
            scores_dict[doc_id] = scores_dict.get(doc_id, 0) + score
        return scores_dict
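
    # Assumed layout of the inverted-index table (inferred from the code
    # above, not from a schema file): each row of the "physics" table is
    # (term, doc_ids, postings), where doc_ids is a newline-separated list
    # of document ids and postings holds one "tf\tdl" pair per document,
    # also newline-separated. The special row "doc_data_statistics" stores
    # the total document count and total document length instead.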
    # fetch a term's row from the sqlite database
    def sqlite_fetch(self, term):
        try:
            # open a connection to the sqlite database
            sqlite_conn = sqlite3.connect(self.sqlite_path)
            # create a cursor object
            cursor = sqlite_conn.cursor()
            # read the term's row from sqlite
            cursor.execute("SELECT * FROM physics WHERE term=?", (term,))
            term_data = cursor.fetchone()
            # close the database connection
            cursor.close()
            sqlite_conn.close()
            return term_data
        except Exception:
            # close whatever handles were opened before the failure
            try:
                cursor.close()
                sqlite_conn.close()
            except Exception:
                pass
            # delete the corrupted sqlite database file
            if os.path.exists(self.sqlite_path):
                os.remove(self.sqlite_path)
            # restore the database from the backup copy and retry
            shutil.copy(self.sqlite_copy_path, self.sqlite_path)
            return self.sqlite_fetch(term)


if __name__ == "__main__":
    # sample query: a Chinese physics exam question about molecular motion
    sent_list = ["中华民族传统文化源远流长!下列诗句能体现“分子在不停地做无规则运动”的是( )<br/>A.大风起兮云飞扬<br/>B.柳絮飞时花满城<br/>C.满架蔷薇一院香<br/>D.秋雨梧桐叶落时"]
    data_process = DataPreProcessing()
    ir = Info_Retrieval(data_process, n_grams_flag=True)
    start = time.time()
    for sentence in sent_list:
        scores_sort_list, seg_list = ir(sentence)
    print(time.time() - start)