SHARE
TWEET

Untitled

a guest Jan 21st, 2020 73 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #!/usr/bin/env python2
  2. from __future__ import print_function
  3. from __future__ import absolute_import, division, print_function, unicode_literals
  4. #import matplotlib
  5. #matplotlib.use('Agg')
  6. import os
  7. import time
  8. import numpy as np
  9. from numpy import inf, random
  10. #import matplotlib.pyplot as plt
  11. import pickle
  12. import json
  13. #import robobo
  14. import cv2
  15. import sys
  16. import signal
  17. from pprint import pprint
  18. import prey
  19.  
  20. import collections
  21.  
  22. use_simulation = False
  23. run_test = True
  24. speed = 20 if use_simulation else 30
  25. dist = 500 if use_simulation else 400
  26. rewards = [0]
  27. fitness = [0]
  28.  
  29.  
  30. def terminate_program(signal_number, frame):
  31.     print("Ctrl-C received, terminating program")
  32.     sys.exit(1)
  33.  
  34.  
  35. def main():
  36.     signal.signal(signal.SIGINT, terminate_program)
  37.  
  38.     rob = robobo.SimulationRobobo().connect(address='192.168.0.102', port=19997) if use_simulation \
  39.         else robobo.HardwareRobobo(camera=True).connect(address="10.15.3.48")
  40.  
  41.     def get_sensor_info(direction):
  42.         a = np.log(np.array(rob.read_irs())) / 10
  43.         all_sensor_info = np.array([0 if x == inf else 1 + (-x / 2) - 0.2 for x in a]) if use_simulation \
  44.             else np.array(np.log(rob.read_irs())) / 10
  45.         all_sensor_info[all_sensor_info == inf] = 0
  46.         all_sensor_info[all_sensor_info == -inf] = 0
  47.         # [0, 1, 2, 3, 4, 5, 6, 7]
  48.         if direction == 'front':
  49.             return all_sensor_info[5]
  50.         elif direction == 'back':
  51.             return all_sensor_info[1]
  52.         elif direction == 'front_left':
  53.             return np.max(all_sensor_info[[6, 7]])
  54.         elif direction == 'front_right':
  55.             return np.max(all_sensor_info[[3, 4]])
  56.         elif direction == 'back_left':
  57.             return all_sensor_info[0]
  58.         elif direction == 'back_right':
  59.             return all_sensor_info[2]
  60.         elif direction == 'all':
  61.             print(all_sensor_info[3:])
  62.             return all_sensor_info
  63.         else:
  64.             raise Exception('Invalid direction')
  65.  
  66.     # safe, almost safe, not safe. combine with previous state of safe almost safe and not safe.
  67.     # safe to almost safe is good, almost safe to safe is okay, safe to safe is neutral
  68.     # s to a to r to s'.
  69.     # Small steps for going left or right (left or right are only rotating and straight is going forward).
  70.     # controller is the q values: the boundary for every sensor.
  71.  
  72.     def move_left():
  73.         rob.move(-speed, speed, dist)
  74.  
  75.     def move_right():
  76.         rob.move(speed, -speed, dist)
  77.  
  78.     def go_straight():
  79.         rob.move(speed, speed, dist)
  80.  
  81.     def move_back():
  82.         rob.move(-speed, -speed, dist)
  83.  
  84.     boundary = [0.4, 0.7] if not use_simulation else [0.75, 0.95]
  85.  
  86.     # A static collision-avoidance policy
  87.     def static_policy(s):
  88.         if get_sensor_info('front_left') >= s \
  89.                 and get_sensor_info('front_left') > get_sensor_info('front_right'):
  90.             return 2
  91.  
  92.         elif get_sensor_info('front_right') >= s \
  93.                 and get_sensor_info('front_right') > get_sensor_info('front_left'):
  94.             return 1
  95.         else:
  96.             return 0
  97.  
  98.     state_table = {}
  99.     if os.path.exists('./src/state_table.json'):
  100.         with open('./src/state_table.json') as f:
  101.             state_table = json.load(f)
  102.  
  103.     def epsilon_policy(s, epsilon):
  104.         s = str(s)
  105.         # epsilon greedy
  106.         """"
  107.        ACTIONS ARE DEFINED AS FOLLOWS:
  108.          NUM: ACTION
  109.            ------------
  110.            0: STRAIGHT
  111.            1: LEFT
  112.            2: RIGHT
  113.            ------------
  114.        """
  115.         e = 0 if run_test else epsilon
  116.         if e > random.random():
  117.             return random.choice([0, 1, 2])
  118.         else:
  119.             return np.argmax(state_table[s])
  120.  
  121.     def take_action(action):
  122.         if action == 1:
  123.             move_left()
  124.         elif action == 2:
  125.             move_right()
  126.         elif action == 0:
  127.             go_straight()
  128.         # elif action == 'back':
  129.         #     move_back()
  130.  
  131.     def get_reward(current, new, action):
  132.         if current == 0 and new == 0:
  133.             return 0 if action == 0 else -1
  134.         elif current == 0 and new == 1:
  135.             return 1
  136.         elif current == 0 and new == 2:
  137.             return -10
  138.         elif current == 1 and new == 0:
  139.             return 1
  140.         elif current == 1 and new == 1:
  141.             return 1 if action == 0 else 0
  142.         elif current == 1 and new == 2:
  143.             return -10
  144.         elif current == 2 and new == 2:
  145.             return -10
  146.         return 0
  147.         # TODO give negative reward for repetitions
  148.  
  149.     def make_discrete(values, boundaries):
  150.         discrete_list = []
  151.         for x in values:
  152.             if x > boundaries[1]:
  153.                 discrete_list.append(2)
  154.             elif boundaries[1] > x > boundaries[0]:
  155.                 discrete_list.append(1)
  156.             elif boundaries[0] > x:
  157.                 discrete_list.append(0)
  158.         return discrete_list
  159.  
  160.     """
  161.    REINFORCEMENT LEARNING PROCESS
  162.    INPUT:  alpha    : learning rate
  163.            gamma    : discount factor
  164.            epsilon  : epsilon value for e-greedy
  165.            episodes : no. of episodes
  166.            act_lim  : no. of actions robot takes before ending episode
  167.            qL       : True if you use Q-Learning
  168.    """
  169.     stat_fitness = [0]
  170.     stat_rewards = [0]
  171.  
  172.     def run_static(lim):
  173.         for _ in range(lim):
  174.             if use_simulation:
  175.                 rob.play_simulation()
  176.  
  177.             current_state = make_discrete(get_sensor_info('all')[3:], boundary)
  178.  
  179.             if str(current_state) not in state_table.keys():
  180.                 state_table[str(current_state)] = [0 for _ in range(3)]
  181.  
  182.             action = static_policy(0.75)
  183.  
  184.             take_action(action)
  185.  
  186.             new_state = make_discrete(get_sensor_info('all')[3:], boundary)
  187.  
  188.             r = get_reward(np.max(current_state), np.max(new_state), action)
  189.  
  190.             normalized_r = ((r - -10) / (1 - -10)) * (1 - -1) + -1
  191.             stat_fitness.append(stat_fitness[-1] + (normalized_r * np.max(get_sensor_info("all")[3:])))
  192.             # print(fitness)
  193.             if stat_rewards:
  194.                 stat_rewards.append(stat_rewards[-1] + normalized_r)
  195.             else:
  196.                 rewards.append(normalized_r)
  197.  
  198.     def rl(alpha, gamma, epsilon, episodes, act_lim, qL=False):
  199.         for i in range(episodes):
  200.             print('Episode ' + str(i))
  201.             terminate = False
  202.             if use_simulation:
  203.                 rob.play_simulation()
  204.  
  205.             current_state = make_discrete(get_sensor_info('all')[3:], boundary)
  206.  
  207.             if str(current_state) not in state_table.keys():
  208.                 state_table[str(current_state)] = [0 for _ in range(3)]
  209.  
  210.             action = epsilon_policy(current_state, epsilon)
  211.             # initialise state if it doesn't exist, else retrieve the current q-value
  212.             x = 0
  213.             while not terminate:
  214.                 take_action(action)
  215.                 new_state = make_discrete(get_sensor_info('all')[3:], boundary)
  216.  
  217.                 if str(new_state) not in state_table.keys():
  218.                     state_table[str(new_state)] = [0 for _ in range(3)]
  219.  
  220.                 new_action = epsilon_policy(new_state, epsilon)
  221.  
  222.                 # Retrieve the max action if we use Q-Learning
  223.                 max_action = np.argmax(state_table[str(new_state)]) if qL else new_action
  224.  
  225.                 # Get reward
  226.                 r = get_reward(np.max(current_state), np.max(new_state), action)
  227.  
  228.                 normalized_r = ((r - -10) / (1 - -10)) * (1 - -1) + -1
  229.                 fitness.append(fitness[-1] + normalized_r * np.max(get_sensor_info("all")[3:]))
  230.                 # print(fitness)
  231.                 if rewards:
  232.                     rewards.append(rewards[-1] + normalized_r)
  233.                 else:
  234.                     rewards.append(normalized_r)
  235.  
  236.                 # Update rule
  237.                 print("r: ", r)
  238.  
  239.                 if not run_test:
  240.                     print('update')
  241.                     state_table[str(current_state)][action] += \
  242.                         alpha * (r + (gamma *
  243.                                       np.array(
  244.                                           state_table[str(new_state)][max_action]))
  245.                                  - np.array(state_table[str(current_state)][action]))
  246.  
  247.                 # Stop episode if we get very close to an obstacle
  248.                 if (max(new_state) == 2 and use_simulation) or x == act_lim-1:
  249.                     state_table[str(new_state)][new_action] = -10
  250.                     terminate = True
  251.                     print("done")
  252.                     if not run_test:
  253.                         print('writing json')
  254.                         with open('./src/state_table.json', 'w') as json_file:
  255.                             json.dump(state_table, json_file)
  256.  
  257.                     if use_simulation:
  258.                         print("stopping the simulation")
  259.                         rob.stop_world()
  260.                         while not rob.is_sim_stopped():
  261.                             print("waiting for the simulation to stop")
  262.                         time.sleep(2)
  263.  
  264.                 # update current state and action
  265.                 current_state = new_state
  266.                 action = new_action
  267.  
  268.                 # increment action limit counter
  269.                 x += 1
  270.  
  271.     # alpha, gamma, epsilon, episodes, actions per episode
  272.     # run_static(200)
  273.     rl(0.9, 0.9, 0.08, 1, 500, qL=True)
  274.  
  275.     pprint(state_table)
  276.  
  277.     if run_test:
  278.         all_rewards = []
  279.         all_fits = []
  280.         if os.path.exists('./src/rewards.csv'):
  281.             with open('./src/rewards.csv') as f:
  282.                 all_rewards = pickle.load(f)
  283.  
  284.         if os.path.exists('./src/fitness.csv'):
  285.             with open('./src/fitness.csv') as f:
  286.                 all_fits = pickle.load(f)
  287.  
  288.         all_rewards += rewards
  289.         all_fits += fitness
  290.  
  291.         # print(all_rewards)
  292.         # print(all_fits)
  293.  
  294.         # with open('./src/rewards.csv', 'w') as f:
  295.         #     pickle.dump(all_rewards, f)
  296.         #
  297.         # with open('./src/fitness.csv', 'w') as f:
  298.         #     pickle.dump(all_fits, f)
  299.         #
  300.         # with open('./src/stat_rewards.csv', 'w') as f:
  301.         #     pickle.dump(stat_rewards, f)
  302.         #
  303.  
  304.         # with open('./src/stat_fitness.csv', 'w') as f:
  305.         #     pickle.dump(stat_fitness, f)
  306.         #
  307.         # plt.figure('Rewards')
  308.         # plt.plot(all_rewards, label='Q-Learning Controller')
  309.         # plt.plot(stat_rewards, label='Static Controller')
  310.         # plt.legend()
  311.         # plt.savefig("./src/plot_reward.png")
  312.         # plt.show()
  313.         #
  314.         # plt.figure('Fitness Values')
  315.         # plt.plot(all_fits, label='Q-Learning Controller')
  316.         # plt.plot(stat_fitness, label='Static Controller')
  317.         # plt.legend()
  318.         # plt.savefig("./src/plot_fitness.png")
  319.         # plt.show()
  320.  
  321.  
  322. def image_test():
  323.     signal.signal(signal.SIGINT, terminate_program)
  324.     # rob = robobo.SimulationRobobo().connect(address='130.37.120.197', port=19997) if use_simulation \
  325.     # #     else robobo.HardwareRobobo(camera=True).connect(address="172.20.10.5")
  326.     # if use_simulation:
  327.     #     rob.play_simulation()
  328.     # rob.set_phone_tilt(109, 100)
  329.  
  330.     print('taking pic')
  331.     # image = rob.get_image_front()
  332.     # cv2.imwrite("test_pictures.png", image)
  333.     image = cv2.imread('../test_pictures1.jpg')
  334.     count = 0
  335.     print(image)
  336.     b = 64
  337.     for i in range(len(image)):
  338.         for j in range(len(image[i])):
  339.             pixel = image[i][j]
  340.             if (pixel[0] > b or pixel[2] > b) and pixel[1] < b*2\
  341.                     or (pixel[0] > b and pixel[1] > b and pixel[2] > b):
  342.                 image[i][j] = [0, 0, 0]
  343.                 count += 1
  344.     print(1 - (count / (640*480)))
  345.     cv2.imwrite("../test_img.png", image)
  346.  
  347.     # if use_simulation:
  348.     #     print('stopping the simulation')
  349.     #     rob.stop_world()
  350.     #     while not rob.is_sim_stopped():
  351.     #         print("waiting for the simulation to stop")
  352.     #     time.sleep(2)
  353.  
  354.  
  355. if __name__ == "__main__":
  356.     # main()
  357.     image_test()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Top