Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- import json
- import matplotlib.pyplot as plt
- from collections import Counter
- # For jupyter notebook
- #%matplotlib inline
- # TODO: Clean up this code
def cat_count(annotations=None):
    """Count annotations per category name for the train and val splits.

    Args:
        annotations: A two-item sequence ``(gtrain, gval)``. Each item is a
            dict with an ``'annotations'`` list (entries carrying a
            ``'category_id'``) and a ``'categories'`` list (entries carrying
            an ``'id'`` and a ``'name'``).

    Returns:
        A list ``[train_counts, val_counts, [train_added, val_added]]``.
        ``train_counts`` / ``val_counts`` map category name to the number of
        annotations in that split. ``train_added`` / ``val_added`` hold the
        category names that had no annotation in the split and were therefore
        inserted into the matching count dict with a count of 0.

    Raises:
        TypeError: If ``annotations`` is None.
    """
    if annotations is None:
        raise TypeError("annotations must be a (train, val) pair, not None")
    gtrain, gval = annotations

    train_counts = _count_by_name(gtrain)
    val_counts = _count_by_name(gval)

    # The training split's category list is the reference for BOTH splits.
    # Names are read by key rather than by position in the category dict, so
    # the result does not depend on key order or on how many extra fields a
    # category entry carries.
    all_names = [cat['name'] for cat in gtrain['categories']]
    train_added = {name: 0 for name in all_names if name not in train_counts}
    val_added = {name: 0 for name in all_names if name not in val_counts}

    train_counts.update(train_added)
    val_counts.update(val_added)
    return [train_counts, val_counts, [train_added, val_added]]


def _count_by_name(split):
    """Return ``{category name: annotation count}`` for a single split."""
    # NOTE(review): assumes category ids are unique within a split, as in
    # COCO-style files; with duplicate ids the last name wins.
    id_to_name = {cat['id']: cat['name'] for cat in split['categories']}
    # Annotations whose category_id has no matching category are skipped.
    return dict(Counter(
        id_to_name[ann['category_id']]
        for ann in split['annotations']
        if ann['category_id'] in id_to_name
    ))
def show_class_distribution_both(annotations=None, dist="train", bar="h"):
    """Plot the per-class annotation counts of the train and val splits.

    One bar chart is drawn per split. Each figure is saved as
    ``train.jpg`` / ``val.jpg`` (relative path) and then shown.

    Args:
        annotations: A two-item sequence ``(gtrain, gval)`` as accepted by
            ``cat_count``.
        dist: Must be ``"train"`` or ``"val"``. It is validated but not
            otherwise used; both splits are always plotted.
        bar: ``"h"`` for horizontal bars (8x10 figure) or ``"v"`` for
            vertical bars (20x10 figure).

    Raises:
        TypeError: If ``annotations`` is None.
        AssertionError: If ``dist`` is not ``"train"`` or ``"val"``.
        ValueError: If ``bar`` is not ``"h"`` or ``"v"``.
    """
    if annotations is None:
        raise TypeError("annotations must be a (train, val) pair, not None")
    gtrain, gval = annotations
    assert dist in ["train", "val"], "Has to be either 'train' or 'val' data"
    # Reject unknown orientations up front; otherwise the figure size would
    # never be assigned and the failure would surface as an unbound name.
    if bar not in ("h", "v"):
        raise ValueError("bar must be 'h' or 'v', got {!r}".format(bar))

    train_counts, val_counts, _ = cat_count([gtrain, gval])

    width = 0.5
    fig_size = (8, 10) if bar == "h" else (20, 10)

    for name, counts in (("train", train_counts), ("val", val_counts)):
        # Read labels/values straight from the count dict so an empty
        # mapping yields an empty chart instead of an unpacking error.
        labels = list(counts)
        values = list(counts.values())
        indexes = np.arange(len(labels))

        plt.figure(figsize=fig_size)
        if bar == "h":
            plt.barh(indexes, values, width, align="edge")
            # Bars are edge-aligned, so shift ticks to the bar centres.
            plt.yticks(indexes + width / 2, labels)
            plt.ylabel('Classes')
            plt.xlabel('Count')
        else:
            plt.bar(indexes, values, width, align="edge")
            plt.xticks(indexes, labels, rotation='vertical')
            plt.xlabel('Classes')
            plt.ylabel('Count')
        plt.tight_layout()
        plt.title('Class Distribution')

        # Write each bar's count next to (h) or above (v) the bar.
        if bar == "h":
            for i, v in enumerate(values):
                plt.text(v + 3, i, str(v), color='blue', fontweight='bold')
        else:
            for i, v in enumerate(values):
                plt.text(i, v + 5, str(v), color='blue',
                         fontsize=0.6 * fig_size[0], fontweight='bold')

        plt.savefig(name + ".jpg")
        plt.show()
def main():
    """Load the train and val annotation JSON files and plot both splits."""
    annotation_paths = (
        'NEW_ANNOTATIONS_18CLASSES/instances_train.json',
        'NEW_ANNOTATIONS_18CLASSES/instances_val.json',
    )
    # Train first, then val: the plotting function expects that order.
    loaded = []
    for path in annotation_paths:
        with open(path) as handle:
            loaded.append(json.load(handle))
    show_class_distribution_both(loaded, dist="train", bar="h")
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement