View difference between Paste ID: GVWA9HE6 and D6tr0a80
SHOW: | | - or go back to the newest paste.
1
import json
2
import os
3
import subprocess
4
import uuid
5
import ctypes
6
from subprocess import Popen
7
from func_timeout import func_timeout, FunctionTimedOut
8
from story.utils import *
9
10
11
class Story:
12
    def __init__(
13
        self, story_start, context="", seed=None, game_state=None, upload_story=False
14
    ):
15
        self.story_start = story_start
16
        self.context = context
17
        self.rating = -1
18
        self.upload_story = upload_story
19
20
        # list of actions. First action is the prompt length should always equal that of story blocks
21
        self.actions = []
22
23
        # list of story blocks first story block follows prompt and is intro story
24
        self.results = []
25
26
        # Only needed in constrained/cached version
27
        self.seed = seed
28
        self.choices = []
29
        self.possible_action_results = None
30
        self.uuid = None
31
32
        if game_state is None:
33
            game_state = dict()
34
        self.game_state = game_state
35
        self.memory = 20
36
37
    def __del__(self):
38
        if self.upload_story:
39
            self.save_to_local("AUTOSAVE")
40
            console_print("Game saved.")
41
            console_print(
42
                "To load the game, type 'load' and enter the following ID: AUTOSAVE"
43
            )
44
45
    def init_from_dict(self, story_dict):
46
        self.story_start = story_dict["story_start"]
47
        self.seed = story_dict["seed"]
48
        self.actions = story_dict["actions"]
49
        self.results = story_dict["results"]
50
        self.choices = story_dict["choices"]
51
        self.possible_action_results = story_dict["possible_action_results"]
52
        self.game_state = story_dict["game_state"]
53
        self.context = story_dict["context"]
54
        self.uuid = story_dict["uuid"]
55
56
        if "rating" in story_dict.keys():
57
            self.rating = story_dict["rating"]
58
        else:
59
            self.rating = -1
60
61
    def initialize_from_json(self, json_string):
62
        story_dict = json.loads(json_string)
63
        self.init_from_dict(story_dict)
64
65
    def add_to_story(self, action, story_block):
66
        self.actions.append(action)
67
        self.results.append(story_block)
68
69
    def latest_result(self, action):
70
71
        mem_ind = self.memory
72
        if len(self.results) < 2:
73
            start_context = self.story_start
74
        else:
75
            start_context = self.context
76
        context_t = start_context.split(' ')
77
        latest_result = ""
78
        while mem_ind > 0:
79
80
            if len(self.results) >= mem_ind:
81
                #1) Split the variables we will be checking into some temp variables
82
                latest_result_t = latest_result.split(' ')
83
                actions_t = self.actions[mem_ind-1].split(' ')
84
                results_t = self.results[mem_ind-1].split(' ')
85
                #2) Check total length of context_t + latest_result_t + results_t + actions_t, it must all be under 700 (?)
86
                #(haven't figured out the precise number yet, it goes out of bonds after 1024 but ~750 words is often already enough to crash it)
87
                if len(context_t) + len(latest_result_t) + len(actions_t) + len(results_t) >= 700:
88
                    break
89
                latest_result = self.actions[mem_ind-1] + self.results[mem_ind-1] + latest_result
90
            #Fuck reddit
91
            mem_ind -= 1
92
93
        latest_result = start_context + latest_result
94
        #print("##DEBUG## LATEST_RESULT FED TO AI: \n" + latest_result)
95
        return latest_result
96
97
    def __str__(self):
98
        story_list = [self.story_start]
99
        for i in range(len(self.results)):
100
            story_list.append("\n" + self.actions[i] + "\n")
101
            story_list.append("\n" + self.results[i])
102
103
        return "".join(story_list)
104
105
    def to_json(self):
106
        story_dict = {}
107
        story_dict["story_start"] = self.story_start
108
        story_dict["seed"] = self.seed
