"""Self-contained subset of :mod:`circuit_sparsity.hook_utils` for inference builds. The full module has no exotic dependencies, but mirroring the definitions here keeps the trimmed :mod:`circuit_sparsity.inference.gpt` module hermetic and easy to vendor. The implementations below are copied with minor tweaks for readability so that code written against :func:`hook_recorder`, :func:`hook_namespace`, and :func:`torch_recompute_preserving_hook_context` behaves identically in both the training and inference configurations. """ from __future__ import annotations import re from contextlib import contextmanager from functools import partial import torch import torch.utils.checkpoint class HookContext: """State container used by the hook helpers.""" def __init__(self) -> None: self._reset() self.curintervtransformer = lambda x: x def _reset(self) -> None: self.curcontext = None self.curname = "" self.curregex = None self.curinterventions = None self.save_grads = None def _get_interventions(self): return self.curintervtransformer( self.curinterventions if self.curinterventions is not None else {} ) @contextmanager def hook_recorder(self, regex: str = ".*", interventions=None, save_grads: bool = False): """Record tensors that pass through hooks matching ``regex``.""" assert self.curcontext is None, "reentrancy not allowed!" try: self.curcontext = {} self.curregex = re.compile(regex) self.curname = "" self.curinterventions = interventions self.save_grads = save_grads yield self.curcontext finally: self._reset() get_context()._reset() @contextmanager def hook_intervention_transform(self, intervention_transformer): oldintervention_transformer = self.curintervtransformer def compose(f, g): return lambda x: f(g(x)) self.curintervtransformer = compose( intervention_transformer, self.curintervtransformer, ) try: yield finally: self.curintervtransformer = oldintervention_transformer @contextmanager def hook_namespace(self, name: str): """Temporarily push ``name`` onto the hook namespace stack.""" oldname = self.curname self.curname = self.curname + name + "." try: yield finally: self.curname = oldname def hook_save(self, name: str, tensor: torch.Tensor) -> torch.Tensor: """Optionally record ``tensor`` using the current namespace.""" curinterventions = self._get_interventions() if curinterventions is not None: key = self.curname + name if key in curinterventions: tensor = curinterventions[key](tensor) if self.curcontext is not None and self.curregex.match(self.curname + name): self.curcontext[self.curname + name] = tensor if self.curcontext is not None and self.save_grads and tensor.requires_grad: class _Grad(torch.autograd.Function): @staticmethod def forward(ctx, input_tensor): return input_tensor @staticmethod def backward(ctx, grad_output): self.curcontext[self.curname + name + ".grad"] = grad_output return grad_output if self.curregex.match(self.curname + name + ".grad"): tensor = _Grad.apply(tensor) return tensor def set_context(new_context: HookContext) -> None: global context context = new_context def get_context() -> HookContext: global context return context def torch_recompute_preserving_hook_context(f, *xs, use_reentrant=None): """Wrapper around :func:`torch.utils.checkpoint` that propagates hooks.""" oldcontext = get_context() curcontext = HookContext() curcontext.curcontext = ( dict(oldcontext.curcontext) if oldcontext.curcontext is not None else None ) curcontext.curregex = oldcontext.curregex curcontext.curname = oldcontext.curname curcontext.curinterventions = ( dict(oldcontext.curinterventions) if oldcontext.curinterventions is not None else None ) curcontext.save_grads = oldcontext.save_grads is_recompute = False def _f(curcontext: HookContext, *xs): initcontext = get_context() nonlocal is_recompute set_context(curcontext) try: res = f(*xs) if not is_recompute and oldcontext.curcontext is not None: oldcontext.curcontext |= curcontext.curcontext finally: set_context(initcontext) is_recompute = True return res res = torch.utils.checkpoint.checkpoint( partial(_f, curcontext), *xs, use_reentrant=use_reentrant ) return res context = HookContext() def hook_recorder(*a, **k): return get_context().hook_recorder(*a, **k) def hook_namespace(*a, **k): return get_context().hook_namespace(*a, **k) def hook_save(*a, **k): return get_context().hook_save(*a, **k) def hook_intervention_transform(*a, **k): return get_context().hook_intervention_transform(*a, **k)