import torch
from exllamav2 import ExLlamaV2Tokenizer
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
import os


class ExLlamaV2Sampler:

    class Settings:

        token_repetition_penalty = 1.15
        token_repetition_range = -1
        token_repetition_decay = 0

        temperature = 0.9
        top_k = 40
        top_p = 0.9
        min_p = 0
        tfs = 0
        typical = 0

        mirostat = False
        mirostat_tau = 1.5
        mirostat_eta = 0.1
        mirostat_mu = None  # (re)initialized from mirostat_tau on first sample

        token_bias = None

        filters = []

        def clone(self):

            c = ExLlamaV2Sampler.Settings()
            c.token_repetition_penalty = self.token_repetition_penalty
            c.token_repetition_range = self.token_repetition_range
            c.token_repetition_decay = self.token_repetition_decay
            c.temperature = self.temperature
            c.top_k = self.top_k
            c.top_p = self.top_p
            c.min_p = self.min_p
            c.tfs = self.tfs
            c.typical = self.typical
            c.mirostat = self.mirostat
            c.mirostat_tau = self.mirostat_tau
            c.mirostat_eta = self.mirostat_eta
            c.mirostat_mu = None if self.mirostat_mu is None else self.mirostat_mu.copy()
            c.token_bias = self.token_bias
            c.filters = [f.clone() for f in self.filters]
            return c

        def greedy_clone(self):

            c = ExLlamaV2Sampler.Settings()
            c.top_k = 1
            c.top_p = 0
            c.token_repetition_penalty = self.token_repetition_penalty
            c.token_repetition_range = self.token_repetition_range
            c.token_repetition_decay = self.token_repetition_decay
            c.token_bias = self.token_bias
            c.filters = []
            return c

        def disallow_tokens(self, tokenizer, tokens):

            if self.token_bias is None:
                padding = -tokenizer.config.vocab_size % 32
                self.token_bias = torch.zeros((tokenizer.config.vocab_size + padding,), dtype = torch.float)

            self.token_bias[tokens] = float("-inf")
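
        # Example sketch: block the EOS token so generation can't stop early. Assumes the
        # tokenizer is an ExLlamaV2Tokenizer, which exposes eos_token_id; `settings` and
        # `tokenizer` are placeholder names for objects created elsewhere.
        #
        #   settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])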

        def begin_filters(self, prefix_str = ""):

            for f in self.filters: f.begin(prefix_str)

        def feed_filters(self, feed_token):

            for f in self.filters: f.feed(feed_token)


    @staticmethod
    def sample(logits: torch.Tensor, settings: Settings, sequence_ids: torch.Tensor, random: float, tokenizer: ExLlamaV2Tokenizer, prefix_token = None):

        batch_size, _, vocab_size = logits.shape

        assert logits.shape[1] == 1, "Logits tensor is incorrect shape, must be (bsz, 1, vocab_size)"
        assert prefix_token is None or prefix_token.shape == (batch_size, 1), "Prefix token list doesn't match batch shape"
        assert batch_size == 1 or len(settings.filters) == 0, "Filters not implemented for batch size > 1"

        logits = logits.clone().squeeze(1)
        logit_filter = torch.ones((batch_size, vocab_size), dtype = torch.bool)

        # Repetition penalty

        if settings.token_repetition_penalty != 1.0:

            ext_c.apply_rep_penalty(sequence_ids,
                                    settings.token_repetition_penalty,
                                    settings.token_repetition_range,
                                    settings.token_repetition_decay,
                                    logits)

        # Token bias

        if settings.token_bias is not None: logits += settings.token_bias

        # Evaluate filters

        if len(settings.filters) > 0:

            pass_tokens = None
            end_tokens = None
            for f in settings.filters:

                pt, et = f.next()
                pass_tokens = pt if pass_tokens is None else pass_tokens & pt
                end_tokens = et if end_tokens is None else end_tokens | et

            assert pass_tokens, "Filter excluded all tokens"
            ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])

        # Healing

        if prefix_token is not None:

            prefix_id_to_ids = tokenizer.get_prefix_id_to_ids_dict()

            valid_token_lists = []
            for i in range(batch_size):
                valid_token_lists.append(prefix_id_to_ids[prefix_token[i, 0].item()])

            ext_c.logit_filter_exclusive(logit_filter, valid_token_lists)

        # for i in range(logit_filter.shape[-1]):
        #     if logit_filter[0, i].item():
        #         print(i)

        # Begin Mirostat

        if settings.mirostat:

            if settings.mirostat_mu is None:
                settings.mirostat_mu = [0.0] * batch_size

        # Sampling

        batch_size = logits.shape[0]

        # This is a goddamn mess because I was too stubborn to modify the C++ for accessibility's sake. Works though.

        def apply_top_k_top_p(logits, top_k, top_p):

            # If top_k is set (and greater than 0), keep only the top k logits
            if top_k > 0:

                top_k = min(top_k, logits.shape[-1])
                top_values, top_indices = logits.topk(top_k, dim = -1)

                # Create a mask identifying logits that are not in the top k
                mask = torch.ones_like(logits, dtype = torch.bool)
                mask.scatter_(1, top_indices, False)

                # Mask out all logits that are not in the top k by setting them to -inf
                logits[mask] = -float("inf")

            # If top_p is set (and less than 1), keep only the smallest set of logits whose
            # cumulative probability mass reaches top_p
            if top_p > 0 and top_p < 1.0:

                # Sort the logits in descending order for the cumulative probability calculation
                sorted_logits, sorted_indices = torch.sort(logits, descending = True)

                # Convert logits to probabilities and accumulate
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim = -1), dim = -1)

                # Mark positions where the cumulative probability exceeds top_p
                remove_mask = cumulative_probs > top_p

                # Keep the first token that exceeds top_p by shifting the mask one place to the right
                remove_mask[..., 1:] = remove_mask[..., :-1].clone()
                remove_mask[..., 0] = False

                # Map the mask back to the original indices and set the corresponding logits to -inf
                indices_to_remove = sorted_indices[remove_mask]
                logits.view(-1)[indices_to_remove.view(-1)] = -float("inf")

            # Return only the surviving logits (those not set to -inf by the filters above)
            valid_logits = logits[logits != -float("inf")]
            return valid_logits
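
        # Note that apply_top_k_top_p is only used below to count the logits that survive
        # top-k/top-p so entropy can be measured over that reduced set; the actual token is
        # still drawn by ext_c.sample_basic from the full filtered logits. The flat view(-1)
        # indexing above also assumes a batch size of 1.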
- #print("-----------------------")
- # Check if the original temperature is within the specified range
- if 1.83 <= settings.temperature <= 1.84:
- # Apply the Top-K and Top-P sampling to get valid logits
- valid_logits = apply_top_k_top_p(logits.clone(), settings.top_k, settings.top_p)
- # Calculate the entropy only on the valid logits
- probabilities = torch.softmax(valid_logits, dim=-1) # Converts valid logits to probabilities.
- entropy = -torch.sum(probabilities * torch.log(probabilities + 1e-10), dim=-1)
- # Check if the entropy is negative zero, and if so, set it to zero
- if entropy.eq(-0.0).any():
- entropy = torch.zeros_like(entropy)
- #print("Entropy:", entropy.item())
- # Calculate the maximum possible entropy based on the number of valid logits
- max_entropy = torch.log(torch.tensor(float(valid_logits.size(-1))))
- #print("Max Possible Entropy:", max_entropy.item())
- # If the maximum entropy is zero (which happens if there's only one valid logit),
- # then the normalized entropy should directly be zero, as there's no uncertainty.
- if max_entropy.eq(0).any():
- normalized_entropy = torch.zeros_like(entropy)
- else:
- normalized_entropy = torch.div(entropy, max_entropy) # Normalize entropy values.
- #print("Normalized Entropy:", normalized_entropy.item())
- # Check if the normalized entropy is 'nan' (which can happen if max_entropy is zero), and if so, set it to zero
- if torch.isnan(normalized_entropy):
- normalized_entropy = torch.zeros_like(normalized_entropy)
- min_temp = 0.0
- max_temp = 2.0
- # Define the file name
- file_name = "EntropySampling.txt"
- # Check if the file exists. If not, create it with default values.
- if not os.path.exists(file_name):
- with open(file_name, 'w') as file:
- file.write("min_temp=0.0\nmax_temp=2.0\n")
- # Read the values from the file
- with open(file_name, 'r') as file:
- lines = file.readlines()
- min_temp = float(lines[0].split('=')[1])
- max_temp = float(lines[1].split('=')[1])
- #print("min_temp:", min_temp)
- #print("max_temp:", max_temp)
- # Calculate the dynamic temperature based on normalized entropy.
- # We use a simple linear mapping.
- dynamic_temperature = min_temp + (max_temp - min_temp) * normalized_entropy # Linear scaling.
- if dynamic_temperature == 0:
- dynamic_temperature = torch.full_like(dynamic_temperature, 0.00390625) # fill with 1/256th of 1 because exllama cries otherwise for some reason
- #print("Dynamic Temperature (dyn_temp):", dynamic_temperature.item())
- # Ensure the dynamic temperature is within the defined min/max range.
- dynamic_temperature = torch.clamp(dynamic_temperature, min=min_temp, max=max_temp)
- else:
- # If the temperature is not in the specified range, use the original temperature
- #print("Temperature was not set to 1.84 override value, using static temp instead:", settings.temperature)
- dynamic_temperature = settings.temperature
- #print("-----------------------")

        # Sampling step using the dynamic temperature

        output_tokens = torch.empty((batch_size, 1), device = "cpu", dtype = torch.long)
        output_probs = torch.empty((batch_size, 1), device = "cpu", dtype = torch.float)
        m = ext_c.sample_basic(logits,
                               dynamic_temperature,  # if the override isn't on, this is just the regular temperature
                               settings.top_k,
                               settings.top_p,
                               settings.min_p,
                               settings.tfs,
                               settings.typical,
                               random,
                               output_tokens,
                               output_probs,
                               logit_filter,
                               settings.mirostat,
                               settings.mirostat_mu if settings.mirostat else [],
                               settings.mirostat_tau,
                               settings.mirostat_eta)

        if settings.mirostat: settings.mirostat_mu = m

        # Stop condition from filters

        end_filter = False
        if len(settings.filters) > 0 and output_tokens[0].item() in end_tokens: end_filter = True

        return output_tokens, output_probs, end_filter
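

# Usage sketch. Assumes a model and ExLlamaV2Tokenizer have been loaded elsewhere and
# that `logits` (shape (1, 1, vocab_size)) and `sequence_ids` come from the forward
# pass; those names are placeholders, not part of this module.
#
#   import random
#
#   settings = ExLlamaV2Sampler.Settings()
#   settings.top_k = 100
#   settings.top_p = 1.0
#   settings.temperature = 1.835  # any value in [1.83, 1.84] enables the entropy-based
#                                 # dynamic temperature path, with the range read from
#                                 # EntropySampling.txt
#
#   token, prob, end_filter = ExLlamaV2Sampler.sample(logits,
#                                                     settings,
#                                                     sequence_ids,
#                                                     random.random(),
#                                                     tokenizer)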