Advertisement
Guest User

Untitled

a guest
May 19th, 2019
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.44 KB | None | 0 0
  1. import numpy as np
  2. import json
  3. import matplotlib.pyplot as plt
  4. from collections import Counter
  5.  
  6. # For jupyter notebook
  7. #%matplotlib inline
  8.  
  9. # TODO: Clean up this code
  10.  
  11. def cat_count(annotations=None):
  12. gtrain ,gval = annotations
  13.  
  14. train_cats = []
  15. val_cats = []
  16. for i in gtrain['annotations']:
  17. j = i['category_id']
  18. for cat in gtrain['categories']:
  19. if j == cat['id']:
  20. train_cats.append(cat['name'])
  21.  
  22.  
  23. for k in gval['annotations']:
  24. w = k['category_id']
  25. for cat in gval['categories']:
  26. if w == cat['id']:
  27. val_cats.append(cat['name'])
  28.  
  29. # Create dictionary of category and counts
  30. train_count_dict = dict(Counter(train_cats))
  31. val_count_dict = dict(Counter(val_cats))
  32.  
  33. # Create dictionary of category and ids
  34. ids = []
  35. cats = []
  36. for ind in range(len(gtrain['categories'])):
  37. cat_id,cat,_ = zip(gtrain['categories'][ind].values())
  38. ids.append(cat_id[0])
  39. cats.append(cat[0])
  40. cat_dict = dict(zip(ids, cats))
  41.  
  42. missing_train = set(list(cat_dict.values())) - set(list(train_count_dict.keys()))
  43. tadd = dict(zip(missing_train, [0]*len(missing_train)))
  44. train_count_dict.update(tadd)
  45.  
  46. missing_val = set(list(cat_dict.values())) - set(list(val_count_dict.keys()))
  47. vadd = dict(zip(missing_val, [0]*len(missing_val)))
  48. val_count_dict.update(vadd)
  49.  
  50. return [train_count_dict,val_count_dict,[tadd,vadd]]
  51.  
  52. def show_class_distribution_both(annotations=None, dist="train",bar="h"):
  53. gtrain,gval = annotations
  54. assert dist in ["train","val"], "Has to be either 'train' or 'val' data"
  55. train_cats, val_cats, _ = cat_count([gtrain,gval])
  56. train_labels, train_values = zip(*Counter(train_cats).items())
  57. val_labels, val_values = zip(*Counter(val_cats).items())
  58.  
  59.  
  60. dat = ["train","val"]
  61.  
  62. for name in dat:
  63.  
  64. if name == "train":
  65. labels = train_labels
  66. values = train_values
  67. elif name == "val":
  68. labels = val_labels
  69. values = val_values
  70.  
  71.  
  72. indexes = np.arange(len(labels))
  73. width = 0.5
  74.  
  75. if bar == "v":
  76. fig_size = (20,10)
  77. elif bar == "h":
  78. fig_size = (8,10)
  79.  
  80. plt.figure(figsize=fig_size)
  81.  
  82. if bar == "h":
  83. plt.barh(indexes, values, width,align="edge")
  84. plt.yticks(indexes+width/2,labels)
  85. plt.ylabel('Classes')
  86. plt.xlabel('Count')
  87. elif bar == "v":
  88. plt.bar(indexes, values, width,align="edge")
  89. plt.xticks(indexes, labels,rotation='vertical')
  90. plt.xlabel('Classes')
  91. plt.ylabel('Count')
  92.  
  93.  
  94. plt.tight_layout()
  95. plt.title('Class Distribution')
  96.  
  97.  
  98. if bar == "h":
  99. for i, v in enumerate(values):
  100. plt.text(v + 3, i , str(v), color='blue', fontweight='bold')
  101. elif bar == "v":
  102. for i, v in enumerate(values):
  103. plt.text(i,v + 5,str(v), color='blue', fontsize=0.6*fig_size[0],fontweight='bold')
  104.  
  105. plt.savefig(name + ".jpg")
  106. plt.show()
  107.  
  108.  
  109. def main():
  110.  
  111. with open('NEW_ANNOTATIONS_18CLASSES/instances_train.json') as gt:
  112. gtrain = json.load(gt)
  113. with open('NEW_ANNOTATIONS_18CLASSES/instances_val.json') as gv:
  114. gval = json.load(gv)
  115. show_class_distribution_both([gtrain,gval],dist="train",bar="h")
  116.  
  117. if __name__ == "__main__":
  118. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement