SHOW:
|
|
- or go back to the newest paste.
1 | import os | |
2 | import sys | |
3 | import time | |
4 | import ctypes | |
5 | import logging | |
6 | logging.getLogger('tensorflow').disabled = True | |
7 | logging.getLogger('numpy').disabled = True | |
8 | os.environ["CUDA_VISIBLE_DEVICES"]="" | |
9 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | |
10 | from generator.gpt2.gpt2_generator import * | |
11 | from story.story_manager import * | |
12 | from story.utils import * | |
13 | from func_timeout import func_timeout, FunctionTimedOut | |
14 | ||
15 | ||
16 | ||
17 | def splash(): | |
18 | print("0) New Game\n1) Load Game\n") | |
19 | choice = get_num_options(2) | |
20 | ||
21 | if choice == 1: | |
22 | return "load" | |
23 | else: | |
24 | return "new" | |
25 | ||
26 | ||
27 | def select_game(): | |
28 | with open(YAML_FILE, "r") as stream: | |
29 | data = yaml.safe_load(stream) | |
30 | ||
31 | print("Pick a setting.") | |
32 | settings = data["settings"].keys() | |
33 | for i, setting in enumerate(settings): | |
34 | print_str = str(i) + ") " + setting | |
35 | if setting == "fantasy": | |
36 | print_str += " (recommended)" | |
37 | ||
38 | console_print(print_str) | |
39 | console_print(str(len(settings)) + ") custom") | |
40 | choice = get_num_options(len(settings) + 1) | |
41 | ||
42 | if choice == len(settings): | |
43 | ||
44 | console_print( | |
45 | "\nEnter a prompt that describes who you are and what are your goals. The AI will always remember this prompt and " | |
46 | "will use it for context, ex:\n 'Your name is John Doe. You are a knight in the kingdom of Larion. You " | |
47 | "were sent by the king to track down and slay an evil dragon.'\n" | |
48 | ) | |
49 | context = input("Story Context: ") + " " | |
50 | console_print( | |
51 | "\nNow enter a prompt that describes the start of your story. This comes after the Story Context and will give the AI " | |
52 | "a starting point for the story. Unlike the context, the AI will eventually forget this prompt, ex:\n 'After arriving " | |
53 | "at the forest, it turns out the evil dragon is actually a pretty cute monster girl. You decide you're going to lay " | |
54 | "this dragon instead.'" | |
55 | ) | |
56 | prompt = input("Starting Prompt: ") | |
57 | return context, prompt | |
58 | ||
59 | setting_key = list(settings)[choice] | |
60 | ||
61 | print("\nPick a character") | |
62 | characters = data["settings"][setting_key]["characters"] | |
63 | for i, character in enumerate(characters): | |
64 | console_print(str(i) + ") " + character) | |
65 | character_key = list(characters)[get_num_options(len(characters))] | |
66 | ||
67 | name = input("\nWhat is your name? ") | |
68 | setting_description = data["settings"][setting_key]["description"] | |
69 | character = data["settings"][setting_key]["characters"][character_key] | |
70 | ||
71 | context = ( | |
72 | "You are " | |
73 | + name | |
74 | + ", a " | |
75 | + character_key | |
76 | + " " | |
77 | + setting_description | |
78 | + "You have a " | |
79 | + character["item1"] | |
80 | + " and a " | |
81 | + character["item2"] | |
82 | + ". " | |
83 | ) | |
84 | prompt_num = np.random.randint(0, len(character["prompts"])) | |
85 | prompt = character["prompts"][prompt_num] | |
86 | ||
87 | return context, prompt | |
88 | ||
89 | ||
90 | def instructions(): | |
91 | text = "\nAI Dungeon 2 Instructions:" | |
92 | text += '\n Enter actions starting with a verb ex. "go to the tavern" or "attack the orc."' | |
93 | text += '\n To speak enter \'say "(thing you want to say)"\' or just "(thing you want to say)" ' | |
94 | text += "\n\nThe following commands can be entered for any action: " | |
95 | text += '\n "revert" Reverts the last action allowing you to pick a different action.' | |
96 | text += '\n "quit" Quits the game and saves' | |
97 | text += '\n "restart" Starts a new game and saves your current one' | |
98 | text += '\n "save" Makes a new save of your game and gives you the save ID' | |
99 | text += '\n "load" Asks for a save ID and loads the game if the ID is valid' | |
100 | text += '\n "print" Prints a transcript of your adventure (without extra newline formatting)' | |
101 | text += '\n "help" Prints these instructions again' | |
102 | text += '\n "infto" Change the default timeout.' | |
103 | text += '\n "retry" Get another result from the same input.' | |
104 | text += '\n "context" Update the story\'s context (helps AI keep track of things in longer stories).' | |
105 | text += '\n "censor off/on" to turn censoring off or on.' | |
106 | text += '\n Use ! at the beginning of a sentence to process it literally.' | |
107 | return text | |
108 | ||
109 | ||
110 | def play_aidungeon_2(): | |
111 | ||
112 | console_print( | |
113 | "AI Dungeon 2 will save and use your actions and game to continually improve AI Dungeon." | |
114 | + " If you would like to disable this enter 'nosaving' for any action. This will also turn off the " | |
115 | + "ability to save games." | |
116 | ) | |
117 | ||
118 | upload_story = True | |
119 | ||
120 | print("\nInitializing AI Dungeon! (This might take a few minutes)\n") | |
121 | generator = GPT2Generator() | |
122 | story_manager = UnconstrainedStoryManager(generator) | |
123 | inference_timeout = 180 | |
124 | def act(action): | |
125 | return func_timeout(inference_timeout, story_manager.act, (action,)) | |
126 | def notify_hanged(): | |
127 | console_print("That input caused the model to hang (timeout is {inference_timeout}, use infto ## command to change)") | |
128 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
129 | print("\n") | |
130 | ||
131 | with open("opening.txt", "r", encoding="utf-8") as file: | |
132 | starter = file.read() | |
133 | print(starter) | |
134 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
135 | ||
136 | while True: | |
137 | if story_manager.story != None: | |
138 | del story_manager.story | |
139 | ||
140 | generated = False | |
141 | while not generated: | |
142 | print("\n\n") | |
143 | splash_choice = splash() | |
144 | ||
145 | if splash_choice == "new": | |
146 | print("\n\n") | |
147 | context, prompt = select_game() | |
148 | console_print(instructions()) | |
149 | print("\nGenerating story...\n") | |
150 | generated = story_manager.start_new_story( | |
151 | prompt, context=context, upload_story=upload_story | |
152 | ) | |
153 | if generated: | |
154 | console_print(str(story_manager.story)) | |
155 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
156 | else: | |
157 | load_ID = input("What is the ID of the saved game? ") | |
158 | result = story_manager.load_new_story(load_ID) | |
159 | print("\nLoading Game...\n") | |
160 | print(result) | |
161 | generated = True | |
162 | ||
163 | while True: | |
164 | sys.stdin.flush() | |
165 | action = input("> ") | |
166 | if action == "restart": | |
167 | rating = input("Please rate the story quality from 1-10: ") | |
168 | rating_float = float(rating) | |
169 | story_manager.story.rating = rating_float | |
170 | break | |
171 | ||
172 | elif action == "quit": | |
173 | rating = input("Please rate the story quality from 1-10: ") | |
174 | rating_float = float(rating) | |
175 | story_manager.story.rating = rating_float | |
176 | exit() | |
177 | ||
178 | elif action == "nosaving": | |
179 | upload_story = False | |
180 | story_manager.story.upload_story = False | |
181 | console_print("Saving turned off.") | |
182 | ||
183 | elif action == "help": | |
184 | console_print(instructions()) | |
185 | ||
186 | elif action == "censor off": | |
187 | generator.censor = False | |
188 | ||
189 | elif action == "censor on": | |
190 | generator.censor = True | |
191 | ||
192 | elif action == "save": | |
193 | if upload_story: | |
194 | save_ID = input("Choose a name (ID) for the saved game:") | |
195 | id = story_manager.story.save_to_local(save_ID) | |
196 | console_print("Game saved.") | |
197 | console_print( | |
198 | "To load the game, type 'load' and enter the following ID: " | |
199 | + save_ID | |
200 | ) | |
201 | else: | |
202 | console_print("Saving has been turned off. Cannot save.") | |
203 | ||
204 | elif action == "load": | |
205 | load_ID = input("What is the ID of the saved game?") | |
206 | result = story_manager.story.load_from_local(load_ID) | |
207 | console_print("\nLoading Game...\n") | |
208 | console_print(result) | |
209 | ||
210 | elif len(action.split(" ")) == 2 and action.split(" ")[0] == "load": | |
211 | load_ID = action.split(" ")[1] | |
212 | result = story_manager.story.load_from_local(load_ID) | |
213 | console_print("\nLoading Game...\n") | |
214 | console_print(result) | |
215 | ||
216 | elif action == "print": | |
217 | print("\nPRINTING\n") | |
218 | print(str(story_manager.story)) | |
219 | ||
220 | elif action == "revert": | |
221 | ||
222 | if len(story_manager.story.actions) is 0: | |
223 | console_print("You can't go back any farther. ") | |
224 | continue | |
225 | ||
226 | story_manager.story.actions = story_manager.story.actions[:-1] | |
227 | story_manager.story.results = story_manager.story.results[:-1] | |
228 | console_print("Last action reverted. ") | |
229 | if len(story_manager.story.results) > 0: | |
230 | console_print(story_manager.story.results[-1]) | |
231 | else: | |
232 | console_print(story_manager.story.story_start) | |
233 | continue | |
234 | ||
235 | elif action == "retry": | |
236 | if len(story_manager.story.actions) is 0: | |
237 | console_print("There is nothing to retry.") | |
238 | continue | |
239 | last_action = story_manager.story.actions.pop() | |
240 | last_result = story_manager.story.results.pop() | |
241 | try: | |
242 | act | |
243 | except NameError: | |
244 | act = story_manager.act | |
245 | try: | |
246 | try: | |
247 | result = act(last_action) | |
248 | if result == "": | |
249 | console_print("Error: The story is too long to be processed, try using a shorter input or reverting some actions.") | |
250 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
251 | else: | |
252 | console_print(last_action) | |
253 | console_print(story_manager.story.results[-1]) | |
254 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
255 | except FunctionTimeOut: | |
256 | story_manager.story.actions.append(last_action) | |
257 | story_manager.story.results.append(last_result) | |
258 | notify_hanged() | |
259 | console_print("Your story progress has not been altered.") | |
260 | except NameError: | |
261 | pass | |
262 | continue | |
263 | ||
264 | elif action == "context": | |
265 | console_print("Current story context: \n" + story_manager.get_context() + "\n") | |
266 | new_context = input("Enter a new context describing the general status of your character and story: ") | |
267 | story_manager.set_context(new_context) | |
268 | console_print("Story context updated.\n") | |
269 | #Fuck reddit | |
270 | elif len(action.split(" ")) == 2 and action.split(" ")[0] == 'infto': | |
271 | try: | |
272 | inference_timeout = int(action.split(" ")[1]) | |
273 | console_print("Set timeout to {inference_timeout}") | |
274 | except: | |
275 | console_print("Failed to set timeout. Example usage: infto 30") | |
276 | else: | |
277 | if action == "": | |
278 | action = "" | |
279 | try: | |
280 | result = act(action) | |
281 | except FunctionTimedOut: | |
282 | notify_hanged() | |
283 | continue | |
284 | console_print(result) | |
285 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
286 | ||
287 | elif action[0] == '"': | |
288 | action = "You say " + action | |
289 | ||
290 | elif action[0] == '!': | |
291 | action = "\n" + action[1:].replace("\\n", "\n") + "\n" | |
292 | ||
293 | else: | |
294 | action = action.strip() | |
295 | action = action[0].lower() + action[1:] | |
296 | if "You" not in action[:6] and "I" not in action[:6]: | |
297 | action = "You " + action | |
298 | if action[-1] not in [".", "?", "!"]: | |
299 | action = action + "." | |
300 | action = first_to_second_person(action) | |
301 | action = "\n> " + action + "\n" | |
302 | try: | |
303 | result = "\n" + act(action) | |
304 | except FunctionTimedOut: | |
305 | notify_hanged() | |
306 | continue | |
307 | if len(story_manager.story.results) >= 2: | |
308 | similarity = get_similarity( | |
309 | story_manager.story.results[-1], story_manager.story.results[-2] | |
310 | ) | |
311 | if similarity > 0.9: | |
312 | story_manager.story.actions = story_manager.story.actions[:-1] | |
313 | story_manager.story.results = story_manager.story.results[:-1] | |
314 | console_print( | |
315 | "Woops that action caused the model to start looping. Try a different action to prevent that." | |
316 | ) | |
317 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
318 | continue | |
319 | ||
320 | if player_won(result): | |
321 | console_print(result + "\n CONGRATS YOU WIN") | |
322 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
323 | break | |
324 | elif player_died(result): | |
325 | console_print(result) | |
326 | console_print("YOU DIED. GAME OVER") | |
327 | console_print("\nOptions:") | |
328 | console_print("0) Start a new game") | |
329 | console_print( | |
330 | "1) \"I'm not dead yet!\" (If you didn't actually die) " | |
331 | ) | |
332 | console_print("Which do you choose? ") | |
333 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
334 | choice = get_num_options(2) | |
335 | if choice == 0: | |
336 | break | |
337 | else: | |
338 | console_print("Sorry about that...where were we?") | |
339 | console_print(result) | |
340 | ||
341 | else: | |
342 | console_print(result) | |
343 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
344 | ||
345 | ||
346 | if __name__ == "__main__": | |
347 | play_aidungeon_2() |