# Small PyTorch script: trains a convolutional network to tell "fake" from "real"
# face images, read from the training_fake / training_real dataset folders.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import PIL.Image as Image
import random
import os
import torchvision.transforms as transforms

# Preprocessing: resize the shorter image side to 200 px and convert to a tensor.
# Named "transform" (singular) so it does not shadow the torchvision.transforms module.
transform = transforms.Compose([
    transforms.Resize(200),
    transforms.ToTensor()#,
    #transforms.Normalize(mean=[0.3, 0.4, 0.2], std=[0.2, 0.3, 0.2])  # values added together must stay below 1
    ])

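# Hedged sketch of how the commented-out normalization could be used: Normalize expects
# per-channel mean/std statistics of the input images; the values below are the widely
# used ImageNet statistics, and "transform_normalized" is an illustrative name that is
# not used anywhere else in this script.
transform_normalized = transforms.Compose([
    transforms.Resize(200),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
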
def concatenate(list1, list2):
    # Returns a new list with the items of list1 followed by those of list2
    # (equivalent to list1 + list2 for plain Python lists).
    result = []
    for i in list1:
        result.append(i)
    for i in list2:
        result.append(i)
    return result

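# Tiny usage check, purely illustrative and safe to delete: for plain Python lists,
# concatenate(a, b) behaves exactly like a + b.
assert concatenate([1, 2], [3]) == [1, 2] + [3]
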
# Dataset root: expects a "training_fake" and a "training_real" folder containing images.
root = "/home/ich/Desktop/NN-UEbung/dataset"
image_paths = os.listdir(root + "/training_fake")
image_paths = concatenate(image_paths, os.listdir(root + "/training_real"))

# Images and labels are collected batch_size at a time, then stacked into batches below.
train_data_list = []
test_data_list = []
test_data = []
train_data = []
target_train_list = []
target_test_list = []

batch_size = 4
batch = 0

# Draw the file names in random order and assign each image randomly (50/50) to the
# training or the test split; label 1 = fake, 0 = real, decided by whether "fake"
# occurs in the file name.
for i in range(len(image_paths)):
    img_path = random.choice(image_paths)
    image_paths.remove(img_path)

    if "fake" in img_path:
        if random.choice([0, 1]) == 1:
            target_train_list.append(1)
            train_data_list.append(transform(Image.open(root + "/training_fake/" + img_path)))
        else:
            target_test_list.append(1)
            test_data_list.append(transform(Image.open(root + "/training_fake/" + img_path)))
    else:
        if random.choice([0, 1]) == 1:
            target_train_list.append(0)
            train_data_list.append(transform(Image.open(root + "/training_real/" + img_path)))
        else:
            target_test_list.append(0)
            test_data_list.append(transform(Image.open(root + "/training_real/" + img_path)))

    # As soon as batch_size images have accumulated on either side, stack them into one
    # tensor of shape [batch_size, 3, H, W] and store it together with its label list.
    if len(train_data_list) >= batch_size:
        train_data.append((torch.stack(train_data_list), target_train_list))
        train_data_list = []
        target_train_list = []
        batch += 1
        print("Batch no. " + str(batch))
    if len(test_data_list) >= batch_size:
        test_data.append((torch.stack(test_data_list), target_test_list))
        test_data_list = []
        target_test_list = []
        batch += 1
        print("Batch no. " + str(batch))

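# Purely diagnostic output, not required by the rest of the script (n_fake / n_real are
# names introduced here): how many batches each split received and the label balance of
# the training batches (1 = fake, 0 = real).
n_fake = sum(targets.count(1) for _, targets in train_data)
n_real = sum(targets.count(0) for _, targets in train_data)
print("Train batches: " + str(len(train_data)) + ", test batches: " + str(len(test_data)))
print("Training labels -> fake: " + str(n_fake) + ", real: " + str(n_real))
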
class Netz(nn.Module):
    def __init__(self):
        super(Netz, self).__init__()
        # Three small convolution blocks followed by two fully connected layers.
        self.conv1 = nn.Conv2d(3, 5, kernel_size=5)
        self.conv2 = nn.Conv2d(5, 8, kernel_size=5)
        self.conv3 = nn.Conv2d(8, 14, kernel_size=5)
        #self.conv4 = nn.Conv2d(18, 24, kernel_size=3)
        self.fc1 = nn.Linear(24696, 1000)
        # One output unit per image (binary fake/real score); the output size must not
        # depend on batch_size, otherwise it cannot be matched against the target labels.
        self.fc2 = nn.Linear(1000, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)
        #x = self.conv4(x)
        #x = F.max_pool2d(x, 2)
        #x = F.relu(x)
        # Flatten the conv features; 24696 must equal the conv output size per image.
        x = x.view(-1, 24696)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return torch.sigmoid(x)

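# Hedged sanity check, not part of the training flow: the hard-coded 24696 in fc1 and in
# view() must equal channels * height * width of the conv output for the image size the
# Resize(200) transform produces. The probe below (illustrative "_probe" / "_x" names)
# runs one real batch through the convolutional part and prints that size.
if train_data:
    _probe = Netz()
    with torch.no_grad():
        _x = train_data[0][0]
        _x = F.relu(F.max_pool2d(_probe.conv1(_x), 2))
        _x = F.relu(F.max_pool2d(_probe.conv2(_x), 2))
        _x = F.relu(F.max_pool2d(_probe.conv3(_x), 2))
    print("Flattened conv features per image: " + str(_x.shape[1] * _x.shape[2] * _x.shape[3]))
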
model = Netz()
#if torch.cuda.is_available():
#    model = model.cuda()
#    print("Network moved to CUDA!")

# Resume from a previously pickled model if one exists in the working directory.
if os.path.isfile("nn.pt"):
    model = torch.load("nn.pt")
    print("File nn.pt loaded!")

optimizer = optim.Adam(model.parameters(), lr=0.01)
# Anomaly detection helps locate NaN/inf gradients during backward, at the cost of speed.
torch.autograd.set_detect_anomaly(True)

def train(epoch):
    model.train()
    batch_id = 1
    for data, target in train_data:
        # Targets come in as plain Python lists; convert them to a float tensor for BCE.
        target = torch.tensor(target, dtype=torch.float32)
        #if torch.cuda.is_available():
        #    data = data.cuda()
        #    target = target.cuda()
        #    print("Data moved to CUDA!")

        optimizer.zero_grad()
        out = model(data)
        criterion = F.binary_cross_entropy
        loss = criterion(out.squeeze(), target)
        loss.backward()
        optimizer.step()
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} (Batch-ID: {})'.format(
            epoch, (batch_id - 1) * batch_size, len(train_data) * batch_size,
            100. * batch_id / len(train_data), loss.item(), batch_id))
        batch_id += 1
    # Persist the whole module after every epoch so training can be resumed.
    torch.save(model, "nn.pt")

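# Hedged alternative to pickling the whole module: storing only the parameters via
# state_dict is the more portable PyTorch pattern. The helper below is illustrative
# (the name save_checkpoint and the file nn_state.pt appear nowhere else in this script)
# and is never called automatically.
def save_checkpoint(path="nn_state.pt"):
    torch.save(model.state_dict(), path)
    # To restore later: model.load_state_dict(torch.load(path))
    print("State dict saved to " + path)
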
def test():
    model.eval()
    loss = 0
    # Average binary cross-entropy over the held-out test batches; no gradients needed here.
    with torch.no_grad():
        for data, target in test_data:
            target = torch.tensor(target, dtype=torch.float32)
            #if torch.cuda.is_available():
            #    data = data.cuda()
            #    target = target.cuda()
            #    print("Data moved to CUDA!")
            out = model(data)
            loss += F.binary_cross_entropy(out.squeeze(), target)
            #torch.cuda.empty_cache()
    print("Average loss: " + str(loss.item() / len(test_data)))

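# Hedged sketch of an additional metric: classification accuracy over the test batches,
# obtained by thresholding the sigmoid outputs at 0.5. The helper is illustrative and is
# not wired into the epoch loop below.
def accuracy():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_data:
            pred = (model(data).squeeze() > 0.5).float()
            correct += (pred == torch.tensor(target, dtype=torch.float32)).sum().item()
            total += len(target)
    print("Test accuracy: " + str(correct / total))
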
# Train for 30 epochs, evaluating on the test batches after each one.
for epoch in range(1, 31):
    train(epoch)
    test()

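# Hedged sketch: scoring a single image after training. The image is assumed to end up
# with the same shape as the training images after the transform pipeline; otherwise the
# hard-coded flatten size in the network will not match. The function name and argument
# are illustrative, and nothing in this script calls it.
def predict(img_path):
    model.eval()
    with torch.no_grad():
        img = transform(Image.open(img_path)).unsqueeze(0)  # add a batch dimension
        return model(img).item()  # close to 1 -> fake, close to 0 -> real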