Guest User

Untitled

a guest
Apr 22nd, 2018
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.24 KB | None | 0 0
  1. import math
  2. import sys
  3. import errno
  4. import os
  5. import ctypes
  6. import signal
  7. import torch
  8. import time
  9. import traceback
  10. import unittest
  11. import subprocess
  12. from torch import multiprocessing
  13. from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
  14. from torch.utils.data.dataset import random_split
  15. from torch.utils.data.dataloader import default_collate, ExceptionWrapper
  16.  
# Maximum time, in seconds, that test_worker_exit keeps polling for the
# DataLoader workers to disappear before it raises.
JOIN_TIMEOUT = 10
  18.  
  19. def _manager_process(dataset, worker_pids):
  20. print("_manager_process is started")
  21. loader = iter(DataLoader(dataset, batch_size=2, num_workers=4, pin_memory=True))
  22. print("loader init done")
  23. workers = loader.workers
  24. print("# workers: ", len(workers))
  25. for i in range(len(workers)):
  26. print(workers[i].pid)
  27. worker_pids[i] = int(workers[i].pid)
  28. print(worker_pids[i])
  29. for i, sample in enumerate(loader):
  30. if i == 3:
  31. break
  32. os.kill(os.getpid(), signal.SIGKILL)
  33.  
  34. def _is_process_alive(pid, pname):
  35. command = 'ps -p {} -o comm='.format(pid)
  36. p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
  37. (output, err) = p.communicate()
  38. p_status = p.wait()
  39. output = output.decode('utf-8')
  40. return pname in output
  41.  
  42. def test_worker_exit(dataset):
  43. worker_pids = multiprocessing.Array('i', [0] * 4)
  44.  
  45. # _manager_process(dataset, worker_pids)
  46.  
  47. mp = multiprocessing.Process(target=_manager_process, args=(dataset, worker_pids, ))
  48. mp.start()
  49.  
  50. time.sleep(30)
  51.  
  52. exit_status = [False] * len(worker_pids)
  53. start_time = time.time()
  54. pname = 'python'
  55. while True:
  56. for i in range(len(worker_pids)):
  57. pid = worker_pids[i]
  58. if not exit_status[i]:
  59. if not _is_process_alive(pid, pname):
  60. exit_status[i] = True
  61. if all(exit_status):
  62. break
  63. else:
  64. time.sleep(1)
  65. if time.time() - start_time > JOIN_TIMEOUT:
  66. raise Exception('subprocess not terminated')
  67.  
  68. if __name__ == '__main__':
  69. data = torch.randn(100, 2, 3, 5)
  70. labels = torch.randperm(50).repeat(2)
  71. dataset = TensorDataset(data, labels)
  72. multiprocessing.set_start_method('spawn')
  73. test_worker_exit(dataset)
Add Comment
Please, Sign In to add comment