Guest User

Untitled

a guest
Dec 13th, 2017
80
0
Never
Not a member of Pastebin yet? Sign up; it unlocks many cool features!
text 9.32 KB | None | 0 0
  1. require 'torch'
  2. require 'nn'
  3. require 'image'
  4.  
  5. require 'fast_neural_style.ShaveImage'
  6. require 'fast_neural_style.TotalVariation'
  7. require 'fast_neural_style.InstanceNormalization'
  8.  
  9. local utils = require 'fast_neural_style.utils'
  10. local preprocess = require 'fast_neural_style.preprocess'
  11.  
  12. local ev = require'ev'
  13. --JSON = (loadfile "JSON.lua")() -- one-time load of the routines
  14.  
  15.  
  -- Base64 alphabet (RFC 4648, standard variant); shared by enc() and dec().
  local b = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'

  --- Encode a byte string as standard Base64 with '=' padding.
  -- Works on 3-byte groups: each group is packed into a 24-bit integer and
  -- split into four 6-bit indices into `b`. A trailing group of 1 or 2 bytes
  -- contributes 2 or 3 symbols, and the matching padding is appended.
  -- Pure arithmetic (no bit library), so it runs on Lua 5.1/LuaJIT and 5.3+.
  -- Kept global (as before) so existing callers keep working.
  -- @tparam string data raw bytes
  -- @treturn string Base64 text ('' for empty input)
  function enc(data)
    local symbols = {}
    local n = #data
    for pos = 1, n, 3 do
      local c1, c2, c3 = data:byte(pos, pos + 2)
      local triple = c1 * 65536 + (c2 or 0) * 256 + (c3 or 0)
      -- 4 symbols for a full group, 3 when two bytes remain, 2 for one byte
      local emit = c3 and 4 or (c2 and 3 or 2)
      for k = 1, emit do
        local idx = math.floor(triple / 64 ^ (4 - k)) % 64
        symbols[#symbols + 1] = b:sub(idx + 1, idx + 1)
      end
    end
    return table.concat(symbols) .. ({ '', '==', '=' })[n % 3 + 1]
  end

  --- Decode standard Base64 text back to raw bytes.
  -- Characters outside the alphabet (whitespace, line breaks, ...) are
  -- stripped first, and '=' is ignored wherever it appears. Symbols feed
  -- 6 bits at a time into an accumulator; a byte is emitted whenever 8 bits
  -- are available. Fewer than 8 leftover bits at the end are discarded.
  -- Kept global (as before) so existing callers keep working.
  -- @tparam string data Base64 text
  -- @treturn string decoded bytes
  function dec(data)
    data = string.gsub(data, '[^'..b..'=]', '')
    local bytes = {}
    local acc, nbits = 0, 0
    for pos = 1, #data do
      local ch = data:sub(pos, pos)
      if ch ~= '=' then
        acc = acc * 64 + (b:find(ch, 1, true) - 1)
        nbits = nbits + 6
        if nbits >= 8 then
          nbits = nbits - 8
          local scale = 2 ^ nbits
          bytes[#bytes + 1] = string.char(math.floor(acc / scale))
          acc = acc % scale
        end
      end
    end
    return table.concat(bytes)
  end
  -- Runtime options. These are hard-coded stand-ins for the command-line
  -- flags of the original fast_neural_style script.
  local opt = {
    gpu = 0,
    backend = 'cpu',
    use_cudnn = 0,
    cudnn_benchmark = 0,
    image_size = 768,
    median_filter = 3,
    timing = 0,
    input_image = 'saveData/input.png',
    output_image = 'saveData/output.png',
    -- number of style images sent to a client on an 'init' message
    styleNums = 1,
  }

  -- Style images offered to clients. NOTE: indexed from 0; the websocket
  -- handler reads opt.style_images[i - 1] for i = 1..opt.styleNums.
  -- Local on purpose: this previously leaked as the global `style_images`.
  local style_images = {
    [0] = 'themes/1.jpg',
    [1] = 'themes/2.jpg',
    [2] = 'themes/3.jpg',
    [3] = 'themes/4.jpg',
    [4] = 'themes/5.jpg',
    [5] = 'themes/6.jpg',
    [6] = 'themes/7.jpg',
    [7] = 'themes/8.jpg',
  }
  opt.style_images = style_images

  -- Pretrained checkpoint to load (also previously an accidental global).
  local modelName = "models/instance_norm/la_muse.t7"
  opt.model = modelName

  print(opt.model)
  -- Choose the tensor type for the configured backend, then load the
  -- checkpoint (a table holding .model and the training-time .opt).
  local dtype, use_cudnn = utils.setup_gpu(opt.gpu, opt.backend, opt.use_cudnn == 1)
  local ok, checkpoint = pcall(torch.load, opt.model)
  if not ok then
    print('ERROR: Could not load model from ' .. opt.model)
    -- on failure pcall returns the error message in place of the checkpoint
    print(tostring(checkpoint))
    print('You may need to download the pretrained models by running')
    print('bash models/download_style_transfer_models.sh')
    return
  end

  local model = checkpoint.model
  model:evaluate()   -- inference mode
  model:type(dtype)  -- cast the network to the selected tensor type
  if use_cudnn then
    -- NOTE(review): `cudnn` is a global here, presumably required by
    -- utils.setup_gpu when cuDNN is enabled -- confirm.
    cudnn.convert(model, cudnn)
    if opt.cudnn_benchmark == 0 then
      cudnn.benchmark = false
      cudnn.fastest = true
    end
  end

  -- Pre/de-processing must match what the model was trained with. Fail fast
  -- on an unknown method instead of crashing on the first client request.
  local preprocess_method = (checkpoint.opt or {}).preprocessing or 'vgg'
  if preprocess[preprocess_method] == nil then
    error('Unknown preprocessing method in checkpoint: ' .. tostring(preprocess_method))
  end
  local preprocess = preprocess[preprocess_method]
  --- Recursively copy a (possibly torch-classed) table such as a network.
  -- Every nested Lua table is duplicated. Non-table values (numbers,
  -- strings, userdata such as tensors) are shared by reference, so the copy
  -- gets new module tables but the same tensors. If `tbl` has a torch
  -- class, the copy is given the same class.
  -- NOTE: not called in the visible script; there is no cycle detection,
  -- so a self-referencing table would recurse forever.
  local function deepCopy(tbl)
    local copy = {}
    for key, value in pairs(tbl) do
      if type(value) ~= 'table' then
        copy[key] = value
      else
        copy[key] = deepCopy(value)
      end
    end
    local class_name = torch.typename(tbl)
    if class_name then
      torch.setmetatable(copy, class_name)
    end
    return copy
  end
  --- Stylize one image file with the loaded model and save the result.
  -- Loads `in_path` with 3 channels, optionally rescales it with image.scale
  -- to opt.image_size, runs the network, de-processes the output, applies an
  -- optional median filter, and writes `out_path`. The output directory is
  -- created when missing.
  -- @tparam string in_path  content image to read
  -- @tparam string out_path where the stylized image is written
  local function run_image(in_path, out_path)
    local img = image.load(in_path, 3)
    if opt.image_size > 0 then
      img = image.scale(img, opt.image_size)
    end
    local H, W = img:size(2), img:size(3)

    -- The network takes a batch: view the image as 1 x 3 x H x W.
    local img_pre = preprocess.preprocess(img:view(1, 3, H, W)):type(dtype)

    local timer = nil
    if opt.timing == 1 then
      -- Do an extra forward pass to warm up memory and cuDNN
      model:forward(img_pre)
      timer = torch.Timer()
      if cutorch then cutorch.synchronize() end
    end
    local img_out = model:forward(img_pre)
    if timer then
      if cutorch then cutorch.synchronize() end
      print(string.format('Image %s (%d x %d) took %f',
                          in_path, H, W, timer:time().real))
    end

    -- Undo preprocessing and take the single image out of the batch.
    img_out = preprocess.deprocess(img_out)[1]

    if opt.median_filter > 0 then
      img_out = utils.median_filter(img_out, opt.median_filter)
    end

    print('Writing output image to ' .. out_path)
    local out_dir = paths.dirname(out_path)
    -- `path` was never defined; torch's `paths` tests directories with dirp.
    if not paths.dirp(out_dir) then
      paths.mkdir(out_dir)
    end
    image.save(out_path, img_out)
  end
  local json = require "cjson"

  -- Message protocol as implemented below (JSON text frames):
  --   client {"init": ...}
  --     -> {"numOfStyles": N}, then N x {"styleNum": i, "base64": <style image bytes>}
  --   client {"base64": <image bytes>, "type": ..., "styleNum": ...}
  --     -> {"styleNum": 1, "base64": <stylized image bytes>}

  --- Read a whole file in binary mode; raises if it cannot be opened.
  local function read_file(path)
    local fh = assert(io.open(path, "rb"))
    local data = fh:read("*all")
    fh:close()
    return data
  end

  --- Write `data` to `path` in binary mode; raises on failure.
  local function write_file(path, data)
    local fh = assert(io.open(path, "wb"))
    fh:write(data)
    assert(fh:close())
  end

  --- Local-time stamp (YYYYmmddHHMMSS) used to name per-request files.
  -- (The old helper was called getUTCtimestamp but never used UTC.)
  local function timestamp()
    return os.date("%Y%m%d%H%M%S")
  end

  --- 'init' request: announce the style count, then send each style image.
  local function send_styles(ws)
    print('intial connect -- Sending style images ')
    ws:send(json.encode({ numOfStyles = opt.styleNums }))
    for i = 1, opt.styleNums do
      local path = opt.style_images[i - 1]  -- style_images is 0-indexed
      print(i)
      print(path)
      ws:send(json.encode({ styleNum = i, base64 = enc(read_file(path)) }))
    end
  end

  --- Image request: save the upload, stylize it, and send the result back.
  local function stylize_and_reply(ws, request)
    print('msg has image of Type : ')
    print(request['type'])
    print('Style to be applied Number: ' .. tostring(request['styleNum']))

    local uniqueName = tostring(timestamp())
    print(uniqueName)
    opt.input_image = "saveData/input_" .. uniqueName .. ".png"
    opt.output_image = "saveData/output_" .. uniqueName .. ".png"

    write_file(opt.input_image, dec(request.base64))
    print("Image saved")

    run_image(opt.input_image, opt.output_image)

    -- Reply with the stylized image bytes read back from disk, Base64-encoded.
    local stylized = read_file(opt.output_image)
    ws:send(json.encode({ styleNum = 1, base64 = enc(stylized) }))
  end

  -- Called for each new client connection; `ws` is that client's websocket.
  local echo_handler = function(ws)
    ws:on_message(function(ws, message)
      print("msg Recived")

      local ok, request = pcall(json.decode, message)
      if not ok or type(request) ~= 'table' then
        print('Ignoring message that is not a JSON object: ' .. tostring(request))
        return
      end
      print("json decoded in lua table")

      -- A message may carry 'init', 'base64', or both; 'init' is served first.
      local function dispatch()
        if request.init ~= nil then
          send_styles(ws)
        end
        if request.base64 ~= nil then
          stylize_and_reply(ws, request)
        end
      end
      -- Log failures (missing file, bad image, ...) with a traceback instead
      -- of letting them escape the event-loop callback.
      local handled, err = xpcall(dispatch, debug.traceback)
      if not handled then
        print('ERROR while handling message: ' .. tostring(err))
      end
    end)
  end
  -- Create a lua-websocket server on the lua-ev backend and run the loop.
  -- Connections are handled by echo_handler, whether the client asks for
  -- the "echo" subprotocol or falls through to `default`.
  local websocket = require 'websocket'
  local server = websocket.server.ev.listen {
    -- TCP port to listen on
    port = 9002,
    -- the protocols field holds
    -- key: protocol name
    -- value: callback on new connection
    protocols = {
      echo = echo_handler,
    },
    -- presumably used when the client requests no listed protocol
    -- (see lua-websocket docs)
    default = echo_handler,
  }

  -- Run the lua-ev default loop; code after this only runs once it returns.
  ev.Loop.default:loop()
  -- Leftover batch mode from the original command-line script; it only runs
  -- after the event loop above returns. opt.input_dir / opt.output_dir are
  -- never set in this script (nil). The old `~= ''` tests treated nil as
  -- "set" and crashed in paths.files(nil), so nil now counts as unset.
  local function is_set(value)
    return value ~= nil and value ~= ''
  end

  if is_set(opt.input_dir) then
    if not is_set(opt.output_dir) then
      error('Must give -output_dir with -input_dir')
    end
    for fn in paths.files(opt.input_dir) do
      if utils.is_image_file(fn) then
        local in_path = paths.concat(opt.input_dir, fn)
        local out_path = paths.concat(opt.output_dir, fn)
        run_image(in_path, out_path)
      end
    end
  elseif is_set(opt.input_image) then
    if not is_set(opt.output_image) then
      error('Must give -output_image with -input_image')
    end
    run_image(opt.input_image, opt.output_image)
  end
Add Comment
Please sign in to add a comment.