View difference between Paste ID: hZX3q4dy and XAyJLb4k
SHOW: | | - or go back to the newest paste.
1
import os
2
import sys
3
import time
4
import ctypes
5
import logging
6
logging.getLogger('tensorflow').disabled = True
7
logging.getLogger('numpy').disabled = True
8
os.environ["CUDA_VISIBLE_DEVICES"]=""
9
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
10
from generator.gpt2.gpt2_generator import *
11
from story.story_manager import *
12
from story.utils import *
13
from func_timeout import func_timeout, FunctionTimedOut
14
15
16
17
def splash():
18
    print("0) New Game\n1) Load Game\n")
19
    choice = get_num_options(2)
20
21
    if choice == 1:
22
        return "load"
23
    else:
24
        return "new"
25
26
27
def select_game():
28
    with open(YAML_FILE, "r") as stream:
29
        data = yaml.safe_load(stream)
30
31
    print("Pick a setting.")
32
    settings = data["settings"].keys()
33
    for i, setting in enumerate(settings):
34
        print_str = str(i) + ") " + setting
35
        if setting == "fantasy":
36
            print_str += " (recommended)"
37
38
        console_print(print_str)
39
    console_print(str(len(settings)) + ") custom")
40
    choice = get_num_options(len(settings) + 1)
41
42
    if choice == len(settings):
43
44
        console_print(
45
            "\nEnter a prompt that describes who you are and what are your goals. The AI will always remember this prompt and "
46
            "will use it for context, ex:\n 'Your name is John Doe. You are a knight in the kingdom of Larion. You "
47
            "were sent by the king to track down and slay an evil dragon.'\n"
48
        )
49
        context = input("Story Context: ") + " "
50
        console_print(
51
            "\nNow enter a prompt that describes the start of your story. This comes after the Story Context and will give the AI "
52
            "a starting point for the story. Unlike the context, the AI will eventually forget this prompt, ex:\n 'After arriving "
53
            "at the forest, it turns out the evil dragon is actually a pretty cute monster girl. You decide you're going to lay "
54
            "this dragon instead.'"
55
        )
56
        prompt = input("Starting Prompt: ")
57
        return context, prompt
58
59
    setting_key = list(settings)[choice]
60
61
    print("\nPick a character")
62
    characters = data["settings"][setting_key]["characters"]
63
    for i, character in enumerate(characters):
64
        console_print(str(i) + ") " + character)
65
    character_key = list(characters)[get_num_options(len(characters))]
66
67
    name = input("\nWhat is your name? ")
68
    setting_description = data["settings"][setting_key]["description"]
69
    character = data["settings"][setting_key]["characters"][character_key]
70
71
    context = (
72
        "You are "
73
        + name
74
        + ", a "
75
        + character_key
76
        + " "
77
        + setting_description
78
        + "You have a "
79
        + character["item1"]
80
        + " and a "
81
        + character["item2"]
82
        + ". "
83
    )
84
    prompt_num = np.random.randint(0, len(character["prompts"]))
85
    prompt = character["prompts"][prompt_num]
86
87
    return context, prompt
88
89
90
def instructions():
91
    text = "\nAI Dungeon 2 Instructions:"
92
    text += '\n Enter actions starting with a verb ex. "go to the tavern" or "attack the orc."'
93
    text += '\n To speak enter \'say "(thing you want to say)"\' or just "(thing you want to say)" '
94
    text += "\n\nThe following commands can be entered for any action: "
95
    text += '\n  "revert"   Reverts the last action allowing you to pick a different action.'
96
    text += '\n  "quit"     Quits the game and saves'
97
    text += '\n  "restart"  Starts a new game and saves your current one'
98
    text += '\n  "save"     Makes a new save of your game and gives you the save ID'
99
    text += '\n  "load"     Asks for a save ID and loads the game if the ID is valid'
100
    text += '\n  "print"    Prints a transcript of your adventure (without extra newline formatting)'
101
    text += '\n  "help"     Prints these instructions again'
102
    text += '\n  "infto"    Change the default timeout.'
103
    text += '\n  "retry"    Get another result from the same input.'
