textbbx.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. """
  2. Copyright (c) 2019-present NAVER Corp.
  3. MIT License
  4. """
  5. # -*- coding: utf-8 -*-
  6. import sys
  7. import os
  8. import time
  9. import argparse
  10. import torch
  11. import torch.nn as nn
  12. import torch.backends.cudnn as cudnn
  13. from torch.autograd import Variable
  14. from PIL import Image
  15. import cv2
  16. from skimage import io
  17. import numpy as np
  18. import craft_utils
  19. import imgproc
  20. import file_utils
  21. import json
  22. import zipfile
  23. from craft import CRAFT
  24. from collections import OrderedDict
  25. def copyStateDict(state_dict):
  26. if list(state_dict.keys())[0].startswith("module"):
  27. start_idx = 1
  28. else:
  29. start_idx = 0
  30. new_state_dict = OrderedDict()
  31. for k, v in state_dict.items():
  32. name = ".".join(k.split(".")[start_idx:])
  33. new_state_dict[name] = v
  34. return new_state_dict
  35. def str2bool(v):
  36. return v.lower() in ("yes", "y", "true", "t", "1")
  37. parser = argparse.ArgumentParser(description='CRAFT Text Detection')
  38. parser.add_argument('--trained_model', default='D:\试卷切割\CRAFT-pytorch\weights\craft_mlt_25k.pth', type=str, help='pretrained model')
  39. parser.add_argument('--text_threshold', default=0.3, type=float, help='text confidence threshold')
  40. parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
  41. parser.add_argument('--link_threshold', default=0.6, type=float, help='link confidence threshold')
  42. parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
  43. parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
  44. parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
  45. parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
  46. parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
  47. parser.add_argument('--test_folder', default='./data/', type=str, help='folder path to input images')
  48. parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
  49. parser.add_argument('--refiner_model', default='D:\试卷切割\CRAFT-pytorch\weights\craft_refiner_CTW1500.pth', type=str,
  50. help='pretrained refiner model')
  51. args = parser.parse_args()
  52. """ For test images in a folder """
  53. image_list, _, _ = file_utils.get_files(args.test_folder)
  54. result_folder = './result/'
  55. if not os.path.isdir(result_folder):
  56. os.mkdir(result_folder)
  57. def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
  58. t0 = time.time()
  59. # resize
  60. img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size,
  61. interpolation=cv2.INTER_LINEAR,
  62. mag_ratio=args.mag_ratio)
  63. ratio_h = ratio_w = 1 / target_ratio
  64. # preprocessing
  65. x = imgproc.normalizeMeanVariance(img_resized)
  66. x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]
  67. x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]
  68. if cuda:
  69. x = x.cuda()
  70. # forward pass
  71. with torch.no_grad():
  72. y, feature = net(x)
  73. # make score and link map
  74. score_text = y[0, :, :, 0].cpu().data.numpy()
  75. score_link = y[0, :, :, 1].cpu().data.numpy()
  76. # refine link
  77. if refine_net is not None:
  78. with torch.no_grad():
  79. y_refiner = refine_net(y, feature)
  80. score_link = y_refiner[0, :, :, 0].cpu().data.numpy()
  81. t0 = time.time() - t0
  82. t1 = time.time()
  83. # print('score', score_text.shape)
  84. # print('score_link', score_link)
  85. # Post-processing
  86. boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)
  87. # coordinate adjustment
  88. boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
  89. polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
  90. for k in range(len(polys)):
  91. if polys[k] is None: polys[k] = boxes[k]
  92. t1 = time.time() - t1
  93. # render results (optional)
  94. render_img = score_text.copy()
  95. render_img = np.hstack((render_img, score_link))
  96. ret_score_text = imgproc.cvt2HeatmapImg(render_img)
  97. if args.show_time: print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))
  98. return boxes, polys, ret_score_text
  99. def getResult(boxes):
  100. Results = []
  101. for i, box in enumerate(boxes):
  102. poly = np.array(box).astype(np.int32).reshape((-1))
  103. Result = [p for p in poly]
  104. Results.append(Result)
  105. return Results
  106. net = CRAFT()
  107. print('Loading weights from checkpoint (' + args.trained_model + ')')
  108. if args.cuda:
  109. net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
  110. else:
  111. net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))
  112. if args.cuda:
  113. net = net.cuda()
  114. net = torch.nn.DataParallel(net)
  115. cudnn.benchmark = False
  116. net.eval()
  117. # LinkRefiner
  118. refine_net = None
  119. if args.refine:
  120. from refinenet import RefineNet
  121. refine_net = RefineNet()
  122. print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
  123. if args.cuda:
  124. refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
  125. refine_net = refine_net.cuda()
  126. refine_net = torch.nn.DataParallel(refine_net)
  127. else:
  128. refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
  129. refine_net.eval()
  130. args.poly = True
  131. if __name__ == '__main__':
  132. # load net
  133. # net = CRAFT() # initialize
  134. t = time.time()
  135. # load data
  136. for k, image_path in enumerate(image_list):
  137. image = imgproc.loadImage(image_path)
  138. bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text,
  139. args.cuda, args.poly, refine_net)
  140. # save score text
  141. filename, file_ext = os.path.splitext(os.path.basename(image_path))
  142. mask_file = result_folder + "/res_" + filename + '_mask.jpg'
  143. cv2.imwrite(mask_file, score_text)
  144. file_utils.saveResult(image_path, image[:, :, ::-1], polys, dirname=result_folder)
  145. print('====>', getResult(polys))
  146. print("elapsed time : {}s".format(time.time() - t))