# circuit-sparsity/config.py
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from transformers import PretrainedConfig

if TYPE_CHECKING:
    from circuit_sparsity.gpt import GPTConfig


class CircuitGPTConfig(PretrainedConfig):
"""
Minimal Hugging Face config wrapper around the circuit_sparsity GPTConfig.
Only the fields exercised by the Neuronpedia runs are exposed.
"""
model_type = "circuitgpt"
def __init__(
self,
vocab_size: int = 2048,
block_size: int = 256,
n_layer: int = 8,
n_head: int = 8,
d_model: int = 1024,
d_mlp: int | None = None,
d_head: int | None = None,
dropout: float = 0.0,
bias: bool = True,
ln_bias: bool = True,
rms_norm: bool = True,
activation_type: str = "gelu",
residual_activation_type: str = "identity",
tied_unembed: bool = False,
unembed_rank: int | None = None,
afrac: float | None = None,
afrac_loctypes: str = "attn_in,attn_out,mlp_in,mlp_out",
flash: bool = True,
use_position_embeddings: bool = False,
sink: bool = False,
enable_bigram_table: bool = False,
learnable_bigram_table: bool = False,
bigram_table_rank: int | None = None,
dropout_cat_pos_emb: bool = False,
sinusoidal_cat_pos_emb: bool = False,
d_pos_emb: int | None = None,
auto_map: dict[str, str] | None = None,
**kwargs: Any,
) -> None:
        # Drop unsupported training/internal keys that may be present in a loaded config.
for key in [
"afrac_ste",
"afrac_ste_only_non_neurons",
"afrac_approx",
"rtopk",
"mup",
"mup_width_multiplier",
"grad_checkpointing",
"enable_fp8_linear",
"scale_invariance",
"cat_pos_emb",
]:
kwargs.pop(key, None)
d_mlp = d_mlp or 4 * d_model
d_head = d_head or d_model // n_head
# Avoid duplicate kwargs when loading from a config dict.
bos_token_id = kwargs.pop("bos_token_id", None)
eos_token_id = kwargs.pop("eos_token_id", vocab_size - 1)
pad_token_id = kwargs.pop("pad_token_id", None)
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.block_size = block_size
self.max_position_embeddings = block_size
self.n_layer = n_layer
self.n_head = n_head
self.d_model = d_model
self.d_mlp = d_mlp
self.d_head = d_head
self.dropout = dropout
self.bias = bias
self.ln_bias = ln_bias
self.rms_norm = rms_norm
self.activation_type = activation_type
self.residual_activation_type = residual_activation_type
self.tied_unembed = tied_unembed
self.unembed_rank = unembed_rank
self.afrac = afrac
self.afrac_loctypes = afrac_loctypes
self.flash = flash
self.use_position_embeddings = use_position_embeddings
self.d_pos_emb = d_pos_emb
self.sink = sink
self.enable_bigram_table = enable_bigram_table
self.learnable_bigram_table = learnable_bigram_table
self.bigram_table_rank = bigram_table_rank
self.dropout_cat_pos_emb = dropout_cat_pos_emb
self.sinusoidal_cat_pos_emb = sinusoidal_cat_pos_emb
self.is_decoder = True
        # Provide explicit auto_map entries so AutoConfig / AutoModelForCausalLM can
        # locate these custom classes when loading from the Hub with trust_remote_code=True.
self.auto_map = auto_map or {
"AutoConfig": "config.CircuitGPTConfig",
"AutoModelForCausalLM": "modeling_circuitgpt.CircuitGPTForCausalLM",
}
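        # A loading sketch showing how these entries are used (the repo id below
        # is a hypothetical placeholder, not a published checkpoint):
        #
        #     from transformers import AutoConfig, AutoModelForCausalLM
        #     config = AutoConfig.from_pretrained("org/circuitgpt-demo", trust_remote_code=True)
        #     model = AutoModelForCausalLM.from_pretrained("org/circuitgpt-demo", trust_remote_code=True)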
# ---------------------------------------------------------------------
# Conversion helpers
# ---------------------------------------------------------------------
@classmethod
    def from_circuit_config(cls, circuit_config: GPTConfig) -> CircuitGPTConfig:
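        """Build a CircuitGPTConfig from a circuit_sparsity GPTConfig instance."""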
config_dict: dict[str, Any] = {
"vocab_size": circuit_config.vocab_size,
"block_size": circuit_config.block_size,
"n_layer": circuit_config.n_layer,
"n_head": circuit_config.n_head,
"d_model": circuit_config.d_model,
"d_mlp": circuit_config.d_mlp,
"d_head": circuit_config.d_head,
"dropout": circuit_config.dropout,
"bias": circuit_config.bias,
"ln_bias": circuit_config.ln_bias,
"rms_norm": circuit_config.rms_norm,
"activation_type": circuit_config.activation_type,
"residual_activation_type": circuit_config.residual_activation_type,
"tied_unembed": circuit_config.tied_unembed,
"unembed_rank": circuit_config.unembed_rank,
"afrac": circuit_config.afrac,
"afrac_loctypes": circuit_config.afrac_loctypes,
"flash": circuit_config.flash,
"use_position_embeddings": circuit_config.d_pos_emb is not None,
"d_pos_emb": getattr(circuit_config, "d_pos_emb", None),
"sink": getattr(circuit_config, "sink", False),
"enable_bigram_table": getattr(circuit_config, "enable_bigram_table", False),
"learnable_bigram_table": getattr(circuit_config, "learnable_bigram_table", False),
"bigram_table_rank": getattr(circuit_config, "bigram_table_rank", None),
"dropout_cat_pos_emb": getattr(circuit_config, "dropout_cat_pos_emb", False),
"sinusoidal_cat_pos_emb": getattr(circuit_config, "sinusoidal_cat_pos_emb", False),
}
return cls(**config_dict)
    def to_circuit_config(self) -> GPTConfig:
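        """Reconstruct a circuit_sparsity GPTConfig from this wrapper config."""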
from circuit_sparsity.gpt import GPTConfig as CircuitConfig
config_kwargs: dict[str, Any] = dict(
vocab_size=self.vocab_size,
block_size=self.block_size,
n_layer=self.n_layer,
n_head=self.n_head,
d_model=self.d_model,
dropout=self.dropout,
bias=self.bias,
ln_bias=self.ln_bias,
rms_norm=self.rms_norm,
activation_type=self.activation_type,
residual_activation_type=self.residual_activation_type,
tied_unembed=self.tied_unembed,
unembed_rank=self.unembed_rank,
afrac=self.afrac,
afrac_loctypes=self.afrac_loctypes,
flash=self.flash,
afrac_ste=False,
afrac_ste_only_non_neurons=False,
afrac_approx=False,
rtopk=False,
mup=False,
mup_width_multiplier=None,
grad_checkpointing=False,
enable_fp8_linear=False,
scale_invariance=False,
d_mlp=self.d_mlp,
d_head=self.d_head,
enable_sparse_kernels=False,
enable_bigram_table=self.enable_bigram_table,
learnable_bigram_table=self.learnable_bigram_table,
bigram_table_rank=self.bigram_table_rank,
d_pos_emb=self.d_pos_emb
if self.d_pos_emb is not None
else (self.d_model if self.use_position_embeddings else None),
sink=self.sink,
dropout_cat_pos_emb=self.dropout_cat_pos_emb,
sinusoidal_cat_pos_emb=self.sinusoidal_cat_pos_emb,
)
return CircuitConfig(**config_kwargs)
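
    # Round-trip sketch for the two conversion helpers above (assumes the
    # circuit_sparsity package is importable; `circuit_config` stands for an
    # existing circuit_sparsity GPTConfig instance):
    #
    #     hf_config = CircuitGPTConfig.from_circuit_config(circuit_config)
    #     rebuilt = hf_config.to_circuit_config()
    #     assert rebuilt.d_model == circuit_config.d_model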
def to_dict(self) -> dict[str, Any]:
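        """Serialize to a dict, re-adding the wrapper-specific fields on top of the base config."""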
base = super().to_dict()
data = {
"vocab_size": self.vocab_size,
"block_size": self.block_size,
"n_layer": self.n_layer,
"n_head": self.n_head,
"d_model": self.d_model,
"d_mlp": self.d_mlp,
"d_head": self.d_head,
"dropout": self.dropout,
"bias": self.bias,
"ln_bias": self.ln_bias,
"rms_norm": self.rms_norm,
"activation_type": self.activation_type,
"residual_activation_type": self.residual_activation_type,
"tied_unembed": self.tied_unembed,
"unembed_rank": self.unembed_rank,
"flash": self.flash,
"afrac": self.afrac,
"afrac_loctypes": self.afrac_loctypes,
"use_position_embeddings": self.use_position_embeddings,
"d_pos_emb": self.d_pos_emb,
"sink": self.sink,
"enable_bigram_table": self.enable_bigram_table,
"learnable_bigram_table": self.learnable_bigram_table,
"bigram_table_rank": self.bigram_table_rank,
"dropout_cat_pos_emb": self.dropout_cat_pos_emb,
"sinusoidal_cat_pos_emb": self.sinusoidal_cat_pos_emb,
"auto_map": self.auto_map,
}
base.update(data)
return base
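

# Minimal usage sketch (assumes only `transformers` is installed; the values are
# arbitrary examples): build a config, write it to disk with `save_pretrained`,
# and reload it through `from_pretrained`.
if __name__ == "__main__":
    import tempfile

    demo = CircuitGPTConfig(vocab_size=2048, block_size=256, n_layer=2, d_model=256)
    with tempfile.TemporaryDirectory() as tmp_dir:
        demo.save_pretrained(tmp_dir)  # writes config.json with model_type "circuitgpt"
        reloaded = CircuitGPTConfig.from_pretrained(tmp_dir)
    assert reloaded.d_model == 256 and reloaded.d_mlp == 4 * 256
    print(reloaded.auto_map)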