Advertisement
Guest User

Untitled

a guest
Sep 18th, 2019
136
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.80 KB | None | 0 0
  1. from glob import glob
  2. from shutil import copyfile
  3. from json import load
  4. from random import shuffle
  5. import matplotlib.pyplot as plt
  6. import matplotlib.patches as patches
  7.  
  8. # stat
  9. stat = {"file": []}
  10. all_shapes = []
  11. str_for_map = "0123456789.:-"
  12. label_map = {j: i for (i, j) in enumerate(str_for_map)}
  13. dst = "train"
  14. src = "ori"
  15. train_path = "/home/tmp/"
  16.  
  17.  
  18. # plot
  19. def plot_imgandlabel(img, txt):
  20. plt.imshow(img)
  21. cord = []
  22.  
  23. for labels in txt:
  24. h, w = img.shape[:2]
  25. x1 = w * (labels[1] - labels[3]/2)
  26. y1 = h * (labels[2] - labels[4]/2)
  27. x2 = w * (labels[1] + labels[3]/2)
  28. y2 = h * (labels[2] + labels[4]/2)
  29. print(labels[0], x1, y1, x2, y2)
  30.  
  31. bbox = patches.Rectangle((x1, y1), x2-x1, y2-y1,
  32. linewidth=2, facecolor='none', edgecolor='blue')
  33. plt.gca().add_patch(bbox)
  34.  
  35. plt.imshow(img)
  36. plt.show()
  37.  
  38.  
  39. # plot by name
  40. def plot_imgandlabel_raw(imgpath, txtpath):
  41. img = plt.imread(imgpath)
  42. txt = open(txtpath).readlines()
  43. labels = [[float(t) for t in ti.split()] for ti in txt]
  44. plot_imgandlabel(img, labels)
  45.  
  46.  
  47. # copy imgae to train
  48. def copyImage(base):
  49. imgfile = base + ".jpg"
  50. copyfile(imgfile, imgfile.replace(src, dst))
  51.  
  52.  
  53. # transfer label to train
  54. def copyLabel(base):
  55. stat["file"].append(base)
  56. return
  57. js = load(open(base + ".json"))
  58.  
  59. # image meta
  60. h, w = js["imageHeight"], js["imageWidth"]
  61. # debug
  62. # print(h, w)
  63. # imgfile = base + ".jpg"
  64. # img = plt.imread(imgfile)
  65.  
  66. # shape transfer
  67. txt = []
  68. for sh in js["shapes"]:
  69. label, (x1, y1), (x2, y2)= sh['label'], *sh["points"]
  70. x1, x2 = min(x1, x2), max(x1, x2)
  71. y1, y2 = min(y1, y2), max(y1, y2)
  72. all_shapes.append((x2 - x1, y2 - y1))
  73. labels = [label_map[label], (x1 + x2) / 2 / w, (y1 + y2) / 2 / h, (x2 - x1) / w, (y2 - y1) / h]
  74. txt.append(labels)
  75. # stat
  76. if not stat.get(label):
  77. stat[label] = 1
  78. else:
  79. stat[label] += 1
  80.  
  81. # write
  82. fout = open(base.replace(src, dst) + ".txt", "w")
  83. for t in txt:
  84. print("{} {:6f} {:6f} {:6f} {:6f}".format(*t), file=fout)
  85. # debug
  86. # print(txt)
  87. # plot_imgandlabel(img, txt)
  88. # break
  89.  
  90.  
  91. def writeMeta():
  92. data = """\
  93. classes={}
  94. train=../train/train.txt
  95. valid=../train/valid.txt
  96. names=../train/name.txt
  97. """.format(len(str_for_map))
  98. open(dst + "/name.txt", "w").write("\n".join(str_for_map) + "\n")
  99. open(dst + "/digit.data", "w").write(data)
  100. all_file = [train_path + d.replace(src, dst) + ".jpg" for d in stat['file']]
  101. # random
  102. shuffle(all_file)
  103. vlen = len(all_file) // 5
  104. train_file = all_file[:-vlen]
  105. valid_file = all_file[-vlen:]
  106. open(dst + "/train.txt", "w").write(
  107. "\n".join(train_file))
  108. open(dst + "/valid.txt", "w").write(
  109. "\n".join(valid_file))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement