| import abc |
|
|
| from typing import Any |
|
|
| import numpy as np |
| import numpy.typing as npt |
|
|
|
|
class LlamaDraftModel(abc.ABC):
    """Abstract interface for draft models.

    A concrete subclass must implement ``__call__``, which receives an array
    of token ids (positional-only) plus arbitrary keyword arguments and
    returns an array of token ids.
    """

    @abc.abstractmethod
    def __call__(
        self,
        input_ids: npt.NDArray[np.intc],
        /,
        **kwargs: Any,
    ) -> npt.NDArray[np.intc]:
        """Return draft token ids for ``input_ids``.

        Raises:
            NotImplementedError: always, when this base implementation is
                reached (e.g. through ``super().__call__``).
        """
        raise NotImplementedError()
|
|
|
|
class LlamaPromptLookupDecoding(LlamaDraftModel):
    """Draft model that proposes tokens by n-gram lookup within the input.

    Based on https://github.com/apoorvumang/prompt-lookup-decoding
    """

    def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
        """Store the lookup configuration.

        Args:
            max_ngram_size: Largest trailing n-gram length tried when
                searching for a match.
            num_pred_tokens: Upper bound on the number of tokens returned
                per call.
        """
        self.max_ngram_size = max_ngram_size
        self.num_pred_tokens = num_pred_tokens

    @staticmethod
    def find_candidate_pred_tokens(
        input_ids: npt.NDArray[np.intc],
        max_ngram_size: int,
        num_pred_tokens: int,
    ):
        """Find the tokens that followed an earlier copy of the trailing n-gram.

        N-gram lengths are tried from the longest allowed down to 1. For each
        length, every window of ``input_ids`` equal to its final n-gram is
        visited in order of position, and the first one that has at least one
        token after it wins.

        Args:
            input_ids: Token ids to search.
            max_ngram_size: Largest n-gram length to try; it is capped at
                ``len(input_ids) - 1``.
            num_pred_tokens: Maximum number of following tokens to return.

        Returns:
            A slice of ``input_ids`` holding up to ``num_pred_tokens`` tokens
            that follow the matched n-gram, or an empty ``np.intc`` array
            when nothing matches.
        """
        total = input_ids.shape[0]
        largest = min(max_ngram_size, total - 1)

        # Longest n-gram first.
        for width in reversed(range(1, largest + 1)):
            suffix = input_ids[-width:]
            candidates = np.lib.stride_tricks.sliding_window_view(
                input_ids, (width,)
            )
            hit_mask = (candidates == suffix).all(axis=1)

            for pos in np.flatnonzero(hit_mask):
                begin = pos + width
                stop = min(begin + num_pred_tokens, total)
                # The final window is the suffix itself, so its continuation
                # is empty; such hits are skipped here.
                if begin >= stop:
                    continue
                return input_ids[begin:stop]

        return np.array([], dtype=np.intc)

    def __call__(
        self,
        input_ids: npt.NDArray[np.intc],
        /,
        **kwargs: Any,
    ) -> npt.NDArray[np.intc]:
        """Return candidate tokens for ``input_ids`` using the stored settings.

        Extra keyword arguments are accepted and ignored.
        """
        return self.find_candidate_pred_tokens(
            input_ids=input_ids,
            max_ngram_size=self.max_ngram_size,
            num_pred_tokens=self.num_pred_tokens,
        )
|
|