#!/usr/bin/env python
from __future__ import print_function
from __future__ import division  # keep '/' as true division for the int-valued initial features under Python 2

import argparse
import sys
import random
import numpy as np
from collections import deque

import json
import socket

from keras.initializers import normal, identity
from keras.models import model_from_json
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.optimizers import SGD, Adam
#from keras.utils import plot_model
import tensorflow as tf
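
# Overview: a small deep Q-learning agent for the game "glidethroughthesky".
# The script runs as a TCP server; the game client connects, streams JSON
# feature updates, and receives "dash"/"jump" commands back. The network is
# a plain dense net over a handful of scalar features (no convolutions).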

GAME = 'glidethroughthesky' # the name of the game being played, for log files
ACTIONS = 3 # number of valid actions
GAMMA = 0.99 # decay rate of past observations
OBSERVATION = 50 # timesteps to observe before training begins
EXPLORE = 3000000. # frames over which to anneal epsilon
FINAL_EPSILON = 0.0001 # final value of epsilon
INITIAL_EPSILON = 0.01 # starting value of epsilon
REPLAY_MEMORY = 50000 # number of previous transitions to remember
BATCH = 32 # size of minibatch
FRAME_PER_ACTION = 1
LEARNING_RATE = 1e-4

### init values
features = {
    "targetYDistance": 25,
    "midYDistance": 25,
    # "targetDistance": 999,
    # "midDistance": 999,
    "distance": 25,
    "vY": 0,
    "mana": 3
}

passedWall = False
dashTarget = False
died = False
win = False

passedWalls = 0

client = None
address = None

def buildmodel():
    global features
    model = Sequential()
    model.add(Dense(units=32, input_dim=len(features), activation='relu'))
    model.add(Dense(units=32, input_dim=len(features), activation='relu'))
    model.add(Dense(units=ACTIONS, activation='softmax'))

    adam = Adam(lr=LEARNING_RATE)
    model.compile(loss='mse', optimizer=adam)
    model.summary()
    return model

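# With the five features above, the network is:
#   input(5) -> Dense(32, relu) -> Dense(32, relu) -> Dense(3, softmax)
# Note: the softmax output with an MSE loss is kept from the original code;
# a more conventional DQN head would use a linear output layer for Q-values.
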
def standardize(features):
    # rescale the raw feature values to a roughly unit range for the network
    f = features.copy()
    f["mana"] = f["mana"] / 10
    f["distance"] = f["distance"] / 25
    f["targetYDistance"] = (f["targetYDistance"] - 8) / 16
    f["midYDistance"] = (f["midYDistance"] - 8) / 16
    f["vY"] = (f["vY"] + 20) / 40

    #print(json.dumps(f))
    return f

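# Worked example, using the initial feature values defined above:
#   standardize({"targetYDistance": 25, "midYDistance": 25,
#                "distance": 25, "vY": 0, "mana": 3})
#   -> {"targetYDistance": 1.0625, "midYDistance": 1.0625,
#       "distance": 1.0, "vY": 0.5, "mana": 0.3}
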
def updateFeatures(dataMerge, died, win, dashTarget, passedWall):
    global client
    size = 8192

    loadedData = client.recv(size).decode("utf-8")
    endIndex = loadedData.find('}')  # bug fix: find() returns an index; the original sliced instead
    dataMerge += loadedData

    if endIndex == -1:
        # no complete JSON object yet; report the flags unchanged
        # (bug fix: a bare return here crashed the tuple unpacking in the caller)
        return died, win, dashTarget, passedWall

    try:
        data = json.loads(dataMerge[:dataMerge.find('}') + 1])
        # note: dataMerge is a local parameter, so this buffering does not
        # actually persist across calls (the caller resets it each iteration)
        dataMerge = ""
    except ValueError:
        return died, win, dashTarget, passedWall

    if "targetYDistance" in data:
        features["targetYDistance"] = float(data["targetYDistance"])

    if "midYDistance" in data:
        features["midYDistance"] = float(data["midYDistance"])

    if "distance" in data:
        features["distance"] = float(data["distance"])

    # if "targetDistance" in data:
    #     features["targetDistance"] = float(data["targetDistance"])
    #
    # if "midDistance" in data:
    #     features["midDistance"] = float(data["midDistance"])

    if "mana" in data:
        features["mana"] = float(data["mana"])

    if "vY" in data:
        features["vY"] = float(data["vY"])

    if "died" in data:
        died = True

    if "win" in data:
        win = True

    if "passedWall" in data:
        passedWall = True

    if "dashTarget" in data:
        dashTarget = True

    return died, win, dashTarget, passedWall

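# A hypothetical feature message from the game client, inferred from the keys
# parsed above (the real payload format is defined by the game, not shown here):
#   {"targetYDistance": 12.5, "midYDistance": 3.0, "distance": 18.0,
#    "vY": -4.2, "mana": 2, "passedWall": true}
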
def waitForFeatures():
    global passedWalls

    died = False
    win = False
    dashTarget = False
    passedWall = False

    # snapshot the features so a change can be detected below
    # (.copy() matters: updateFeatures mutates the features dict in place,
    # so a plain assignment would alias it and never see a difference)
    oldFeatures = features.copy()

    # poll the socket until at least one feature value changes
    while True:
        dataMerge = ""
        died, win, dashTarget, passedWall = updateFeatures(dataMerge, died, win, dashTarget, passedWall)
        if any(features[key] != value for key, value in oldFeatures.items()):
            break

    terminal = False
    reward = 0

    if passedWall:
        print("passedWall")
        reward = 1
        passedWalls += 1
    else:
        # losing mana without hitting a dash target is penalized
        if oldFeatures["mana"] > features["mana"] and not dashTarget:
            reward = -2
        if oldFeatures["mana"] == 1 and features["mana"] == 0:
            reward = -5

    if dashTarget:
        print("dashTarget")
        reward = 2

    if oldFeatures["mana"] == 1 and features["mana"] == 0:
        reward = -2

    if features["mana"] == -1:
        reward = -10

    if died:
        print("------------ DIED ------------")
        reward = -1 - 10 * abs(oldFeatures["midYDistance"]) - 10 * features["mana"]
        terminal = True

    if win:
        print(" ** ------------ --- ------------")
        print(" ** ------------ WIN ------------")
        print(" ** ------------ --- ------------")

    if died or win:
        with open("log.txt", "a") as outfile:
            outfile.write(str(passedWalls) + "\r\n")
        passedWalls = 0

    #print("\t" + json.dumps(features))

    return standardize(features), reward, terminal

def sendAction(a_t):
    global client

    # a_t is a one-hot action vector: index 0 = do nothing,
    # index 1 = dash, index 2 = jump
    if a_t[1] == 1:
        client.send(("dash\n").encode())
    if a_t[2] == 1:
        client.send(("jump\n").encode())
def trainNetwork(model, args):
    # store the previous observations in replay memory
    D = deque()

    # send the no-op action once to receive the initial state
    do_nothing = np.zeros(ACTIONS)
    do_nothing[0] = 1
    sendAction(do_nothing)
    s_t, r_0, terminal = waitForFeatures()

    OBSERVE = OBSERVATION  # I think 0 would do here, since there is no convolutional network
    if args['mode'] == 'Load':
        epsilon = 0.008
        # epsilon = FINAL_EPSILON
        # epsilon = INITIAL_EPSILON
        print("Now we load weights")
        model.load_weights("model.h5")
        adam = Adam(lr=LEARNING_RATE)
        model.compile(loss='mse', optimizer=adam)
        print("Weights loaded successfully")
    else:  # we go to training mode
        epsilon = INITIAL_EPSILON

    t = 0
    while True:
        loss = 0
        Q_sa = 0
        action_index = 0
        r_t = 0
        a_t = np.zeros([ACTIONS])

        # choose an action, epsilon-greedily
        if t % FRAME_PER_ACTION == 0:
            if random.random() <= epsilon:
                action_index = random.randrange(ACTIONS)
                a_t[action_index] = 1
                print("----------Random Action: ", a_t, " ---------- ", epsilon)
            else:
                p = np.fromiter(s_t.values(), float).reshape((1, len(features)))
                q = model.predict(p)
                max_Q = np.argmax(q)
                action_index = max_Q
                a_t[max_Q] = 1

        # reduce epsilon gradually once the observation phase is over
        if epsilon > FINAL_EPSILON and t > OBSERVE:
            epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE

        sendAction(a_t)
        s_t1, r_t, terminal = waitForFeatures()

        # store the transition in replay memory
        # (bug fix: the original stored the raw global features dict as the
        # state while s_t1 is standardized; using s_t keeps both consistent)
        D.append((np.fromiter(s_t.values(), float).reshape((1, len(features))),
                  action_index, r_t,
                  np.fromiter(s_t1.values(), float).reshape((1, len(features))),
                  terminal))
        if len(D) > REPLAY_MEMORY:
            D.popleft()

        # only train if done observing
        if t > OBSERVE:
            # sample a minibatch to train on
            minibatch = random.sample(D, BATCH)

            # experience replay: targets[a] = r + GAMMA * max_a' Q(s', a'),
            # with the bootstrap term zeroed out on terminal transitions
            state_t, action_t, reward_t, state_t1, terminal = zip(*minibatch)
            state_t = np.concatenate(state_t)
            state_t1 = np.concatenate(state_t1)
            targets = model.predict(state_t)
            Q_sa = model.predict(state_t1)
            targets[range(BATCH), action_t] = reward_t + GAMMA * np.max(Q_sa, axis=1) * np.invert(terminal)

            loss += model.train_on_batch(state_t, targets)

        s_t = s_t1
        t = t + 1

        # save progress every 10000 iterations
        if t % 10000 == 0:
            print("Now we save model")
            model.save_weights("model.h5", overwrite=True)
            with open("model.json", "w") as outfile:
                json.dump(model.to_json(), outfile)

        # print info
        if t <= OBSERVE:
            state = "observe"
        elif t <= OBSERVE + EXPLORE:
            state = "explore"
        else:
            state = "train"

        #print("TIMESTEP", t, "/ STATE", state,
        #      "/ EPSILON", epsilon, "/ ACTION", action_index, "/ REWARD", r_t,
        #      "/ Q_MAX ", np.max(Q_sa), "/ Loss ", loss)

    # note: the loop above never breaks, so these lines are unreachable
    print("Episode finished!")
    print("************************")

def main():
    parser = argparse.ArgumentParser(description='DQN agent for glidethroughthesky')
    parser.add_argument('-m', '--mode', help='Train / Load', required=False)
    args = vars(parser.parse_args())

    model = buildmodel()

    # listen for the game client on all interfaces
    host = ''
    port = 50000
    backlog = 5
    size = 4096
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind((host, port))
    s.listen(backlog)

    global client
    global address

    client, address = s.accept()
    # while client != None:
    #     client, address = s.accept() # TODO WAIT
    print("Client connected.")
    client.send(("Hello!\n").encode())

    trainNetwork(model, args)

if __name__ == "__main__":
    # let TensorFlow grow GPU memory on demand instead of claiming it all
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    from keras import backend as K
    K.set_session(sess)
    main()
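
# Usage (assuming this file is saved as agent.py; the paste gives no name):
#   python agent.py -m Train   # train from scratch
#   python agent.py -m Load    # load weights from model.h5 before playing
# Any mode other than 'Load' (including no -m flag) starts fresh training.
# The game client must connect to this script on TCP port 50000.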