from __future__ import annotations

from typing import TYPE_CHECKING, Any

from transformers import PretrainedConfig

if TYPE_CHECKING:
    from circuit_sparsity.gpt import GPTConfig


class CircuitGPTConfig(PretrainedConfig):
    """
    Minimal Hugging Face config wrapper around the circuit_sparsity GPTConfig.

    Only the fields exercised by the Neuronpedia runs are exposed.
    """

    model_type = "circuitgpt"
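
    # Typical round-trip, as an illustrative sketch (`gpt_config` stands in for a
    # circuit_sparsity GPTConfig instance; save_pretrained comes from PretrainedConfig):
    #
    #     hf_config = CircuitGPTConfig.from_circuit_config(gpt_config)
    #     hf_config.save_pretrained("path/to/export")
    #     restored = hf_config.to_circuit_config()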
    def __init__(
        self,
        vocab_size: int = 2048,
        block_size: int = 256,
        n_layer: int = 8,
        n_head: int = 8,
        d_model: int = 1024,
        d_mlp: int | None = None,
        d_head: int | None = None,
        dropout: float = 0.0,
        bias: bool = True,
        ln_bias: bool = True,
        rms_norm: bool = True,
        activation_type: str = "gelu",
        residual_activation_type: str = "identity",
        tied_unembed: bool = False,
        unembed_rank: int | None = None,
        afrac: float | None = None,
        afrac_loctypes: str = "attn_in,attn_out,mlp_in,mlp_out",
        flash: bool = True,
        use_position_embeddings: bool = False,
        sink: bool = False,
        enable_bigram_table: bool = False,
        learnable_bigram_table: bool = False,
        bigram_table_rank: int | None = None,
        dropout_cat_pos_emb: bool = False,
        sinusoidal_cat_pos_emb: bool = False,
        d_pos_emb: int | None = None,
        auto_map: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> None:
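        # Drop training-only circuit_sparsity options that this wrapper does not
        # expose, so they are not forwarded to PretrainedConfig.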
        for key in [
            "afrac_ste",
            "afrac_ste_only_non_neurons",
            "afrac_approx",
            "rtopk",
            "mup",
            "mup_width_multiplier",
            "grad_checkpointing",
            "enable_fp8_linear",
            "scale_invariance",
            "cat_pos_emb",
        ]:
            kwargs.pop(key, None)
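
        # Fall back to the usual GPT shapes: a 4x-wide MLP and d_model split
        # evenly across attention heads.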
        d_mlp = d_mlp or 4 * d_model
        d_head = d_head or d_model // n_head
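
        # Special-token ids are forwarded to PretrainedConfig; by default only an
        # eos id is set, pointing at the last vocab entry.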
        bos_token_id = kwargs.pop("bos_token_id", None)
        eos_token_id = kwargs.pop("eos_token_id", vocab_size - 1)
        pad_token_id = kwargs.pop("pad_token_id", None)

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.max_position_embeddings = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.d_model = d_model
        self.d_mlp = d_mlp
        self.d_head = d_head
        self.dropout = dropout
        self.bias = bias
        self.ln_bias = ln_bias
        self.rms_norm = rms_norm
        self.activation_type = activation_type
        self.residual_activation_type = residual_activation_type
        self.tied_unembed = tied_unembed
        self.unembed_rank = unembed_rank
        self.afrac = afrac
        self.afrac_loctypes = afrac_loctypes
        self.flash = flash
        self.use_position_embeddings = use_position_embeddings
        self.d_pos_emb = d_pos_emb
        self.sink = sink
        self.enable_bigram_table = enable_bigram_table
        self.learnable_bigram_table = learnable_bigram_table
        self.bigram_table_rank = bigram_table_rank
        self.dropout_cat_pos_emb = dropout_cat_pos_emb
        self.sinusoidal_cat_pos_emb = sinusoidal_cat_pos_emb
        self.is_decoder = True
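
        # auto_map lets AutoConfig / AutoModelForCausalLM resolve these custom
        # classes when the checkpoint is loaded with trust_remote_code=True.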
        self.auto_map = auto_map or {
            "AutoConfig": "config.CircuitGPTConfig",
            "AutoModelForCausalLM": "modeling_circuitgpt.CircuitGPTForCausalLM",
        }

    @classmethod
    def from_circuit_config(cls, circuit_config: GPTConfig) -> CircuitGPTConfig:
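        # getattr with a default keeps this compatible with older GPTConfig
        # objects that may lack the newer fields (d_pos_emb, sink, the
        # bigram-table options, the cat-pos-emb flags).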
        config_dict: dict[str, Any] = {
            "vocab_size": circuit_config.vocab_size,
            "block_size": circuit_config.block_size,
            "n_layer": circuit_config.n_layer,
            "n_head": circuit_config.n_head,
            "d_model": circuit_config.d_model,
            "d_mlp": circuit_config.d_mlp,
            "d_head": circuit_config.d_head,
            "dropout": circuit_config.dropout,
            "bias": circuit_config.bias,
            "ln_bias": circuit_config.ln_bias,
            "rms_norm": circuit_config.rms_norm,
            "activation_type": circuit_config.activation_type,
            "residual_activation_type": circuit_config.residual_activation_type,
            "tied_unembed": circuit_config.tied_unembed,
            "unembed_rank": circuit_config.unembed_rank,
            "afrac": circuit_config.afrac,
            "afrac_loctypes": circuit_config.afrac_loctypes,
            "flash": circuit_config.flash,
            "use_position_embeddings": getattr(circuit_config, "d_pos_emb", None) is not None,
            "d_pos_emb": getattr(circuit_config, "d_pos_emb", None),
            "sink": getattr(circuit_config, "sink", False),
            "enable_bigram_table": getattr(circuit_config, "enable_bigram_table", False),
            "learnable_bigram_table": getattr(circuit_config, "learnable_bigram_table", False),
            "bigram_table_rank": getattr(circuit_config, "bigram_table_rank", None),
            "dropout_cat_pos_emb": getattr(circuit_config, "dropout_cat_pos_emb", False),
            "sinusoidal_cat_pos_emb": getattr(circuit_config, "sinusoidal_cat_pos_emb", False),
        }
        return cls(**config_dict)

    def to_circuit_config(self) -> GPTConfig:
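        # Local import: circuit_sparsity only has to be importable when a config
        # is actually converted back to a GPTConfig.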
        from circuit_sparsity.gpt import GPTConfig as CircuitConfig

        config_kwargs: dict[str, Any] = dict(
            vocab_size=self.vocab_size,
            block_size=self.block_size,
            n_layer=self.n_layer,
            n_head=self.n_head,
            d_model=self.d_model,
            dropout=self.dropout,
            bias=self.bias,
            ln_bias=self.ln_bias,
            rms_norm=self.rms_norm,
            activation_type=self.activation_type,
            residual_activation_type=self.residual_activation_type,
            tied_unembed=self.tied_unembed,
            unembed_rank=self.unembed_rank,
            afrac=self.afrac,
            afrac_loctypes=self.afrac_loctypes,
            flash=self.flash,
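            # Options this wrapper does not track (mostly the training-only flags
            # dropped in __init__) are pinned to disabled defaults.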
            afrac_ste=False,
            afrac_ste_only_non_neurons=False,
            afrac_approx=False,
            rtopk=False,
            mup=False,
            mup_width_multiplier=None,
            grad_checkpointing=False,
            enable_fp8_linear=False,
            scale_invariance=False,
            d_mlp=self.d_mlp,
            d_head=self.d_head,
            enable_sparse_kernels=False,
            enable_bigram_table=self.enable_bigram_table,
            learnable_bigram_table=self.learnable_bigram_table,
            bigram_table_rank=self.bigram_table_rank,
            # Prefer an explicit d_pos_emb; otherwise use d_model when learned
            # position embeddings are enabled, else None.
            d_pos_emb=(
                self.d_pos_emb
                if self.d_pos_emb is not None
                else (self.d_model if self.use_position_embeddings else None)
            ),
            sink=self.sink,
            dropout_cat_pos_emb=self.dropout_cat_pos_emb,
            sinusoidal_cat_pos_emb=self.sinusoidal_cat_pos_emb,
        )
        return CircuitConfig(**config_kwargs)

    def to_dict(self) -> dict[str, Any]:
        base = super().to_dict()
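        # Merge the wrapper-specific fields into the base serialization so they
        # end up in config.json via save_pretrained.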
        data = {
            "vocab_size": self.vocab_size,
            "block_size": self.block_size,
            "n_layer": self.n_layer,
            "n_head": self.n_head,
            "d_model": self.d_model,
            "d_mlp": self.d_mlp,
            "d_head": self.d_head,
            "dropout": self.dropout,
            "bias": self.bias,
            "ln_bias": self.ln_bias,
            "rms_norm": self.rms_norm,
            "activation_type": self.activation_type,
            "residual_activation_type": self.residual_activation_type,
            "tied_unembed": self.tied_unembed,
            "unembed_rank": self.unembed_rank,
            "flash": self.flash,
            "afrac": self.afrac,
            "afrac_loctypes": self.afrac_loctypes,
            "use_position_embeddings": self.use_position_embeddings,
            "d_pos_emb": self.d_pos_emb,
            "sink": self.sink,
            "enable_bigram_table": self.enable_bigram_table,
            "learnable_bigram_table": self.learnable_bigram_table,
            "bigram_table_rank": self.bigram_table_rank,
            "dropout_cat_pos_emb": self.dropout_cat_pos_emb,
            "sinusoidal_cat_pos_emb": self.sinusoidal_cat_pos_emb,
            "auto_map": self.auto_map,
        }
        base.update(data)
        return base