Danbooru2017 classifier performance: 85% accuracy

a guest
Jul 8th, 2018
Prepare the Danbooru2017 anime image dataset for use with the fast.ai NN library, to train a Safe/Questionable/Explicit classifier (see https://danbooru.donmai.us/wiki_pages/10920 for discussion of what each category means; roughly, SFW/lewd/NSFW):

~~~{.Bash}
# [
# Extract the post IDs for each rating ('s', 'q', 'e') from the JSON metadata, one ID per line:
cat metadata/*.json | jq '[.id, .rating]' -c | fgrep '"s"' | sed -e 's/\,"s"\]//' | tr -d '["' > ./sfw-ids.txt
cat metadata/*.json | jq '[.id, .rating]' -c | fgrep '"q"' | sed -e 's/\,"q"\]//' | tr -d '["' > ./questionable-ids.txt
cat metadata/*.json | jq '[.id, .rating]' -c | fgrep '"e"' | sed -e 's/\,"e"\]//' | tr -d '["' > ./nsfw-ids.txt # ]" # ]

## Note: this assumes you've run 'rescale-images.sh' to create a 512px version of the full corpus. If you don't want to do that, you can replace '512px-all/' with 'original/' and the NN library *should* automatically rescale down to 512x512 as necessary, but you'll spend more time loading images off the disk & processing, so GPU utilization may be worse.
## Each function takes one post ID, looks up its bucket directory (ID modulo 1000, zero-padded to 4 digits),
## and symlinks the image into the matching training directory if the file exists.
symSFW() { BUCKET=$(printf "%04d" $(( $1 % 1000 )) ); if [[ -e ./512px-all/$BUCKET/"$1".jpg ]]; then ln -s /media/gwern/Data/danbooru2017/512px-all/"$BUCKET"/"$1".jpg ./data/train/sfw/"$1".jpg; fi; }
export -f symSFW
cat sfw-ids.txt | head -600000 | nice parallel --progress symSFW &

symQ() { BUCKET=$(printf "%04d" $(( $1 % 1000 )) ); if [[ -e ./512px-all/$BUCKET/"$1".jpg ]]; then ln -s /media/gwern/Data/danbooru2017/512px-all/"$BUCKET"/"$1".jpg ./data/train/questionable/"$1".jpg; fi; }
export -f symQ
cat questionable-ids.txt | head -600000 | nice parallel --progress symQ &

symNSFW() { BUCKET=$(printf "%04d" $(( $1 % 1000 )) ); if [[ -e ./512px-all/$BUCKET/"$1".jpg ]]; then ln -s /media/gwern/Data/danbooru2017/512px-all/"$BUCKET"/"$1".jpg ./data/train/nsfw/"$1".jpg; fi; }
export -f symNSFW
cat nsfw-ids.txt | head -600000 | nice parallel --progress symNSFW &

# Let the background symlink jobs finish before sampling the validation set:
wait
cd data
# 10k validation images total:
mv $(find train/sfw/ -type l|shuf|head -3334) valid/sfw/
mv $(find train/questionable/ -type l|shuf|head -3333) valid/questionable/
mv $(find train/nsfw/ -type l|shuf|head -3333) valid/nsfw/
cd ../

# Count the training images overall and per class:
find data/train/ -type l | wc --lines
# 1271820
find data/train/sfw/ -type l | wc --lines
# 588879
find data/train/questionable/ -type l | wc --lines
# 434770
find data/train/nsfw/ -type l | wc --lines
# 248171
# Class proportions (SFW, Questionable, NSFW), computed in R:
# R> round(digits=2, c(588879, 434770, 248171) / 1271820)
# [1] 0.46 0.34 0.20
~~~
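
The ID extraction and bucket lookup above can also be written in Python, which avoids the `fgrep`/`sed` string surgery. This is a minimal sketch, assuming the metadata files are newline-delimited JSON objects with `id` and `rating` fields (as the `jq` calls imply) and that images are bucketed by ID modulo 1000, zero-padded to 4 digits (as in the `sym*` functions). The helper names `split_ids_by_rating` and `bucket_path` are hypothetical, and the sample records are made up for illustration.

```python
import json
from collections import defaultdict


def split_ids_by_rating(lines):
    """Group post IDs by rating ('s', 'q', 'e') from JSON-lines metadata."""
    ids = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines between records
        post = json.loads(line)
        ids[post["rating"]].append(int(post["id"]))
    return ids


def bucket_path(post_id, root="512px-all"):
    """Image path under the bucketing scheme: ID modulo 1000, zero-padded to 4 digits."""
    return f"{root}/{post_id % 1000:04d}/{post_id}.jpg"


# Illustrative records only (not real metadata):
sample = [
    '{"id": "1452109", "rating": "s"}',
    '{"id": "2", "rating": "q"}',
    '{"id": "1003", "rating": "e"}',
]
by_rating = split_ids_by_rating(sample)
print(by_rating["s"], bucket_path(1452109))
```

Keeping the IDs as integers makes the modulo arithmetic explicit and sidesteps the quoting issues of the shell version.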

Train a deep NN with cyclic learning rates to 85% accuracy, and examine the validation set for mislabeled images (specifically, SFW images mislabeled as the default, Questionable):
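
The training code in the following block calls `learn.fit(lr, 5, cycle_len=1, cycle_mult=2)`: 5 cycles, the first 1 epoch long, each subsequent cycle twice as long as the previous one. The sketch below works out the resulting epoch count and shows a simplified cosine-annealing schedule that restarts at the peak rate at the start of each cycle. It illustrates the general technique and is not fast.ai's exact implementation; the function names and the 100-step cycle are hypothetical.

```python
import math


def cycle_lengths(n_cycle, cycle_len, cycle_mult):
    """Epochs in each cycle when every cycle is cycle_mult times longer than the last."""
    return [cycle_len * cycle_mult ** i for i in range(n_cycle)]


def cosine_lr(base_lr, step, steps_in_cycle):
    """Simplified cosine annealing: starts at base_lr and decays towards 0 within one cycle."""
    return base_lr / 2 * (math.cos(math.pi * step / steps_in_cycle) + 1)


lengths = cycle_lengths(n_cycle=5, cycle_len=1, cycle_mult=2)
print(lengths, sum(lengths))

# Learning rate at the start, middle, and end of a hypothetical 100-step cycle:
for step in (0, 50, 100):
    print(step, cosine_lr(2.29e-2, step, 100))
```

Under these settings the five cycles last 1, 2, 4, 8, and 16 epochs, for 31 epochs in total after the initial one-epoch fit.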

~~~{.Python}
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *

## Specify: DenseNet-121 (fast.ai's 'dn121') on 512x512 image data, from Danbooru2017, minibatch=26 (just fits in 2x1080ti), all fast.ai data augmentations:
PATH = "/media/gwern/Data/danbooru2017/data/"
sz = 512
bs = 13*2
arch = dn121
tfms = tfms_from_model(arch, sz, aug_tfms=transforms_top_down+transforms_side_on, max_zoom=1.1)

data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=bs)
learn = ConvLearner.pretrained(arch, data, precompute=False)
## Split each minibatch across both GPUs:
learn.models.model = torch.nn.DataParallel(learn.models.model)

## Train: first fit only the new final layers for 1 epoch, then unfreeze and fine-tune everything
## with per-layer-group learning rates (lowest for the earliest layers):
learn.fit(0.0229, 1)
learn.unfreeze()
lr = np.array([2e-4,2e-3,2.29e-2])

learn.fit(lr, 5, cycle_len=1, cycle_mult=2)
learn.save("2018-07-07-densenet101-512-sfwdanbooru85percent")
# https://www.dropbox.com/s/yi3ut10co7hhvy2/2018-07-05-densenet101-sfwdanbooru85percent.h5?dl=0

## Generate predictions on validation set for active learning:
## to compute on the test set instead: log_preds,y = learn.TTA(is_test=True)
log_preds,y = learn.TTA()
## Average the class probabilities over the test-time augmentations, then take the most likely class:
probs = np.mean(np.exp(log_preds), axis=0)
preds = np.argmax(probs, axis=1)

## Prediction quality:
accuracy_np(probs, y)
# 0.852

from sklearn.metrics import confusion_matrix
confusion_matrix(y, np.argmax(probs,axis=1))
# Rows = true label, columns = predicted label:
#         NSFW     Q     S
# NSFW    2552,  686,   95
# Q        230, 2544,  559
# S         20,  243, 3071

## Merge filenames/labels/predictions; stacking with the filenames turns every column into strings.
## Columns: 0 = filename, 1 = label, 2 = prediction, 3-5 = mean log-probabilities for NSFW/Q/S.
results = np.column_stack((data.val_ds.fnames, data.val_y, preds, np.mean(log_preds, axis=0)))
## Filter out just the Questionable images (label 1) the NN thinks are Safe (prediction 2):
m = np.stack([row for row in results if (row[1] == '1' and row[2] == '2' ) ])
## Order mistakes by confidence in the Safe prediction (column 5), from highest confidence to lowest:
m = m[m[:,5].astype(float).argsort()[::-1]]
## Write out for interactive evaluation:
np.savetxt('/media/gwern/Data/danbooru2017/data/q2s.txt', m[:,0], delimiter='', fmt="%s")
~~~
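
The filter-and-sort step can also be done without stacking everything into one string-typed array. The following is a minimal NumPy-only sketch on toy data, assuming class indices 0/1/2 correspond to NSFW/Questionable/Safe (the order used in the confusion matrix above). The function name `confident_mistakes` and the toy filenames and probabilities are hypothetical.

```python
import numpy as np


def confident_mistakes(fnames, y_true, probs, true_class, pred_class):
    """Filenames labeled true_class but predicted as pred_class, most confident first."""
    fnames = np.asarray(fnames)
    y_true = np.asarray(y_true)
    probs = np.asarray(probs, dtype=float)
    preds = probs.argmax(axis=1)
    mask = (y_true == true_class) & (preds == pred_class)
    confidence = probs[mask, pred_class]  # probability assigned to the (wrong) prediction
    order = np.argsort(-confidence)       # descending confidence
    return fnames[mask][order]


# Toy data: four images, columns are P(NSFW), P(Questionable), P(Safe).
fnames = ["a.jpg", "b.jpg", "c.jpg", "d.jpg"]
y_true = [1, 1, 2, 1]
probs = [
    [0.1, 0.3, 0.6],  # Questionable predicted Safe, confidence 0.6
    [0.0, 0.1, 0.9],  # Questionable predicted Safe, confidence 0.9
    [0.0, 0.1, 0.9],  # Safe predicted Safe (correct)
    [0.2, 0.7, 0.1],  # Questionable predicted Questionable (correct)
]
worst = confident_mistakes(fnames, y_true, probs, true_class=1, pred_class=2)
print(list(worst))
```

Keeping labels and probabilities numeric means the comparison and the sort operate on numbers rather than on their string representations.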

Visually examine the Questionable images predicted as Safe, and fix the mislabeled ones both locally and on Danbooru:

~~~{.Bash}
cd /media/gwern/Data/danbooru2017/data/

## Usage: danbooruEditStatus FILE RATING
## Sets the rating ('s', 'q', or 'e') of the post whose ID is FILE's basename, then prints the rating now stored.
danbooruEditStatus() {
    local USERNAME="gwern-bot"
    local API_KEY="XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
    local ID
    ID=$(basename --suffix=".jpg" "$1")
    curl -u "$USERNAME:$API_KEY" -X PUT "https://danbooru.donmai.us/posts/$ID.json" -d "post[rating]=$2"
    echo "Result:"
    curl --get "https://danbooru.donmai.us/posts/$ID.json" --data "login=$USERNAME&api_key=$API_KEY" | jq '.rating'
}
export -f danbooruEditStatus

## Example use:
# $ danbooruEditStatus valid/questionable/1452109.jpg s
# {"tag_string":"2girls alternate_costume black_legwear blue_eyes blue_legwear brown_hair carrying cherry_blossoms cloud day hair_ornament hairclip holding loafers long_hair long_sleeves lyrical_nanoha mahou_shoujo_lyrical_nanoha mahou_shoujo_lyrical_nanoha_a's miniskirt multiple_girls open_mouth pantyhose princess_carry print_legwear red_eyes reinforce shoes short_hair shorts silver_hair single_hair_intake skirt sky smile socks standing sweater takana turtleneck x_hair_ornament yagami_hayate","is_banned":false,"id":1452109,"rating":"s","parent_id":null,"source":"http://25.media.tumblr.com/tumblr_m7ijbpKjKy1r9yjhso1_1280.jpg","image_width":800,"image_height":614,"file_size":102206,"file_ext":"jpg","pixiv_id":null,"uploader_id":23799,"keeper_data":{"uid":23799},"tag_count":42,"tag_count_general":36,"tag_count_character":2,"tag_count_copyright":3,"tag_count_artist":1,"tag_count_meta":0,"pool_string":"","created_at":"2013-06-29T08:58:15.443-04:00","score":1,"md5":"cc735d40ee521b74acaec6721e7e1549","last_comment_bumped_at":null,"is_note_locked":false,"fav_count":1,"last_noted_at":null,"is_rating_locked":false,"has_children":false,"approver_id":287254,"is_status_locked":false,"up_score":1,"down_score":0,"is_pending":false,"is_flagged":false,"is_deleted":false,"updated_at":"2018-07-08T17:06:53.192-04:00","last_commented_at":null,"has_active_children":false,"bit_flags":0,"uploader_name":"BrokenEagle98","has_large":false,"has_visible_children":false,"children_ids":null,"is_favorited":false,"tag_string_general":"2girls alternate_costume black_legwear blue_eyes blue_legwear brown_hair carrying cherry_blossoms cloud day hair_ornament hairclip holding loafers long_hair long_sleeves miniskirt multiple_girls open_mouth pantyhose princess_carry print_legwear red_eyes shoes short_hair shorts silver_hair single_hair_intake skirt sky smile socks standing sweater turtleneck x_hair_ornament","tag_string_character":"reinforce yagami_hayate","tag_string_copyright":"lyrical_nanoha mahou_shoujo_lyrical_nanoha mahou_shoujo_lyrical_nanoha_a's","tag_string_artist":"takana","tag_string_meta":"","file_url":"https://raikou2.donmai.us/cc/73/cc735d40ee521b74acaec6721e7e1549.jpg","large_file_url":"https://raikou2.donmai.us/cc/73/cc735d40ee521b74acaec6721e7e1549.jpg","preview_file_url":"https://raikou2.donmai.us/preview/cc/73/cc735d40ee521b74acaec6721e7e1549.jpg"}
#   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
#                                  Dload  Upload   Total   Spent    Left  Speed
# 100  2370    0  2370    0     0   8061      0 --:--:-- --:--:-- --:--:--  8061

## Scroll through the images in order; press '1' to move a Questionable image to SFW & set it to Safe on https://danbooru.donmai.us as well; press '2' to move it to NSFW & set it to Explicit.
feh --filelist q2s.txt \
    --action1 "mv %f ./valid/sfw/ && echo %f >> q2s-moved.txt && danbooruEditStatus %f s &" \
    --action2 "mv %f ./valid/nsfw/ && echo %f >> q2e-moved.txt && danbooruEditStatus %f e &"
~~~
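
For scripted use outside the shell, the rating update can be prepared in Python with only the standard library. This is a hedged sketch that mirrors the first `curl` call above (HTTP basic auth, `PUT` to `/posts/ID.json`, form field `post[rating]`); it only builds the request and does not send it, and whether the Danbooru API still accepts this form is not verified here. The function name `build_rating_request` and the credentials are hypothetical placeholders.

```python
import base64
import urllib.parse
import urllib.request
from pathlib import Path

API_ROOT = "https://danbooru.donmai.us"  # same host as the curl calls above


def build_rating_request(image_path, rating, username, api_key):
    """Build (but do not send) a PUT request that sets a post's rating."""
    if rating not in ("s", "q", "e"):
        raise ValueError(f"unknown rating: {rating!r}")
    post_id = Path(image_path).stem  # 'valid/questionable/1452109.jpg' -> '1452109'
    body = urllib.parse.urlencode({"post[rating]": rating}).encode("ascii")
    request = urllib.request.Request(
        f"{API_ROOT}/posts/{post_id}.json", data=body, method="PUT"
    )
    token = base64.b64encode(f"{username}:{api_key}".encode("utf-8")).decode("ascii")
    request.add_header("Authorization", f"Basic {token}")
    return request


# Placeholder credentials (assumptions, not real values):
req = build_rating_request(
    "valid/questionable/1452109.jpg", "s", "example-user", "example-key"
)
print(req.get_method(), req.full_url, req.data)
# Sending it would be: urllib.request.urlopen(req)
```

Separating request construction from sending makes it possible to inspect or log each edit before anything is changed on the site.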