Advertisement
tham7777

Untitled

Jul 4th, 2020
491
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.33 KB | None | 0 0
  1. import argparse
  2. import numpy as np
  3. import os
  4.  
  5. import matplotlib.pyplot as plt
  6. import matplotlib.image as mpimg
  7.  
  8. import mindspore.dataset.transforms.vision.c_transforms as transforms
  9. import mindspore.dataset as ds
  10. import mindspore.dataset.transforms.c_transforms as C
  11. import mindspore.dataset.transforms.vision.c_transforms as CV
  12. import mindspore.nn as nn
  13.  
  14. from mindspore import context
  15. from mindspore.common import dtype as mstype
  16. from mindspore.common.initializer import TruncatedNormal
  17. from mindspore.model_zoo.resnet import resnet50
  18. from mindspore.nn.metrics import Accuracy
  19. from mindspore.train import Model
  20. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
  21. from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
  22.  
  23.  
  24. def create_train_dataset(args, num_parallel_workers = 1):
  25.     dataset = ds.ImageFolderDatasetV2(args.paths_train, decode=True, shuffle=True)
  26.        
  27.     dataset = dataset.map(input_columns="label", operations=C.TypeCast(mstype.int32), num_parallel_workers=num_parallel_workers)   
  28.     dataset = dataset.map(input_columns="image", operations=CV.RandomResizedCrop((args.resize_height, args.resize_width)), num_parallel_workers=num_parallel_workers)
  29.     dataset = dataset.map(input_columns="image", operations=CV.RandomRotation(30), num_parallel_workers=num_parallel_workers)
  30.     dataset = dataset.map(input_columns="image", operations=CV.RandomHorizontalFlip(), num_parallel_workers=num_parallel_workers)    
  31.     dataset = dataset.map(input_columns="image", operations=CV.Rescale(1.0/255.0, 0.0), num_parallel_workers=num_parallel_workers)
  32.     dataset = dataset.map(input_columns="image", operations=CV.RandomColorAdjust(), num_parallel_workers=num_parallel_workers)
  33.     dataset = dataset.map(input_columns="image", operations=CV.HWC2CHW(), num_parallel_workers=num_parallel_workers)
  34.     dataset = dataset.batch(args.batch, drop_remainder=True)
  35.     dataset = dataset.repeat(1)
  36.    
  37.     return dataset  
  38.    
  39. def weight_variable():
  40.     """Weight initial."""
  41.     return TruncatedNormal(0.02)
  42.  
  43. def conv(in_channels, out_channels, kernel_size, stride=1):
  44.     """Conv layer weight initial."""
  45.     weight = weight_variable()
  46.     return nn.Conv2d(in_channels, out_channels,
  47.                      kernel_size=kernel_size, stride=stride,
  48.                      weight_init=weight, has_bias=False, pad_mode="same")
  49.                      
  50. def fc_with_initialize(input_channels, out_channels):
  51.     """Fc layer weight initial."""
  52.     weight = weight_variable()
  53.     bias = weight_variable()
  54.     return nn.Dense(input_channels, out_channels, weight, bias)
  55.  
  56. class LeNet5(nn.Cell):
  57.     """Lenet network structure."""
  58.     # define the operator required
  59.     def __init__(self):
  60.         super(LeNet5, self).__init__()
  61.         self.conv1 = conv(3, 6, 5)
  62.         self.conv1BatchNorm = nn.BatchNorm2d(6)    
  63.         self.conv2 = conv(6, 16, 5)
  64.         self.fc1 = fc_with_initialize(16 * 32 * 32, 120)
  65.         self.fc2 = fc_with_initialize(120, 84)
  66.         self.fc3 = fc_with_initialize(84, 2)
  67.         self.relu = nn.ReLU()
  68.         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  69.         self.flatten = nn.Flatten()
  70.  
  71.     # use the preceding operators to construct networks
  72.     def construct(self, x):
  73.         x = self.conv1(x)
  74.         x = self.relu(x)
  75.         x = self.conv1BatchNorm(x) #works fine without batchnorm
  76.         x = self.max_pool2d(x)
  77.         x = self.conv2(x)
  78.         x = self.relu(x)
  79.         x = self.max_pool2d(x)
  80.         x = self.flatten(x)
  81.         x = self.fc1(x)
  82.         x = self.relu(x)
  83.         x = self.fc2(x)
  84.         x = self.relu(x)
  85.         x = self.fc3(x)
  86.         return x   
  87.  
  88. if __name__ == "__main__":
  89.     parser = argparse.ArgumentParser(description='Fire detection example')
  90.     parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
  91.                         help='device where the code will be implemented (default: CPU)')
  92.     parser.add_argument('--paths_train', default="./data/train", type=str, help="paths of the train folder")
  93.     parser.add_argument('--paths_test', default="./data/test", type=str, help="paths of the test folder")
  94.     parser.add_argument('--resize_height', type=int, default=128, help="height of the input image of model")
  95.     parser.add_argument('--resize_width', type=int, default=128, help="width of the input image of model")
  96.     parser.add_argument('--epoch', type=int, default=1, help="Epoch of training")
  97.     parser.add_argument('--batch', type=int, default=32, help="Batch size")
  98.    
  99.     args = parser.parse_args()
  100.     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
  101.     dataset_train = create_train_dataset(args)    
  102.    
  103.     net = LeNet5()#fire_classify_net.fire_classify_net()
  104.     loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
  105.     optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
  106.     #optim = nn.Adam(params=net.trainable_params())
  107.     model = Model(net, loss_fn=loss, optimizer=optim, metrics={"Accuracy": Accuracy()})
  108.     config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
  109.     # save the network model and parameters for subsequence fine-tuning
  110.     ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
  111.     model.train(args.epoch, dataset_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False)
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement