import torch
from exllamav2 import ExLlamaV2Tokenizer
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
import os

class ExLlamaV2Sampler:

    class Settings:

        token_repetition_penalty = 1.15
        token_repetition_range = -1
        token_repetition_decay = 0

        temperature = 0.9
        top_k = 40
        top_p = 0.9
        min_p = 0
        tfs = 0
        typical = 0

        temperature_last = False

        mirostat = False
        mirostat_tau = 1.5
        mirostat_eta = 0.1
        mirostat_mu = None  # (re)initialized from mirostat_tau on first sample

        token_bias = None

        filters = []

        def clone(self):

            c = ExLlamaV2Sampler.Settings()

            c.token_repetition_penalty = self.token_repetition_penalty
            c.token_repetition_range = self.token_repetition_range
            c.token_repetition_decay = self.token_repetition_decay

            c.temperature = self.temperature
            c.top_k = self.top_k
            c.top_p = self.top_p
            c.min_p = self.min_p
            c.tfs = self.tfs
            c.typical = self.typical
            c.temperature_last = self.temperature_last  # was missing; without it clones fall back to the class default

            c.mirostat = self.mirostat
            c.mirostat_tau = self.mirostat_tau
            c.mirostat_eta = self.mirostat_eta
            c.mirostat_mu = None if self.mirostat_mu is None else self.mirostat_mu.copy()

            c.token_bias = self.token_bias
            c.filters = [f.clone() for f in self.filters]

            return c

        def greedy_clone(self):

            c = ExLlamaV2Sampler.Settings()
            c.top_k = 1
            c.top_p = 0
            c.token_repetition_penalty = self.token_repetition_penalty
            c.token_repetition_range = self.token_repetition_range
            c.token_repetition_decay = self.token_repetition_decay
            c.token_bias = self.token_bias
            c.filters = []
            return c

        def disallow_tokens(self, tokenizer, tokens):

            if self.token_bias is None:
                padding = -tokenizer.config.vocab_size % 32
                self.token_bias = torch.zeros((tokenizer.config.vocab_size + padding,), dtype = torch.float)

            self.token_bias[tokens] = float("-inf")

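        # The padding arithmetic above rounds the bias tensor up to a multiple of 32 entries,
        # e.g. vocab_size = 32000 gives -32000 % 32 = 0 extra entries, while vocab_size = 32003
        # gives -32003 % 32 = 29, for a padded length of 32032. Example call (hypothetical ids):
        #
        #     settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])
        #
        # sets the bias for those ids to -inf so they can never be sampled.
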
        def begin_filters(self, prefix_str = ""):

            for f in self.filters: f.begin(prefix_str)


        def feed_filters(self, feed_token):

            for f in self.filters: f.feed(feed_token)

    @staticmethod
    def sample(logits: torch.Tensor, settings: Settings, sequence_ids: torch.Tensor, random: float, tokenizer: ExLlamaV2Tokenizer, prefix_token = None):

        batch_size, _, vocab_size = logits.shape

        assert logits.shape[1] == 1, "Logits tensor is incorrect shape, must be (bsz, 1, vocab_size)"
        assert prefix_token is None or prefix_token.shape == (batch_size, 1), "Prefix token list doesn't match batch shape"
        assert batch_size == 1 or len(settings.filters) == 0, "Filters not implemented for batch size > 1"

        logits = logits.clone().squeeze(1)
        logit_filter = torch.ones((batch_size, vocab_size), dtype = torch.bool)

        # Repetition penalty

        if settings.token_repetition_penalty != 1.0:

            ext_c.apply_rep_penalty(sequence_ids,
                                    settings.token_repetition_penalty,
                                    settings.token_repetition_range,
                                    settings.token_repetition_decay,
                                    logits)

        # Token bias

        if settings.token_bias is not None: logits += settings.token_bias

        # Evaluate filters

        if len(settings.filters) > 0:

            pass_tokens = None
            end_tokens = None
            for f in settings.filters:

                pt, et = f.next()
                pass_tokens = pt if pass_tokens is None else pass_tokens & pt
                end_tokens = et if end_tokens is None else end_tokens | et

            assert pass_tokens, "Filter excluded all tokens"
            ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])
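
        # Filters compose by intersecting allowed tokens and unioning end tokens. For example
        # (hypothetical sets), pass sets {5, 7, 9} and {7, 9, 11} leave {7, 9} sampleable,
        # while end sets {2} and {2, 32000} let either id count as a stop condition below.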

        # Healing

        if prefix_token is not None:

            prefix_id_to_ids = tokenizer.get_prefix_id_to_ids_dict()

            valid_token_lists = []
            for i in range(batch_size):
                valid_token_lists.append(prefix_id_to_ids[prefix_token[i, 0].item()])

            ext_c.logit_filter_exclusive(logit_filter, valid_token_lists)
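
        # Token healing: sampling is restricted to tokens whose string begins with the string of
        # prefix_token, so a partial token at the end of the context can be re-completed. E.g. if
        # the prefix token decodes to "pre", only ids decoding to "pre", "prefix", "pretty", ...
        # (whatever exists in the vocabulary) remain allowed.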

        # for i in range(logit_filter.shape[-1]):
        #     if logit_filter[0, i].item():
        #         print(i)

        # Begin Mirostat

        if settings.mirostat:
            if settings.mirostat_mu is None:
                settings.mirostat_mu = [0.0] * batch_size

        # Sampling

        batch_size = logits.shape[0]

        # This is a goddamn mess because I was too stubborn to modify the C++ for accessibility's sake. Works though.

        def apply_top_k_top_p(logits, top_k, top_p):
            # If top_k is set (and greater than 0), filter out all but the top k logits
            if top_k > 0:
                _, top_indices = logits.topk(top_k, dim=-1)

                # Create a mask to identify logits that are not in the top k
                mask = torch.ones_like(logits, dtype=torch.bool)
                mask.scatter_(1, top_indices, False)

                # Mask out all logits that are not in the top k
                logits[mask] = -float('Inf')

            # If top_p is set (and less than 1), filter out all but the top p probability mass
            if top_p > 0 and top_p < 1.0:
                # Sort the logits in descending order for the cumulative probability calculation
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)

                # Convert logits to probabilities and accumulate
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                # Mark values to be removed, i.e. where cumulative probability > top_p
                remove_mask = cumulative_probs > top_p

                # Keep the first token that exceeds top_p by shifting the mask one place to the right
                remove_mask[..., 1:] = remove_mask[..., :-1].clone()
                remove_mask[..., 0] = 0

                # Map the mask back through the sorted indices and mask out the corresponding logits
                # (note: the flat indexing below assumes an effective batch size of 1)
                indices_to_remove = sorted_indices[remove_mask]
                logits.view(-1)[indices_to_remove.view(-1)] = -float('Inf')

            # Extract the valid logits (those not set to -Inf by the filters above)
            valid_logits = logits[logits != -float('Inf')]

            return valid_logits

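        # Worked example (hypothetical numbers): with logits [2.0, 1.0, 0.0, -1.0], top_k = 3
        # drops the last entry; softmax over the rest gives roughly [0.665, 0.245, 0.090], so
        # with top_p = 0.9 the cumulative mass [0.665, 0.910, 1.000] keeps the first two tokens
        # (the token that crosses 0.9 is kept, everything after it is masked).
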
        #print("-----------------------")

        # Check if the original temperature is within the special override range
        if 1.83 <= settings.temperature <= 1.84:
            # Apply the top-k and top-p sampling to get the valid logits
            valid_logits = apply_top_k_top_p(logits.clone(), settings.top_k, settings.top_p)

            # Calculate the entropy only on the valid logits
            probabilities = torch.softmax(valid_logits, dim=-1)  # Convert valid logits to probabilities
            entropy = -torch.sum(probabilities * torch.log(probabilities + 1e-10), dim=-1)

            # If the entropy is (negative) zero, snap it to exactly zero
            if entropy.eq(-0.0).any():
                entropy = torch.zeros_like(entropy)

            #print("Entropy:", entropy.item())

            # Calculate the maximum possible entropy based on the number of valid logits
            max_entropy = torch.log(torch.tensor(float(valid_logits.size(-1))))

            #print("Max Possible Entropy:", max_entropy.item())

            # If the maximum entropy is zero (which happens if there's only one valid logit),
            # the normalized entropy is directly zero, as there's no uncertainty.
            if max_entropy.eq(0).any():
                normalized_entropy = torch.zeros_like(entropy)
            else:
                normalized_entropy = torch.div(entropy, max_entropy)  # Normalize entropy to [0, 1]

            #print("Normalized Entropy:", normalized_entropy.item())

            # Guard against NaN (which can happen if max_entropy is zero) by setting it to zero
            if torch.isnan(normalized_entropy).any():
                normalized_entropy = torch.zeros_like(normalized_entropy)

            min_temp = 0.0
            max_temp = 2.0

            # Define the file name
            file_name = "EntropySampling.txt"

            # Check if the file exists. If not, create it with default values.
            if not os.path.exists(file_name):
                with open(file_name, 'w') as file:
                    file.write("min_temp=0.0\nmax_temp=2.0\n")

            # Read the values from the file
            with open(file_name, 'r') as file:
                lines = file.readlines()
                min_temp = float(lines[0].split('=')[1])
                max_temp = float(lines[1].split('=')[1])
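
            # The file is re-read on every sample call, so min_temp/max_temp can be tuned live
            # while generating. Expected format, one value per line:
            #
            #     min_temp=0.0
            #     max_temp=2.0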

            #print("min_temp:", min_temp)
            #print("max_temp:", max_temp)

            # Calculate the dynamic temperature from the normalized entropy
            # using a simple linear mapping.
            dynamic_temperature = min_temp + (max_temp - min_temp) * normalized_entropy

            if dynamic_temperature == 0:
                # The C++ sampler rejects a temperature of exactly zero, so fill with 1/256 instead
                dynamic_temperature = torch.full_like(dynamic_temperature, 0.00390625)

            #print("Dynamic Temperature (dyn_temp):", dynamic_temperature.item())

            # Ensure the dynamic temperature stays within the defined min/max range
            dynamic_temperature = torch.clamp(dynamic_temperature, min=min_temp, max=max_temp)
        else:
            # If the temperature is not in the override range, use the original temperature
            #print("Temperature was not in the 1.83-1.84 override range, using static temp instead:", settings.temperature)
            dynamic_temperature = settings.temperature

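        # Worked example of the mapping (hypothetical numbers): with min_temp = 0.0 and
        # max_temp = 2.0, a normalized entropy of 0.55 gives 0.0 + (2.0 - 0.0) * 0.55 = 1.1,
        # so uncertain distributions sample hotter and confident ones sample colder.
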
        #print("-----------------------")

        # Sampling step using the dynamic temperature

        output_tokens = torch.empty((batch_size, 1), device = "cpu", dtype = torch.long)
        output_probs = torch.empty((batch_size, 1), device = "cpu", dtype = torch.float)

        m = ext_c.sample_basic(logits,
                               #dynamic_temperature, # If the override isn't on, this is just the regular temp.
                               #1.0 if settings.temperature_last else settings.temperature,
                               1.0, # pre-filtering temperature disabled; dynamic_temperature is applied last instead
                               settings.top_k,
                               settings.top_p,
                               settings.min_p,
                               settings.tfs,
                               settings.typical,
                               random,
                               output_tokens,
                               output_probs,
                               logit_filter,
                               settings.mirostat,
                               settings.mirostat_mu if settings.mirostat else [],
                               settings.mirostat_tau,
                               settings.mirostat_eta,
                               #settings.temperature if settings.temperature_last else 1.0)
                               float(dynamic_temperature)) # applied after the other filters ("temp last")

        if settings.mirostat: settings.mirostat_mu = m

        # Stop condition from filters

        end_filter = False
        if len(settings.filters) > 0 and output_tokens[0].item() in end_tokens: end_filter = True

        return output_tokens, output_probs, end_filter
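
# Usage sketch (an addition, not part of the original paste): configuring Settings so the
# entropy-based dynamic temperature path triggers. Any temperature in [1.83, 1.84] acts as
# the sentinel; the actual sample() call needs logits, sequence ids and a tokenizer from a
# loaded model, which is assumed here and not shown.

if __name__ == "__main__":

    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 1.835  # inside the sentinel range -> dynamic temperature enabled
    settings.top_k = 40
    settings.top_p = 0.9

    # sample() would then be called once per generated token, e.g.:
    # token, prob, eos = ExLlamaV2Sampler.sample(logits, settings, sequence_ids, random.random(), tokenizer)
    print("dynamic temperature sentinel active:", 1.83 <= settings.temperature <= 1.84)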