109
        story_dict["actions"] = self.actions
110
        story_dict["results"] = self.results
111
        story_dict["choices"] = self.choices
112
        story_dict["possible_action_results"] = self.possible_action_results
113
        story_dict["game_state"] = self.game_state
114
        story_dict["context"] = self.context
115
        story_dict["uuid"] = self.uuid
116
        story_dict["rating"] = self.rating
117
118
        return json.dumps(story_dict)
119
120
    def save_to_local(self, save_name):
121
        self.uuid = str(uuid.uuid1())
122
        story_json = self.to_json()
123
        file_name = "stories\AIDungeonSave_" + save_name + ".json"
124
        f = open(file_name, "w+")
125
        f.write(story_json)
126
        f.close()
127
128
    def load_from_local(self, save_name):
129
        file_name = "AIDungeonSave_" + save_name + ".json"
130
        print("Save ID that can be used to load game is: ", self.uuid)
131
132
        with open(file_name, "r") as fp:
133
            game = json.load(fp)
134
        self.init_from_dict(game)
135
136
    def save_to_storage(self):
137
        self.uuid = str(uuid.uuid1())
138
139
        story_json = self.to_json()
140
        file_name = "story" + str(self.uuid) + ".json"
141
        f = open(file_name, "w")
142
        f.write(story_json)
143
        f.close()
144
145
        FNULL = open(os.devnull, "w")
146
        p = Popen(
147
            ["gsutil", "cp", file_name, "gs://aidungeonstories"],
148
            stdout=FNULL,
149
            stderr=subprocess.STDOUT,
150
        )
151
        return self.uuid
152
153
    def load_from_storage(self, story_id):
154
155
        file_name = "stories\AIDungeonSave_" + story_id + ".json"
156
        exists = os.path.isfile(file_name)
157
158
        if exists:
159
            with open(file_name, "r") as fp:
160
                game = json.load(fp)
161
            self.init_from_dict(game)
162
            return str(self)
163
        else:
164
            return "Error save not found."
165
166
167
class StoryManager:
168
    def __init__(self, generator):
169
        self.generator = generator
170
        self.story = None
171
172
    def start_new_story(
173
        self, story_prompt, context="", game_state=None, upload_story=False
174
    ):
175
        try:
176
            block = func_timeout(180, self.generator.generate, (context + story_prompt,))
177
        except FunctionTimedOut:
178
            console_print("Error generating story, please try a different start.\n")
179
            ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True)
180
            return False
181
        block = cut_trailing_sentence(block)
182
        self.story = Story(
183
            context + story_prompt + block,
184
            context=context,
185
            game_state=game_state,
186
            upload_story=upload_story,
187
        )
188
        return True
189
    def load_new_story(self, story_id):
190
        file_name = "stories\AIDungeonSave_" + story_id + ".json"
191
        exists = os.path.isfile(file_name)
192
193
        if exists:
194
            with open(file_name, "r") as fp:
195
                game = json.load(fp)
196
            self.story = Story("")
197
            self.story.init_from_dict(game)
198
            return str(self.story)
199
        else:
200
            return "Error: save not found."
201
202
    def load_story(self, story, from_json=False):
203
        if from_json:
204
            self.story = Story("")
205
            self.story.initialize_from_json(story)
206
        else:
207
            self.story = story
208
        return str(story)
209
210
    def json_story(self):
211
        return self.story.to_json()
212
213
    def story_context(self, action):
214
        return self.story.latest_result(action)
215
216
217
class UnconstrainedStoryManager(StoryManager):
218
    def act(self, action_choice):
219
220
        result = self.generate_result(action_choice)
221
        if result == "":
222
            return result
223
        self.story.add_to_story(action_choice, result)
224
        return result
225
226
    def generate_result(self, action):
227
        try:
228
            block = self.generator.generate(self.story_context(action) + action)
229
        except BaseException as e:
230
            console_print("An exception occured: %s" % (e))
231
            return ""
