|
|
from huggingface_hub import hf_hub_download |
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
from torch.utils.data import Dataset, DataLoader, random_split |
|
|
import urllib.request |
|
|
import os |
|
|
from transformers import AutoTokenizer, logging |
|
|
import pandas as pd |
|
|
from tqdm import tqdm |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention followed by an MLP.

    Each sub-layer reads a LayerNorm-ed copy of its input, and its output is
    added back onto the residual stream.

    Args:
        emb_dim: Width of the residual stream (last dimension of ``x``).
        num_heads: Number of attention heads; ``nn.MultiheadAttention``
            requires ``emb_dim`` to be divisible by it.
        context_length: Accepted for signature compatibility with the model
            constructor; not used by this layer.
        dropout: Dropout probability applied inside attention and after the
            MLP.
    """

    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):
        super().__init__()
        # These attribute names become state_dict keys of saved checkpoints,
        # so they must not be renamed.
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply one transformer layer.

        Args:
            x: Tensor of shape ``(batch, seq, emb_dim)`` (``batch_first=True``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Normalize once and use the result as query, key and value instead
        # of evaluating the same LayerNorm three times.
        h = self.ln1(x)
        # NOTE(review): no attn_mask is passed, so every position also attends
        # to later positions. If this block is trained for next-token
        # prediction, positions can see the tokens they are asked to predict —
        # confirm against the training code before adding a causal mask
        # (doing so changes what the loaded checkpoint computes).
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x
|
|
|
|
|
|
|
|
class MiniTransformer(nn.Module):
    """Token-level language model built from stacked ``TransformerBlock``s.

    Token embeddings plus learned absolute position embeddings go through
    ``num_layers`` blocks, a final LayerNorm and a bias-free linear head that
    yields one logit per vocabulary entry.

    Args:
        vocab_size: Number of token ids (embedding rows / head output width).
        emb_dim: Embedding and residual-stream width.
        context_length: Maximum sequence length; sizes the position table.
        num_heads: Attention heads per block.
        num_layers: Number of stacked blocks.
        dropout: Dropout probability passed to every block.
    """

    def __init__(
        self,
        vocab_size,
        emb_dim,
        context_length,
        num_heads,
        num_layers,
        dropout=0.1,
    ):
        super().__init__()
        # Attribute names form the checkpoint's state_dict keys; keep them.
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(emb_dim, num_heads, context_length, dropout)
                for _ in range(num_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, vocab_size, bias=False)
        self.context_length = context_length

    def forward(self, x):
        """Return logits for every position of ``x``.

        Args:
            x: Integer token ids of shape ``(batch, seq)`` with
                ``seq <= context_length``.

        Returns:
            Float tensor of shape ``(batch, seq, vocab_size)``.

        Raises:
            ValueError: If ``seq`` exceeds ``context_length`` (the position
                table has no row for those positions).
        """
        _, seq_len = x.shape
        if seq_len > self.context_length:
            # Fail early with a readable message instead of an embedding
            # index error (a device-side assert on CUDA).
            raise ValueError(
                f"sequence length {seq_len} exceeds context_length "
                f"{self.context_length}"
            )
        pos = torch.arange(seq_len, device=x.device)
        # (seq, emb_dim) position embeddings broadcast over the batch axis.
        x = self.emb(x) + self.pos_emb(pos)
        x = self.blocks(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, x, max_new_tokens=20, temperature=1.0, top_k=None):
        """Append ``max_new_tokens`` sampled tokens to the prompt ``x``.

        Args:
            x: Prompt token ids of shape ``(batch, seq)``.
            max_new_tokens: Number of tokens to append.
            temperature: Divisor applied to the logits before softmax; larger
                values flatten the distribution. ``0`` selects greedy
                (argmax) decoding.
            top_k: If a positive int, sample only among the ``top_k``
                highest-scoring tokens at each step (clamped to the
                vocabulary size). ``None`` or a non-positive value disables
                the filter.

        Returns:
            Tensor of shape ``(batch, seq + max_new_tokens)``: the prompt
            followed by the generated ids.
        """
        for _ in range(max_new_tokens):
            # Feed only the most recent window the position table can index.
            x_cond = x[:, -self.context_length :]
            logits = self(x_cond)[:, -1, :]

            if temperature == 0:
                # Dividing by zero would yield NaN probabilities; use argmax.
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    k = min(top_k, logits.size(-1))
                    kth_best = torch.topk(logits, k, dim=-1).values[:, -1:]
                    # Tokens scoring below the k-th best get zero probability.
                    logits = logits.masked_fill(logits < kth_best, float("-inf"))
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            x = torch.cat([x, next_token], dim=1)
        return x
|
|
|
|
|
|
|
|
CONTEXT_LENGTH = 128
EMBEDDING_DIMENSION = 512
HEAD_NUMBER = 4
N_LAYER = 4

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU so the script
# still runs on machines that have neither accelerator.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Fetch the checkpoint from the Hugging Face Hub; returns a local file path.
model_path = hf_hub_download(
    repo_id="pierjoe/MiniTransformer",
    filename="checkpoints/mini_transformer_v3/model_40.safetensors",
)

# These hyperparameters must match the checkpoint, otherwise
# load_state_dict below fails on shape/key mismatches.
model = MiniTransformer(
    vocab_size=tokenizer.vocab_size,
    emb_dim=EMBEDDING_DIMENSION,
    context_length=CONTEXT_LENGTH,
    num_heads=HEAD_NUMBER,
    num_layers=N_LAYER,
).to(device)

state_dict = load_file(model_path)
# Weights saved from a torch.compile()-wrapped model carry an "_orig_mod."
# key prefix; strip it so the keys match this uncompiled module.
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()  # disable dropout for inference

max_tokens = 100
prompt = "You are a helpful assistant. Provide clear, concise, and accurate responses to the user "
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
# NOTE(review): temperature=5 divides the logits by 5, which flattens the
# sampling distribution considerably — confirm this is intended.
output_ids = model.generate(
    input_ids, max_new_tokens=max_tokens, temperature=5, top_k=10
)

generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# A bare expression only displays in a notebook; print so scripts show it too.
print(generated_text)
|
|
|