Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- require 'torch'
- require 'nn'
- require 'image'
- require 'fast_neural_style.ShaveImage'
- require 'fast_neural_style.TotalVariation'
- require 'fast_neural_style.InstanceNormalization'
- local utils = require 'fast_neural_style.utils'
- local preprocess = require 'fast_neural_style.preprocess'
- local ev = require'ev'
- --JSON = (loadfile "JSON.lua")() -- one-time load of the routines
-- Standard base64 alphabet; shared by enc() and dec().
local b='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'

--- Base64-encode a string.
-- Input is consumed three bytes at a time; each group becomes four alphabet
-- characters. A trailing group of one byte yields two characters plus '==',
-- a trailing group of two bytes yields three characters plus '='.
-- @tparam string data raw bytes to encode
-- @treturn string base64 text ('' for empty input)
function enc(data)
  local pieces = {}
  for pos = 1, #data, 3 do
    local b1, b2, b3 = data:byte(pos, pos + 2)
    -- Pack up to 24 bits; missing trailing bytes count as zero bits.
    local triple = b1 * 65536 + (b2 or 0) * 256 + (b3 or 0)
    local sextets = {
      math.floor(triple / 262144) % 64,
      math.floor(triple / 4096) % 64,
      math.floor(triple / 64) % 64,
      triple % 64,
    }
    -- A group of k input bytes produces k + 1 significant characters.
    local used = (b3 and 4) or (b2 and 3) or 2
    for s = 1, used do
      local v = sextets[s]
      pieces[#pieces + 1] = b:sub(v + 1, v + 1)
    end
  end
  local padding = ({ '', '==', '=' })[#data % 3 + 1]
  return table.concat(pieces) .. padding
end
-- decoding
--- Base64-decode a string.
-- Every character outside the alphabet `b` (including '=' padding, line
-- breaks and spaces) is discarded first. Each remaining character then
-- contributes 6 bits, most significant first; every complete group of 8 bits
-- becomes one output byte and any leftover (< 8) bits are dropped.
-- @tparam string data base64 text
-- @treturn string decoded bytes
function dec(data)
  -- Character -> 6-bit value lookup.
  local digit = {}
  for idx = 1, #b do
    digit[b:sub(idx, idx)] = idx - 1
  end
  local filtered = string.gsub(data, '[^' .. b .. ']', '')
  local out = {}
  local acc, nbits = 0, 0
  for ch in filtered:gmatch('.') do
    acc = acc * 64 + digit[ch]
    nbits = nbits + 6
    if nbits >= 8 then
      -- Emit the top 8 buffered bits, keep the rest for the next byte.
      nbits = nbits - 8
      local scale = 2 ^ nbits
      out[#out + 1] = string.char(math.floor(acc / scale))
      acc = acc % scale
    end
  end
  return table.concat(out)
end
--- Hard-coded run configuration (stands in for command-line options).
local opt = {
  gpu = 0,                 -- passed to utils.setup_gpu
  backend = 'cpu',         -- passed to utils.setup_gpu
  use_cudnn = 0,           -- setup_gpu receives (use_cudnn == 1)
  cudnn_benchmark = 0,     -- 0 => cudnn.benchmark = false, cudnn.fastest = true
  image_size = 768,        -- run_image rescales input when > 0
  median_filter = 3,       -- run_image applies a median filter when > 0
  timing = 0,              -- 1 => run_image prints per-image timing
  input_image = 'saveData/input.png',
  output_image = 'saveData/output.png',
  styleNums = 1,           -- number of style images sent on an 'init' message
}

-- Style preview images, indexed from 0: the 'init' handler sends
-- style_images[i - 1] for i = 1..styleNums. Kept global as before.
style_images = {
  [0] = 'themes/1.jpg',
  'themes/2.jpg', -- [1]
  'themes/3.jpg', -- [2]
  'themes/4.jpg', -- [3]
  'themes/5.jpg', -- [4]
  'themes/6.jpg', -- [5]
  'themes/7.jpg', -- [6]
  'themes/8.jpg', -- [7]
}
opt.style_images = style_images

-- Checkpoint file loaded with torch.load below (kept global as before).
modelName = "models/instance_norm/la_muse.t7"
opt.model = modelName
print(opt.model)
-- Choose the tensor type and whether cuDNN is used, from the options above.
local dtype, use_cudnn = utils.setup_gpu(opt.gpu, opt.backend, opt.use_cudnn == 1)

-- Load the pretrained checkpoint. On failure, print the underlying error
-- (previously discarded) and stop the script.
local ok, checkpoint = pcall(function() return torch.load(opt.model) end)
if not ok then
  print('ERROR: Could not load model from ' .. opt.model)
  print('Reason: ' .. tostring(checkpoint))
  print('You may need to download the pretrained models by running')
  print('bash models/download_style_transfer_models.sh')
  return
end

local model = checkpoint.model
model:evaluate()
model:type(dtype)
if use_cudnn then
  -- NOTE(review): `cudnn` is never required in this file; this assumes
  -- utils.setup_gpu loaded it as a global when use_cudnn is true — confirm.
  cudnn.convert(model, cudnn)
  if opt.cudnn_benchmark == 0 then
    cudnn.benchmark = false
    cudnn.fastest = true
  end
end

-- Pick the preprocess/deprocess pair named by the checkpoint (default 'vgg').
-- The right-hand `preprocess` is still the module required at the top.
local preprocess_method = checkpoint.opt.preprocessing or 'vgg'
local preprocess = preprocess[preprocess_method]
if not preprocess then
  -- Fail at startup rather than on the first client request in run_image.
  error('Unknown preprocessing method in checkpoint: ' .. tostring(preprocess_method))
end
--- Recursively copy a (possibly torch-typed) table.
-- Nested tables are copied; every non-table value (numbers, strings,
-- functions, userdata such as torch tensors) is shared with the source, so a
-- network copy gets new module tables but the same tensors. When `tbl` has a
-- torch type name, the copy receives the same torch metatable.
-- NOTE(review): there is no cycle detection (a self-referencing table
-- recurses forever), and sub-tables shared between fields are duplicated.
-- This function is not called anywhere in this file as seen.
-- @tparam table tbl table to copy
-- @treturn table the copy
local function deepCopy(tbl)
  local copy = {}
  for key, value in pairs(tbl) do
    if type(value) ~= 'table' then
      copy[key] = value
    else
      copy[key] = deepCopy(value)
    end
  end
  local typename = torch.typename(tbl)
  if typename then
    torch.setmetatable(copy, typename)
  end
  return copy
end
--- Stylize one image file with the loaded model and save the result.
-- Reads the file-level `opt`, `model`, `preprocess`, `dtype` and `utils`.
-- @tparam string in_path image to read (loaded with 3 channels)
-- @tparam string out_path destination file; its directory is created when
--   it does not exist
local function run_image(in_path, out_path)
  local img = image.load(in_path, 3)
  if opt.image_size > 0 then
    img = image.scale(img, opt.image_size)
  end
  local H, W = img:size(2), img:size(3)
  -- Add a batch dimension of 1 and convert to the model's tensor type.
  local img_pre = preprocess.preprocess(img:view(1, 3, H, W)):type(dtype)

  local timer = nil
  if opt.timing == 1 then
    -- Do an extra forward pass to warm up memory and cuDNN
    model:forward(img_pre)
    timer = torch.Timer()
    if cutorch then cutorch.synchronize() end
  end
  local img_out = model:forward(img_pre)
  if opt.timing == 1 then
    if cutorch then cutorch.synchronize() end
    local time = timer:time().real
    print(string.format('Image %s (%d x %d) took %f',
      in_path, H, W, time))
  end

  -- Convert back with deprocess and take the single image out of the batch.
  img_out = preprocess.deprocess(img_out)[1]
  if opt.median_filter > 0 then
    img_out = utils.median_filter(img_out, opt.median_filter)
  end

  print('Writing output image to ' .. out_path)
  local out_dir = paths.dirname(out_path)
  -- The original called `path.isdir`, but no `path` global exists here;
  -- the torch `paths` package provides dirp() for this check.
  if not paths.dirp(out_dir) then
    paths.mkdir(out_dir)
  end
  image.save(out_path, img_out)
end
- local json = require "cjson"
-- Upstream CLI usage this script was adapted from (kept for reference):
-- th fast_neural_style.lua -model checkpoint.t7 -input_image images/content/chicago.jpg -output_image out.png -gpu 0

--- Timestamp used to build per-request file names.
-- NOTE: despite the name, os.date() without a leading '!' formats local
-- time, not UTC. Resolution is one second, so two uploads in the same second
-- reuse (overwrite) the same file names.
-- @treturn string YYYYMMDDHHMMSS, e.g. "20170519140012"
local function getUTCtimestamp()
  local dateString = os.date("%Y%m%d%H%M%S")
  print(dateString)
  return dateString
end

--- Read a whole file in binary mode and close the handle.
-- @tparam string file_path
-- @treturn string file contents
-- @raise error if the file cannot be opened
local function read_binary_file(file_path)
  local fh = assert(io.open(file_path, "rb"))
  local data = fh:read("*all")
  fh:close()
  return data
end

--- Handle an 'init' message: send {numOfStyles = n}, then one
-- {styleNum = i, base64 = <image>} message per style image.
-- @param ws websocket to reply on
local function send_style_images(ws)
  print('intial connect -- Sending style images ')
  local tempTable = {}
  tempTable["numOfStyles"] = opt.styleNums
  ws:send(json.encode(tempTable))
  for i = 1, opt.styleNums do
    print(i)
    -- opt.style_images is indexed from 0
    local style_path = opt.style_images[i - 1]
    print(style_path)
    local outTable = {}
    outTable["styleNum"] = i
    outTable["base64"] = enc(read_binary_file(style_path))
    ws:send(json.encode(outTable))
  end
end

--- Handle a 'base64' message: save the decoded upload under saveData/,
-- stylize it with run_image, and reply with the base64-encoded result.
-- @param ws websocket to reply on
-- @tparam table msg the whole decoded message (read for 'type'/'styleNum')
-- @param b64 value of the message's 'base64' field
local function process_image(ws, msg, b64)
  if type(b64) ~= 'string' then
    print('Ignoring image message: base64 field is not a string')
    return
  end
  print('msg has image of Type : ')
  print(msg['type'])
  -- tostring(): a missing styleNum must not abort the request
  print('Style to be applied Number: ' .. tostring(msg['styleNum']))

  local uniqueName = tostring(getUTCtimestamp())
  print(uniqueName)
  opt.input_image = "saveData/input_" .. uniqueName .. ".png"
  opt.output_image = "saveData/output_" .. uniqueName .. ".png"

  local out = assert(io.open(opt.input_image, "wb"))
  out:write(dec(b64))
  assert(out:close())
  print("Image saved")
  print(opt)

  run_image(opt.input_image, opt.output_image)

  local outTable = {}
  -- NOTE(review): the reply always carries styleNum = 1 (only one model is
  -- loaded); confirm whether the client expects the requested styleNum.
  outTable["styleNum"] = 1
  -- Encode the stylized output. Previously the placeholder number `1` was
  -- encoded here, which raised an error and no image was ever sent back.
  outTable["base64"] = enc(read_binary_file(opt.output_image))
  ws:send(json.encode(outTable))
end

-- Called for every new websocket connection; `ws` is that connection's
-- websocket object. Each incoming message is parsed as JSON and dispatched
-- on its keys: 'init' -> send_style_images, 'base64' -> process_image;
-- other keys are ignored.
local echo_handler = function(ws)
  ws:on_message(function(ws, message)
    print("msg Recived")
    -- The payload comes from the network: log a malformed message instead
    -- of letting json.decode raise out of the callback.
    local decoded, lua_table = pcall(json.decode, message)
    if not decoded or type(lua_table) ~= 'table' then
      print('Ignoring malformed message: ' .. tostring(lua_table))
      return
    end
    print("json decoded in lua table")
    for k, v in pairs(lua_table) do
      local ok, err = true, nil
      if k == 'init' then
        ok, err = pcall(send_style_images, ws)
      elseif k == 'base64' then
        ok, err = pcall(process_image, ws, lua_table, v)
      end
      if not ok then
        print('Error while handling "' .. tostring(k) .. '": ' .. tostring(err))
      end
    end
  end)
end
-- Create a lua-websocket server on the lua-ev backend (not copas) and start
-- listening. Connections that select the "echo" subprotocol are handled by
-- echo_handler; `default` presumably covers connections that select no
-- listed protocol (lua-websocket option) — also echo_handler here.
local server = require'websocket'.server.ev.listen
{
  -- TCP port to listen on
  port = 9002,
  -- protocol name -> callback invoked for each new connection
  protocols = {
    echo = echo_handler
  },
  default = echo_handler
}
-- Run the default lua-ev loop. Statements after this line only execute once
-- the loop returns.
ev.Loop.default:loop()
-- Leftover batch mode from the upstream CLI script; it only runs after the
-- event loop returns. opt.input_dir / opt.output_dir are never assigned in
-- this file, so treat missing values as '' (previously `nil ~= ''` entered
-- the directory branch and called paths.files(nil)).
local input_dir = opt.input_dir or ''
local output_dir = opt.output_dir or ''
if input_dir ~= '' then
  if output_dir == '' then
    error('Must give -output_dir with -input_dir')
  end
  for fn in paths.files(input_dir) do
    if utils.is_image_file(fn) then
      local in_path = paths.concat(input_dir, fn)
      local out_path = paths.concat(output_dir, fn)
      run_image(in_path, out_path)
    end
  end
elseif (opt.input_image or '') ~= '' then
  if (opt.output_image or '') == '' then
    error('Must give -output_image with -input_image')
  end
  run_image(opt.input_image, opt.output_image)
end
- --local inp = assert(io.open('out.png', "rb"))
- --local out = assert(io.open('out2.png', "wb"))
- --local data = inp:read("*all")
- --data = string.gsub(data, "\r\n", "\n")
- --out:write(data)
- --assert(out:close())
Add Comment
Please, Sign In to add comment