# -*- coding: utf-8 -*-
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""
- import sys
- import os
- import time
- import argparse
- import torch
- import torch.nn as nn
- import torch.backends.cudnn as cudnn
- from torch.autograd import Variable
- from PIL import Image
- import cv2
- from skimage import io
- import numpy as np
- import craft_utils
- import imgproc
- import file_utils
- import json
- import zipfile
- from craft import CRAFT
- from collections import OrderedDict
def copyStateDict(state_dict):
    """Return a copy of *state_dict* with any leading 'module.' prefix removed.

    Checkpoints saved from a ``DataParallel``-wrapped model prefix every key
    with ``module.``; stripping it lets the weights load into a bare model.
    """
    first_key = next(iter(state_dict))
    prefix_depth = 1 if first_key.startswith("module") else 0
    return OrderedDict(
        (".".join(key.split(".")[prefix_depth:]), value)
        for key, value in state_dict.items()
    )
def str2bool(v):
    """Interpret a CLI string as a boolean (truthy: yes/y/true/t/1, any case)."""
    truthy = {"yes", "y", "true", "t", "1"}
    return v.lower() in truthy
# ---------------------------------------------------------------------------
# Command-line configuration.
# NOTE: the two Windows checkpoint paths are raw strings so their backslashes
# are not parsed as escape sequences ('\c', '\w' are invalid escapes — a
# DeprecationWarning today and a SyntaxError in future Python versions).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default=r'D:\试卷切割\CRAFT-pytorch\weights\craft_mlt_25k.pth',
                    type=str, help='pretrained model')
parser.add_argument('--text_threshold', default=0.3, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.6, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
parser.add_argument('--test_folder', default='./data/', type=str, help='folder path to input images')
parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
parser.add_argument('--refiner_model', default=r'D:\试卷切割\CRAFT-pytorch\weights\craft_refiner_CTW1500.pth',
                    type=str, help='pretrained refiner model')
args = parser.parse_args()
# Collect every test image found under the configured input folder.
image_list, _, _ = file_utils.get_files(args.test_folder)

# Ensure the output directory exists before any result is written.
result_folder = './result/'
if not os.path.isdir(result_folder):
    os.mkdir(result_folder)
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
    """Run CRAFT (plus the optional link refiner) on a single image.

    Args:
        net: CRAFT model, already in eval mode.
        image: input image as a numpy array (H, W, C).
        text_threshold: confidence threshold for text regions.
        link_threshold: confidence threshold for character links.
        low_text: low-bound text score used when expanding boxes.
        cuda: when True, move the input tensor to the GPU.
        poly: when True, return polygons instead of quadrilateral boxes.
        refine_net: optional RefineNet that replaces the link score map.

    Returns:
        (boxes, polys, ret_score_text): detected boxes, polygons (each None
        polygon replaced by its box), and a heatmap image of the score maps.
    """
    t0 = time.time()

    # Resize keeping aspect ratio; canvas size / magnification come from the
    # module-level CLI args rather than parameters.
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # Preprocessing: normalize and convert HWC numpy -> BCHW tensor.
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] -> [c, h, w]
    x = x.unsqueeze(0)                        # [c, h, w] -> [b, c, h, w]
    # (torch.autograd.Variable is deprecated; plain tensors carry autograd state.)
    if cuda:
        x = x.cuda()

    # Forward pass — no gradients are needed at inference time, so the whole
    # model evaluation (including the refiner) runs under no_grad.
    with torch.no_grad():
        y, feature = net(x)

        # Score and link maps for the single batch element.
        score_text = y[0, :, :, 0].cpu().numpy()
        score_link = y[0, :, :, 1].cpu().numpy()

        # Optionally refine the link map.
        if refine_net is not None:
            y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing: turn score maps into boxes/polygons.
    boxes, polys = craft_utils.getDetBoxes(
        score_text, score_link, text_threshold, link_threshold, low_text, poly)

    # Map coordinates back to the original image scale.
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]
    t1 = time.time() - t1

    # Render the concatenated text/link score maps as a heatmap image.
    render_img = np.hstack((score_text.copy(), score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if args.show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text
def getResult(boxes):
    """Flatten each detected box/polygon into a list of int32 coordinates.

    Each entry of *boxes* (an array-like of (x, y) points) becomes one flat
    list [x0, y0, x1, y1, ...].
    """
    return [
        list(np.asarray(box).astype(np.int32).reshape(-1))
        for box in boxes
    ]
# ---------------------------------------------------------------------------
# Build the CRAFT detector and load the pretrained weights.
# ---------------------------------------------------------------------------
net = CRAFT()
print('Loading weights from checkpoint (' + args.trained_model + ')')
if args.cuda:
    net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
    net = net.cuda()
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = False
else:
    net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))
net.eval()

# Optional link refiner stacked on top of the base detector.
refine_net = None
if args.refine:
    from refinenet import RefineNet

    refine_net = RefineNet()
    print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
    if args.cuda:
        refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
        refine_net = refine_net.cuda()
        refine_net = torch.nn.DataParallel(refine_net)
    else:
        refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
    refine_net.eval()
    # The refiner produces polygon-friendly link maps, so force polygon output.
    args.poly = True
if __name__ == '__main__':
    t = time.time()

    # Run detection over every image found in the test folder.
    for image_path in image_list:
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_text = test_net(
            net, image, args.text_threshold, args.link_threshold,
            args.low_text, args.cuda, args.poly, refine_net)

        # Save the score-map heatmap alongside the detection results.
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path, image[:, :, ::-1], polys, dirname=result_folder)
        print('====>', getResult(polys))

    print("elapsed time : {}s".format(time.time() - t))
|