232
        return block
233
    def set_context(self, context):
234
        self.story.context = context
235
    def get_context(self):
236
        return self.story.context
237
238
class ConstrainedStoryManager(StoryManager):
239
    def __init__(self, generator, action_verbs_key="classic"):
240
        super().__init__(generator)
241
        self.action_phrases = get_action_verbs(action_verbs_key)
242
        self.cache = False
243
        self.cacher = None
244
        self.seed = None
245
246
    def enable_caching(
247
        self, credentials_file=None, seed=0, bucket_name="dungeon-cache"
248
    ):
249
        self.cache = True
250
        self.cacher = Cacher(credentials_file, bucket_name)
251
        self.seed = seed
252
253
    def start_new_story(self, story_prompt, context="", game_state=None):
254
        if self.cache:
255
            return self.start_new_story_cache(story_prompt, game_state=game_state)
256
        else:
257
            return super().start_new_story(
258
                story_prompt, context=context, game_state=game_state
259
            )
260
261
    def start_new_story_generate(self, story_prompt, game_state=None):
262
        super().start_new_story(story_prompt, game_state=game_state)
263
        self.story.possible_action_results = self.get_action_results()
264
        return self.story.story_start
265
266
    def start_new_story_cache(self, story_prompt, game_state=None):
267
268
        response = self.cacher.retrieve_from_cache(self.seed, [], "story")
269
        if response is not None:
270
            story_start = story_prompt + response
271
            self.story = Story(story_start, seed=self.seed)
272
            self.story.possible_action_results = self.get_action_results()
273
        else:
274
            story_start = self.start_new_story_generate(
275
                story_prompt, game_state=game_state
276
            )
277
            self.story.seed = self.seed
278
            self.cacher.cache_file(self.seed, [], story_start, "story")
279
280
        return story_start
281
282
    def load_story(self, story, from_json=False):
283
        story_string = super().load_story(story, from_json=from_json)
284
        return story_string
285
286
    def get_possible_actions(self):
287
        if self.story.possible_action_results is None:
288
            self.story.possible_action_results = self.get_action_results()
289
290
        return [
291
            action_result[0] for action_result in self.story.possible_action_results
292
        ]
293
294
    def act(self, action_choice_str):
295
296
        try:
297
            action_choice = int(action_choice_str)
298
        except:
299
            print("Error invalid choice.")
300
            return None, None
301
302
        if action_choice < 0 or action_choice >= len(self.action_phrases):
303
            print("Error invalid choice.")
304
            return None, None
305
306
        self.story.choices.append(action_choice)
307
        action, result = self.story.possible_action_results[action_choice]
308
        self.story.add_to_story(action, result)
309
        self.story.possible_action_results = self.get_action_results()
310
        return result, self.get_possible_actions()
311
312
    def get_action_results(self):
313
        if self.cache:
314
            return self.get_action_results_cache()
315
        else:
316
            return self.get_action_results_generate()
317
318
    def get_action_results_generate(self):
319
        action_results = [
320
            self.generate_action_result(self.story_context(), phrase)
321
            for phrase in self.action_phrases
322
        ]
323
        return action_results
324
325
    def get_action_results_cache(self):
326
        response = self.cacher.retrieve_from_cache(
327
            self.story.seed, self.story.choices, "choices"
328
        )
329
330
        if response is not None:
331
            print("Retrieved from cache")
332
            return json.loads(response)
333
        else:
334
            print("Didn't receive from cache")
335
            action_results = self.get_action_results_generate()
336
            response = json.dumps(action_results)
337
            self.cacher.cache_file(
338
                self.story.seed, self.story.choices, response, "choices"
339
            )
340
            return action_results
341
342
    def generate_action_result(self, prompt, phrase, options=None):
343
344
        action_result = (
345
            phrase + " " + self.generator.generate(prompt + " " + phrase, options)
346
        )
347
        action, result = split_first_sentence(action_result)
348
        return action, result