# info_retrieval.py
import math
import os
import shutil
import sqlite3
import time
from collections import Counter

from config import sqlite_path, sqlite_copy_path, log_msg
from data_preprocessing import DataPreProcessing
from word_segment import Word_Segment
  10. class Info_Retrieval():
  11. def __init__(self, data_process, logger=None, n_grams_flag=False):
  12. # 数据预处理实例化
  13. self.dpp = data_process
  14. # 日志配置
  15. self.logger = logger
  16. # 分词算法
  17. self.word_seg = Word_Segment(n_grams_flag=n_grams_flag)
  18. # bm25算法参数
  19. self.k1, self.b = 1.5, 0.75
  20. # sqlite数据库连接
  21. self.sqlite_path = sqlite_path
  22. self.sqlite_copy_path = sqlite_copy_path
  23. def __call__(self, sentence, topk=50):
  24. # 将搜索语句进行标准化清洗
  25. sentence = self.dpp.content_clear_func(sentence)
  26. # 将搜索语句分词
  27. seg_list, seg_init_list = self.word_seg(sentence)
  28. # 日志采集
  29. self.logger.info(log_msg.format(id="文本查重",
  30. type="info_retrieve分词",
  31. message=seg_list)) if self.logger else None
  32. # bm25算法分数计算词典
  33. scores_dict = dict()
  34. # "doc_data_statistics"存储文档总数和文档总长度
  35. doc_data = self.sqlite_fetch("doc_data_statistics")
  36. # 获取文档总数与平均文档长度
  37. doc_param = [int(doc_data[1]), int(doc_data[2]) / int(doc_data[1])]
  38. # 初始化召回文档列表和检索数据列表
  39. recall_doc_list, term_data_list = [], []
  40. # 遍历分词元素用于搜索相关数据
  41. for word in set(seg_list):
  42. # 通过sqlite数据库检索term相关数据
  43. term_data = self.sqlite_fetch(word)
  44. if term_data is not None:
  45. # 获取每个分词元素对应的文档列表和数据列表
  46. doc_id_list = str(term_data[1]).split('\n')
  47. term_data_list.append((doc_id_list, term_data[2]))
  48. recall_doc_list.extend(doc_id_list)
  49. # 计算召回文档列表中出现频率最高的文档集合
  50. recall_doc_set = set([ele[0] for ele in Counter(recall_doc_list).most_common(500)])
  51. # return [ele[0] for ele in Counter(recall_doc_list).most_common(topk)], seg_list
  52. # 遍历分词元素数据列表
  53. for term_data in term_data_list:
  54. # bm25算法计算倒排索引关键词对应文档得分
  55. scores_dict = self.BM25(scores_dict, term_data, doc_param, recall_doc_set)
  56. # 将分数词典进行排序并返回排序结果
  57. # scores_list = heap_sort(list(scores_dict.items()), topk)
  58. scores_list = sorted(list(scores_dict.items()), key=lambda x:x[1], reverse=True)[:topk]
  59. # 对检索结果进行判断并取出排序后的id
  60. scores_sort_list = [ele[0] for ele in scores_list]
  61. return scores_sort_list, seg_init_list
  62. # bm25算法计算倒排索引关键词对应文档得分
  63. def BM25(self, scores_dict, term_data, doc_param, recall_doc_set):
  64. # 获取文档总数与平均文档长度
  65. doc_count, avg_doc_length = doc_param[0], doc_param[1]
  66. # 获取当前term对应的文档数量
  67. doc_id_list = term_data[0]
  68. doc_freq = len(doc_id_list)
  69. # 计算idf值
  70. idf_weight = math.log((doc_count - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
  71. term_docs = term_data[1].split('\n')
  72. # 遍历计算当前term对应的所有文档的得分
  73. for i,doc in enumerate(term_docs):
  74. # 获取文档id并判断是否在召回文档集合中
  75. doc_id = doc_id_list[i]
  76. if doc_id not in recall_doc_set:
  77. continue
  78. # 获取词频以及当前文档长度
  79. tf, dl = map(int, doc.split('\t'))
  80. # 计算bm25得分(idf值必须大于0)
  81. score = (idf_weight * self.k1 * tf) / (tf + self.k1 * (1 - self.b + self.b * dl / avg_doc_length))
  82. # 通过词典锁定文档id累加计分
  83. scores_dict[doc_id] = scores_dict.get(doc_id, 0) + score
  84. return scores_dict
  85. # 通过sqlite数据库检索term相关数据
  86. def sqlite_fetch(self, term):
  87. try:
  88. # 建立sqlite数据库链接
  89. sqlite_conn = sqlite3.connect(self.sqlite_path)
  90. # 创建游标对象cursor
  91. cursor = sqlite_conn.cursor()
  92. # 从sqlite读取数据
  93. cursor.execute("SELECT * FROM physics WHERE term=?", (term,))
  94. term_data = cursor.fetchone()
  95. # 关闭数据库连接
  96. cursor.close()
  97. sqlite_conn.close()
  98. return term_data
  99. except Exception as e:
  100. # 关闭已损坏失效的sqlite数据库连接
  101. cursor.close()
  102. sqlite_conn.close()
  103. # 删除损害失效的sqlite数据库
  104. if os.path.exists(self.sqlite_path):
  105. os.remove(self.sqlite_path)
  106. # 复制并重命名已备份sqlite数据库
  107. shutil.copy(self.sqlite_copy_path, self.sqlite_path)
  108. return self.sqlite_fetch(term)
  109. if __name__ == "__main__":
  110. sent_list = ["中华民族传统文化源远流长!下列诗句能体现“分子在不停地做无规则运动”的是( )<br/>A.大风起兮云飞扬<br/>B.柳絮飞时花满城<br/>C.满架蔷薇一院香<br/>D.秋雨梧桐叶落时"]
  111. data_process = DataPreProcessing()
  112. ir = Info_Retrieval(data_process, n_grams_flag=True)
  113. start = time.time()
  114. for sentence in sent_list:
  115. scores_sort_list = ir(sentence)
  116. print(time.time()-start)