from __future__ import annotations

from typing import Any

from transformers import PretrainedConfig


class CircuitGPTConfig(PretrainedConfig):
    """
    Minimal Hugging Face config wrapper around the circuit_sparsity GPTConfig.

    Only the fields exercised by the Neuronpedia runs are exposed.
    """

    model_type = "circuitgpt"

    def __init__(
        self,
        vocab_size: int = 2048,
        block_size: int = 256,
        n_layer: int = 8,
        n_head: int = 8,
        d_model: int = 1024,
        d_mlp: int | None = None,
        d_head: int | None = None,
        dropout: float = 0.0,
        bias: bool = True,
        ln_bias: bool = True,
        rms_norm: bool = True,
        activation_type: str = "gelu",
        residual_activation_type: str = "identity",
        tied_unembed: bool = False,
        unembed_rank: int | None = None,
        afrac: float | None = None,
        afrac_loctypes: str = "attn_in,attn_out,mlp_in,mlp_out",
        flash: bool = True,
        use_position_embeddings: bool = False,
        sink: bool = False,
        enable_bigram_table: bool = False,
        learnable_bigram_table: bool = False,
        bigram_table_rank: int | None = None,
        dropout_cat_pos_emb: bool = False,
        sinusoidal_cat_pos_emb: bool = False,
        d_pos_emb: int | None = None,
        auto_map: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> None:
        # Drop unsupported/sensitive keys that may be present in a loaded config.
        for key in [
            "afrac_ste",
            "afrac_ste_only_non_neurons",
            "afrac_approx",
            "rtopk",
            "mup",
            "mup_width_multiplier",
            "grad_checkpointing",
            "enable_fp8_linear",
            "scale_invariance",
            "cat_pos_emb",
        ]:
            kwargs.pop(key, None)

        d_mlp = d_mlp or 4 * d_model
        d_head = d_head or d_model // n_head

        # Avoid duplicate kwargs when loading from a config dict.
        bos_token_id = kwargs.pop("bos_token_id", None)
        eos_token_id = kwargs.pop("eos_token_id", vocab_size - 1)
        pad_token_id = kwargs.pop("pad_token_id", None)

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.max_position_embeddings = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.d_model = d_model
        self.d_mlp = d_mlp
        self.d_head = d_head
        self.dropout = dropout
        self.bias = bias
        self.ln_bias = ln_bias
        self.rms_norm = rms_norm
        self.activation_type = activation_type
        self.residual_activation_type = residual_activation_type
        self.tied_unembed = tied_unembed
        self.unembed_rank = unembed_rank
        self.afrac = afrac
        self.afrac_loctypes = afrac_loctypes
        self.flash = flash
        self.use_position_embeddings = use_position_embeddings
        self.d_pos_emb = d_pos_emb
        self.sink = sink
        self.enable_bigram_table = enable_bigram_table
        self.learnable_bigram_table = learnable_bigram_table
        self.bigram_table_rank = bigram_table_rank
        self.dropout_cat_pos_emb = dropout_cat_pos_emb
        self.sinusoidal_cat_pos_emb = sinusoidal_cat_pos_emb
        self.is_decoder = True

        # Provide explicit auto_map entries so AutoModel/AutoConfig can locate
        # the custom classes when trust_remote_code=True on the Hub.
        self.auto_map = auto_map or {
            "AutoConfig": "config.CircuitGPTConfig",
            "AutoModelForCausalLM": "modeling_circuitgpt.CircuitGPTForCausalLM",
        }

    # ---------------------------------------------------------------------
    # Conversion helpers
    # ---------------------------------------------------------------------

    @classmethod
    def from_circuit_config(cls, circuit_config: "GPTConfig") -> "CircuitGPTConfig":  # type: ignore[name-defined]
        config_dict: dict[str, Any] = {
            "vocab_size": circuit_config.vocab_size,
            "block_size": circuit_config.block_size,
            "n_layer": circuit_config.n_layer,
            "n_head": circuit_config.n_head,
            "d_model": circuit_config.d_model,
            "d_mlp": circuit_config.d_mlp,
            "d_head": circuit_config.d_head,
            "dropout": circuit_config.dropout,
            "bias": circuit_config.bias,
            "ln_bias": circuit_config.ln_bias,
            "rms_norm": circuit_config.rms_norm,
            "activation_type": circuit_config.activation_type,
            "residual_activation_type": circuit_config.residual_activation_type,
            "tied_unembed": circuit_config.tied_unembed,
            "unembed_rank": circuit_config.unembed_rank,
            "afrac": circuit_config.afrac,
            "afrac_loctypes": circuit_config.afrac_loctypes,
            "flash": circuit_config.flash,
            "use_position_embeddings": circuit_config.d_pos_emb is not None,
            "d_pos_emb": getattr(circuit_config, "d_pos_emb", None),
            "sink": getattr(circuit_config, "sink", False),
            "enable_bigram_table": getattr(circuit_config, "enable_bigram_table", False),
            "learnable_bigram_table": getattr(circuit_config, "learnable_bigram_table", False),
            "bigram_table_rank": getattr(circuit_config, "bigram_table_rank", None),
            "dropout_cat_pos_emb": getattr(circuit_config, "dropout_cat_pos_emb", False),
            "sinusoidal_cat_pos_emb": getattr(circuit_config, "sinusoidal_cat_pos_emb", False),
        }
        return cls(**config_dict)

    def to_circuit_config(self) -> "GPTConfig":  # type: ignore[name-defined]
        from circuit_sparsity.gpt import GPTConfig as CircuitConfig

        config_kwargs: dict[str, Any] = dict(
            vocab_size=self.vocab_size,
            block_size=self.block_size,
            n_layer=self.n_layer,
            n_head=self.n_head,
            d_model=self.d_model,
            dropout=self.dropout,
            bias=self.bias,
            ln_bias=self.ln_bias,
            rms_norm=self.rms_norm,
            activation_type=self.activation_type,
            residual_activation_type=self.residual_activation_type,
            tied_unembed=self.tied_unembed,
            unembed_rank=self.unembed_rank,
            afrac=self.afrac,
            afrac_loctypes=self.afrac_loctypes,
            flash=self.flash,
            afrac_ste=False,
            afrac_ste_only_non_neurons=False,
            afrac_approx=False,
            rtopk=False,
            mup=False,
            mup_width_multiplier=None,
            grad_checkpointing=False,
            enable_fp8_linear=False,
            scale_invariance=False,
            d_mlp=self.d_mlp,
            d_head=self.d_head,
            enable_sparse_kernels=False,
            enable_bigram_table=self.enable_bigram_table,
            learnable_bigram_table=self.learnable_bigram_table,
            bigram_table_rank=self.bigram_table_rank,
            d_pos_emb=self.d_pos_emb
            if self.d_pos_emb is not None
            else (self.d_model if self.use_position_embeddings else None),
            sink=self.sink,
            dropout_cat_pos_emb=self.dropout_cat_pos_emb,
            sinusoidal_cat_pos_emb=self.sinusoidal_cat_pos_emb,
        )
        return CircuitConfig(**config_kwargs)
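
    # Serialization note (added comment): to_dict() below explicitly folds the
    # circuit-specific fields into the dict returned by PretrainedConfig.to_dict(),
    # so they are written out alongside the standard HF keys (e.g. into config.json).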
    def to_dict(self) -> dict[str, Any]:
        base = super().to_dict()
        data = {
            "vocab_size": self.vocab_size,
            "block_size": self.block_size,
            "n_layer": self.n_layer,
            "n_head": self.n_head,
            "d_model": self.d_model,
            "d_mlp": self.d_mlp,
            "d_head": self.d_head,
            "dropout": self.dropout,
            "bias": self.bias,
            "ln_bias": self.ln_bias,
            "rms_norm": self.rms_norm,
            "activation_type": self.activation_type,
            "residual_activation_type": self.residual_activation_type,
            "tied_unembed": self.tied_unembed,
            "unembed_rank": self.unembed_rank,
            "flash": self.flash,
            "afrac": self.afrac,
            "afrac_loctypes": self.afrac_loctypes,
            "use_position_embeddings": self.use_position_embeddings,
            "d_pos_emb": self.d_pos_emb,
            "sink": self.sink,
            "enable_bigram_table": self.enable_bigram_table,
            "learnable_bigram_table": self.learnable_bigram_table,
            "bigram_table_rank": self.bigram_table_rank,
            "dropout_cat_pos_emb": self.dropout_cat_pos_emb,
            "sinusoidal_cat_pos_emb": self.sinusoidal_cat_pos_emb,
            "auto_map": self.auto_map,
        }
        base.update(data)
        return base
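

# ---------------------------------------------------------------------------
# Minimal usage sketch (added, illustrative only): builds a small config,
# round-trips it through the circuit_sparsity GPTConfig, and checks that the
# derived dimensions survive. Running the conversion assumes circuit_sparsity
# is importable; the Hub repo id in the trailing comment is a hypothetical
# placeholder, not a real repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    hf_config = CircuitGPTConfig(n_layer=2, n_head=4, d_model=256)

    # __init__ fills in the derived defaults: d_mlp = 4 * d_model, d_head = d_model // n_head.
    assert hf_config.d_mlp == 4 * hf_config.d_model
    assert hf_config.d_head == hf_config.d_model // hf_config.n_head

    # Round-trip through the circuit_sparsity config and back.
    circuit_config = hf_config.to_circuit_config()
    restored = CircuitGPTConfig.from_circuit_config(circuit_config)
    assert restored.d_model == hf_config.d_model
    assert restored.block_size == hf_config.block_size

    # Loading from a Hub repo that ships this file would look roughly like:
    #   AutoConfig.from_pretrained("<org>/<circuitgpt-repo>", trust_remote_code=True)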