"""Compute BERT-whitening parameters from stored sentence vectors and save them.

Run as a script: reads documents from ``config.mongo_coll``, collects their
``sentence_vec`` fields, computes the whitening kernel/bias and pickles
``[kernel, bias]`` to ``config.whitening_path``.
"""
import pickle

import numpy as np


def compute_kernel_bias(vec_list, n_components=None):
    """Compute the BERT-whitening kernel and bias.

    The final transform is ``y = (x + bias).dot(kernel)``. ``bias`` is the
    negated sample mean, and ``kernel`` maps the centred vectors onto the
    principal axes, scaled by ``1 / sqrt(variance)``. The transformed vectors
    therefore have zero mean and identity sample covariance.

    Args:
        vec_list: sequence of equal-length sentence vectors. It is stacked
            into an array of shape ``[num_samples, embedding_size]``.
        n_components: if given, keep only the first ``n_components`` columns
            of the kernel. These are the directions of largest variance, so
            this also reduces the output dimension. ``None`` (the default)
            keeps every column.

    Returns:
        Tuple ``(kernel, bias)``:

        - ``kernel`` has shape ``[embedding_size, n_components]``, or
          ``[embedding_size, embedding_size]`` when ``n_components`` is ``None``.
        - ``bias`` has shape ``[1, embedding_size]``.

    Raises:
        ValueError: if the vectors do not stack into a 2-D array or there are
            fewer than two of them (the sample covariance is then undefined).

    Note:
        If the covariance is (near-)singular, e.g. when ``num_samples <=
        embedding_size``, some singular values are ~0. The corresponding
        kernel columns then become very large or infinite.
    """
    vecs = np.array(vec_list)
    if vecs.ndim != 2 or vecs.shape[0] < 2:
        raise ValueError(
            "expected at least 2 vectors of equal length, got array of shape "
            f"{vecs.shape}"
        )
    mu = vecs.mean(axis=0, keepdims=True)
    # atleast_2d: np.cov returns a 0-d array when embedding_size == 1.
    cov = np.atleast_2d(np.cov(vecs.T))
    # For the symmetric covariance, u holds the principal directions and s the
    # variances along them, in descending order.
    u, s, _ = np.linalg.svd(cov)
    kernel = np.dot(u, np.diag(1 / np.sqrt(s)))
    if n_components is not None:
        kernel = kernel[:, :n_components]
    return kernel, -mu


def compute_vec_list(origin_dataset, dim=384):
    """Collect sentence vectors of the expected size from a document iterable.

    A document is skipped if it has no ``"sentence_vec"`` key, or if its
    unpickled vector does not contain exactly ``dim`` elements. Each accepted
    vector is flattened to 1-D. Shapes such as ``(dim,)`` and ``(1, dim)``
    therefore stack cleanly into ``[num_samples, dim]``.

    Security: ``sentence_vec`` is deserialised with ``pickle.loads``, which
    can execute arbitrary code. Only use this on a trusted data source.

    Args:
        origin_dataset: iterable of dict-like documents (e.g. a MongoDB cursor).
        dim: required number of elements per vector. Defaults to 384.

    Returns:
        List of 1-D numpy arrays, each of length ``dim``.
    """
    vec_list = []
    for data in origin_dataset:
        if "sentence_vec" not in data:
            continue
        sentence_vec = np.asarray(pickle.loads(data["sentence_vec"]))
        if sentence_vec.size != dim:
            continue
        vec_list.append(sentence_vec.reshape(-1))
    return vec_list


if __name__ == "__main__":
    # Imported here rather than at module top so the helpers above can be
    # imported (e.g. in tests) without loading the project config, which
    # presumably opens the database connection.
    from config import mongo_coll, whitening_path

    # Only the vector field is read, so project to it and skip the rest of
    # each document. With no_cursor_timeout=True the server never reaps the
    # cursor, so it must be closed explicitly, even if processing fails.
    cursor = mongo_coll.find(
        {}, {"sentence_vec": 1}, no_cursor_timeout=True, batch_size=5
    )
    try:
        # Compute the list of sentence vectors.
        vec_list = compute_vec_list(cursor)
    finally:
        cursor.close()

    # Compute the BERT-whitening parameters.
    kernel, bias = compute_kernel_bias(vec_list)

    # Save the BERT-whitening parameters.
    with open(whitening_path, "wb") as f:
        pickle.dump([kernel, bias], f)