| |
|
| | import os |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import tiktoken |
| |
|
| | |
# Files read at start-up and the name of the tiktoken encoding to load.
MODEL_PATH, VOCAB_PATH = "chatgclm_base_2.9M.pt", "vocab_map.pt"
TOKENIZER_NAME = "gpt2"

# Model hyperparameters.
D_MODEL = 256                    # embedding width / number of conv channels
N_LAYERS = 4                     # number of stacked Block layers
MAX_SEQ_LEN = 1024               # rows in the positional embedding table
LOCAL_KERNEL_SIZE = 5            # taps of the depthwise conv in LocalConv1D
GLOBAL_KERNEL_SIZE = 256         # taps of the FFT conv kernel in GlobalConv1D
USE_GLOBAL_EVERY_N_LAYERS = 2    # layers with index % this == 0 get a global conv
FFT_SIZE = 1024                  # FFT window length used by GlobalConv1D

# Reserved ids at the start of the compact vocabulary. Ids >= OFFSET index
# into the `used_tokens` table (shifted down by OFFSET).
PAD_ID, SEP_ID, EOS_ID = 0, 1, 2
OFFSET = 3
| | |
| |
|
| | |
class GlobalConv1D(nn.Module):
    """Causal per-channel long convolution computed with block-wise FFTs.

    Every channel owns a learnable kernel of ``kernel_size`` taps. The input is
    left-padded with zeros so that output position ``t`` depends only on input
    positions ``<= t``; the padded signal is then processed in windows of
    ``fft_size`` samples, keeping from each window only the samples that are
    free of circular wrap-around.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        """Convolve ``x`` with the kernel along the last axis.

        Args:
            x: 3-D tensor laid out as (batch, channels, time); the channel
                axis has to broadcast against the kernel's first dimension.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if the effective kernel length (the smaller of
                ``kernel_size`` and the sequence length) exceeds ``fft_size``,
                which would leave no valid samples per FFT window.
        """
        _, _, T = x.shape
        if T == 0:
            # Nothing to convolve; keep the (batch, channels, 0) shape.
            return x

        # The kernel is truncated for sequences shorter than it: the dropped
        # taps could only ever multiply the zero padding.
        K = min(self.kernel_size, T)
        overlap = K - 1
        block = self.fft_size - overlap
        if block <= 0:
            # Without this guard the window loop would never advance.
            raise ValueError(
                f"fft_size ({self.fft_size}) must be at least the effective "
                f"kernel length ({K})"
            )

        x = F.pad(x, (overlap, 0))
        # rfft(..., n=fft_size) zero-pads shorter inputs itself, so neither the
        # kernel nor a short trailing segment needs explicit padding.
        k_f = torch.fft.rfft(self.kernel[:, :K], n=self.fft_size).unsqueeze(0)

        outs = []
        for pos in range(0, T, block):
            seg = x[..., pos:pos + self.fft_size]
            y = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * k_f,
                n=self.fft_size,
            )
            # The first `overlap` samples are contaminated by circular
            # wrap-around; the remaining `block` samples are valid.
            outs.append(y[..., overlap:])

        return torch.cat(outs, dim=-1)[..., :T]
| |
|
| |
|
class LocalConv1D(nn.Module):
    """Causal short-range mixer: depthwise conv, ReLU, then a 1x1 conv.

    The input is padded on the left by ``k - 1`` zeros, so the output has the
    same length as the input and position ``t`` never sees later positions.
    """

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        # groups=d_model: each channel is filtered independently with k taps.
        self.dw = nn.Conv1d(d_model, d_model, kernel_size=k, groups=d_model)
        # Kernel size 1: mixes channels at every position.
        self.pw = nn.Conv1d(d_model, d_model, kernel_size=1)

    def forward(self, x):
        """Apply the block to ``x`` laid out as (batch, channels, time)."""
        left_pad = self.k - 1
        padded = F.pad(x, (left_pad, 0))
        activated = F.relu(self.dw(padded))
        return self.pw(activated)
| |
|
| |
|
class Block(nn.Module):
    """Pre-norm residual layer: local conv, optional global conv, feed-forward.

    Each sub-layer is applied as ``x + sublayer(LayerNorm(x))``. Activations
    are kept as (batch, time, d_model); the convolutions expect channels on
    axis 1, so the tensor is transposed around them.
    """

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global

        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)

        # ln2 / global_conv only exist on layers that use the global conv.
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)

        self.ln3 = nn.LayerNorm(d_model)
        hidden = d_model * 4
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    @staticmethod
    def _channels_first(conv, h):
        """Run ``conv`` on ``h`` with axes 1 and 2 swapped, then swap back."""
        return conv(h.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        """Return the block output for ``x``; the shape is unchanged."""
        x = x + self._channels_first(self.local, self.ln1(x))
        if self.use_global:
            x = x + self._channels_first(self.global_conv, self.ln2(x))
        normed = self.ln3(x)
        return x + self.ff(normed)
| |
|
| |
|
class GCLM(nn.Module):
    """Convolutional language model: embeddings, a stack of Blocks, tied head.

    Token and learned positional embeddings are summed, passed through
    ``N_LAYERS`` blocks, layer-normalised and projected to vocabulary logits.
    The output projection shares its weight matrix with the token embedding.
    """

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)

        # Layers whose index is a multiple of USE_GLOBAL_EVERY_N_LAYERS
        # (including layer 0) get the global convolution.
        global_flags = [i % USE_GLOBAL_EVERY_N_LAYERS == 0 for i in range(N_LAYERS)]
        self.layers = nn.ModuleList(Block(D_MODEL, flag) for flag in global_flags)

        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)

        # Weight tying: the head reuses the embedding matrix (its bias stays
        # a separate parameter).
        self.head.weight = self.emb.weight

    def forward(self, x):
        """Map token ids of shape (batch, time) to logits (batch, time, vocab).

        The sequence length must not exceed ``MAX_SEQ_LEN``, the size of the
        positional embedding table.
        """
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        h = self.emb(x) + self.pos(positions)
        for block in self.layers:
            h = block(h)
        return self.head(self.ln(h))
| |
|
| |
|
| | |
def load_model_and_vocab(device):
    """Load the vocabulary mapping and the model checkpoint from disk.

    Args:
        device: device the model is moved to and the checkpoint is mapped to.

    Returns:
        ``(model, used_tokens, id2new)`` with the model in eval mode, or
        ``(None, None, None)`` when either ``VOCAB_PATH`` or ``MODEL_PATH``
        does not exist (an ``[ERROR]`` line is printed in that case).
    """
    failure = (None, None, None)

    if not os.path.exists(VOCAB_PATH):
        print(f"[ERROR] Vocab file not found: {VOCAB_PATH}")
        return failure

    # NOTE(review): torch.load unpickles its input; only point these paths at
    # trusted files.
    vocab_data = torch.load(VOCAB_PATH, map_location="cpu")
    used_tokens, id2new = vocab_data["used_tokens"], vocab_data["id2new"]
    # The first OFFSET ids are reserved for the special tokens.
    vocab_size = len(used_tokens) + OFFSET

    print(f"[INFO] Vocab loaded. Size: {vocab_size}")

    model = GCLM(vocab_size).to(device)

    if not os.path.exists(MODEL_PATH):
        print(f"[ERROR] Model file not found: {MODEL_PATH}")
        return failure

    print(f"[INFO] Loading model from {MODEL_PATH}...")
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()

    return model, used_tokens, id2new
| |
|
@torch.no_grad()
def generate(model, prompt, tokenizer, id2new, used_tokens, device,
             max_new_tokens=200, temperature=0.8, top_k=50):
    """Sample a continuation of ``prompt`` and return it as decoded text.

    Args:
        model: callable mapping a (1, time) id tensor to (1, time, vocab) logits.
        prompt: text to condition on.
        tokenizer: object providing ``encode(str)`` and ``decode(list)``.
        id2new: mapping from tokenizer ids to compact vocabulary ids; prompt
            tokens missing from it are dropped.
        used_tokens: table mapping ``compact_id - OFFSET`` back to tokenizer ids.
        device: device on which the id tensor is created.
        max_new_tokens: upper bound on the number of sampled tokens.
        temperature: logits are divided by this value before sampling. A value
            of exactly 0 selects greedy (argmax) decoding instead of dividing
            by zero.
        top_k: keep only the ``top_k`` highest logits before sampling; ``None``
            or a non-positive value disables the filter.

    Returns:
        The decoded text of the generated tokens only (the prompt is not
        included). Generation stops early when ``EOS_ID`` is sampled; that
        token is not part of the output.
    """
    model.eval()

    # Translate tokenizer ids into the compact vocabulary, skipping unknowns.
    input_ids = [id2new[rid] for rid in tokenizer.encode(prompt) if rid in id2new]
    if not input_ids:
        print("[WARN] No known tokens in prompt.")
        input_ids = [PAD_ID]

    x = torch.tensor([input_ids], dtype=torch.long, device=device)
    greedy = temperature == 0
    use_top_k = top_k is not None and top_k > 0

    generated = []
    for _ in range(max_new_tokens):
        # Only the most recent MAX_SEQ_LEN tokens fit the positional table;
        # the slice is a no-op for shorter sequences.
        ctx = x[:, -MAX_SEQ_LEN:]
        logits = model(ctx)[:, -1, :]

        if greedy:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            if use_top_k:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th largest logit gets zero probability.
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        idx = next_token.item()
        if idx == EOS_ID:
            break

        x = torch.cat((x, next_token), dim=1)
        generated.append(idx)

    return decoder(generated, used_tokens, tokenizer)
| |
|
def decoder(ids, used_tokens, tokenizer):
    """Turn compact vocabulary ids back into text.

    Ids below ``OFFSET`` (the reserved special tokens) are skipped; every other
    id is looked up in ``used_tokens`` at ``id - OFFSET`` and the resulting
    list is passed to ``tokenizer.decode``.
    """
    raw_ids = [used_tokens[i - OFFSET] for i in ids if i >= OFFSET]
    return tokenizer.decode(raw_ids)
| |
|
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Sampling settings for the interactive demo; the banner below is built
    # from the same values so the two cannot drift apart.
    TEMPERATURE = 0.8
    TOP_K = 50
    MAX_NEW_TOKENS = 500

    # Prefer CUDA, then Apple MPS, then CPU. torch.backends.mps is looked up
    # defensively because the attribute is missing on older PyTorch builds.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        device = "mps"
    else:
        device = "cpu"

    print(f"Using device: {device}")

    model, used_tokens, id2new = load_model_and_vocab(device)
    enc = tiktoken.get_encoding(TOKENIZER_NAME)

    # load_model_and_vocab signals failure with None (and has already printed
    # the reason), so test for None explicitly.
    if model is not None:
        # Every sample is seeded with a single newline token; fall back to the
        # first non-special id when the newline is not in the vocabulary.
        newline_id = id2new.get(enc.encode("\n")[0], OFFSET)

        while True:
            print(f"\n--- Generating Sample (Temp={TEMPERATURE}, TopK={TOP_K}) ---")
            print("-" * 20)

            x = torch.tensor([[newline_id]], dtype=torch.long, device=device)

            with torch.no_grad():
                for _ in range(MAX_NEW_TOKENS):
                    # Keep at most MAX_SEQ_LEN tokens of context; the slice is
                    # a no-op for shorter sequences.
                    ctx = x[:, -MAX_SEQ_LEN:]

                    logits = model(ctx)[:, -1, :] / TEMPERATURE

                    # Top-k filter: drop everything below the k-th largest logit.
                    v, _ = torch.topk(logits, min(TOP_K, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                    idx = next_token.item()
                    x = torch.cat((x, next_token), dim=1)

                    if idx == EOS_ID:
                        print("[EOS]", end="", flush=True)
                        break

                    # Stream each token as soon as it is sampled.
                    if idx >= OFFSET:
                        raw_id = used_tokens[idx - OFFSET]
                        print(enc.decode([raw_id]), end="", flush=True)
                    elif idx == PAD_ID:
                        print("[PAD]", end="", flush=True)
                    elif idx == SEP_ID:
                        print("[SEP]", end="", flush=True)

            print("\n" + "-" * 20)
            try:
                cont = input("\nPress [Enter] to generate again, or type 'exit': ")
            except EOFError:
                # stdin was closed (e.g. piped input ran out): leave quietly.
                break
            if cont.strip().lower() == 'exit':
                break
| |
|
| |
|
| |
|