import sys
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from stripedhyena.model import StripedHyena
from stripedhyena.sample import sample
from stripedhyena.tokenizer import CharLevelTokenizer

from .scoring import logits_to_logprobs, prepare_batch


class Generator:
    '''
    Adapted from https://github.com/togethercomputer/stripedhyena.

    Modifications include:
    - `generate()` accepts and returns the recurrent cache state, letting the
      user keep track of it across sampling runs.
    - Able to sample with long token prompts, in which case the cache is
      initialized with recurrent teacher forcing.
    '''
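    # A minimal usage sketch (the `model` and `tokenizer` objects are assumed to
    # be loaded elsewhere): the cache returned by one call to `generate()` can be
    # passed back in to continue sampling from the same recurrent state.
    #
    #     g = Generator(model, tokenizer, top_k=1, top_p=1., temperature=0.)
    #     ids, logits, cache = g.generate(
    #         device='cuda:0', input_string='ACGT', num_tokens=16,
    #     )
    #     more_ids, _, cache = g.generate(
    #         device='cuda:0', input_ids=ids[:, -1:], num_tokens=16,
    #         inference_params_dict=cache,
    #     )
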
    def __init__(
        self,
        model: StripedHyena,
        tokenizer: CharLevelTokenizer,
        top_k: int = 50,
        top_p: float = 0.7,
        temperature: float = 1.,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.untils = ['\n\n']

    def generate(
        self,
        device: str,
        input_string: Optional[str] = None,
        input_ids: Optional[torch.Tensor] = None,
        num_tokens: int = 32,
        cached_generation: bool = True,
        force_prompt_threshold: int = 128,
        print_generation: bool = True,
        verbose: bool = False,
        skip_special_tokens: bool = False,
        stop_at_eos: bool = True,
        max_seqlen: Optional[int] = None,
        inference_params_dict: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        A version of the generate() method that enables passing in, and that
        returns, the `inference_params_dict` for replaying cached sampling from
        a given state.
        """
        if isinstance(self.tokenizer.eos, int):
            eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
        else:
            eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

        if input_ids is None:
            input = self.tokenizer.tokenize(input_string)
            if isinstance(input, list):
                input = torch.LongTensor(input).unsqueeze(0).to(device)
            else:
                input = input.unsqueeze(0).to(device)
        else:
            input = input_ids
        x = input

        if max_seqlen is not None:
            x = x[:, -max_seqlen:]

        num_tokens = int(num_tokens)
        batch_size = x.shape[0]

        # For long prompts, prefill the cache with the first `force_prompt_threshold`
        # tokens, then teacher-force the remaining prompt tokens one at a time.
        prompt_length = x.shape[1]
        prompt_forcing = prompt_length > force_prompt_threshold
        if prompt_forcing:
            forced_prompt_length = prompt_length - force_prompt_threshold
            x_force = x[:, force_prompt_threshold:]
            x = x[:, :force_prompt_threshold]
        else:
            forced_prompt_length = 0

        generation = torch.empty(
            x.shape[0],
            num_tokens,
            dtype=torch.long,
            device=x.device,
        )

        scores = torch.empty(
            x.shape[0],
            num_tokens,
            self.tokenizer.vocab_size,
            dtype=torch.float,
            device=x.device,
        )

        prefilled = False

        if inference_params_dict is not None:
            # Resume from a user-provided cache: move all cached state onto the
            # current device and treat the cache as already prefilled.
            cached_generation = True
            prefilled = True

            for key, data in inference_params_dict['mha'].key_value_memory_dict.items():
                inference_params_dict['mha'].key_value_memory_dict[key] = data.to(x.device)
            for key, data in inference_params_dict['hyena'].fir_state_dict.items():
                inference_params_dict['hyena'].fir_state_dict[key] = data.to(x.device)
            for key, data in inference_params_dict['hyena'].state_dict.items():
                inference_params_dict['hyena'].state_dict[key] = data.to(x.device)

        elif cached_generation:
            inference_params_dict = self.model.initialize_inference_params()
            inference_params_dict['mha'].max_batch_size = batch_size
            inference_params_dict['hyena'].max_batch_size = batch_size
            prefilled = False

        if verbose:
            mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
            print(f'Memory after tokenization: {mem_after_tok} GB')
            print('Starting generation...')
            if input_string is not None:
                print('Prompt: ' + input_string)
            else:
                print(f'Prompt ids: {input_ids} {input_ids.shape}')

        for i in range(forced_prompt_length + num_tokens):
            if prefilled:
                post_prefill = True
            else:
                post_prefill = cached_generation and i > 0

            if post_prefill:
                # After the prefill pass, feed only the most recent token and
                # advance the cache's sequence-length offsets.
                x = x[:, -1:]
                seqlen_offset = inference_params_dict['mha'].seqlen_offset

                if seqlen_offset == 0:
                    seqlen_offset = input.shape[-1]
                    inference_params_dict['hyena'].seqlen_offset = seqlen_offset
                    inference_params_dict['mha'].seqlen_offset = seqlen_offset
                else:
                    inference_params_dict['mha'].seqlen_offset += 1
                    inference_params_dict['hyena'].seqlen_offset += 1

            with torch.inference_mode():
                logits, inference_params_dict = self.model(
                    x,
                    inference_params_dict=inference_params_dict,
                )

            last_logits = logits[:, -1]

            if prompt_forcing and i < forced_prompt_length:
                # Teacher-force the remaining prompt tokens instead of sampling.
                new_idx = x_force[:, i]
            else:
                new_idx = sample(
                    last_logits,
                    top_k=self.top_k,
                    top_p=self.top_p,
                    temperature=self.temperature,
                )

            if stop_at_eos and (generation[0, -2:] == eos_token_ids).all():
                print('Stopping generation at EOS')

            if print_generation and verbose and batch_size == 1:
                print(
                    f'{self.tokenizer.detokenize([new_idx.item()])}',
                    end=' ',
                )

            if prompt_forcing:
                if i >= forced_prompt_length:
                    scores[:, i - forced_prompt_length] = last_logits
                    generation[:, i - forced_prompt_length] = new_idx
            else:
                scores[:, i] = last_logits
                generation[:, i] = new_idx

            if post_prefill:
                x = new_idx[:, None]
            else:
                x = torch.cat([x, new_idx[:, None]], dim=-1)

        if verbose:
            y = self.tokenizer.detokenize_batch(generation[:, : i + 1])

            for until in self.untils:
                if until in y:
                    y = y.split(until)[0]
                    break

            print(f'\nInput: {input_string}, Output: {y}')

            mem_end = torch.cuda.memory_allocated(device=x.device) / 1e9
            print(f'Memory after generation: {mem_end} GB')

        return generation[:, : i + 1], scores[:, : i + 1], inference_params_dict


def generate(
    prompt_seqs: List[str],
    model: StripedHyena,
    tokenizer: CharLevelTokenizer,
    n_tokens: int = 100,
    temperature: float = 0.,
    top_k: int = 1,
    top_p: float = 1.,
    batched: bool = True,
    prepend_bos: bool = False,
    cached_generation: bool = False,
    force_prompt_threshold: int = 128,
    verbose: int = 1,
    device: str = 'cuda:0',
    **kwargs,
) -> Tuple[List[str], List[float]]:
    """
    Performs generation from a list of prompts.
    If all prompts are the same length, this can do batched generation.
    Also supports cached generation for efficient sampling.
    """
    model.eval()

    g = Generator(
        model,
        tokenizer,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )

    uniform_lengths = all(len(s) == len(prompt_seqs[0]) for s in prompt_seqs)

    if batched and uniform_lengths:
        input_ids_list = [
            prepare_batch(
                prompt_seqs,
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
        ]
    else:
        if verbose:
            if not uniform_lengths:
                sys.stderr.write('Note: Prompts are of different lengths.\n')
            sys.stderr.write('Note: Will not do batched generation.\n')
        input_ids_list = [
            prepare_batch(
                [prompt_seq],
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
            for prompt_seq in prompt_seqs
        ]

    generated_seqs, generated_scores = [], []
    for input_ids in input_ids_list:
        batch_size = input_ids.shape[0]

        output_ids, logits, _ = g.generate(
            input_ids=input_ids,
            num_tokens=n_tokens,
            cached_generation=cached_generation,
            force_prompt_threshold=force_prompt_threshold,
            device=device,
            print_generation=(verbose > 1),
            verbose=(verbose > 1),
            stop_at_eos=False,
        )
        if verbose > 1:
            print('input_ids.shape', input_ids.shape)
            print('output_ids.shape', output_ids.shape)
            print('logits.shape', logits.shape)

        generated_seqs_batch = list(tokenizer.detokenize_batch(output_ids))
        assert len(generated_seqs_batch) == batch_size
        generated_seqs += generated_seqs_batch

        # Score each generated sequence by its mean per-token log-probability.
        logprobs = logits_to_logprobs(logits, output_ids)
        logprobs = logprobs.float().cpu().numpy()

        generated_scores += [np.mean(logprobs[idx]) for idx in range(batch_size)]

    assert len(generated_seqs) == len(generated_scores) == len(prompt_seqs)
    if verbose:
        for seq, score, prompt in zip(generated_seqs, generated_scores, prompt_seqs):
            print(f'Prompt: "{prompt}",\tOutput: "{seq}",\tScore: {score}')

    return generated_seqs, generated_scores
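
# A minimal, hypothetical usage sketch of the module-level `generate()` above
# (loading the StripedHyena checkpoint and tokenizer is assumed to happen
# elsewhere and is not part of this module):
#
#     model, tokenizer = ...  # a StripedHyena model and CharLevelTokenizer
#     seqs, scores = generate(
#         prompt_seqs=['ATCG', 'GGCC'],
#         model=model,
#         tokenizer=tokenizer,
#         n_tokens=100,
#         temperature=0.,
#         top_k=1,
#         batched=True,
#         cached_generation=True,
#         device='cuda:0',
#     )
#     # `seqs` holds the generated continuations; `scores` holds their mean
#     # per-token log-probabilities.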