bert_whitening.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import pickle
  2. import numpy as np
  3. from config import mongo_coll, whitening_path
  4. # 计算bert-whitening参数
  5. def compute_kernel_bias(vec_list):
  6. """计算kernel和bias
  7. vecs.shape = [num_samples, embedding_size]
  8. 最后的变换: y = (x + bias).dot(kernel)
  9. """
  10. vecs = np.array(vec_list)
  11. mu = vecs.mean(axis=0, keepdims=True)
  12. cov = np.cov(vecs.T)
  13. u, s, vh = np.linalg.svd(cov)
  14. W = np.dot(u, np.diag(1 / np.sqrt(s)))
  15. return W, -mu
  16. # 计算句向量列表
  17. def compute_vec_list(origin_dataset):
  18. vec_list = []
  19. for data in origin_dataset:
  20. if "sentence_vec" not in data:
  21. continue
  22. sentence_vec = pickle.loads(data["sentence_vec"])
  23. if sentence_vec.size != 384:
  24. continue
  25. vec_list.append(sentence_vec)
  26. return vec_list
  27. if __name__ == "__main__":
  28. origin_dataset = mongo_coll.find(no_cursor_timeout=True, batch_size=5)
  29. # 计算句向量列表
  30. vec_list = compute_vec_list(origin_dataset)
  31. # 计算bert-whitening参数
  32. kernel, bias = compute_kernel_bias(vec_list)
  33. # 保存bert-whitening参数
  34. with open(whitening_path, 'wb') as f:
  35. pickle.dump([kernel, bias], f)