ArgieAnon

play.pi

Dec 12th, 2019
204
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import os
  2. import sys
  3. import time
  4. import ctypes
  5. import logging
# Disable the 'tensorflow' and 'numpy' loggers so records logged directly to
# them are dropped instead of cluttering the game console.
logging.getLogger('tensorflow').disabled = True
logging.getLogger('numpy').disabled = True
# These environment variables are set before the generator/story imports that
# follow, presumably so TensorFlow (loaded via generator.gpt2) sees them when it
# initializes -- keep them above those imports.
# An empty CUDA_VISIBLE_DEVICES hides every GPU, so inference runs on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"]=""
# TF_CPP_MIN_LOG_LEVEL=3 filters TensorFlow's native INFO, WARNING and ERROR logs.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
  10. from generator.gpt2.gpt2_generator import *
  11. from story.story_manager import *
  12. from story.utils import *
  13. from func_timeout import func_timeout, FunctionTimedOut
  14.  
  15.  
  16.  
  17. def splash():
  18.     print("0) New Game\n1) Load Game\n")
  19.     choice = get_num_options(2)
  20.  
  21.     if choice == 1:
  22.         return "load"
  23.     else:
  24.         return "new"
  25.  
  26.  
  27. def select_game():
  28.     with open(YAML_FILE, "r") as stream:
  29.         data = yaml.safe_load(stream)
  30.  
  31.     print("Pick a setting.")
  32.     settings = data["settings"].keys()
  33.     for i, setting in enumerate(settings):
  34.         print_str = str(i) + ") " + setting
  35.         if setting == "fantasy":
  36.             print_str += " (recommended)"
  37.  
  38.         console_print(print_str)
  39.     console_print(str(len(settings)) + ") custom")
  40.     choice = get_num_options(len(settings) + 1)
  41.  
  42.     if choice == len(settings):
  43.  
  44.         console_print(
  45.             "\nEnter a prompt that describes who you are and what are your goals. The AI will always remember this prompt and "
  46.             "will use it for context, ex:\n 'Your name is John Doe. You are a knight in the kingdom of Larion. You "
  47.             "were sent by the king to track down and slay an evil dragon.'\n"
  48.         )
  49.         context = input("Story Context: ") + " "
  50.         console_print(
  51.             "\nNow enter a prompt that describes the start of your story. This comes after the Story Context and will give the AI "
  52.             "a starting point for the story. Unlike the context, the AI will eventually forget this prompt, ex:\n 'After arriving "
  53.             "at the forest, it turns out the evil dragon is actually a pretty cute monster girl. You decide you're going to lay "
  54.             "this dragon instead.'"
  55.         )
  56.         prompt = input("Starting Prompt: ")
  57.         return context, prompt
  58.  
  59.     setting_key = list(settings)[choice]
  60.  
  61.     print("\nPick a character")
  62.     characters = data["settings"][setting_key]["characters"]
  63.     for i, character in enumerate(characters):
  64.         console_print(str(i) + ") " + character)
  65.     character_key = list(characters)[get_num_options(len(characters))]
  66.  
  67.     name = input("\nWhat is your name? ")
  68.     setting_description = data["settings"][setting_key]["description"]
  69.     character = data["settings"][setting_key]["characters"][character_key]
  70.  
  71.     context = (
  72.         "You are "
  73.         + name
  74.         + ", a "
  75.         + character_key
  76.         + " "
  77.         + setting_description
  78.         + "You have a "
  79.         + character["item1"]
  80.         + " and a "
  81.         + character["item2"]
  82.         + ". "
  83.     )
  84.     prompt_num = np.random.randint(0, len(character["prompts"]))
  85.     prompt = character["prompts"][prompt_num]
  86.  
  87.     return context, prompt
  88.  
  89.  
  90. def instructions():
  91.     text = "\nAI Dungeon 2 Instructions:"
  92.     text += '\n Enter actions starting with a verb ex. "go to the tavern" or "attack the orc."'
  93.     text += '\n To speak enter \'say "(thing you want to say)"\' or just "(thing you want to say)" '
  94.     text += "\n\nThe following commands can be entered for any action: "
  95.     text += '\n  "revert"   Reverts the last action allowing you to pick a different action.'
  96.     text += '\n  "quit"     Quits the game and saves'
  97.     text += '\n  "restart"  Starts a new game and saves your current one'
  98.     text += '\n  "save"     Makes a new save of your game and gives you the save ID'
  99.     text += '\n  "load"     Asks for a save ID and loads the game if the ID is valid'
  100.     text += '\n  "print"    Prints a transcript of your adventure (without extra newline formatting)'
  101.     text += '\n  "help"     Prints these instructions again'
  102.     text += '\n  "infto"    Change the default timeout.'
  103.     text += '\n  "retry"    Get another result from the same input.'
  104.     text += '\n  "context"  Update the story\'s context (helps AI keep track of things in longer stories).'
  105.     text += '\n  "censor off/on" to turn censoring off or on.'
  106.     text += '\n  Use ! at the beginning of a sentence to process it literally.'
  107.     return text
  108.  
  109.  
  110. def play_aidungeon_2():
  111.  
  112.     console_print(
  113.         "AI Dungeon 2 will save and use your actions and game to continually improve AI Dungeon."
  114.         + " If you would like to disable this enter 'nosaving' for any action. This will also turn off the "
  115.         + "ability to save games."
  116.     )
  117.  
  118.     upload_story = True
  119.  
  120.     print("\nInitializing AI Dungeon! (This might take a few minutes)\n")
  121.     generator = GPT2Generator()
  122.     story_manager = UnconstrainedStoryManager(generator)
  123.     inference_timeout = 180
  124.     def act(action):
  125.         return func_timeout(inference_timeout, story_manager.act, (action,))
  126.     def notify_hanged():
  127.         console_print("That input caused the model to hang (timeout is {inference_timeout}, use infto ## command to change)")
  128.         ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
  129.     print("\n")
  130.  
  131.     with open("opening.txt", "r", encoding="utf-8") as file:
  132.         starter = file.read()
  133.     print(starter)
  134.     ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
  135.  
  136.     while True:
  137.         if story_manager.story != None:
  138.             del story_manager.story
  139.  
  140.         generated = False
  141.         while not generated:
  142.             print("\n\n")
  143.             splash_choice = splash()
  144.  
  145.             if splash_choice == "new":
  146.                 print("\n\n")
  147.                 context, prompt = select_game()
  148.                 console_print(instructions())
  149.                 print("\nGenerating story...\n")
  150.                 generated = story_manager.start_new_story(
  151.                     prompt, context=context, upload_story=upload_story
  152.                 )
  153.                 if generated:
  154.                     console_print(str(story_manager.story))
  155.                     ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
  156.             else:
  157.                 load_ID = input("What is the ID of the saved game? ")
  158.                 result = story_manager.load_new_story(load_ID)
  159.                 print("\nLoading Game...\n")
  160.                 print(result)
  161.                 generated = True
  162.  
  163.         while True:
  164.             sys.stdin.flush()
  165.             action = input("> ")
  166.             if action == "restart":
  167.                 rating = input("Please rate the story quality from 1-10: ")
  168.                 rating_float = float(rating)
  169.                 story_manager.story.rating = rating_float
  170.                 break
  171.  
  172.             elif action == "quit":
  173.                 rating = input("Please rate the story quality from 1-10: ")
  174.                 rating_float = float(rating)
  175.                 story_manager.story.rating = rating_float
  176.                 exit()
  177.  
  178.             elif action == "nosaving":
  179.                 upload_story = False
  180.                 story_manager.story.upload_story = False
  181.                 console_print("Saving turned off.")
  182.  
  183.             elif action == "help":
  184.                 console_print(instructions())
  185.  
  186.             elif action == "censor off":
  187.                 generator.censor = False
  188.  
  189.             elif action == "censor on":
  190.                 generator.censor = True
  191.  
  192.             elif action == "save":
  193.                 if upload_story:
  194.                     save_ID = input("Choose a name (ID) for the saved game:")
  195.                     id = story_manager.story.save_to_local(save_ID)
  196.                     console_print("Game saved.")
  197.                     console_print(
  198.                         "To load the game, type 'load' and enter the following ID: "
  199.                         + save_ID
  200.                     )
  201.                 else:
  202.                     console_print("Saving has been turned off. Cannot save.")
  203.  
  204.             elif action == "load":
  205.                 load_ID = input("What is the ID of the saved game?")
  206.                 result = story_manager.story.load_from_local(load_ID)
  207.                 console_print("\nLoading Game...\n")
  208.                 console_print(result)
  209.  
  210.             elif len(action.split(" ")) == 2 and action.split(" ")[0] == "load":
  211.                 load_ID = action.split(" ")[1]
  212.                 result = story_manager.story.load_from_local(load_ID)
  213.                 console_print("\nLoading Game...\n")
  214.                 console_print(result)
  215.  
  216.             elif action == "print":
  217.                 print("\nPRINTING\n")
  218.                 print(str(story_manager.story))
  219.  
  220.             elif action == "revert":
  221.  
  222.                 if len(story_manager.story.actions) is 0:
  223.                     console_print("You can't go back any farther. ")
  224.                     continue
  225.  
  226.                 story_manager.story.actions = story_manager.story.actions[:-1]
  227.                 story_manager.story.results = story_manager.story.results[:-1]
  228.                 console_print("Last action reverted. ")
  229.                 if len(story_manager.story.results) > 0:
  230.                     console_print(story_manager.story.results[-1])
  231.                 else:
  232.                     console_print(story_manager.story.story_start)
  233.                 continue
  234.                
  235.             elif action == "retry":
  236.                 if len(story_manager.story.actions) is 0:
  237.                     console_print("There is nothing to retry.")
  238.                     continue
  239.                 last_action = story_manager.story.actions.pop()
  240.                 last_result = story_manager.story.results.pop()
  241.                 try:
  242.                     act
  243.                 except NameError:
  244.                     act = story_manager.act
  245.                 try:
  246.                     try:
  247.                         result = act(last_action)
  248.                         if result == "":
  249.                             console_print("Error: The story is too long to be processed, try using a shorter input or reverting some actions.")
  250.                             ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
  251.                         else:
  252.                             console_print(last_action)
  253.                             console_print(story_manager.story.results[-1])
  254.                             ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
  255.                     except FunctionTimeOut:
  256.                         story_manager.story.actions.append(last_action)
  257.                         story_manager.story.results.append(last_result)
  258.                         notify_hanged()
  259.                         console_print("Your story progress has not been altered.")
  260.                 except NameError:
  261.                     pass
  262.                 continue
  263.  
  264.             elif action == "context":
  265.                 console_print("Current story context: \n" + story_manager.get_context() + "\n")
  266.                 new_context = input("Enter a new context describing the general status of your character and story: ")
  267.                 story_manager.set_context(new_context)
  268.                 console_print("Story context updated.\n")
  269.             #Fuck reddit
  270.             elif len(action.split(" ")) == 2 and action.split(" ")[0] == 'infto':
  271.                 try:
  272.                     inference_timeout = int(action.split(" ")[1])
  273.                     console_print("Set timeout to {inference_timeout}")
  274.                 except:
  275.                     console_print("Failed to set timeout. Example usage: infto 30")
  276.             else:
  277.                 if action == "":
  278.                     action = ""
  279.                     try:
  280.                         result = act(action)
  281.                     except FunctionTimedOut:
  282.                         notify_hanged()
  283.                         continue
  284.                     console_print(result)
  285.                     ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
  286.  
  287.                 elif action[0] == '"':
  288.                     action = "You say " + action
  289.  
  290.                 elif action[0] == '!':
  291.                     action = "\n" + action[1:].replace("\\n", "\n") + "\n"
  292.  
  293.                 else:
  294.                     action = action.strip()
  295.                     action = action[0].lower() + action[1:]
  296.                     if "You" not in action[:6] and "I" not in action[:6]:
  297.                         action = "You " + action
  298.                     if action[-1] not in [".", "?", "!"]:
  299.                         action = action + "."
  300.                     action = first_to_second_person(action)
  301.                     action = "\n> " + action + "\n"
  302.                 try:
  303.                     result = "\n" + act(action)
  304.                 except FunctionTimedOut:
  305.                     notify_hanged()
  306.                     continue
  307.                 if len(story_manager.story.results) >= 2:
  308.                     similarity = get_similarity(
  309.                         story_manager.story.results[-1], story_manager.story.results[-2]
  310.                     )
  311.                     if similarity > 0.9:
  312.                         story_manager.story.actions = story_manager.story.actions[:-1]
  313.                         story_manager.story.results = story_manager.story.results[:-1]
  314.                         console_print(
  315.                             "Woops that action caused the model to start looping. Try a different action to prevent that."
  316.                         )
  317.                         ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
  318.                         continue
  319.  
  320.                 if player_won(result):
  321.                     console_print(result + "\n CONGRATS YOU WIN")
  322.                     ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
  323.                     break
  324.                 elif player_died(result):
  325.                     console_print(result)
  326.                     console_print("YOU DIED. GAME OVER")
  327.                     console_print("\nOptions:")
  328.                     console_print("0) Start a new game")
  329.                     console_print(
  330.                         "1) \"I'm not dead yet!\" (If you didn't actually die) "
  331.                     )
  332.                     console_print("Which do you choose? ")
  333.                     ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
  334.                     choice = get_num_options(2)
  335.                     if choice == 0:
  336.                         break
  337.                     else:
  338.                         console_print("Sorry about that...where were we?")
  339.                         console_print(result)
  340.  
  341.                 else:
  342.                     console_print(result)
  343.                     ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
  344.  
  345.  
if __name__ == "__main__":
    # Start the interactive game only when run as a script, not when imported.
    play_aidungeon_2()
Add Comment
Please, Sign In to add comment