Advertisement
Guest User

Untitled

a guest
Apr 1st, 2020
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.50 KB | None | 0 0
  1. import cv2
  2. import torch
  3. import torch.nn as nn
  4. from torchvision import transforms
  5. from PIL import Image
  6.  
  7.  
  8. transform = transforms.Compose([
  9. transforms.Resize((200, 200)),
  10. transforms.ToTensor(),
  11. transforms.Normalize([0.5, 0.5, 0.5],
  12. [0.5, 0.5, 0.5])
  13. ])
  14.  
  15.  
  16. class ConvNet(nn.Module):
  17. def __init__(self):
  18. super().__init__()
  19.  
  20. self.conv1 = nn.Sequential(
  21. nn.Conv2d(3, 10, (3, 3), padding=1),
  22. nn.MaxPool2d((2, 2), stride=(2, 2)),
  23. nn.ReLU()
  24. )
  25. self.conv2 = nn.Sequential(
  26. nn.Conv2d(10, 20, (3, 3), padding=1),
  27. nn.MaxPool2d((2, 2), stride=(2, 2)),
  28. nn.ReLU()
  29. )
  30. self.conv3 = nn.Sequential(
  31. nn.Conv2d(20, 30, (3, 3), padding=1),
  32. nn.MaxPool2d((2, 2), stride=(2, 2)),
  33. nn.ReLU()
  34. )
  35. self.conv4 = nn.Sequential(
  36. nn.Conv2d(30, 64, (3, 3), padding=1),
  37. nn.MaxPool2d((2, 2), stride=(2, 2)),
  38. nn.ReLU()
  39. )
  40. self.linear1 = nn.Sequential(
  41. nn.Linear(9216, 2048),
  42. nn.Dropout(0.5),
  43. nn.ReLU()
  44. )
  45. self.linear2 = nn.Sequential(
  46. nn.Linear(2048, 512),
  47. nn.Dropout(0.5),
  48. nn.ReLU()
  49. )
  50. self.linear3 = nn.Sequential(
  51. nn.Linear(512, 100),
  52. nn.Dropout(0.5),
  53. nn.ReLU()
  54. )
  55.  
  56. self.linear4 = nn.Sequential(
  57. nn.Linear(100, 4),
  58.  
  59. )
  60.  
  61. def forward(self, x):
  62. output = self.conv1(x)
  63. output = self.conv2(output)
  64. output = self.conv3(output)
  65. output = self.conv4(output)
  66. output = output.view(-1)
  67. output = self.linear1(output)
  68. output = self.linear2(output)
  69. output = self.linear3(output)
  70. output = self.linear4(output)
  71. return output
  72.  
  73.  
  74. model = ConvNet()
  75. #model path
  76. model.load_state_dict(torch.load('D:\soft\PyCharmCommunityEdition2019.2.3\pycharmprojects\project\conv_net_model.ckpt'))
  77. model.eval()
  78. cap = cv2.VideoCapture(-1)
  79. while(True):
  80. ret, frame = cap.read()
  81. if not ret:
  82. raise ValueError("unable to load Image")
  83. img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  84. im_pil = Image.fromarray(img)
  85. im_pil = transforms(im_pil)
  86. im_pil.resize_(1, 3, 200, 200)
  87. coord = model(im_pil)
  88. frame = cv2.rectangle(frame, (coord[0], coord[1]), (coord[2], coord[3]), (0, 0, 255))
  89. cv2.imshow(frame)
  90. cap.release()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement