Advertisement
Guest User

Untitled

a guest
Jun 18th, 2019
135
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.24 KB | None | 0 0
  1. # -*- coding:utf-8 -*-
  2. import os
  3. import sys
  4. import shutil
  5. import xml.etree.ElementTree as ET
  6.  
  7. target_classes_list = ["didi", "group", "hello", "mobike", "ofo", "other", "rider"]
  8. origin_dir = "G:\\combine_all"
  9. result_dir = "G:\\combine_all_bike"
  10.  
  11.  
  12. def get_xml_file_path_and_name(file_dir):
  13. """获取文件夹下所有xml的路径"""
  14. file_path_list = []
  15. file_name_list = []
  16. for root, dirs, files in os.walk(file_dir):
  17. for file in files:
  18. if os.path.splitext(file)[1] == '.xml':
  19. file_path_list.append(os.path.join(root, file))
  20. file_name_list.append(file)
  21. return file_path_list, file_name_list
  22.  
  23.  
  24. def write_xml(tree, out_path):
  25. """
  26. 将xml文件写出
  27. tree: xml树
  28. out_path: 写出路径
  29. """
  30. tree.write(out_path, encoding="utf-8", xml_declaration=True)
  31.  
  32.  
  33. def init_dir(origin, result):
  34. origin_ann_dir = os.path.join(origin, "Annotations")
  35. origin_img_dir = os.path.join(origin, "JPEGImages")
  36. if not os.path.isdir(origin_ann_dir) or not os.path.isdir(origin_img_dir):
  37. raise Exception("origin data is missing!")
  38.  
  39. result_ann_dir = os.path.join(result, "Annotations")
  40. result_img_dir = os.path.join(result, "JPEGImages")
  41. # 判断文件夹是否存在,不存在则创建
  42. if not os.path.isdir(result_ann_dir):
  43. os.makedirs(result_ann_dir)
  44. if not os.path.isdir(result_img_dir):
  45. os.makedirs(result_img_dir)
  46. # 判断文件夹是否为空,不为空则清空
  47. if not os.listdir(result_ann_dir):
  48. shutil.rmtree(result_ann_dir)
  49. os.mkdir(result_ann_dir)
  50. if not os.listdir(result_img_dir):
  51. shutil.rmtree(result_img_dir)
  52. os.mkdir(result_img_dir)
  53.  
  54. return origin_ann_dir, origin_img_dir, result_ann_dir, result_img_dir
  55.  
  56.  
  57. def copy_img(xml_name, origin, result):
  58. img_name = xml_name.replace(".xml", ".jpg")
  59. old_path = os.path.join(origin, img_name)
  60. new_path = os.path.join(result, img_name)
  61. shutil.copyfile(old_path, new_path)
  62.  
  63.  
  64. def select(file_path_list, file_name_list):
  65. """选择所需要的类别标签,剔除多余标签"""
  66. for i in range(len(file_path_list)):
  67. file = file_path_list[i]
  68. try:
  69. tree = ET.parse(file)
  70. root = tree.getroot() # 获得根节点
  71. except Exception:
  72. print("[Error] Cannot parse file %s" % file)
  73. sys.exit(1)
  74. object_list = root.findall('object')
  75. # 剔除多余标签
  76. for target in object_list:
  77. if not target.find('name').text in target_classes_list:
  78. root.remove(target)
  79. # 剔除完毕后,判断当前xml中是否还有object,如果没有,则跳过,如果有,则将xml和对应的图片保存到新的位置,
  80. if len(root.findall('object')) == 0:
  81. continue
  82. else:
  83. write_xml(tree, os.path.join(result_annotation_dir, file_name_list[i]))
  84. copy_img(file_name_list[i], origin_image_dir, result_image_dir)
  85.  
  86.  
  87. if __name__ == '__main__':
  88. origin_annotation_dir, origin_image_dir, result_annotation_dir, result_image_dir = init_dir(origin_dir, result_dir)
  89.  
  90. xml_path_list, xml_name_list = get_xml_file_path_and_name(origin_annotation_dir)
  91.  
  92. select(xml_path_list, xml_name_list)
  93. print('Select finished!')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement