SHOW:
|
|
- or go back to the newest paste.
1 | import json | |
2 | import os | |
3 | import subprocess | |
4 | import uuid | |
5 | import ctypes | |
6 | from subprocess import Popen | |
7 | from func_timeout import func_timeout, FunctionTimedOut | |
8 | from story.utils import * | |
9 | ||
10 | ||
11 | class Story: | |
12 | def __init__( | |
13 | self, story_start, context="", seed=None, game_state=None, upload_story=False | |
14 | ): | |
15 | self.story_start = story_start | |
16 | self.context = context | |
17 | self.rating = -1 | |
18 | self.upload_story = upload_story | |
19 | ||
20 | # list of actions. First action is the prompt length should always equal that of story blocks | |
21 | self.actions = [] | |
22 | ||
23 | # list of story blocks first story block follows prompt and is intro story | |
24 | self.results = [] | |
25 | ||
26 | # Only needed in constrained/cached version | |
27 | self.seed = seed | |
28 | self.choices = [] | |
29 | self.possible_action_results = None | |
30 | self.uuid = None | |
31 | ||
32 | if game_state is None: | |
33 | game_state = dict() | |
34 | self.game_state = game_state | |
35 | self.memory = 20 | |
36 | ||
37 | def __del__(self): | |
38 | if self.upload_story: | |
39 | self.save_to_local("AUTOSAVE") | |
40 | console_print("Game saved.") | |
41 | console_print( | |
42 | "To load the game, type 'load' and enter the following ID: AUTOSAVE" | |
43 | ) | |
44 | ||
45 | def init_from_dict(self, story_dict): | |
46 | self.story_start = story_dict["story_start"] | |
47 | self.seed = story_dict["seed"] | |
48 | self.actions = story_dict["actions"] | |
49 | self.results = story_dict["results"] | |
50 | self.choices = story_dict["choices"] | |
51 | self.possible_action_results = story_dict["possible_action_results"] | |
52 | self.game_state = story_dict["game_state"] | |
53 | self.context = story_dict["context"] | |
54 | self.uuid = story_dict["uuid"] | |
55 | ||
56 | if "rating" in story_dict.keys(): | |
57 | self.rating = story_dict["rating"] | |
58 | else: | |
59 | self.rating = -1 | |
60 | ||
61 | def initialize_from_json(self, json_string): | |
62 | story_dict = json.loads(json_string) | |
63 | self.init_from_dict(story_dict) | |
64 | ||
65 | def add_to_story(self, action, story_block): | |
66 | self.actions.append(action) | |
67 | self.results.append(story_block) | |
68 | ||
69 | def latest_result(self, action): | |
70 | ||
71 | mem_ind = self.memory | |
72 | if len(self.results) < 2: | |
73 | start_context = self.story_start | |
74 | else: | |
75 | start_context = self.context | |
76 | context_t = start_context.split(' ') | |
77 | latest_result = "" | |
78 | while mem_ind > 0: | |
79 | ||
80 | if len(self.results) >= mem_ind: | |
81 | #1) Split the variables we will be checking into some temp variables | |
82 | latest_result_t = latest_result.split(' ') | |
83 | actions_t = self.actions[mem_ind-1].split(' ') | |
84 | results_t = self.results[mem_ind-1].split(' ') | |
85 | #2) Check total length of context_t + latest_result_t + results_t + actions_t, it must all be under 700 (?) | |
86 | #(haven't figured out the precise number yet, it goes out of bonds after 1024 but ~750 words is often already enough to crash it) | |
87 | if len(context_t) + len(latest_result_t) + len(actions_t) + len(results_t) >= 700: | |
88 | break | |
89 | latest_result = self.actions[mem_ind-1] + self.results[mem_ind-1] + latest_result | |
90 | #Fuck reddit | |
91 | mem_ind -= 1 | |
92 | ||
93 | latest_result = start_context + latest_result | |
94 | #print("##DEBUG## LATEST_RESULT FED TO AI: \n" + latest_result) | |
95 | return latest_result | |
96 | ||
97 | def __str__(self): | |
98 | story_list = [self.story_start] | |
99 | for i in range(len(self.results)): | |
100 | story_list.append("\n" + self.actions[i] + "\n") | |
101 | story_list.append("\n" + self.results[i]) | |
102 | ||
103 | return "".join(story_list) | |
104 | ||
105 | def to_json(self): | |
106 | story_dict = {} | |
107 | story_dict["story_start"] = self.story_start | |
108 | story_dict["seed"] = self.seed | |
109 | story_dict["actions"] = self.actions | |
110 | story_dict["results"] = self.results | |
111 | story_dict["choices"] = self.choices | |
112 | story_dict["possible_action_results"] = self.possible_action_results | |
113 | story_dict["game_state"] = self.game_state | |
114 | story_dict["context"] = self.context | |
115 | story_dict["uuid"] = self.uuid | |
116 | story_dict["rating"] = self.rating | |
117 | ||
118 | return json.dumps(story_dict) | |
119 | ||
120 | def save_to_local(self, save_name): | |
121 | self.uuid = str(uuid.uuid1()) | |
122 | story_json = self.to_json() | |
123 | file_name = "stories\AIDungeonSave_" + save_name + ".json" | |
124 | f = open(file_name, "w+") | |
125 | f.write(story_json) | |
126 | f.close() | |
127 | ||
128 | def load_from_local(self, save_name): | |
129 | file_name = "AIDungeonSave_" + save_name + ".json" | |
130 | print("Save ID that can be used to load game is: ", self.uuid) | |
131 | ||
132 | with open(file_name, "r") as fp: | |
133 | game = json.load(fp) | |
134 | self.init_from_dict(game) | |
135 | ||
136 | def save_to_storage(self): | |
137 | self.uuid = str(uuid.uuid1()) | |
138 | ||
139 | story_json = self.to_json() | |
140 | file_name = "story" + str(self.uuid) + ".json" | |
141 | f = open(file_name, "w") | |
142 | f.write(story_json) | |
143 | f.close() | |
144 | ||
145 | FNULL = open(os.devnull, "w") | |
146 | p = Popen( | |
147 | ["gsutil", "cp", file_name, "gs://aidungeonstories"], | |
148 | stdout=FNULL, | |
149 | stderr=subprocess.STDOUT, | |
150 | ) | |
151 | return self.uuid | |
152 | ||
153 | def load_from_storage(self, story_id): | |
154 | ||
155 | file_name = "stories\AIDungeonSave_" + story_id + ".json" | |
156 | exists = os.path.isfile(file_name) | |
157 | ||
158 | if exists: | |
159 | with open(file_name, "r") as fp: | |
160 | game = json.load(fp) | |
161 | self.init_from_dict(game) | |
162 | return str(self) | |
163 | else: | |
164 | return "Error save not found." | |
165 | ||
166 | ||
167 | class StoryManager: | |
168 | def __init__(self, generator): | |
169 | self.generator = generator | |
170 | self.story = None | |
171 | ||
172 | def start_new_story( | |
173 | self, story_prompt, context="", game_state=None, upload_story=False | |
174 | ): | |
175 | try: | |
176 | block = func_timeout(180, self.generator.generate, (context + story_prompt,)) | |
177 | except FunctionTimedOut: | |
178 | console_print("Error generating story, please try a different start.\n") | |
179 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
180 | return False | |
181 | block = cut_trailing_sentence(block) | |
182 | self.story = Story( | |
183 | context + story_prompt + block, | |
184 | context=context, | |
185 | game_state=game_state, | |
186 | upload_story=upload_story, | |
187 | ) | |
188 | return True | |
189 | def load_new_story(self, story_id): | |
190 | file_name = "stories\AIDungeonSave_" + story_id + ".json" | |
191 | exists = os.path.isfile(file_name) | |
192 | ||
193 | if exists: | |
194 | with open(file_name, "r") as fp: | |
195 | game = json.load(fp) | |
196 | self.story = Story("") | |
197 | self.story.init_from_dict(game) | |
198 | return str(self.story) | |
199 | else: | |
200 | return "Error: save not found." | |
201 | ||
202 | def load_story(self, story, from_json=False): | |
203 | if from_json: | |
204 | self.story = Story("") | |
205 | self.story.initialize_from_json(story) | |
206 | else: | |
207 | self.story = story | |
208 | return str(story) | |
209 | ||
210 | def json_story(self): | |
211 | return self.story.to_json() | |
212 | ||
213 | def story_context(self, action): | |
214 | return self.story.latest_result(action) | |
215 | ||
216 | ||
217 | class UnconstrainedStoryManager(StoryManager): | |
218 | def act(self, action_choice): | |
219 | ||
220 | result = self.generate_result(action_choice) | |
221 | if result == "": | |
222 | return result | |
223 | self.story.add_to_story(action_choice, result) | |
224 | return result | |
225 | ||
226 | def generate_result(self, action): | |
227 | try: | |
228 | block = self.generator.generate(self.story_context(action) + action) | |
229 | except BaseException as e: | |
230 | console_print("An exception occured: %s" % (e)) | |
231 | return "" | |
232 | return block | |
233 | def set_context(self, context): | |
234 | self.story.context = context | |
235 | def get_context(self): | |
236 | return self.story.context | |
237 | ||
238 | class ConstrainedStoryManager(StoryManager): | |
239 | def __init__(self, generator, action_verbs_key="classic"): | |
240 | super().__init__(generator) | |
241 | self.action_phrases = get_action_verbs(action_verbs_key) | |
242 | self.cache = False | |
243 | self.cacher = None | |
244 | self.seed = None | |
245 | ||
246 | def enable_caching( | |
247 | self, credentials_file=None, seed=0, bucket_name="dungeon-cache" | |
248 | ): | |
249 | self.cache = True | |
250 | self.cacher = Cacher(credentials_file, bucket_name) | |
251 | self.seed = seed | |
252 | ||
253 | def start_new_story(self, story_prompt, context="", game_state=None): | |
254 | if self.cache: | |
255 | return self.start_new_story_cache(story_prompt, game_state=game_state) | |
256 | else: | |
257 | return super().start_new_story( | |
258 | story_prompt, context=context, game_state=game_state | |
259 | ) | |
260 | ||
261 | def start_new_story_generate(self, story_prompt, game_state=None): | |
262 | super().start_new_story(story_prompt, game_state=game_state) | |
263 | self.story.possible_action_results = self.get_action_results() | |
264 | return self.story.story_start | |
265 | ||
266 | def start_new_story_cache(self, story_prompt, game_state=None): | |
267 | ||
268 | response = self.cacher.retrieve_from_cache(self.seed, [], "story") | |
269 | if response is not None: | |
270 | story_start = story_prompt + response | |
271 | self.story = Story(story_start, seed=self.seed) | |
272 | self.story.possible_action_results = self.get_action_results() | |
273 | else: | |
274 | story_start = self.start_new_story_generate( | |
275 | story_prompt, game_state=game_state | |
276 | ) | |
277 | self.story.seed = self.seed | |
278 | self.cacher.cache_file(self.seed, [], story_start, "story") | |
279 | ||
280 | return story_start | |
281 | ||
282 | def load_story(self, story, from_json=False): | |
283 | story_string = super().load_story(story, from_json=from_json) | |
284 | return story_string | |
285 | ||
286 | def get_possible_actions(self): | |
287 | if self.story.possible_action_results is None: | |
288 | self.story.possible_action_results = self.get_action_results() | |
289 | ||
290 | return [ | |
291 | action_result[0] for action_result in self.story.possible_action_results | |
292 | ] | |
293 | ||
294 | def act(self, action_choice_str): | |
295 | ||
296 | try: | |
297 | action_choice = int(action_choice_str) | |
298 | except: | |
299 | print("Error invalid choice.") | |
300 | return None, None | |
301 | ||
302 | if action_choice < 0 or action_choice >= len(self.action_phrases): | |
303 | print("Error invalid choice.") | |
304 | return None, None | |
305 | ||
306 | self.story.choices.append(action_choice) | |
307 | action, result = self.story.possible_action_results[action_choice] | |
308 | self.story.add_to_story(action, result) | |
309 | self.story.possible_action_results = self.get_action_results() | |
310 | return result, self.get_possible_actions() | |
311 | ||
312 | def get_action_results(self): | |
313 | if self.cache: | |
314 | return self.get_action_results_cache() | |
315 | else: | |
316 | return self.get_action_results_generate() | |
317 | ||
318 | def get_action_results_generate(self): | |
319 | action_results = [ | |
320 | self.generate_action_result(self.story_context(), phrase) | |
321 | for phrase in self.action_phrases | |
322 | ] | |
323 | return action_results | |
324 | ||
325 | def get_action_results_cache(self): | |
326 | response = self.cacher.retrieve_from_cache( | |
327 | self.story.seed, self.story.choices, "choices" | |
328 | ) | |
329 | ||
330 | if response is not None: | |
331 | print("Retrieved from cache") | |
332 | return json.loads(response) | |
333 | else: | |
334 | print("Didn't receive from cache") | |
335 | action_results = self.get_action_results_generate() | |
336 | response = json.dumps(action_results) | |
337 | self.cacher.cache_file( | |
338 | self.story.seed, self.story.choices, response, "choices" | |
339 | ) | |
340 | return action_results | |
341 | ||
342 | def generate_action_result(self, prompt, phrase, options=None): | |
343 | ||
344 | action_result = ( | |
345 | phrase + " " + self.generator.generate(prompt + " " + phrase, options) | |
346 | ) | |
347 | action, result = split_first_sentence(action_result) | |
348 | return action, result |