SHOW:
|
|
- or go back to the newest paste.
| 1 | import json | |
| 2 | import os | |
| 3 | import subprocess | |
| 4 | import uuid | |
| 5 | import ctypes | |
| 6 | from subprocess import Popen | |
| 7 | from func_timeout import func_timeout, FunctionTimedOut | |
| 8 | from story.utils import * | |
| 9 | ||
| 10 | ||
| 11 | class Story: | |
| 12 | def __init__( | |
| 13 | self, story_start, context="", seed=None, game_state=None, upload_story=False | |
| 14 | ): | |
| 15 | self.story_start = story_start | |
| 16 | self.context = context | |
| 17 | self.rating = -1 | |
| 18 | self.upload_story = upload_story | |
| 19 | ||
| 20 | # list of actions. First action is the prompt length should always equal that of story blocks | |
| 21 | self.actions = [] | |
| 22 | ||
| 23 | # list of story blocks first story block follows prompt and is intro story | |
| 24 | self.results = [] | |
| 25 | ||
| 26 | # Only needed in constrained/cached version | |
| 27 | self.seed = seed | |
| 28 | self.choices = [] | |
| 29 | self.possible_action_results = None | |
| 30 | self.uuid = None | |
| 31 | ||
| 32 | if game_state is None: | |
| 33 | game_state = dict() | |
| 34 | self.game_state = game_state | |
| 35 | self.memory = 20 | |
| 36 | ||
| 37 | def __del__(self): | |
| 38 | if self.upload_story: | |
| 39 | self.save_to_local("AUTOSAVE")
| |
| 40 | console_print("Game saved.")
| |
| 41 | console_print( | |
| 42 | "To load the game, type 'load' and enter the following ID: AUTOSAVE" | |
| 43 | ) | |
| 44 | ||
| 45 | def init_from_dict(self, story_dict): | |
| 46 | self.story_start = story_dict["story_start"] | |
| 47 | self.seed = story_dict["seed"] | |
| 48 | self.actions = story_dict["actions"] | |
| 49 | self.results = story_dict["results"] | |
| 50 | self.choices = story_dict["choices"] | |
| 51 | self.possible_action_results = story_dict["possible_action_results"] | |
| 52 | self.game_state = story_dict["game_state"] | |
| 53 | self.context = story_dict["context"] | |
| 54 | self.uuid = story_dict["uuid"] | |
| 55 | ||
| 56 | if "rating" in story_dict.keys(): | |
| 57 | self.rating = story_dict["rating"] | |
| 58 | else: | |
| 59 | self.rating = -1 | |
| 60 | ||
| 61 | def initialize_from_json(self, json_string): | |
| 62 | story_dict = json.loads(json_string) | |
| 63 | self.init_from_dict(story_dict) | |
| 64 | ||
| 65 | def add_to_story(self, action, story_block): | |
| 66 | self.actions.append(action) | |
| 67 | self.results.append(story_block) | |
| 68 | ||
| 69 | def latest_result(self, action): | |
| 70 | ||
| 71 | mem_ind = self.memory | |
| 72 | if len(self.results) < 2: | |
| 73 | start_context = self.story_start | |
| 74 | else: | |
| 75 | start_context = self.context | |
| 76 | context_t = start_context.split(' ')
| |
| 77 | latest_result = "" | |
| 78 | while mem_ind > 0: | |
| 79 | ||
| 80 | if len(self.results) >= mem_ind: | |
| 81 | #1) Split the variables we will be checking into some temp variables | |
| 82 | latest_result_t = latest_result.split(' ')
| |
| 83 | actions_t = self.actions[mem_ind-1].split(' ')
| |
| 84 | results_t = self.results[mem_ind-1].split(' ')
| |
| 85 | #2) Check total length of context_t + latest_result_t + results_t + actions_t, it must all be under 700 (?) | |
| 86 | #(haven't figured out the precise number yet, it goes out of bonds after 1024 but ~750 words is often already enough to crash it) | |
| 87 | if len(context_t) + len(latest_result_t) + len(actions_t) + len(results_t) >= 700: | |
| 88 | break | |
| 89 | latest_result = self.actions[mem_ind-1] + self.results[mem_ind-1] + latest_result | |
| 90 | #Fuck reddit | |
| 91 | mem_ind -= 1 | |
| 92 | ||
| 93 | latest_result = start_context + latest_result | |
| 94 | #print("##DEBUG## LATEST_RESULT FED TO AI: \n" + latest_result)
| |
| 95 | return latest_result | |
| 96 | ||
| 97 | def __str__(self): | |
| 98 | story_list = [self.story_start] | |
| 99 | for i in range(len(self.results)): | |
| 100 | story_list.append("\n" + self.actions[i] + "\n")
| |
| 101 | story_list.append("\n" + self.results[i])
| |
| 102 | ||
| 103 | return "".join(story_list) | |
| 104 | ||
| 105 | def to_json(self): | |
| 106 | story_dict = {}
| |
| 107 | story_dict["story_start"] = self.story_start | |
| 108 | story_dict["seed"] = self.seed | |
| 109 | story_dict["actions"] = self.actions | |
| 110 | story_dict["results"] = self.results | |
| 111 | story_dict["choices"] = self.choices | |
| 112 | story_dict["possible_action_results"] = self.possible_action_results | |
| 113 | story_dict["game_state"] = self.game_state | |
| 114 | story_dict["context"] = self.context | |
| 115 | story_dict["uuid"] = self.uuid | |
| 116 | story_dict["rating"] = self.rating | |
| 117 | ||
| 118 | return json.dumps(story_dict) | |
| 119 | ||
| 120 | def save_to_local(self, save_name): | |
| 121 | self.uuid = str(uuid.uuid1()) | |
| 122 | story_json = self.to_json() | |
| 123 | file_name = "stories\AIDungeonSave_" + save_name + ".json" | |
| 124 | f = open(file_name, "w+") | |
| 125 | f.write(story_json) | |
| 126 | f.close() | |
| 127 | ||
| 128 | def load_from_local(self, save_name): | |
| 129 | file_name = "AIDungeonSave_" + save_name + ".json" | |
| 130 | print("Save ID that can be used to load game is: ", self.uuid)
| |
| 131 | ||
| 132 | with open(file_name, "r") as fp: | |
| 133 | game = json.load(fp) | |
| 134 | self.init_from_dict(game) | |
| 135 | ||
| 136 | def save_to_storage(self): | |
| 137 | self.uuid = str(uuid.uuid1()) | |
| 138 | ||
| 139 | story_json = self.to_json() | |
| 140 | file_name = "story" + str(self.uuid) + ".json" | |
| 141 | f = open(file_name, "w") | |
| 142 | f.write(story_json) | |
| 143 | f.close() | |
| 144 | ||
| 145 | FNULL = open(os.devnull, "w") | |
| 146 | p = Popen( | |
| 147 | ["gsutil", "cp", file_name, "gs://aidungeonstories"], | |
| 148 | stdout=FNULL, | |
| 149 | stderr=subprocess.STDOUT, | |
| 150 | ) | |
| 151 | return self.uuid | |
| 152 | ||
| 153 | def load_from_storage(self, story_id): | |
| 154 | ||
| 155 | file_name = "stories\AIDungeonSave_" + story_id + ".json" | |
| 156 | exists = os.path.isfile(file_name) | |
| 157 | ||
| 158 | if exists: | |
| 159 | with open(file_name, "r") as fp: | |
| 160 | game = json.load(fp) | |
| 161 | self.init_from_dict(game) | |
| 162 | return str(self) | |
| 163 | else: | |
| 164 | return "Error save not found." | |
| 165 | ||
| 166 | ||
| 167 | class StoryManager: | |
| 168 | def __init__(self, generator): | |
| 169 | self.generator = generator | |
| 170 | self.story = None | |
| 171 | ||
| 172 | def start_new_story( | |
| 173 | self, story_prompt, context="", game_state=None, upload_story=False | |
| 174 | ): | |
| 175 | try: | |
| 176 | block = func_timeout(180, self.generator.generate, (context + story_prompt,)) | |
| 177 | except FunctionTimedOut: | |
| 178 | console_print("Error generating story, please try a different start.\n")
| |
| 179 | ctypes.windll.user32.FlashWindow(ctypes.windll.kernel32.GetConsoleWindow(), True) | |
| 180 | return False | |
| 181 | block = cut_trailing_sentence(block) | |
| 182 | self.story = Story( | |
| 183 | context + story_prompt + block, | |
| 184 | context=context, | |
| 185 | game_state=game_state, | |
| 186 | upload_story=upload_story, | |
| 187 | ) | |
| 188 | return True | |
| 189 | def load_new_story(self, story_id): | |
| 190 | file_name = "stories\AIDungeonSave_" + story_id + ".json" | |
| 191 | exists = os.path.isfile(file_name) | |
| 192 | ||
| 193 | if exists: | |
| 194 | with open(file_name, "r") as fp: | |
| 195 | game = json.load(fp) | |
| 196 | self.story = Story("")
| |
| 197 | self.story.init_from_dict(game) | |
| 198 | return str(self.story) | |
| 199 | else: | |
| 200 | return "Error: save not found." | |
| 201 | ||
| 202 | def load_story(self, story, from_json=False): | |
| 203 | if from_json: | |
| 204 | self.story = Story("")
| |
| 205 | self.story.initialize_from_json(story) | |
| 206 | else: | |
| 207 | self.story = story | |
| 208 | return str(story) | |
| 209 | ||
| 210 | def json_story(self): | |
| 211 | return self.story.to_json() | |
| 212 | ||
| 213 | def story_context(self, action): | |
| 214 | return self.story.latest_result(action) | |
| 215 | ||
| 216 | ||
| 217 | class UnconstrainedStoryManager(StoryManager): | |
| 218 | def act(self, action_choice): | |
| 219 | ||
| 220 | result = self.generate_result(action_choice) | |
| 221 | if result == "": | |
| 222 | return result | |
| 223 | self.story.add_to_story(action_choice, result) | |
| 224 | return result | |
| 225 | ||
| 226 | def generate_result(self, action): | |
| 227 | try: | |
| 228 | block = self.generator.generate(self.story_context(action) + action) | |
| 229 | except BaseException as e: | |
| 230 | console_print("An exception occured: %s" % (e))
| |
| 231 | return "" | |
| 232 | return block | |
| 233 | def set_context(self, context): | |
| 234 | self.story.context = context | |
| 235 | def get_context(self): | |
| 236 | return self.story.context | |
| 237 | ||
| 238 | class ConstrainedStoryManager(StoryManager): | |
| 239 | def __init__(self, generator, action_verbs_key="classic"): | |
| 240 | super().__init__(generator) | |
| 241 | self.action_phrases = get_action_verbs(action_verbs_key) | |
| 242 | self.cache = False | |
| 243 | self.cacher = None | |
| 244 | self.seed = None | |
| 245 | ||
| 246 | def enable_caching( | |
| 247 | self, credentials_file=None, seed=0, bucket_name="dungeon-cache" | |
| 248 | ): | |
| 249 | self.cache = True | |
| 250 | self.cacher = Cacher(credentials_file, bucket_name) | |
| 251 | self.seed = seed | |
| 252 | ||
| 253 | def start_new_story(self, story_prompt, context="", game_state=None): | |
| 254 | if self.cache: | |
| 255 | return self.start_new_story_cache(story_prompt, game_state=game_state) | |
| 256 | else: | |
| 257 | return super().start_new_story( | |
| 258 | story_prompt, context=context, game_state=game_state | |
| 259 | ) | |
| 260 | ||
| 261 | def start_new_story_generate(self, story_prompt, game_state=None): | |
| 262 | super().start_new_story(story_prompt, game_state=game_state) | |
| 263 | self.story.possible_action_results = self.get_action_results() | |
| 264 | return self.story.story_start | |
| 265 | ||
| 266 | def start_new_story_cache(self, story_prompt, game_state=None): | |
| 267 | ||
| 268 | response = self.cacher.retrieve_from_cache(self.seed, [], "story") | |
| 269 | if response is not None: | |
| 270 | story_start = story_prompt + response | |
| 271 | self.story = Story(story_start, seed=self.seed) | |
| 272 | self.story.possible_action_results = self.get_action_results() | |
| 273 | else: | |
| 274 | story_start = self.start_new_story_generate( | |
| 275 | story_prompt, game_state=game_state | |
| 276 | ) | |
| 277 | self.story.seed = self.seed | |
| 278 | self.cacher.cache_file(self.seed, [], story_start, "story") | |
| 279 | ||
| 280 | return story_start | |
| 281 | ||
| 282 | def load_story(self, story, from_json=False): | |
| 283 | story_string = super().load_story(story, from_json=from_json) | |
| 284 | return story_string | |
| 285 | ||
| 286 | def get_possible_actions(self): | |
| 287 | if self.story.possible_action_results is None: | |
| 288 | self.story.possible_action_results = self.get_action_results() | |
| 289 | ||
| 290 | return [ | |
| 291 | action_result[0] for action_result in self.story.possible_action_results | |
| 292 | ] | |
| 293 | ||
| 294 | def act(self, action_choice_str): | |
| 295 | ||
| 296 | try: | |
| 297 | action_choice = int(action_choice_str) | |
| 298 | except: | |
| 299 | print("Error invalid choice.")
| |
| 300 | return None, None | |
| 301 | ||
| 302 | if action_choice < 0 or action_choice >= len(self.action_phrases): | |
| 303 | print("Error invalid choice.")
| |
| 304 | return None, None | |
| 305 | ||
| 306 | self.story.choices.append(action_choice) | |
| 307 | action, result = self.story.possible_action_results[action_choice] | |
| 308 | self.story.add_to_story(action, result) | |
| 309 | self.story.possible_action_results = self.get_action_results() | |
| 310 | return result, self.get_possible_actions() | |
| 311 | ||
| 312 | def get_action_results(self): | |
| 313 | if self.cache: | |
| 314 | return self.get_action_results_cache() | |
| 315 | else: | |
| 316 | return self.get_action_results_generate() | |
| 317 | ||
| 318 | def get_action_results_generate(self): | |
| 319 | action_results = [ | |
| 320 | self.generate_action_result(self.story_context(), phrase) | |
| 321 | for phrase in self.action_phrases | |
| 322 | ] | |
| 323 | return action_results | |
| 324 | ||
| 325 | def get_action_results_cache(self): | |
| 326 | response = self.cacher.retrieve_from_cache( | |
| 327 | self.story.seed, self.story.choices, "choices" | |
| 328 | ) | |
| 329 | ||
| 330 | if response is not None: | |
| 331 | print("Retrieved from cache")
| |
| 332 | return json.loads(response) | |
| 333 | else: | |
| 334 | print("Didn't receive from cache")
| |
| 335 | action_results = self.get_action_results_generate() | |
| 336 | response = json.dumps(action_results) | |
| 337 | self.cacher.cache_file( | |
| 338 | self.story.seed, self.story.choices, response, "choices" | |
| 339 | ) | |
| 340 | return action_results | |
| 341 | ||
| 342 | def generate_action_result(self, prompt, phrase, options=None): | |
| 343 | ||
| 344 | action_result = ( | |
| 345 | phrase + " " + self.generator.generate(prompt + " " + phrase, options) | |
| 346 | ) | |
| 347 | action, result = split_first_sentence(action_result) | |
| 348 | return action, result |