from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Sequence

import torch
from torch import nn
from transformers.generation.utils import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.utils.generic import ModelOutput

from .config import CircuitGPTConfig
from .gpt import GPT
from .hook_utils import hook_recorder
|
|
@dataclass
class CircuitGPTCausalLMOutput(ModelOutput):
    """Output of `CircuitGPTForCausalLM.forward`, mirroring standard HF causal-LM outputs."""

    loss: torch.Tensor | None = None
    logits: torch.Tensor | None = None
    activations: dict[str, torch.Tensor] | None = None
|
|
def _activations_regex(keys: Sequence[str]) -> str:
    """Build a regex matching exactly the given hook names and nothing else."""
    escaped = (re.escape(k) for k in keys)
    return "^(" + "|".join(escaped) + ")$"
|
|
class CircuitGPTPreTrainedModel(PreTrainedModel):
    """Base class wiring the inner `GPT` into the Hugging Face pretrained-model machinery."""

    config_class = CircuitGPTConfig
    base_model_prefix = "circuit_model"
    circuit_model: GPT

    def __init__(self, config: CircuitGPTConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)

    def get_input_embeddings(self) -> nn.Module:
        return self.circuit_model.transformer.wte

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.circuit_model.transformer.wte = value

    def get_output_embeddings(self) -> nn.Module:
        return self.circuit_model.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.circuit_model.lm_head = new_embeddings
|
|
class CircuitGPTForCausalLM(CircuitGPTPreTrainedModel, GenerationMixin):
    """
    Hugging Face-compatible wrapper around `circuit_sparsity.gpt.GPT`.

    All math happens inside the original module, so numerical parity with the
    unwrapped model is guaranteed.
    """

    def __init__(self, config: CircuitGPTConfig, circuit_model: GPT | None = None) -> None:
        super().__init__(config)

        if circuit_model is None:
            # Fresh model: build the inner GPT from the config and run HF post-init
            # (weight initialization and related setup).
            self.circuit_model = GPT(config.to_circuit_config())
            self.post_init()
        else:
            # Adopt an existing GPT as-is, without re-initializing its weights.
            self.circuit_model = circuit_model
|
    @classmethod
    def from_circuit_model(cls, circuit_model: GPT) -> "CircuitGPTForCausalLM":
        """Wrap an existing `GPT`, deriving the HF config from its circuit config."""
        config = CircuitGPTConfig.from_circuit_config(circuit_model.config)
        return cls(config, circuit_model=circuit_model)
|
    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.LongTensor | None = None,
        output_activations: Sequence[str] | None = None,
        return_dict: bool | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> CircuitGPTCausalLMOutput:
        # The standard HF flags above are accepted (so `generate` and pipelines can
        # pass them) but ignored; any other non-None keyword argument is an error.
        remaining_kwargs = {k: v for k, v in kwargs.items() if v is not None}
        if remaining_kwargs:
            unsupported = ", ".join(remaining_kwargs.keys())
            raise ValueError(f"Unsupported arguments for CircuitGPTForCausalLM: {unsupported}")

        if input_ids.size(-1) > self.config.block_size:
            raise ValueError(
                f"Sequence length {input_ids.size(-1)} exceeds block size {self.config.block_size}"
            )

        if output_activations:
            # Record only the requested activations while the inner model runs.
            regex = _activations_regex(output_activations)
            with hook_recorder(regex=regex) as recorded:
                logits, loss, _ = self.circuit_model(input_ids, targets=labels)
            activations = {key: recorded[key] for key in output_activations if key in recorded}
        else:
            activations = None
            logits, loss, _ = self.circuit_model(input_ids, targets=labels)

        if labels is None:
            loss = None

        return CircuitGPTCausalLMOutput(
            loss=loss,
            logits=logits,
            activations=activations,
        )
|
    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, **kwargs):
        # No KV cache is kept, so each generation step re-runs the full context,
        # cropped to the last `block_size` tokens.
        if input_ids.size(-1) > self.config.block_size:
            input_ids = input_ids[:, -self.config.block_size :]
        return {"input_ids": input_ids}
|
    def _reorder_cache(self, past, beam_idx):
        # There is no KV cache to reorder during beam search; return `past` unchanged.
        return past
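
# A minimal usage sketch (not executed on import). It assumes `gpt` is an existing,
# trained `GPT` instance and that "h.0.mlp" is a hook name exposed by `hook_utils`;
# both are illustrative stand-ins, not fixed APIs.
#
#     model = CircuitGPTForCausalLM.from_circuit_model(gpt)   # wrap an existing GPT
#     prompt = torch.randint(0, 100, (1, 8))                  # dummy token ids
#     out = model(prompt, output_activations=["h.0.mlp"])
#     print(out.logits.shape, sorted(out.activations))
#     ids = model.generate(prompt, max_new_tokens=16)         # via GenerationMixin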