104
    text += '\n  "context"  Update the story\'s context (helps AI keep track of things in longer stories).'
105
    text += '\n  "censor off/on" to turn censoring off or on.'
106
    text += '\n  Use ! at the beginning of a sentence to process it literally.'
107
    return text
108
109
110
def play_aidungeon_2():
111
112
    console_print(
113
        "AI Dungeon 2 will save and use your actions and game to continually improve AI Dungeon."
114
        + " If you would like to disable this enter 'nosaving' for any action. This will also turn off the "
115
        + "ability to save games."
116
    )
117
118
    upload_story = True
119
120
    print("\nInitializing AI Dungeon! (This might take a few minutes)\n")
121
    generator = GPT2Generator()
122
    story_manager = UnconstrainedStoryManager(generator)
123
    inference_timeout = 180
124
    def act(action):
125
        return func_timeout(inference_timeout, story_manager.act, (action,))
126
    def notify_hanged():
127
        console_print("That input caused the model to hang (timeout is {inference_timeout}, use infto ## command to change)")
128
        ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
129
    print("\n")
130
131
    with open("opening.txt", "r", encoding="utf-8") as file:
132
        starter = file.read()
133
    print(starter)
134
    ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
135
136
    while True:
137
        if story_manager.story != None:
138
            del story_manager.story
139
140
        generated = False
141
        while not generated:
142
            print("\n\n")
143
            splash_choice = splash()
144
145
            if splash_choice == "new":
146
                print("\n\n")
147
                context, prompt = select_game()
148
                console_print(instructions())
149
                print("\nGenerating story...\n")
150
                generated = story_manager.start_new_story(
151
                    prompt, context=context, upload_story=upload_story
152
                )
153
                if generated:
154
                    console_print(str(story_manager.story))
155
                    ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
156
            else:
157
                load_ID = input("What is the ID of the saved game? ")
158
                result = story_manager.load_new_story(load_ID)
159
                print("\nLoading Game...\n")
160
                print(result)
161
                generated = True
162
163
        while True:
164
            sys.stdin.flush()
165
            action = input("> ")
166
            if action == "restart":
167
                rating = input("Please rate the story quality from 1-10: ")
168
                rating_float = float(rating)
169
                story_manager.story.rating = rating_float
170
                break
171
172
            elif action == "quit":
173
                rating = input("Please rate the story quality from 1-10: ")
174
                rating_float = float(rating)
175
                story_manager.story.rating = rating_float
176
                exit()
177
178
            elif action == "nosaving":
179
                upload_story = False
180
                story_manager.story.upload_story = False
181
                console_print("Saving turned off.")
182
183
            elif action == "help":
184
                console_print(instructions())
185
186
            elif action == "censor off":
187
                generator.censor = False
188
189
            elif action == "censor on":
190
                generator.censor = True
191
192
            elif action == "save":
193
                if upload_story:
194
                    save_ID = input("Choose a name (ID) for the saved game:")
195
                    id = story_manager.story.save_to_local(save_ID)
196
                    console_print("Game saved.")
197
                    console_print(
198
                        "To load the game, type 'load' and enter the following ID: "
199
                        + save_ID
200
                    )
201
                else:
202
                    console_print("Saving has been turned off. Cannot save.")
203
204
            elif action == "load":
205
                load_ID = input("What is the ID of the saved game?")
206
                result = story_manager.story.load_from_local(load_ID)
207
                console_print("\nLoading Game...\n")
208
                console_print(result)
209
210
            elif len(action.split(" ")) == 2 and action.split(" ")[0] == "load":
211
                load_ID = action.split(" ")[1]
212
                result = story_manager.story.load_from_local(load_ID)
213
                console_print("\nLoading Game...\n")
214
                console_print(result)
215
216
            elif action == "print":
217
                print("\nPRINTING\n")
218
                print(str(story_manager.story))
219
220
            elif action == "revert":
221
222
                if len(story_manager.story.actions) is 0:
223
                    console_print("You can't go back any farther. ")
224
                    continue
225
226
                story_manager.story.actions = story_manager.story.actions[:-1]
227
                story_manager.story.results = story_manager.story.results[:-1]
