pierjoe committed · Commit 2b64ae9 · verified · 1 Parent(s): 0a1fce2

Upload inference.py with huggingface_hub
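
For reference, an upload like this is typically done with huggingface_hub's upload_file helper. A minimal sketch, assuming the file sits in the current working directory (the local path is an assumption; the repo id is the one used in the script below):

    from huggingface_hub import upload_file

    upload_file(
        path_or_fileobj="inference.py",     # local file to push (assumed path)
        path_in_repo="inference.py",        # destination path inside the repo
        repo_id="pierjoe/MiniTransformer",  # repo id used by the inference script
    )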

Files changed (1)
  1. inference.py +133 -0
inference.py ADDED
@@ -0,0 +1,133 @@
+ from huggingface_hub import hf_hub_download
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from transformers import AutoTokenizer
+ from safetensors.torch import load_file
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(emb_dim)
+         self.ln2 = nn.LayerNorm(emb_dim)
+         self.attn = nn.MultiheadAttention(
+             emb_dim, num_heads, dropout=dropout, batch_first=True
+         )
+         self.mlp = nn.Sequential(
+             nn.Linear(emb_dim, 4 * emb_dim),
+             nn.GELU(),
+             nn.Linear(4 * emb_dim, emb_dim),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         # pre-norm self-attention (note: no causal mask is passed to the attention layer)
+         attn_out, _ = self.attn(
+             self.ln1(x), self.ln1(x), self.ln1(x), need_weights=False
+         )
+         x = x + attn_out
+         x = x + self.mlp(self.ln2(x))
+         return x
+
+
+ class MiniTransformer(nn.Module):
+     def __init__(
+         self,
+         vocab_size,
+         emb_dim,
+         context_length,
+         num_heads,
+         num_layers,
+         dropout=0.1,
+     ):
+         super().__init__()
+         self.emb = nn.Embedding(vocab_size, emb_dim)
+         self.pos_emb = nn.Embedding(context_length, emb_dim)
+         self.blocks = nn.Sequential(
+             *[
+                 TransformerBlock(emb_dim, num_heads, context_length, dropout)
+                 for _ in range(num_layers)
+             ]
+         )
+         self.ln_f = nn.LayerNorm(emb_dim)
+         self.head = nn.Linear(emb_dim, vocab_size, bias=False)
+         self.context_length = context_length
+
+     def forward(self, x):
+         B, T = x.shape
+         pos = torch.arange(T, device=x.device)
+         x = self.emb(x) + self.pos_emb(pos)
+         x = self.blocks(x)
+         x = self.ln_f(x)
+         logits = self.head(x)
+         return logits
+
+     @torch.no_grad()
+     def generate(self, x, max_new_tokens=20, temperature=1.0, top_k=None):
+         for _ in range(max_new_tokens):
+             # truncate context if needed
+             x_cond = x[:, -self.context_length :]
+
+             # get predictions
+             logits = self(x_cond)  # (B, T_cond, vocab_size)
+             logits = logits[:, -1, :] / temperature  # only last position
+
+             # optionally restrict sampling to the top-k most likely tokens
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = float("-inf")
+
+             probs = F.softmax(logits, dim=-1)
+
+             # sample from the distribution
+             next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
+             # greedy alternative: next_token = torch.argmax(probs, dim=1).unsqueeze(-1)
+             # append to sequence
+             x = torch.cat([x, next_token], dim=1)
+
+         return x
+
+
+ CONTEXT_LENGTH = 128
+ EMBEDDING_DIMENSION = 512
+ HEAD_NUMBER = 4
+ N_LAYER = 4
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+ # pick the best available device, falling back to CPU
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ elif torch.backends.mps.is_available():
+     device = torch.device("mps")
+ else:
+     device = torch.device("cpu")
+
+ # Download the model file
+ model_path = hf_hub_download(
+     repo_id="pierjoe/MiniTransformer",
+     filename="checkpoints/mini_transformer_v3/model_40.safetensors",
+ )
+
+ # Load with the custom class
+ model = MiniTransformer(
+     vocab_size=tokenizer.vocab_size,
+     emb_dim=EMBEDDING_DIMENSION,
+     context_length=CONTEXT_LENGTH,
+     num_heads=HEAD_NUMBER,
+     num_layers=N_LAYER,
+ ).to(device)
+ state_dict = load_file(model_path)
+ # strip the torch.compile prefix, if present, from checkpoint keys
+ state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ max_tokens = 100
+ prompt = "You are a helpful assistant. Provide clear, concise, and accurate responses to the user "
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+ output_ids = model.generate(
+     input_ids, max_new_tokens=max_tokens, temperature=5, top_k=10  # high temperature + top-k sampling
+ )
+ generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+ print(generated_text)