12345678910111213141516171819202122232425262728293031323334353637383940 |
- import pickle
- import numpy as np
- from config import mongo_coll, whitening_path
# Compute the BERT-whitening parameters.
def compute_kernel_bias(vec_list):
    """Compute the whitening kernel and bias for a set of sample vectors.

    The result is meant to be applied as ``y = (x + bias).dot(kernel)``.

    Args:
        vec_list: Sequence (or array) of sample vectors. Every sample is
            flattened, so samples shaped ``(d,)`` and ``(1, d)`` are both
            accepted and give the same result.

    Returns:
        A ``(kernel, bias)`` tuple. ``kernel`` has shape ``(d, d)`` and
        ``bias`` has shape ``(1, d)`` (the negated per-feature mean).

    Raises:
        ValueError: If fewer than two non-empty samples are supplied; a
            covariance matrix cannot be estimated from them.
    """
    vecs = np.asarray(vec_list)
    if vecs.ndim < 2 or vecs.shape[0] < 2 or vecs.size == 0:
        raise ValueError(
            "compute_kernel_bias needs at least two sample vectors"
        )
    # Collapse each sample to one row so (n, 1, d) input behaves like (n, d).
    vecs = vecs.reshape(vecs.shape[0], -1)

    mu = vecs.mean(axis=0, keepdims=True)
    # atleast_2d: np.cov returns a 0-d array when there is a single feature.
    cov = np.atleast_2d(np.cov(vecs.T))
    u, s, _ = np.linalg.svd(cov)

    # Directions with (numerically) zero variance cannot be rescaled to unit
    # variance; 1 / sqrt(0) would put inf into the kernel. Give those
    # directions a zero scale instead so the kernel stays finite.
    tol = s.max() * len(s) * np.finfo(s.dtype).eps
    inv_sqrt = np.zeros_like(s)
    keep = s > tol
    inv_sqrt[keep] = 1.0 / np.sqrt(s[keep])

    # Scaling the columns of u is equivalent to u.dot(np.diag(inv_sqrt)).
    W = u * inv_sqrt
    return W, -mu
# Build the list of sentence vectors.
def compute_vec_list(origin_dataset, expected_size=384):
    """Collect the unpickled sentence vectors stored in a dataset.

    Args:
        origin_dataset: Iterable of mapping-like records. Records may carry
            a pickled vector under the ``"sentence_vec"`` key.
        expected_size: Number of elements a vector must have to be kept.
            Defaults to 384, the size the original code hard-coded.

    Returns:
        A list of the unpickled vectors whose ``.size`` equals
        ``expected_size``, in iteration order. Records without a
        ``"sentence_vec"`` value, or with a vector of another size, are
        skipped silently.
    """
    vec_list = []
    for data in origin_dataset:
        if "sentence_vec" not in data:
            continue
        raw = data["sentence_vec"]
        if raw is None:
            # A null field is treated like a missing one.
            continue
        # SECURITY: pickle.loads executes arbitrary code embedded in the
        # payload. Only run this against a data store you fully trust.
        sentence_vec = pickle.loads(raw)
        if sentence_vec.size != expected_size:
            continue
        vec_list.append(sentence_vec)

    return vec_list
if __name__ == "__main__":
    # no_cursor_timeout=True asks the server not to time the cursor out, so
    # it must be closed explicitly or it stays open on the server if this
    # script fails part-way through.
    origin_dataset = mongo_coll.find(no_cursor_timeout=True, batch_size=5)
    try:
        # Build the list of sentence vectors.
        vec_list = compute_vec_list(origin_dataset)
    finally:
        origin_dataset.close()
    # Compute the BERT-whitening parameters.
    kernel, bias = compute_kernel_bias(vec_list)
    # Save the BERT-whitening parameters.
    with open(whitening_path, 'wb') as f:
        pickle.dump([kernel, bias], f)
|