from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Sequence

import torch
from torch import nn
from transformers.generation.utils import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.utils.generic import ModelOutput

from .config import CircuitGPTConfig
from .gpt import GPT
from .hook_utils import hook_recorder


@dataclass
class CircuitGPTCausalLMOutput(ModelOutput):
    loss: torch.Tensor | None = None
    logits: torch.Tensor | None = None
    activations: dict[str, torch.Tensor] | None = None


def _activations_regex(keys: Sequence[str]) -> str:
    escaped = (re.escape(k) for k in keys)
    return "^(" + "|".join(escaped) + ")$"


class CircuitGPTPreTrainedModel(PreTrainedModel):
    config_class = CircuitGPTConfig
    base_model_prefix = "circuit_model"

    circuit_model: GPT

    def __init__(self, config: CircuitGPTConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)

    def get_input_embeddings(self) -> nn.Module:
        return self.circuit_model.transformer.wte  # type: ignore[return-value]

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.circuit_model.transformer.wte = value  # type: ignore[assignment]

    def get_output_embeddings(self) -> nn.Module:
        return self.circuit_model.lm_head  # type: ignore[return-value]

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.circuit_model.lm_head = new_embeddings  # type: ignore[assignment]


class CircuitGPTForCausalLM(CircuitGPTPreTrainedModel, GenerationMixin):
    """
    Hugging Face-compatible wrapper around `circuit_sparsity.gpt.GPT`.

    All math happens inside the original module so parity is guaranteed.
    """

    def __init__(self, config: CircuitGPTConfig, circuit_model: GPT | None = None) -> None:
        super().__init__(config)
        if circuit_model is None:
            self.circuit_model = GPT(config.to_circuit_config())
            self.post_init()
        else:
            self.circuit_model = circuit_model

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------
    @classmethod
    def from_circuit_model(cls, circuit_model: GPT) -> "CircuitGPTForCausalLM":
        config = CircuitGPTConfig.from_circuit_config(circuit_model.config)
        return cls(config, circuit_model=circuit_model)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.LongTensor | None = None,
        output_activations: Sequence[str] | None = None,
        return_dict: bool | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> CircuitGPTCausalLMOutput:
        # Ignore HF generation kwargs we don't use; surface any unknowns.
        remaining_kwargs = {k: v for k, v in kwargs.items() if v is not None}
        if remaining_kwargs:
            unsupported = ", ".join(remaining_kwargs.keys())
            raise ValueError(f"Unsupported arguments for CircuitGPTForCausalLM: {unsupported}")

        if input_ids.size(-1) > self.config.block_size:
            raise ValueError(
                f"Sequence length {input_ids.size(-1)} exceeds block size {self.config.block_size}"
            )

        if output_activations:
            # Run the wrapped GPT under a hook recorder so the requested
            # activation tensors can be returned alongside the logits.
            regex = _activations_regex(output_activations)
            with hook_recorder(regex=regex) as recorded:
                logits, loss, _ = self.circuit_model(input_ids, targets=labels)
            activations = {key: recorded[key] for key in output_activations if key in recorded}
        else:
            activations = None
            logits, loss, _ = self.circuit_model(input_ids, targets=labels)

        # Only report a loss when labels were provided.
        if labels is None:
            loss = None

        return CircuitGPTCausalLMOutput(
            loss=loss,
            logits=logits,
            activations=activations,
        )

    # ------------------------------------------------------------------
    # Generation helpers
    # ------------------------------------------------------------------
    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, **kwargs):
        if input_ids.size(-1) > self.config.block_size:
            input_ids = input_ids[:, -self.config.block_size :]
        return {"input_ids": input_ids}

    def _reorder_cache(self, past, beam_idx):
        # No KV cache implemented; method exists for interface completeness.
        return past
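

# ----------------------------------------------------------------------
# Illustrative usage (a minimal sketch, kept as comments; not executed as
# part of this module). The hook key "transformer.h.0.mlp" is an
# assumption -- the real keys depend on what names `hook_recorder` exposes
# for the wrapped `GPT` -- and the config construction below is a
# placeholder for whatever `CircuitGPTConfig` actually accepts.
#
#     config = CircuitGPTConfig()  # or CircuitGPTConfig.from_circuit_config(...)
#     model = CircuitGPTForCausalLM(config)
#     input_ids = torch.randint(0, 100, (1, 16))
#     out = model(
#         input_ids,
#         labels=input_ids,
#         output_activations=["transformer.h.0.mlp"],
#     )
#     # out.loss, out.logits, and out.activations follow
#     # CircuitGPTCausalLMOutput above; generation goes through the usual
#     # GenerationMixin entry point, e.g. model.generate(input_ids).
# ----------------------------------------------------------------------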