#======================================================================
# DDQN in Grid World
# Author : Junmo Nam
# Reference : Van Hasselt, Hado, Arthur Guez, and David Silver.
# "Deep reinforcement learning with double q-learning."
# Thirtieth AAAI Conference on Artificial Intelligence. 2016.
# BEGAS
#======================================================================

#required packages
library(keras)    #model building and training
library(dplyr)    #data manipulation
library(reshape2) #melt() for reshaping the q table
library(ggplot2)  #policy visualization

#======================================================================
# Environment setting
#======================================================================
#step function : return the next state and reward after taking an action
step = function(state,action){

  #x and y in state
  x = state$x
  y = state$y

  #do action
  if(x==1 & action == 'up' | x==nrow(grid) & action=='down' |
     y==1 & action == 'left' | y==ncol(grid) & action == 'right'){
    #action not allowed (would leave the grid) : penalty
    reward = -1.5
    next_state = state
  }else if(action == 'up'){
    reward = grid[x-1,y]
    next_state = list(x = x-1,y = y)
  }else if(action == 'down'){
    reward = grid[x+1,y]
    next_state = list(x = x+1,y = y)
  }else if(action == 'left'){
    reward = grid[x,y-1]
    next_state = list(x = x,y = y-1)
  }else{
    reward = grid[x,y+1]
    next_state = list(x = x,y = y+1)
  }

  if(reward==-1){ #hit a wall : go back
    reward = -1.75
    next_state = state
  }

  return(list(next_state = next_state,reward = reward))
}
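
#Illustrative usage (a minimal sketch, not part of the original script):
#step() reads a global matrix `grid` in which -1 marks a wall, 1 the
#destination and 0 an empty cell. The 4x4 grid below is a made-up example.
# grid = matrix(0, nrow = 4, ncol = 4)
# grid[2,3] = -1                       #wall
# grid[4,4] = 1                        #destination
# step(list(x = 1, y = 1), 'right')    #legal move : reward = grid[1,2] = 0
# step(list(x = 1, y = 1), 'up')       #off-grid move : reward = -1.5, state unchanged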

#epsilon-greedy action selection for a given state
get_action = function(state,epsilon,qtable,actions){

  ifelse(sample(0:1,1,prob = c(epsilon,1-epsilon))==1, #epsilon check
         as.character((qtable %>% filter(x==state$x,y==state$y) %>%
                         filter(q == max(q)) %>% sample_n(1))$action), #greedy : max-q action, ties broken at random
         sample(actions,1)) #random action
}
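
#Illustrative usage (assumed values, not part of the original script):
#with epsilon = 0.1 the greedy (max-q) action is taken about 90% of the
#time. The q table layout mirrors the one built inside double_dqn_agent().
# actions = c('up','down','left','right')
# qtable = expand.grid(x = 1:4, y = 1:4, action = actions, q = 0)
# get_action(list(x = 1, y = 1), epsilon = 0.1, qtable, actions)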


#model-building function for the q network (DQN)
build_model = function(state_size, action_size,learning_rate){
  model <- keras_model_sequential() %>%
    # 1st layer
    layer_dense(input_shape = state_size,30,kernel_initializer = 'he_uniform',activation = 'relu') %>%
    # 2nd layer
    layer_dense(30, kernel_initializer = 'he_uniform',activation = 'relu') %>%
    # 3rd layer : output
    layer_dense(action_size,activation = 'linear') %>%
    #compile model
    compile(optimizer = optimizer_adam(lr = learning_rate),
            loss = 'mse')
}
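
#Illustrative usage (a minimal sketch, assuming keras and a TensorFlow
#backend are installed): the state is (x, y), so state_size = 2, and the
#four actions give action_size = 4.
# m = build_model(state_size = 2, action_size = 4, learning_rate = 0.001)
# summary(m)   #two hidden layers of 30 units, linear output of size 4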

#training function : DDQN
double_trainer = function(model,memory,target_model,batch_size,discount,qtable){

  #draw a sample (mini-batch) from memory
  mini_batch = memory %>% sample_n(batch_size)

  #extract values : S, A, R, S'
  states = mini_batch[,1:2] #state vars
  names(states) = c('x','y')
  actions = mini_batch %>% select(action)
  rewards = mini_batch %>% select(reward)
  next_states = mini_batch %>% select(grep('next_state',names(.)))
  names(next_states) = c('x','y')

  #get dones
  dones = mini_batch %>% select(done)

  #make q(s) and q(s')
  target = predict(model,data.matrix(states))
  target_val = predict(target_model,data.matrix(next_states))

  #Bellman equation for updating the target
  for(i in 1:batch_size){
    ifelse(dones[i,],
           target[i,which(c('up','down','left','right')==actions[i,])] <- rewards[i,], #when the episode is done, only the reward ; column index matches the action
           #update the target by the Bellman equation, column index matches the action
           #DDQN method : R_t+1 + discount*Q(s', argmax Q'(s'))
           target[i,which(c('up','down','left','right')==actions[i,])] <- rewards[i,] +
             discount*(filter(qtable,x==next_states[i,1],y==next_states[i,2],
                              action ==c('up','down','left','right')[which.max(target_val[i,])]) %>% sample_n(1))$q
    )
  }
  return(list(states_batch = data.matrix(states),target = target)) #batch components for fitting the model
}
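
#For reference (from the cited Van Hasselt et al. paper), the two targets are:
#  DQN  : y = r + discount * max_a Q_target(s', a)
#  DDQN : y = r + discount * Q(s', argmax_a Q'(s', a)), i.e. one network
#         selects the next action and the other evaluates it. In
#         double_trainer() the argmax is taken over target_val (target model)
#         and the value is looked up in qtable, which is built from the
#         online model's predictions.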

#training function : DQN
single_trainer = function(model,memory,target_model,batch_size,discount,qtable){

  #draw a sample (mini-batch) from memory
  mini_batch = memory %>% sample_n(batch_size)

  #extract values : S, A, R, S'
  states = mini_batch[,1:2] #state vars
  names(states) = c('x','y')
  actions = mini_batch %>% select(action)
  rewards = mini_batch %>% select(reward)
  next_states = mini_batch %>% select(grep('next_state',names(.)))
  names(next_states) = c('x','y')

  #get dones
  dones = mini_batch %>% select(done)

  #make q(s) and q(s')
  target = predict(model,data.matrix(states))
  target_val = predict(target_model,data.matrix(next_states))

  #Bellman equation for updating the target
  for(i in 1:batch_size){
    ifelse(dones[i,],
           target[i,which(c('up','down','left','right')==actions[i,])] <- rewards[i,], #when the episode is done, only the reward ; column index matches the action
           #update the target by the Bellman equation, column index matches the action
           #DQN method : R_t+1 + discount*max Q'(s')
           target[i,which(c('up','down','left','right')==actions[i,])] <- rewards[i,] + discount*max(target_val[i,])
    )
  }
  return(list(states_batch = data.matrix(states),target = target)) #batch components for fitting the model
}


#======================================================================
# DDQN Agent
#======================================================================


#training agent function
double_dqn_agent = function(episode,params,grid,model_weight = NULL,double = T){

  #choose whether the trainer will be DDQN or DQN
  if(double){
    trainer = double_trainer
  }else{
    trainer = single_trainer
  }


  #unpack parameters from params
  discount = params$discount
  learning_rate = params$learning_rate
  batch_size = params$batch_size
  epsilon = params$epsilon
  epsilon_decay = params$epsilon_decay
  epsilon_min = params$epsilon_min
  train_start = params$train_start
  max_memory = params$max_memory

  #calculate grid size
  grid_size = dim(grid) %>% prod

  #set memory dataframe
  memory = data.frame()

  #allow 4 actions
  actions = c('up','down','left','right')

  #sizes of state and action
  state_size = 2
  action_size = length(actions)

  #initiate online and target models
  model = build_model(state_size,action_size,learning_rate)
  target_model = build_model(state_size,action_size,learning_rate)

  #give both models the same weights
  set_weights(target_model,get_weights(model))

  #optional : load preset model weights if provided
  if(!is.null(model_weight)){
    set_weights(model,model_weight)
    set_weights(target_model,model_weight)
  }

  #initiate q table
  qtable = expand.grid(x = 1:nrow(grid),y = 1:ncol(grid),action = actions,q = 0)


  #record training loss
  loss = c()

  #record scores
  scores = c()

  #record success frequency
  sfreq = c()

  #episode loop
  for(i in 1:episode){

    #set initial values
    done = F
    score = 0

    #reset state
    if(i %% 10 == 0){ #start at the first position
      state = list(x = 1, y = 1)
    }else{ #random initialization
      state = list(x = sample(1:nrow(grid),1),y = sample(1:ncol(grid),1))
    }

    #small memory of visited states, used to penalize the agent for revisiting
    small_memory = c()

    n_step = 0
    repeat{ #loop over a single episode

      #update small memory
      small_memory = append(small_memory,state %>% paste(collapse = ''))

      #choose an action for the current state
      action = get_action(state,epsilon,qtable,actions)

      #step by the given action and state
      time_step = step(state,action)


      n_step = n_step+1

      #update values
      reward = ifelse(time_step$reward==1,2*grid_size,time_step$reward) - 0.04 #0.04 = cost of making a move
      next_state = time_step$next_state
      score = score + reward
      done = (time_step$reward==1 | -grid_size*0.5 >= score) #finished, or stop to prevent an infinite loop

      if(done==F & next_state %>% paste(collapse = '') %in% small_memory){
        reward = reward-0.9 #penalty when the agent revisits a place it has already been
      }

      #append to memory
      memory = data.frame(state = state,
                          action=action,
                          reward=reward,
                          next_state = next_state,
                          done=done) %>%
        rbind(memory,.)

      #keep memory within max_memory rows
      if(nrow(memory)>max_memory){
        memory = memory[-1,]
      }

      #train the model
      if(nrow(memory)>=train_start){ #when memory has enough data
        #update epsilon
        epsilon = ifelse(epsilon > epsilon_min,epsilon*epsilon_decay,epsilon)

        #build the training batch and fit the model
        fit_vars = trainer(model,memory,target_model,batch_size,discount,qtable)
        model %>% fit(fit_vars$states_batch, fit_vars$target,
                      batch_size, epochs=1, verbose=0) #training
        train.loss = model$history$history$loss[[1]]

      }

      #update qtable
      qtable = predict(model,expand.grid(x=1:nrow(grid),y=1:ncol(grid)) %>% data.matrix) %>%
        data.frame %>% cbind(expand.grid(x=1:nrow(grid),y=1:ncol(grid)),.)
      colnames(qtable)[-(1:2)] = actions
      qtable = melt(qtable,c('x','y'))
      colnames(qtable)[3:4] = c('action','q')

      #update state
      state = next_state

      if(done){ #when the episode is done
        set_weights(target_model,get_weights(model)) #update the target model with the online model's weights
        loss[i] = ifelse(exists('train.loss'),train.loss,NA)
        scores[i] = score
        small_memory = append(small_memory,state %>% paste(collapse = '')) #update small memory
        break
      }

    } #loop

    #record whether the episode ended in success
    sfreq[i] = ifelse(time_step$reward==1,1,0)

    cat('Episode ',i,'is done \n',
        'Epsilon :',epsilon,' \n',
        'Score :',score,' \n',
        'Loss :',loss[i],' \n',
        'Memory size :',nrow(memory),' \n',
        'Steps :',n_step,' \n',
        'Result :',ifelse(time_step$reward==1,'Success','Fail'),' \n',
        'Winning rate (last 10) :',ifelse(length(sfreq)<10,'- less than 10 episodes',sum(tail(sfreq,10))/10),'\n',
        'Overall winning rate :',sum(sfreq)/i,'\n',
        '----------- \n') #end-of-episode message

    if(sum(tail(sfreq,20))/20 >= 0.95){
      cat("Winning rate over the last 20 episodes is at least 95% : stop training \n")
      break
    }

  }#episode end

  return(list(loss = loss,score = scores,model = model,target_model = target_model,
              last_moveset = small_memory))
}#function end
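
#Illustrative usage (assumed hyper-parameter values, not part of the original
#script): build a grid, collect the parameters the agent expects into a list
#and train for a number of episodes.
# grid = matrix(0, nrow = 5, ncol = 5); grid[2,3] = -1; grid[5,5] = 1
# params = list(discount = 0.95, learning_rate = 0.001, batch_size = 32,
#               epsilon = 1, epsilon_decay = 0.999, epsilon_min = 0.01,
#               train_start = 100, max_memory = 2000)
# result = double_dqn_agent(episode = 200, params = params, grid = grid, double = T)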



#======================================================================
# Visualization
#======================================================================

#draw the policy implied by a given model
draw_policy_model = function(grid,qmodel,method){

  #map destination and obstacles
  obstacle = which(grid==-1,arr.ind = T)
  destination = which(grid==1,arr.ind = T)

  #allow 4 actions
  actions = c('up','down','left','right')

  #initiate start point and q policy
  state = list(x = 1,y = 1)
  qpolicy = predict(qmodel,expand.grid(state.x = 1:nrow(grid),state.y = 1:ncol(grid)) %>% data.matrix) %>%
    as.data.frame %>% mutate(action = actions[max.col(.)]) %>% select(action) %>%
    cbind(expand.grid(x = 1:nrow(grid),y = 1:ncol(grid)),.)

  #choose the first action from the policy
  action = (qpolicy %>% filter(x==state$x,y==state$y))$action

  if(action == 'up' | action == 'left'){
    stop('Your policy asks the agent to make an illegal move : stop')
  }


  #make record dataframe
  record = data.frame(x=state$x,y=state$y,act = action)

  #record the path implied by the q policy
  repeat{

    #update state
    res = step(state,action)
    state = res$next_state
    done = (res$reward==1)


    if(!done){

      #choose the next action from the policy
      action = (qpolicy %>% filter(x==state$x,y==state$y))$action

      #check for illegal moves
      if(state$x==nrow(grid) & action == 'down' | state$y==ncol(grid) & action == 'right' |
         state$x==1 & action == 'up' | state$y==1 & action == 'left'){
        stop('Your policy asks the agent to make an illegal move : stop')
      }

      #update record
      record = rbind(record,data.frame(x=state$x,y=state$y,act = action))



      #check whether the policy asks the agent to loop
      if(TRUE %in% duplicated(record)){
        stop('Your policy asks the agent to revisit an area : this might cause an infinite loop. stop')
      }
    }else{

      #update record
      record = rbind(record,data.frame(x=state$x,y=state$y,act = action))
      break
    }

  }#end loop

  #visualization : make an empty grid object
  grid_plot = expand.grid(x=1:(ncol(grid)+1),y=1:(nrow(grid)+1)) %>% ggplot(aes(x,y))+
    geom_point(alpha=0)+
    geom_vline(xintercept = 1:(ncol(grid)+1))+
    geom_hline(yintercept = 1:(nrow(grid)+1))+
    ggtitle('Grid World Policy Viewer',
            subtitle = paste(method,'result'))+
    theme_bw()

  #add obstacles
  for(i in 1:nrow(obstacle)){
    x_add = obstacle[i,2]+0.5
    y_add = (nrow(grid)+1)-obstacle[i,1]+0.5
    grid_plot = grid_plot +
      geom_point(aes(x = !!x_add, y = !!y_add),shape = 2,color = 'dark green',size = 2)
  }

  #add destination
  grid_plot = grid_plot +
    geom_point(aes(x = destination[1,2]+0.5,
                   y = (nrow(grid)+1)-destination[1,1]+0.5),
               shape = 9,color = 'red',size = 2.5)

  #add the agent's path
  for(i in 1:nrow(record)){
    x_add = record[i,2]+0.5
    y_add = (nrow(grid)+1)-record[i,1]+0.5

    grid_plot = grid_plot +
      geom_point(aes(x = !!x_add,y = !!y_add),
                 color = 'blue',shape = 13, size = 1.8)
  }

  return(grid_plot)

}
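
#Illustrative usage (a minimal sketch): visualize the greedy path of the
#trained online model returned by double_dqn_agent().
# draw_policy_model(grid, result$model, method = 'DDQN')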