228
                console_print("Last action reverted. ")
229
                if len(story_manager.story.results) > 0:
230
                    console_print(story_manager.story.results[-1])
231
                else:
232
                    console_print(story_manager.story.story_start)
233
                continue
234
                
235
            elif action == "retry":
236
                if len(story_manager.story.actions) is 0:
237
                    console_print("There is nothing to retry.")
238
                    continue
239
                last_action = story_manager.story.actions.pop()
240
                last_result = story_manager.story.results.pop()
241
                try:
242
                    act
243
                except NameError:
244
                    act = story_manager.act
245
                try:
246
                    try:
247
                        result = act(last_action)
248
                        if result == "":
249
                            console_print("Error: The story is too long to be processed, try using a shorter input or reverting some actions.")
250
                            ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
251
                        else:
252
                            console_print(last_action)
253
                            console_print(story_manager.story.results[-1])
254
                            ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
255
                    except FunctionTimeOut:
256
                        story_manager.story.actions.append(last_action)
257
                        story_manager.story.results.append(last_result)
258
                        notify_hanged()
259
                        console_print("Your story progress has not been altered.")
260
                except NameError:
261
                    pass
262
                continue
263
264
            elif action == "context":
265
                console_print("Current story context: \n" + story_manager.get_context() + "\n")
266
                new_context = input("Enter a new context describing the general status of your character and story: ")
267
                story_manager.set_context(new_context)
268
                console_print("Story context updated.\n")
269
            #Fuck reddit
270
            elif len(action.split(" ")) == 2 and action.split(" ")[0] == 'infto':
271
                try:
272
                    inference_timeout = int(action.split(" ")[1])
273
                    console_print("Set timeout to {inference_timeout}")
274
                except:
275
                    console_print("Failed to set timeout. Example usage: infto 30")
276
            else:
277
                if action == "":
278
                    action = ""
279
                    try:
280
                        result = act(action)
281
                    except FunctionTimedOut:
282
                        notify_hanged()
283
                        continue
284
                    console_print(result)
285
                    ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
286
287
                elif action[0] == '"':
288
                    action = "You say " + action
289
290
                elif action[0] == '!':
291
                    action = "\n" + action[1:].replace("\\n", "\n") + "\n"
292
293
                else:
294
                    action = action.strip()
295
                    action = action[0].lower() + action[1:]
296
                    if "You" not in action[:6] and "I" not in action[:6]:
297
                        action = "You " + action
298
                    if action[-1] not in [".", "?", "!"]:
299
                        action = action + "."
300
                    action = first_to_second_person(action)
301
                    action = "\n> " + action + "\n"
302
                try:
303
                    result = "\n" + act(action)
304
                except FunctionTimedOut:
305
                    notify_hanged()
306
                    continue
307
                if len(story_manager.story.results) >= 2:
308
                    similarity = get_similarity(
309
                        story_manager.story.results[-1], story_manager.story.results[-2]
310
                    )
311
                    if similarity > 0.9:
312
                        story_manager.story.actions = story_manager.story.actions[:-1]
313
                        story_manager.story.results = story_manager.story.results[:-1]
314
                        console_print(
315
                            "Woops that action caused the model to start looping. Try a different action to prevent that."
316
                        )
317
                        ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
318
                        continue
319
320
                if player_won(result):
321
                    console_print(result + "\n CONGRATS YOU WIN")
322
                    ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
323
                    break
324
                elif player_died(result):
325
                    console_print(result)
326
                    console_print("YOU DIED. GAME OVER")
327
                    console_print("\nOptions:")
328
                    console_print("0) Start a new game")
329
                    console_print(
330
                        "1) \"I'm not dead yet!\" (If you didn't actually die) "
331
                    )
332
                    console_print("Which do you choose? ")
333
                    ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
334
                    choice = get_num_options(2)
335
                    if choice == 0:
336
                        break
337
                    else:
338
                        console_print("Sorry about that...where were we?")
339
                        console_print(result)
340
341
                else:
342
                    console_print(result)
343
                    ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
344
345
346
if __name__ == "__main__":
347
    play_aidungeon_2()