test.py

  1. """
  2. Copyright (c) 2019-present NAVER Corp.
  3. MIT License
  4. """
  5. # -*- coding: utf-8 -*-
  6. import sys
  7. import os
  8. import time
  9. import argparse
  10. import torch
  11. import torch.nn as nn
  12. import torch.backends.cudnn as cudnn
  13. from torch.autograd import Variable
  14. from PIL import Image
  15. import cv2
  16. from skimage import io
  17. import numpy as np
  18. import craft_utils
  19. import imgproc
  20. import file_utils
  21. import json
  22. import zipfile
  23. from craft import CRAFT
  24. from collections import OrderedDict
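
# test.py: run the CRAFT text detector over every image in --test_folder and
# write per-image detection results and score-map visualizations to ./result/.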

def copyStateDict(state_dict):
    # strip the "module." prefix that DataParallel adds to checkpoint keys
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict
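
# argparse cannot parse booleans with type=bool (any non-empty string is truthy),
# hence the str2bool helper used for the --cuda flag below.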
def str2bool(v):
    return v.lower() in ("yes", "y", "true", "t", "1")

parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
parser.add_argument('--test_folder', default='./data/', type=str, help='folder path to input images')
parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')

args = parser.parse_args()
  51. """ For test images in a folder """
  52. image_list, _, _ = file_utils.get_files(args.test_folder)
  53. result_folder = './result/'
  54. if not os.path.isdir(result_folder):
  55. os.mkdir(result_folder)
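
# test_net: resize the image, normalize it, run a forward pass through CRAFT,
# and post-process the region/affinity score maps into word boxes and polygons.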
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
    t0 = time.time()

    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    with torch.no_grad():
        y, feature = net(x)

    # make score and link map
    score_text = y[0,:,:,0].cpu().data.numpy()
    score_link = y[0,:,:,1].cpu().data.numpy()
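    # (channel 0 of the output is the region/character score map, channel 1 the affinity/link score map)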

    # refine link
    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0,:,:,0].cpu().data.numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
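    # adjustResultCoordinates maps heatmap-space coordinates back to the original image
    # (the score maps are produced at half the network input resolution)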
    for k in range(len(polys)):
        if polys[k] is None: polys[k] = boxes[k]
    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time: print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text


if __name__ == '__main__':
    # load net
    net = CRAFT()     # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()
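    # eval() disables dropout/batch-norm updates; the forward pass itself also runs under torch.no_grad()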

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        args.poly = True
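        # (polygon output is forced here, presumably because the refiner targets curved text such as CTW1500)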

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)

        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

    print("elapsed time : {}s".format(time.time() - t))