import os
import time
import math
import shutil
import sqlite3
from collections import Counter

from config import sqlite_path, sqlite_copy_path, log_msg
from data_preprocessing import DataPreProcessing
from word_segment import Word_Segment
from heap_sort import heap_sort


class Info_Retrieval:
    def __init__(self, data_process, logger=None, n_grams_flag=False):
        # Data-preprocessing instance
        self.dpp = data_process
        # Logger configuration
        self.logger = logger
        # Word-segmentation algorithm
        self.word_seg = Word_Segment(n_grams_flag=n_grams_flag)
        # BM25 algorithm parameters
        self.k1, self.b = 1.5, 0.75
        # SQLite database paths (live copy and backup)
        self.sqlite_path = sqlite_path
        self.sqlite_copy_path = sqlite_copy_path

    def __call__(self, sentence, topk=50):
        # Normalize and clean the query sentence
        sentence = self.dpp.content_clear_func(sentence)
        # Segment the query sentence into terms
        seg_list, seg_init_list = self.word_seg(sentence)
        # Log the segmentation result
        if self.logger:
            self.logger.info(log_msg.format(
                id="text duplicate check", type="info_retrieve segmentation", message=seg_list))
        # Per-document BM25 score accumulator
        scores_dict = dict()
        # The "doc_data_statistics" row stores the total document count and total document length
        doc_data = self.sqlite_fetch("doc_data_statistics")
        # Total document count and average document length
        doc_param = [int(doc_data[1]), int(doc_data[2]) / int(doc_data[1])]
        # Recalled document list and per-term posting data
        recall_doc_list, term_data_list = [], []
        # Look up index data for each distinct query term
        for word in set(seg_list):
            # Fetch the term's posting data from the SQLite index
            term_data = self.sqlite_fetch(word)
            if term_data is not None:
                # Document-id list and term-frequency payload for this term
                doc_id_list = str(term_data[1]).split('\n')
                term_data_list.append((doc_id_list, term_data[2]))
                recall_doc_list.extend(doc_id_list)
        # Keep only the documents recalled most often across all query terms
        recall_doc_set = set([ele[0] for ele in Counter(recall_doc_list).most_common(500)])
        # return [ele[0] for ele in Counter(recall_doc_list).most_common(topk)], seg_list
        # Score the recalled documents against each term's inverted list
        for term_data in term_data_list:
            # BM25 score contributions for documents containing this term
            scores_dict = self.BM25(scores_dict, term_data, doc_param, recall_doc_set)
        # Sort the score dictionary and keep the top-k results
        # scores_list = heap_sort(list(scores_dict.items()), topk)
        scores_list = sorted(scores_dict.items(), key=lambda x: x[1], reverse=True)[:topk]
        # Extract the sorted document ids
        scores_sort_list = [ele[0] for ele in scores_list]
        return scores_sort_list, seg_init_list

    # Compute BM25 scores for the documents in one term's inverted list
    def BM25(self, scores_dict, term_data, doc_param, recall_doc_set):
        # Total document count and average document length
        doc_count, avg_doc_length = doc_param[0], doc_param[1]
        # Number of documents containing the current term
        doc_id_list = term_data[0]
        doc_freq = len(doc_id_list)
        # IDF weight; the +1 inside the log keeps it positive even for very common terms
        idf_weight = math.log((doc_count - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
        term_docs = term_data[1].split('\n')
        # Score every document that contains the current term
        for i, doc in enumerate(term_docs):
            # Skip documents outside the recall set
            doc_id = doc_id_list[i]
            if doc_id not in recall_doc_set:
                continue
            # Term frequency and document length for this posting
            tf, dl = map(int, doc.split('\t'))
            # BM25 score contribution of this term for this document
            score = (idf_weight * self.k1 * tf) / (tf + self.k1 * (1 - self.b + self.b * dl / avg_doc_length))
            # Accumulate the score under the document id
            scores_dict[doc_id] = scores_dict.get(doc_id, 0) + score
        return scores_dict

    # Fetch a term's index data from the SQLite database
    def sqlite_fetch(self, term):
        sqlite_conn = None
        try:
            # Open the SQLite connection
            sqlite_conn = sqlite3.connect(self.sqlite_path)
            # Create a cursor object
            cursor = sqlite_conn.cursor()
            # Read the term's row from the index table
            cursor.execute("SELECT * FROM physics WHERE term=?", (term,))
            term_data = cursor.fetchone()
            # Close the database connection
            cursor.close()
            sqlite_conn.close()
            return term_data
        except Exception:
            # Close the broken SQLite connection if it was opened
            if sqlite_conn is not None:
                sqlite_conn.close()
            # Delete the corrupted database file
            if os.path.exists(self.sqlite_path):
                os.remove(self.sqlite_path)
            # Restore the database from the backup copy and retry
            shutil.copy(self.sqlite_copy_path, self.sqlite_path)
            return self.sqlite_fetch(term)


if __name__ == "__main__":
    sent_list = ["""中华民族传统文化源远流长!下列诗句能体现“分子在不停地做无规则运动”的是( )
A.大风起兮云飞扬
B.柳絮飞时花满城
C.满架蔷薇一院香
D.秋雨梧桐叶落时"""]
    data_process = DataPreProcessing()
    ir = Info_Retrieval(data_process, n_grams_flag=True)
    start = time.time()
    for sentence in sent_list:
        scores_sort_list, seg_init_list = ir(sentence)
    print(time.time() - start)
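
# NOTE: the index schema is not defined in this file. The sketch below is an
# assumption inferred from the read paths above: a "physics" table keyed by
# "term", whose payload columns hold a newline-joined doc-id list and aligned
# newline-joined "tf\tdl" (term frequency, document length) entries, plus a
# special "doc_data_statistics" row whose columns 1 and 2 hold the corpus
# document count and total document length. Column names other than "term"
# are hypothetical.
#
#   import sqlite3
#   conn = sqlite3.connect(sqlite_path)
#   conn.execute("CREATE TABLE IF NOT EXISTS physics "
#                "(term TEXT PRIMARY KEY, doc_ids TEXT, tf_dl TEXT)")
#   # One row per term; doc_ids and tf_dl are aligned line by line.
#   conn.execute("INSERT INTO physics VALUES (?, ?, ?)",
#                ("分子", "doc_1\ndoc_7", "3\t120\n1\t85"))
#   # Stats row read by __call__: 2 documents, 205 characters in total.
#   conn.execute("INSERT INTO physics VALUES (?, ?, ?)",
#                ("doc_data_statistics", "2", "205"))
#   conn.commit()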