Advertisement
Guest User

modeltrain.py

a guest
Mar 16th, 2022
41
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 15.40 KB | None | 0 0
  1. import sys
  2.  
  3. print("Loading PyTorch...\n", file=sys.stderr)
  4.  
  5. import torch
  6. from torch.utils.data import Dataset
  7. import torchstudio.tcpcodec as tc
  8. from torchstudio.modules import safe_exec
  9. import os
  10. import sys
  11. import io
  12. import tempfile
  13. from tqdm.auto import tqdm
  14. import time
  15.  
  16.  
  17. class CachedDataset(Dataset):
  18.     def __init__(self, disk_cache=False):
  19.         self.reset(disk_cache)
  20.  
  21.     def add_sample(self, data):
  22.         if self.disk_cache:
  23.             file=tempfile.TemporaryFile(prefix='torchstudio.'+str(len(self.index))+'.') #guaranteed to be deleted on win/mac/linux: https://bugs.python.org/issue4928
  24.             file.write(data)
  25.             self.index.append(file)
  26.         else:
  27.             self.index.append(tc.decode_torch_tensors(data))
  28.  
  29.     def reset(self, disk_cache=False):
  30.         self.index = []
  31.         self.disk_cache=disk_cache
  32.  
  33.     def __len__(self):
  34.         return len(self.index)
  35.  
  36.     def __getitem__(self, id):
  37.         if id<0 or id>=len(self):
  38.             raise IndexError
  39.  
  40.         if self.disk_cache:
  41.             file=self.index[id]
  42.             file.seek(0)
  43.             sample=tc.decode_torch_tensors(file.read())
  44.         else:
  45.             sample=self.index[id]
  46.         return sample
  47.  
  48.  
  49. modules_valid=True
  50.  
  51. train_dataset = CachedDataset()
  52. valid_dataset = CachedDataset()
  53. train_bar = None
  54.  
  55. model = None
  56.  
  57. app_socket = tc.connect()
  58. print("Training script connected\n", file=sys.stderr)
  59. while True:
  60.     msg_type, msg_data = tc.recv_msg(app_socket)
  61.     print("received message: "+msg_type+"\n", file=sys.stderr)
  62.     time.sleep(0.1)
  63.  
  64.     if msg_type == 'SetDevice':
  65.         print("Setting device...\n", file=sys.stderr)
  66.         time.sleep(0.1)
  67.         device_id=tc.decode_strings(msg_data)[0]
  68.         print("device_id: "+device_id+"\n", file=sys.stderr)
  69.         time.sleep(0.1)
  70.         device = torch.device(device_id)
  71.         print("device: "+str(device)+"\n", file=sys.stderr)
  72.         time.sleep(0.1)
  73.         pin_memory = True if 'cuda' in device_id else False
  74.         print("pin_memory: "+str(pin_memory)+"\n", file=sys.stderr)
  75.         time.sleep(0.1)
  76.  
  77.     if msg_type == 'SetTorchScriptModel' and modules_valid:
  78.         print("Setting torchscript model...\n", file=sys.stderr)
  79.         time.sleep(0.1)
  80.         buffer=io.BytesIO(msg_data)
  81.         print("torchscript buffer decoded\n", file=sys.stderr)
  82.         time.sleep(0.1)
  83.         model = torch.jit.load(buffer, map_location=device)
  84.         print("torchscript model loaded\n", file=sys.stderr)
  85.         time.sleep(0.1)
  86.  
  87.     if msg_type == 'SetPackageModel' and modules_valid:
  88.         print("Setting package model...\n", file=sys.stderr)
  89.         time.sleep(0.1)
  90.         buffer=io.BytesIO(msg_data)
  91.         print("package buffer decoded\n", file=sys.stderr)
  92.         time.sleep(0.1)
  93.         model = torch.package.PackageImporter(buffer).load_pickle('model', 'model.pkl', map_location=device)
  94.         print("package model loaded\n", file=sys.stderr)
  95.         time.sleep(0.1)
  96.  
  97.     if msg_type == 'SetModelState' and modules_valid:
  98.         print("Setting model state...\n", file=sys.stderr)
  99.         time.sleep(0.1)
  100.         if model is not None:
  101.             buffer=io.BytesIO(msg_data)
  102.             model.load_state_dict(torch.load(buffer,map_location=device))
  103.  
  104.     if msg_type == 'SetLossCodes' and modules_valid:
  105.         print("Setting loss code...\n", file=sys.stderr)
  106.         time.sleep(0.1)
  107.         loss_definitions=tc.decode_strings(msg_data)
  108.         print("loss definition decoded\n", file=sys.stderr)
  109.         criteria = []
  110.         for definition in loss_definitions:
  111.             error_msg, loss_env = safe_exec(definition, description='loss definition')
  112.             if error_msg is not None or 'loss' not in loss_env:
  113.                 print("Unknown loss definition error" if error_msg is None else error_msg, file=sys.stderr)
  114.                 modules_valid=False
  115.                 tc.send_msg(app_socket, 'TrainingError')
  116.                 break
  117.             else:
  118.                 criteria.append(loss_env['loss'])
  119.  
  120.     if msg_type == 'SetMetricCodes' and modules_valid:
  121.         print("Setting metrics code...\n", file=sys.stderr)
  122.         metric_definitions=tc.decode_strings(msg_data)
  123.         metrics = []
  124.         for definition in metric_definitions:
  125.             error_msg, metric_env = safe_exec(definition, description='metric definition')
  126.             if error_msg is not None or 'metric' not in metric_env:
  127.                 print("Unknown metric definition error" if error_msg is None else error_msg, file=sys.stderr)
  128.                 modules_valid=False
  129.                 tc.send_msg(app_socket, 'TrainingError')
  130.                 break
  131.             else:
  132.                 metrics.append(metric_env['metric'])
  133.  
  134.     if msg_type == 'SetOptimizerCode' and modules_valid:
  135.         print("Setting optimizer code...\n", file=sys.stderr)
  136.         error_msg, optimizer_env = safe_exec(tc.decode_strings(msg_data)[0], context=globals(), description='optimizer definition')
  137.         if error_msg is not None or 'optimizer' not in optimizer_env:
  138.             print("Unknown optimizer definition error" if error_msg is None else error_msg, file=sys.stderr)
  139.             modules_valid=False
  140.             tc.send_msg(app_socket, 'TrainingError')
  141.         else:
  142.             optimizer = optimizer_env['optimizer']
  143.     if msg_type == 'SetOptimizerState' and modules_valid:
  144.         buffer=io.BytesIO(msg_data)
  145.         optimizer.load_state_dict(torch.load(buffer,map_location=device))
  146.  
  147.     if msg_type == 'SetSchedulerCode' and modules_valid:
  148.         print("Setting scheduler code...\n", file=sys.stderr)
  149.         error_msg, scheduler_env = safe_exec(tc.decode_strings(msg_data)[0], context=globals(), description='scheduler definition')
  150.         if error_msg is not None or 'scheduler' not in scheduler_env:
  151.             print("Unknown scheduler definition error" if error_msg is None else error_msg, file=sys.stderr)
  152.             modules_valid=False
  153.             tc.send_msg(app_socket, 'TrainingError')
  154.         else:
  155.             scheduler = scheduler_env['scheduler']
  156.  
  157.     if msg_type == 'SetHyperParametersValues' and modules_valid: #set other hyperparameters values
  158.         batch_size, shuffle, epochs, early_stop = tc.decode_ints(msg_data)
  159.         early_stop=True if early_stop==1 else False
  160.         shuffle=True if shuffle==1 else False
  161.  
  162.     if msg_type == 'StartTrainingServer' and modules_valid:
  163.         print("Caching...\n", file=sys.stderr)
  164.  
  165.         sshaddress, sshport, username, password, keydata = tc.decode_strings(msg_data)
  166.  
  167.         training_server, address = tc.generate_server()
  168.  
  169.         if sshaddress and sshport and username:
  170.             import socket
  171.             import paramiko
  172.             import torchstudio.sshtunnel as sshtunnel
  173.  
  174.             if not password:
  175.                 password=None
  176.             if not keydata:
  177.                 pkey=None
  178.             else:
  179.                 import io
  180.                 keybuffer=io.StringIO(keydata)
  181.                 pkey=paramiko.RSAKey.from_private_key(keybuffer)
  182.  
  183.             sshclient = paramiko.SSHClient()
  184.             sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())
  185.             sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5)
  186.  
  187.             reverse_tunnel = sshtunnel.Tunnel(sshclient, sshtunnel.ReverseTunnel, 'localhost', 0, 'localhost', int(address[1]))
  188.             address[1]=str(reverse_tunnel.lport)
  189.  
  190.         tc.send_msg(app_socket, 'TrainingServerRequestingAllSamples', tc.encode_strings(address))
  191.         dataset_socket=tc.start_server(training_server)
  192.         train_dataset.reset()
  193.         valid_dataset.reset()
  194.  
  195.         while True:
  196.             dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket)
  197.  
  198.             if dataset_msg_type == 'NumSamples':
  199.                 num_samples=tc.decode_ints(dataset_msg_data)[0]
  200.                 pbar=tqdm(total=num_samples, desc='Caching...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters
  201.  
  202.             if dataset_msg_type == 'InputTensorsID' and modules_valid:
  203.                 input_tensors_id = tc.decode_ints(dataset_msg_data)
  204.  
  205.             if dataset_msg_type == 'OutputTensorsID' and modules_valid:
  206.                 output_tensors_id = tc.decode_ints(dataset_msg_data)
  207.  
  208.             if dataset_msg_type == 'TrainingSample':
  209.                 train_dataset.add_sample(dataset_msg_data)
  210.                 pbar.update(1)
  211.  
  212.             if dataset_msg_type == 'ValidationSample':
  213.                 valid_dataset.add_sample(dataset_msg_data)
  214.                 pbar.update(1)
  215.  
  216.             if dataset_msg_type == 'DoneSending':
  217.                 pbar.close()
  218.                 tc.send_msg(dataset_socket, 'DoneReceiving')
  219.                 dataset_socket.close()
  220.                 training_server.close()
  221.                 if sshaddress and sshport and username:
  222.                     sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong
  223.                 break
  224.  
  225.         train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)
  226.         valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size, shuffle=False, pin_memory=pin_memory)
  227.  
  228.     if msg_type == 'StartTraining' and modules_valid:
  229.         print("Training... epoch "+str(scheduler.last_epoch)+"\n", file=sys.stderr)
  230.  
  231.     if msg_type == 'TrainOneEpoch' and modules_valid:
  232.         #training
  233.         model.train()
  234.         train_loss = 0
  235.         train_metrics = []
  236.         for metric in metrics:
  237.             metric.reset()
  238.         for batch_id, tensors in enumerate(train_loader):
  239.             inputs = [tensors[i].to(device) for i in input_tensors_id]
  240.             targets = [tensors[i].to(device) for i in output_tensors_id]
  241.             optimizer.zero_grad()
  242.             outputs = model(*inputs)
  243.             outputs = outputs if type(outputs) is not torch.Tensor else [outputs]
  244.             loss = 0
  245.             for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440
  246.                 loss = loss + criterion(output, target)
  247.             loss.backward()
  248.             optimizer.step()
  249.             train_loss += loss.item() * inputs[0].size(0)
  250.  
  251.             with torch.set_grad_enabled(False):
  252.                 for output, target, metric in zip(outputs, targets, metrics):
  253.                     metric.update(output, target)
  254.  
  255.         train_loss = train_loss/len(train_dataset)
  256.         train_metrics = 0
  257.         for metric in metrics:
  258.             train_metrics = train_metrics+metric.compute().item()
  259.         train_metrics/=len(metrics)
  260.         scheduler.step()
  261.  
  262.         #validation
  263.         model.eval()
  264.         valid_loss = 0
  265.         valid_metrics = []
  266.         for metric in metrics:
  267.             metric.reset()
  268.         with torch.set_grad_enabled(False):
  269.             for batch_id, tensors in enumerate(valid_loader):
  270.                 inputs = [tensors[i].to(device) for i in input_tensors_id]
  271.                 targets = [tensors[i].to(device) for i in output_tensors_id]
  272.                 outputs = model(*inputs)
  273.                 outputs = outputs if type(outputs) is not torch.Tensor else [outputs]
  274.                 loss = 0
  275.                 for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440
  276.                     loss = loss + criterion(output, target)
  277.                 valid_loss += loss.item() * inputs[0].size(0)
  278.  
  279.                 for output, target, metric in zip(outputs, targets, metrics):
  280.                     metric.update(output, target)
  281.  
  282.         valid_loss = valid_loss/len(valid_dataset)
  283.         valid_metrics = 0
  284.         for metric in metrics:
  285.             valid_metrics = valid_metrics+metric.compute().item()
  286.         valid_metrics/=len(metrics)
  287.  
  288.         tc.send_msg(app_socket, 'TrainingLoss', tc.encode_floats(train_loss))
  289.         tc.send_msg(app_socket, 'ValidationLoss', tc.encode_floats(valid_loss))
  290.         tc.send_msg(app_socket, 'TrainingMetric', tc.encode_floats(train_metrics))
  291.         tc.send_msg(app_socket, 'ValidationMetric', tc.encode_floats(valid_metrics))
  292.  
  293.         buffer=io.BytesIO()
  294.         torch.save(model.state_dict(), buffer)
  295.         tc.send_msg(app_socket, 'ModelState', buffer.getvalue())
  296.  
  297.         buffer=io.BytesIO()
  298.         torch.save(optimizer.state_dict(), buffer)
  299.         tc.send_msg(app_socket, 'OptimizerState', buffer.getvalue())
  300.  
  301.         tc.send_msg(app_socket, 'Trained')
  302.  
  303.         #create train_bar only after first successful training to avoid ghost progress message after an error
  304.         if train_bar is not None:
  305.             train_bar.bar_format='{desc} epoch {n_fmt} | {remaining} left |{rate_fmt}\n\n'
  306.         else:
  307.             train_bar = tqdm(total=epochs, desc='Training...', bar_format='{desc} epoch '+str(scheduler.last_epoch)+'\n\n')
  308.         train_bar.update(1)
  309.  
  310.     if msg_type == 'StopTraining' and modules_valid:
  311.         if train_bar is not None:
  312.             train_bar.close()
  313.             train_bar=None
  314.         print("Training stopped at epoch "+str(scheduler.last_epoch-1), file=sys.stderr)
  315.  
  316.     if msg_type == 'SetInputTensors' or msg_type == 'InferTensors':
  317.         input_tensors = tc.decode_torch_tensors(msg_data)
  318.         for i, tensor in enumerate(input_tensors):
  319.             input_tensors[i]=torch.unsqueeze(tensor, 0).to(device) #add batch dimension
  320.  
  321.     if msg_type == 'InferTensors':
  322.         if model is not None:
  323.             with torch.set_grad_enabled(False):
  324.                 model.eval()
  325.                 output_tensors=model(*input_tensors)
  326.                 output_tensors=[output.cpu() for output in output_tensors]
  327.                 tc.send_msg(app_socket, 'InferedTensors', tc.encode_torch_tensors(output_tensors))
  328.  
  329.     if msg_type == 'SaveTorchScript':
  330.         path, mode = tc.decode_strings(msg_data)
  331.         if "torch.jit" in str(type(model)):
  332.             torch.jit.save(model, path) #already a torchscript, save as is
  333.             print("Export complete")
  334.         else:
  335.             if mode=="trace":
  336.                 error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing')
  337.             else:
  338.                 error_msg, torchscript_model = safe_exec(torch.jit.script,{'obj':model}, description='model scripting')
  339.             if error_msg:
  340.                 print("Error exporting:", error_msg, file=sys.stderr)
  341.             else:
  342.                 torch.jit.save(torchscript_model, path)
  343.                 print("Export complete")
  344.  
  345.     if msg_type == 'SaveONNX':
  346.         error_msg=None
  347.         torchscript_model=model
  348.         if not "torch.jit" in str(type(model)):
  349.             error_msg, torchscript_model = safe_exec(torch.jit.trace,{'func':model, 'example_inputs':input_tensors, 'check_trace':False}, description='model tracing')
  350.         if error_msg:
  351.             print("Error exporting:", error_msg, file=sys.stderr)
  352.         else:
  353.             error_msg, torchscript_model = safe_exec(torch.onnx.export,{'model':torchscript_model, 'args':input_tensors, 'f':tc.decode_strings(msg_data)[0], 'opset_version':12})
  354.             if error_msg:
  355.                 print("Error exporting:", error_msg, file=sys.stderr)
  356.             else:
  357.                 print("Export complete")
  358.  
  359.     if msg_type == 'Exit':
  360.         break
  361.  
  362.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement