sampler
a guest | Nov 21st, 2023
import torch
from exllamav2 import ExLlamaV2Tokenizer
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
import os
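
# Modified exllamav2 sampler with an experimental entropy-based dynamic
# temperature: when settings.temperature is set inside [1.83, 1.84], the
# effective temperature for each step is scaled between the min_temp/max_temp
# bounds read from EntropySampling.txt, according to the normalized entropy of
# the top-k/top-p candidate distribution. Otherwise sampling behaves as stock
# exllamav2.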

class ExLlamaV2Sampler:

    class Settings:

        token_repetition_penalty = 1.15
        token_repetition_range = -1
        token_repetition_decay = 0

        temperature = 0.9
        top_k = 40
        top_p = 0.9
        min_p = 0
        tfs = 0
        typical = 0

        mirostat = False
        mirostat_tau = 1.5
        mirostat_eta = 0.1
        mirostat_mu = None  # (re)initialized from mirostat_tau on first sample

        token_bias = None

        filters = []

        def clone(self):

            c = ExLlamaV2Sampler.Settings()

            c.token_repetition_penalty = self.token_repetition_penalty
            c.token_repetition_range = self.token_repetition_range
            c.token_repetition_decay = self.token_repetition_decay

            c.temperature = self.temperature
            c.top_k = self.top_k
            c.top_p = self.top_p
            c.min_p = self.min_p
            c.tfs = self.tfs
            c.typical = self.typical

            c.mirostat = self.mirostat
            c.mirostat_tau = self.mirostat_tau
            c.mirostat_eta = self.mirostat_eta
            c.mirostat_mu = None if self.mirostat_mu is None else self.mirostat_mu.copy()

            c.token_bias = self.token_bias
            c.filters = [f.clone() for f in self.filters]

            return c

        def greedy_clone(self):

            c = ExLlamaV2Sampler.Settings()
            c.top_k = 1
            c.top_p = 0
            c.token_repetition_penalty = self.token_repetition_penalty
            c.token_repetition_range = self.token_repetition_range
            c.token_repetition_decay = self.token_repetition_decay
            c.token_bias = self.token_bias
            c.filters = []
            return c

        def disallow_tokens(self, tokenizer, tokens):

            if self.token_bias is None:
                padding = -tokenizer.config.vocab_size % 32
                self.token_bias = torch.zeros((tokenizer.config.vocab_size + padding,), dtype = torch.float)

            self.token_bias[tokens] = float("-inf")
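
        # token_bias is padded up to a multiple of 32 entries, presumably to
        # match the alignment the C++ sampling kernels expect. Example: ban the
        # EOS token so generation can't stop early:
        #
        #   settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])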

        def begin_filters(self, prefix_str = ""):

            for f in self.filters: f.begin(prefix_str)

        def feed_filters(self, feed_token):

            for f in self.filters: f.feed(feed_token)
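
    # Filters (e.g. constrained-generation filters) are assumed to expose
    # begin(prefix_str), feed(token), clone(), and next() -> (pass_tokens,
    # end_tokens), where pass_tokens is the set of token IDs currently allowed
    # and end_tokens the set that terminates the filter; that is the interface
    # the code below relies on.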

    @staticmethod
    def sample(logits: torch.Tensor, settings: Settings, sequence_ids: torch.Tensor, random: float, tokenizer: ExLlamaV2Tokenizer, prefix_token = None):

        batch_size, _, vocab_size = logits.shape

        assert logits.shape[1] == 1, "Logits tensor is incorrect shape, must be (bsz, 1, vocab_size)"
        assert prefix_token is None or prefix_token.shape == (batch_size, 1), "Prefix token list doesn't match batch shape"
        assert batch_size == 1 or len(settings.filters) == 0, "Filters not implemented for batch size > 1"

        logits = logits.clone().squeeze(1)
        logit_filter = torch.ones((batch_size, vocab_size), dtype = torch.bool)

        # Repetition penalty

        if settings.token_repetition_penalty != 1.0:

            ext_c.apply_rep_penalty(sequence_ids,
                                    settings.token_repetition_penalty,
                                    settings.token_repetition_range,
                                    settings.token_repetition_decay,
                                    logits)
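
        # apply_rep_penalty adjusts the logits in place, penalizing every token
        # ID that occurs in the last token_repetition_range positions of
        # sequence_ids (-1 = the whole context), with the penalty fading out
        # over token_repetition_decay further tokens.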

        # Token bias

        if settings.token_bias is not None: logits += settings.token_bias

        # Evaluate filters

        if len(settings.filters) > 0:

            pass_tokens = None
            end_tokens = None
            for f in settings.filters:

                pt, et = f.next()
                pass_tokens = pt if pass_tokens is None else pass_tokens & pt
                end_tokens = et if end_tokens is None else end_tokens | et

            assert pass_tokens, "Filter excluded all tokens"
            ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])

        # Healing

        if prefix_token is not None:

            prefix_id_to_ids = tokenizer.get_prefix_id_to_ids_dict()

            valid_token_lists = []
            for i in range(batch_size):
                valid_token_lists.append(prefix_id_to_ids[prefix_token[i, 0].item()])

            ext_c.logit_filter_exclusive(logit_filter, valid_token_lists)

            # for i in range(logit_filter.shape[-1]):
            #     if logit_filter[0, i].item():
            #         print(i)
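
        # Token healing constrains the next sample to tokens whose string
        # extends the partial token in prefix_token, letting the model
        # regenerate ("heal") a token that was truncated at the end of the
        # prompt.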

        # Begin Mirostat

        if settings.mirostat:
            if settings.mirostat_mu is None:
                settings.mirostat_mu = [0.0] * batch_size

        # Sampling

        batch_size = logits.shape[0]

        # This is a goddamn mess because I was too stubborn to modify the C++ for accessibility's sake. Works though.

        def apply_top_k_top_p(logits, top_k, top_p):
            # If top_k is set (and greater than 0), filter out all but the top k logits
            if top_k > 0:
                top_values, top_indices = logits.topk(top_k, dim = -1)

                # Create a mask to identify logits that are not in the top k
                mask = torch.ones_like(logits, dtype = torch.bool)
                mask.scatter_(1, top_indices, False)

                # Mask out all logits that are not in the top k
                logits[mask] = -float('Inf')

            # If top_p is set (and less than 1), filter out all but the top p probability mass
            if top_p > 0 and top_p < 1.0:
                # Sort the logits in descending order for the cumulative probability calculation
                sorted_logits, sorted_indices = torch.sort(logits, descending = True)

                # Convert the sorted logits to probabilities and accumulate them
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim = -1), dim = -1)

                # Create a mask for values to be removed, i.e. where cumulative probability > top_p
                remove_mask = cumulative_probs > top_p

                # Keep the first token that exceeds top_p by shifting the mask one place to the right
                remove_mask[..., 1:] = remove_mask[..., :-1].clone()
                remove_mask[..., 0] = 0

                # Map the mask back through the sorted indices and mask out the corresponding logits
                indices_to_remove = sorted_indices[remove_mask]
                logits.view(-1)[indices_to_remove.view(-1)] = -float('Inf')

            # Return only the valid logits (those not set to -Inf by the filters above)
            valid_logits = logits[logits != -float('Inf')]

            return valid_logits
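
        # Note: this helper is only used to measure the entropy of the candidate
        # set below; the actual top-k/top-p truncation for sampling happens again
        # inside ext_c.sample_basic. The flat indexing in the top-p branch (and
        # the flattened return value) is only correct for batch size 1.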

        #print("-----------------------")

        # Check whether the original temperature falls in the override range
        if 1.83 <= settings.temperature <= 1.84:
            # Apply the top-k and top-p filtering to get the valid logits
            valid_logits = apply_top_k_top_p(logits.clone(), settings.top_k, settings.top_p)

            # Calculate the entropy over the valid logits only
            probabilities = torch.softmax(valid_logits, dim = -1)
            entropy = -torch.sum(probabilities * torch.log(probabilities + 1e-10), dim = -1)

            # If the entropy comes out as negative zero, set it to zero
            if entropy.eq(-0.0).any():
                entropy = torch.zeros_like(entropy)

            #print("Entropy:", entropy.item())

            # The maximum possible entropy is log(n) for n valid logits
            max_entropy = torch.log(torch.tensor(float(valid_logits.size(-1))))

            #print("Max Possible Entropy:", max_entropy.item())

            # If the maximum entropy is zero (which happens if there's only one valid
            # logit), the normalized entropy is directly zero, as there's no uncertainty
            if max_entropy.eq(0).any():
                normalized_entropy = torch.zeros_like(entropy)
            else:
                normalized_entropy = torch.div(entropy, max_entropy)  # Normalize entropy values

            #print("Normalized Entropy:", normalized_entropy.item())

            # If the normalized entropy is NaN (possible when max_entropy is zero), set it to zero
            if torch.isnan(normalized_entropy):
                normalized_entropy = torch.zeros_like(normalized_entropy)
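
            # Worked example: with four equally likely candidates after filtering,
            # entropy = ln(4) and max_entropy = ln(4), so normalized_entropy = 1.0
            # and the linear mapping below selects max_temp; a near-certain
            # distribution gives normalized_entropy near 0 and a temperature near
            # min_temp.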

            min_temp = 0.0
            max_temp = 2.0

            # Define the file name
            file_name = "EntropySampling.txt"

            # Check if the file exists. If not, create it with default values.
            if not os.path.exists(file_name):
                with open(file_name, 'w') as file:
                    file.write("min_temp=0.0\nmax_temp=2.0\n")

            # Read the values from the file
            with open(file_name, 'r') as file:
                lines = file.readlines()
                min_temp = float(lines[0].split('=')[1])
                max_temp = float(lines[1].split('=')[1])

            #print("min_temp:", min_temp)
            #print("max_temp:", max_temp)

            # Calculate the dynamic temperature from the normalized entropy,
            # using a simple linear mapping
            dynamic_temperature = min_temp + (max_temp - min_temp) * normalized_entropy  # Linear scaling.

            if dynamic_temperature == 0:
                dynamic_temperature = torch.full_like(dynamic_temperature, 0.00390625)  # fill with 1/256th of 1 because exllama cries otherwise for some reason

            #print("Dynamic Temperature (dyn_temp):", dynamic_temperature.item())

            # Ensure the dynamic temperature is within the defined min/max range
            dynamic_temperature = torch.clamp(dynamic_temperature, min = min_temp, max = max_temp)
        else:
            # If the temperature is not in the 1.83-1.84 override range, use the static temperature
            #print("Temperature override not active, using static temp instead:", settings.temperature)
            dynamic_temperature = settings.temperature

        #print("-----------------------")

        # Sampling step using the dynamic temperature

        output_tokens = torch.empty((batch_size, 1), device = "cpu", dtype = torch.long)
        output_probs = torch.empty((batch_size, 1), device = "cpu", dtype = torch.float)

        m = ext_c.sample_basic(logits,
                               dynamic_temperature,  # If the override isn't active, this is just the regular static temp
                               settings.top_k,
                               settings.top_p,
                               settings.min_p,
                               settings.tfs,
                               settings.typical,
                               random,
                               output_tokens,
                               output_probs,
                               logit_filter,
                               settings.mirostat,
                               settings.mirostat_mu if settings.mirostat else [],
                               settings.mirostat_tau,
                               settings.mirostat_eta)

        if settings.mirostat: settings.mirostat_mu = m

        # Stop condition from filters

        end_filter = False
        if len(settings.filters) > 0 and output_tokens[0].item() in end_tokens: end_filter = True

        return output_tokens, output_probs, end_filter
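
# ---------------------------------------------------------------------------
# Usage sketch: one way to drive this sampler, assuming this file replaces
# exllamav2/generator/sampler.py in an exllamav2 install and following the
# loading flow from the repo's examples. The model path is a placeholder, and
# setting temperature inside [1.83, 1.84] is what switches on the dynamic
# temperature above.
#
#   from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
#   from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
#
#   config = ExLlamaV2Config()
#   config.model_dir = "/path/to/model"
#   config.prepare()
#
#   model = ExLlamaV2(config)
#   model.load()
#
#   tokenizer = ExLlamaV2Tokenizer(config)
#   cache = ExLlamaV2Cache(model)
#   generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
#
#   settings = ExLlamaV2Sampler.Settings()
#   settings.temperature = 1.835  # inside the [1.83, 1.84] trigger range
#   settings.top_k = 40
#   settings.top_p = 0.9
#
#   print(generator.generate_simple("Once upon a time", settings, num_tokens = 100))
# ---------------------------------------------------------------------------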