MiniTransformer / inference.py
pierjoe's picture
Upload inference.py with huggingface_hub
2b64ae9 verified
from huggingface_hub import hf_hub_download
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import urllib.request
import os
from transformers import AutoTokenizer, logging
import pandas as pd
from tqdm import tqdm
from safetensors.torch import load_file
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention followed by an MLP.

    Each sub-layer reads a LayerNorm-ed copy of the residual stream and its
    output is added back to the stream (``x + f(norm(x))``).

    Args:
        emb_dim: Width of the residual stream; ``nn.MultiheadAttention``
            requires it to be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        context_length: Accepted for signature compatibility with the
            model constructor; not used by this block.
        dropout: Dropout probability for the attention weights and the MLP
            output.
    """

    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout, batch_first=True
        )
        # Standard 4x-expansion feed-forward network.
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq, emb_dim)`` (the attention layer
                is built with ``batch_first=True``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Normalize once and reuse the result as query, key and value instead
        # of recomputing the same LayerNorm three times.
        h = self.ln1(x)
        # NOTE(review): no attn_mask is passed, so every position attends to
        # every other position (non-causal). Left unchanged so outputs match
        # the released checkpoint — confirm how the model was trained.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x
class MiniTransformer(nn.Module):
    """Small transformer language model over token ids.

    Token embeddings and learned absolute positional embeddings are summed,
    passed through ``num_layers`` ``TransformerBlock`` layers, normalized and
    projected to vocabulary logits.

    Args:
        vocab_size: Number of token ids (rows of the embedding table and
            width of the output logits).
        emb_dim: Model width.
        context_length: Maximum sequence length; size of the positional
            embedding table.
        num_heads: Attention heads per block.
        num_layers: Number of stacked blocks.
        dropout: Dropout probability forwarded to every block.
    """

    def __init__(
        self,
        vocab_size,
        emb_dim,
        context_length,
        num_heads,
        num_layers,
        dropout=0.1,
    ):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(emb_dim, num_heads, context_length, dropout)
                for _ in range(num_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, vocab_size, bias=False)
        self.context_length = context_length

    def forward(self, x):
        """Compute next-token logits for every position.

        Args:
            x: Integer token ids of shape ``(B, T)``.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.

        Raises:
            IndexError: if ``T`` exceeds ``context_length`` (there is no
                positional embedding for those positions).
        """
        _, seq_len = x.shape
        if seq_len > self.context_length:
            # Fail with a readable message instead of an opaque embedding
            # lookup error (a device-side assert on CUDA).
            raise IndexError(
                f"sequence length {seq_len} exceeds context_length "
                f"{self.context_length}"
            )
        pos = torch.arange(seq_len, device=x.device)
        x = self.emb(x) + self.pos_emb(pos)
        x = self.blocks(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, x, max_new_tokens=20, temperature=1.0, top_k=None):
        """Autoregressively extend ``x`` by ``max_new_tokens`` tokens.

        Args:
            x: Integer token ids of shape ``(B, T)`` used as the prompt.
            max_new_tokens: Number of tokens to append.
            temperature: Logits are divided by this before sampling. ``0``
                selects greedy (argmax) decoding.
            top_k: If a positive int, sampling is restricted to the ``top_k``
                highest-scoring tokens (clamped to the vocabulary size).
                ``None`` or ``<= 0`` disables the filter.

        Returns:
            Token ids of shape ``(B, T + max_new_tokens)``.
        """
        for _ in range(max_new_tokens):
            # Only the most recent context_length tokens fit the positional table.
            x_cond = x[:, -self.context_length :]
            logits = self(x_cond)[:, -1, :]  # (B, vocab_size), last position only
            if temperature == 0:
                # Greedy decoding; top-k cannot change the argmax.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    k = min(top_k, logits.size(-1))
                    # Mask everything strictly below the k-th largest logit.
                    kth = torch.topk(logits, k, dim=-1).values[:, -1:]
                    logits = logits.masked_fill(logits < kth, float("-inf"))
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            x = torch.cat([x, next_token], dim=1)
        return x
# Architecture hyperparameters; these must match the checkpoint, otherwise
# load_state_dict below fails on missing keys or shape mismatches.
CONTEXT_LENGTH = 128
EMBEDDING_DIMENSION = 512
HEAD_NUMBER = 4
N_LAYER = 4

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Prefer CUDA, then Apple-silicon MPS, then CPU. Checking MPS explicitly
# avoids selecting an unavailable device on machines that have neither.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Download the model file
model_path = hf_hub_download(
    repo_id="pierjoe/MiniTransformer",
    filename="checkpoints/mini_transformer_v3/model_40.safetensors",
)

# Load with your custom class
model = MiniTransformer(
    vocab_size=tokenizer.vocab_size,
    emb_dim=EMBEDDING_DIMENSION,
    context_length=CONTEXT_LENGTH,
    num_heads=HEAD_NUMBER,
    num_layers=N_LAYER,
).to(device)
state_dict = load_file(model_path)
# Strip the "_orig_mod." prefix from parameter names — presumably the weights
# were saved from a torch.compile-wrapped model.
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

max_tokens = 100
prompt = "You are a helpful assistant. Provide clear, concise, and accurate responses to the user "
# NOTE(review): encode() adds BERT special tokens ([CLS]/[SEP]) by default, so
# the prompt ends in [SEP] — confirm this matches how training data was encoded.
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
# NOTE(review): temperature=5 flattens the distribution heavily; top_k=10
# keeps sampling confined to the 10 most likely tokens at each step.
output_ids = model.generate(
    input_ids, max_new_tokens=max_tokens, temperature=5, top_k=10
)
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# A bare expression only displays in a notebook; print so scripts show output.
print(generated_text)