Viharikvs commited on
Commit
63dc939
·
verified ·
1 Parent(s): e41fd14

Upload ARC2 GLPS checkpoint

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ step_72391 filter=lfs diff=lfs merge=lfs -text
all_config.yaml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_layers: 2
3
+ L_layers: 6
4
+ dep_rank: 64
5
+ dep_topk: 12
6
+ expansion: 4.0
7
+ forward_dtype: bfloat16
8
+ glps_dep_graph: true
9
+ glps_enabled: true
10
+ glps_fill_obvious: true
11
+ glps_global_propagate_on_low_conf: true
12
+ glps_max_targeted_iters: 4
13
+ glps_tau_halt: 0.92
14
+ glps_tau_uncertain: 0.8
15
+ glps_token_masking: true
16
+ halt_exploration_prob: 0.1
17
+ halt_max_steps: 16
18
+ hidden_size: 512
19
+ loss:
20
+ loss_type: stablemax_cross_entropy
21
+ name: losses@ACTLossHead
22
+ mlp_t: false
23
+ name: recursive_reasoning.glps@GLPS_ACTV1
24
+ num_heads: 8
25
+ pos_encodings: rope
26
+ puzzle_emb_ndim: 512
27
+ rms_norm_eps: 1.0e-05
28
+ rope_theta: 10000.0
29
+ share_levels: true
30
+ shared_layers: 9
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Arc2concept-aug-1000-ACT-torch/pretrain_att_arc2concept_4
35
+ data_paths:
36
+ - data/arc2concept-aug-1000
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 100000
41
+ eval_glps_max_targeted_iters: null
42
+ eval_glps_tau_halt: null
43
+ eval_halt_max_steps: null
44
+ eval_interval: 10000
45
+ eval_only: false
46
+ eval_save_outputs: []
47
+ evaluators:
48
+ - name: arc@ARC
49
+ freeze_weights: false
50
+ global_batch_size: 768
51
+ load_checkpoint: null
52
+ lr: 0.0001
53
+ lr_min_ratio: 0.1
54
+ lr_warmup_steps: 2000
55
+ min_eval_interval: 0
56
+ project_name: Arc2concept-aug-1000-ACT-torch
57
+ puzzle_emb_lr: 0.01
58
+ puzzle_emb_weight_decay: 0.1
59
+ run_name: pretrain_att_arc2concept_4
60
+ seed: 0
61
+ weight_decay: 0.1
evaluator_ARC_step_72391/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
glps.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ from models.common import trunc_normal_init_
11
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
12
+ from models.sparse_embedding import CastedSparseEmbedding
13
+
14
+ """
15
+ Global-Local Predictive Solver (GLPS)
16
+ ------------------------------------
17
+ A light-weight control-policy on top of the style blocks:
18
+ - H1: global scan -> certainty map
19
+ - L1: fill-obvious (lock stable cells)
20
+ - H2: dependency scoring over remaining cells
21
+ - L2: targeted refinement (masked updates)
22
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
23
+
24
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
25
+ """
26
+
27
+ @dataclass
28
+ class GLPS_ACTV1InnerCarry:
29
+ z_H: torch.Tensor
30
+ z_L: torch.Tensor
31
+
32
+ @dataclass
33
+ class GLPS_ACTV1Carry:
34
+ inner_carry: GLPS_ACTV1InnerCarry
35
+ steps: torch.Tensor
36
+ halted: torch.Tensor
37
+ current_data: Dict[str, torch.Tensor]
38
+
39
+ class GLPS_ACTV1Config(BaseModel):
40
+ # Core IO / shapes
41
+ batch_size: int
42
+ seq_len: int
43
+ puzzle_emb_ndim: int = 0
44
+ num_puzzle_identifiers: int = 1
45
+ vocab_size: int = 256
46
+
47
+ # Cycle schedule
48
+ H_cycles: int = 3 # (scan -> refine -> check) typical
49
+ L_cycles: int = 1
50
+
51
+ # Depth
52
+ H_layers: int = 2
53
+ L_layers: int = 4
54
+ # Parameter sharing (TRM-style): when true, use one shared stack for H and L
55
+ share_levels: bool = True
56
+ # If > 0, overrides depth of shared stack; otherwise min(H_layers, L_layers)
57
+ shared_layers: int = 0
58
+
59
+ # Transformer config
60
+ hidden_size: int = 512
61
+ expansion: float = 2.0
62
+ num_heads: int = 8
63
+ pos_encodings: str = "rope"
64
+
65
+ rms_norm_eps: float = 1e-5
66
+ rope_theta: float = 10000.0
67
+
68
+ # ACT wrapper
69
+ halt_max_steps: int = 4
70
+ halt_exploration_prob: float = 0.1
71
+
72
+ forward_dtype: str = "bfloat16"
73
+
74
+ # Optional: use MLP on L instead of attention (matches / option)
75
+ mlp_t: bool = False
76
+
77
+ # ---- GLPS extras (tiny) ----
78
+ glps_enabled: bool = True
79
+ glps_fill_obvious: bool = True
80
+ glps_dep_graph: bool = True
81
+ glps_token_masking: bool = True
82
+ glps_global_propagate_on_low_conf: bool = True
83
+
84
+ glps_tau_halt: float = 0.95 # final confidence to halt
85
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
86
+ glps_max_targeted_iters: int = 2 # small number: 1-2
87
+
88
+ # Dependency scorer (low rank bilinear)
89
+ dep_rank: int = 32
90
+ dep_topk: int = 8
91
+
92
+ # When True, use simple halt threshold (q_halt > 0) instead of comparing q_halt vs q_continue
93
+ no_ACT_continue: bool = True
94
+
95
+ class GLPSBlock(nn.Module):
96
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
97
+ super().__init__()
98
+ self.config = config
99
+ if self.config.mlp_t:
100
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
101
+ self.mlp_t = SwiGLU(
102
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
103
+ expansion=config.expansion,
104
+ )
105
+ else:
106
+ self.self_attn = Attention(
107
+ hidden_size=config.hidden_size,
108
+ head_dim=config.hidden_size // config.num_heads,
109
+ num_heads=config.num_heads,
110
+ num_key_value_heads=config.num_heads,
111
+ causal=False,
112
+ )
113
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
114
+ self.norm_eps = config.rms_norm_eps
115
+
116
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
117
+ if self.config.mlp_t:
118
+ # MLP over sequence dimension (mlp-t)
119
+ hidden_states = hidden_states.transpose(1, 2)
120
+ out = self.mlp_t(hidden_states)
121
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
122
+ hidden_states = hidden_states.transpose(1, 2)
123
+ else:
124
+ hidden_states = rms_norm(
125
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
126
+ variance_epsilon=self.norm_eps,
127
+ )
128
+ out = self.mlp(hidden_states)
129
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
130
+ return hidden_states
131
+
132
+ class GLPSReasoningModule(nn.Module):
133
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
134
+ def __init__(self, layers: List[GLPSBlock]):
135
+ super().__init__()
136
+ self.layers = torch.nn.ModuleList(layers)
137
+
138
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
139
+ x = hidden_states
140
+ for layer in self.layers:
141
+ # Compute candidate update using injected context
142
+ y = layer(hidden_states=x + input_injection, **kwargs)
143
+ if update_mask is not None:
144
+ # Convex blend keeps frozen tokens unchanged
145
+ m = update_mask.to(x.dtype)[..., None]
146
+ x = x + m * (y - x)
147
+ else:
148
+ x = y
149
+ return x
150
+
151
+ class GLPS_ACTV1_Inner(nn.Module):
152
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
153
+ super().__init__()
154
+ self.config = config
155
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
156
+
157
+ # I/O
158
+ self.embed_scale = math.sqrt(self.config.hidden_size)
159
+ embed_init_std = 1.0 / self.embed_scale
160
+
161
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
162
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
163
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
164
+
165
+ # Puzzle emb (optional) — same convention as /
166
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
167
+ if self.config.puzzle_emb_ndim > 0:
168
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
169
+
170
+ # Positional encodings
171
+ if self.config.pos_encodings == "rope":
172
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
173
+ elif self.config.pos_encodings == "learned":
174
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
175
+
176
+ # Reasoning stacks (optionally shared between H and L, TRM-style)
177
+ if self.config.share_levels:
178
+ depth = self.config.shared_layers if (getattr(self.config, "shared_layers", 0) and self.config.shared_layers > 0) else min(self.config.H_layers, self.config.L_layers)
179
+ shared_reasoner = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(depth)])
180
+ self.H_level = shared_reasoner
181
+ self.L_level = shared_reasoner
182
+ else:
183
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
184
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
185
+
186
+ # Initial states (match / style)
187
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
188
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
189
+ self.register_buffer("H_init", H_init, persistent=True)
190
+ self.register_buffer("L_init", L_init, persistent=True)
191
+
192
+ # GLPS small heads
193
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
194
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
195
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
196
+
197
+ # Low-rank dependency scorer (shared)
198
+ r = max(1, self.config.dep_rank)
199
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
200
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
201
+
202
+ # Q head init like / (near-zero -> easier bootstrapping)
203
+ with torch.no_grad():
204
+ self.q_head.weight.zero_()
205
+ self.q_head.bias.fill_(-5)
206
+
207
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
208
+ # Token embedding
209
+ embedding = self.embed_tokens(input.to(torch.int32))
210
+
211
+ # Puzzle embeddings
212
+ if self.config.puzzle_emb_ndim > 0:
213
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
214
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
215
+ if pad_count > 0:
216
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
217
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
218
+
219
+ # Position embeddings
220
+ if self.config.pos_encodings == "learned":
221
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
222
+
223
+ return self.embed_scale * embedding
224
+
225
+ def empty_carry(self, batch_size: int):
226
+ return GLPS_ACTV1InnerCarry(
227
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
228
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
229
+ )
230
+
231
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
232
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
233
+ B, L, D = carry.z_H.shape
234
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
235
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
236
+ reset_b = reset_flag.to(torch.bool)
237
+ else:
238
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
239
+ try:
240
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
241
+ except Exception:
242
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
243
+ m = reset_b.view(B, 1, 1)
244
+ mH = m.expand(B, L, D)
245
+ mL = mH # same shape for z_L
246
+ H_init_exp = self.H_init.expand(B, L, D)
247
+ L_init_exp = self.L_init.expand(B, L, D)
248
+ return GLPS_ACTV1InnerCarry(
249
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
250
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
251
+ )
252
+
253
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
254
+ # One light pass to gather global signals
255
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
256
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
257
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
258
+ return z_scan, cand_logits, certainty
259
+
260
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
261
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
262
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
263
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
264
+ """
265
+ B, L, D = z_ctx.shape
266
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
267
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
268
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
269
+ r = max(1, int(Q.shape[-1]))
270
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
271
+ sim = sim / math.sqrt(r)
272
+
273
+ # Aggregate influence from uncertain queries onto target tokens
274
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
275
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
276
+
277
+ # Top-k influenced tokens per batch
278
+ topk = min(self.config.dep_topk, L)
279
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
280
+ dep_mask = torch.zeros_like(uncertain_mask)
281
+ dep_mask.scatter_(1, idx, True)
282
+
283
+ # Always include uncertain cells themselves
284
+ dep_mask = dep_mask | uncertain_mask
285
+ return dep_mask
286
+
287
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
288
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
289
+
290
+ # Encode inputs
291
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
292
+
293
+ # States
294
+ z_H, z_L = carry.z_H, carry.z_L
295
+
296
+ if not self.config.glps_enabled:
297
+ # Fallback: run all cycles with gradients (TRM-style full backprop)
298
+ for _H in range(self.config.H_cycles):
299
+ for _L in range(self.config.L_cycles):
300
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
301
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
302
+
303
+ # Outputs
304
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
305
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
306
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
307
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
308
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
309
+
310
+ # ===== GLPS path =====
311
+ # H1: global scan (keep gradients to enable full backprop through recursion)
312
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
313
+
314
+ # L1: fill-obvious -> compute stable vs uncertain masks
315
+ if self.config.glps_fill_obvious:
316
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
317
+ else:
318
+ obvious_mask = torch.zeros_like(certainty).bool()
319
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
320
+ uncertain_mask = ~stable_mask # [B, L]
321
+
322
+ # H2: dependency prediction over remaining cells (no_grad; selection only)
323
+ if self.config.glps_dep_graph:
324
+ with torch.no_grad():
325
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
326
+ else:
327
+ dep_mask = uncertain_mask
328
+
329
+ # L2: targeted refinement — run all iters with gradients (full backprop)
330
+ update_mask = dep_mask if self.config.glps_token_masking else None
331
+ z = z_scan # use scanned context as start (no detach to keep gradients)
332
+ iters = max(1, int(self.config.glps_max_targeted_iters))
333
+ for _ in range(iters):
334
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
335
+ # Refresh certainty to shrink mask; mask ops are non-differentiable, keep them out of graph
336
+ if self.config.glps_token_masking:
337
+ with torch.no_grad():
338
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
339
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
340
+
341
+ # Merge into H and do a light H update with grad
342
+ z_L = z
343
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
344
+
345
+ # H3: energy/consistency -> confidence & optional global propagate
346
+ with torch.no_grad():
347
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
348
+ conf = 1.0 - energy
349
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
350
+ perform_sweep = self.config.glps_global_propagate_on_low_conf and bool(need_sweep.any())
351
+ if perform_sweep:
352
+ # one final full sweep only for rows needing it (run with gradients)
353
+ maskB = need_sweep.view(-1, 1, 1)
354
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
355
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
356
+ z_L = torch.where(maskB, zL2, z_L)
357
+ z_H = torch.where(maskB, zH2, z_H)
358
+
359
+ # Outputs
360
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
361
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
362
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
363
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
364
+
365
+ class GLPS_ACTV1(nn.Module):
366
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
367
+ def __init__(self, config_dict: dict):
368
+ super().__init__()
369
+ self.config = GLPS_ACTV1Config(**config_dict)
370
+ self.inner = GLPS_ACTV1_Inner(self.config)
371
+
372
+ @property
373
+ def puzzle_emb(self):
374
+ return self.inner.puzzle_emb
375
+
376
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
377
+ batch_size = batch["inputs"].shape[0]
378
+ return GLPS_ACTV1Carry(
379
+ inner_carry=self.inner.empty_carry(batch_size),
380
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
381
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
382
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
383
+ )
384
+
385
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
386
+ # Reset halted seqs
387
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
388
+ new_steps = torch.where(carry.halted, 0, carry.steps)
389
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
390
+
391
+ # Inner step
392
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
393
+
394
+ outputs = {
395
+ "logits": logits,
396
+ "q_halt_logits": q_halt_logits,
397
+ "q_continue_logits": q_continue_logits,
398
+ "conf": conf.squeeze(-1),
399
+ }
400
+
401
+ with torch.no_grad():
402
+ new_steps = new_steps + 1
403
+ is_last_step = new_steps >= self.config.halt_max_steps
404
+
405
+ # Combine halt signals: max-steps, Q-head, and confidence
406
+ if self.config.no_ACT_continue:
407
+ # Simple -style: q_halt > 0 (no comparison with q_continue)
408
+ q_halt_signal = (q_halt_logits > 0)
409
+ else:
410
+ # RL-style: compare q_halt vs q_continue
411
+ q_halt_signal = (q_halt_logits > q_continue_logits)
412
+
413
+ halted = is_last_step | q_halt_signal | (conf.squeeze(-1) >= self.config.glps_tau_halt)
414
+
415
+ # Exploration during training only
416
+ if self.training and (self.config.halt_max_steps > 1):
417
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
418
+ halted = halted & (new_steps >= min_halt_steps)
419
+
420
+ # Optional Q-learning target (only if using RL-style)
421
+ if not self.config.no_ACT_continue:
422
+ _carry2, _logits2, (next_q_halt_logits, next_q_continue_logits), _conf2 = self.inner(new_inner_carry, new_current_data)
423
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
424
+ else:
425
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as / eval behavior)
426
+ halted = is_last_step
427
+
428
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
step_72391 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d520607a1b28d132dbc0ff195e8d77741a5f47e0079569caa2ffeaabe626ecda
3
+ size 2563735221