file_utils.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # -*- coding: utf-8 -*-
  2. import os
  3. import numpy as np
  4. import cv2
  5. import imgproc
  6. # borrowed from https://github.com/lengstrom/fast-style-transfer/blob/master/src/utils.py
  7. def get_files(img_dir):
  8. imgs, masks, xmls = list_files(img_dir)
  9. return imgs, masks, xmls
  10. def list_files(in_path):
  11. img_files = []
  12. mask_files = []
  13. gt_files = []
  14. for (dirpath, dirnames, filenames) in os.walk(in_path):
  15. for file in filenames:
  16. filename, ext = os.path.splitext(file)
  17. ext = str.lower(ext)
  18. if ext == '.jpg' or ext == '.jpeg' or ext == '.gif' or ext == '.png' or ext == '.pgm':
  19. img_files.append(os.path.join(dirpath, file))
  20. elif ext == '.bmp':
  21. mask_files.append(os.path.join(dirpath, file))
  22. elif ext == '.xml' or ext == '.gt' or ext == '.txt':
  23. gt_files.append(os.path.join(dirpath, file))
  24. elif ext == '.zip':
  25. continue
  26. # img_files.sort()
  27. # mask_files.sort()
  28. # gt_files.sort()
  29. return img_files, mask_files, gt_files
  30. def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts=None):
  31. """ save text detection result one by one
  32. Args:
  33. img_file (str): image file name
  34. img (array): raw image context
  35. boxes (array): array of result file
  36. Shape: [num_detections, 4] for BB output / [num_detections, 4] for QUAD output
  37. Return:
  38. None
  39. """
  40. img = np.array(img)
  41. # make result file list
  42. filename, file_ext = os.path.splitext(os.path.basename(img_file))
  43. # result directory
  44. res_file = dirname + "res_" + filename + '.txt'
  45. res_img_file = dirname + "res_" + filename + '.jpg'
  46. if not os.path.isdir(dirname):
  47. os.mkdir(dirname)
  48. with open(res_file, 'w') as f:
  49. for i, box in enumerate(boxes):
  50. poly = np.array(box).astype(np.int32).reshape((-1))
  51. strResult = ','.join([str(p) for p in poly]) + '\r\n'
  52. f.write(strResult)
  53. poly = poly.reshape(-1, 2)
  54. cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
  55. ptColor = (0, 255, 255)
  56. if verticals is not None:
  57. if verticals[i]:
  58. ptColor = (255, 0, 0)
  59. if texts is not None:
  60. font = cv2.FONT_HERSHEY_SIMPLEX
  61. font_scale = 0.5
  62. cv2.putText(img, "{}".format(texts[i]), (poly[0][0]+1, poly[0][1]+1), font, font_scale, (0, 0, 0), thickness=1)
  63. cv2.putText(img, "{}".format(texts[i]), tuple(poly[0]), font, font_scale, (0, 255, 255), thickness=1)
  64. # Save result image
  65. cv2.imwrite(res_img_file, img)