paultltc committed (verified)
Commit 1ed03e7 · Parent(s): 12cba74

Update modeling_modernvbert.py

Files changed (1):
    modeling_modernvbert.py (+214 -60)
modeling_modernvbert.py CHANGED
@@ -1,18 +1,26 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/modernvbert/modular_modernvbert.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_modernvbert.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
-from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, PreTrainedModel, logging
-from transformers.modeling_outputs import BaseModelOutput
-from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput
 
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple
+from ..modernbert import ModernBertConfig, ModernBertForMaskedLM, ModernBertModel
+from ..siglip import SiglipVisionConfig, SiglipVisionModel
 from .configuration_modernvbert import ModernVBertConfig
 
-logger = logging.get_logger(__name__)
-
 
 class DecoupledEmbedding(nn.Embedding):
     # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
@@ -97,7 +105,7 @@ class DecoupledEmbedding(nn.Embedding):
         # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
         input_ids[additional_vocab_indices] = 0
         full_vector = F.embedding(input_ids, self.weight)
-        full_vector[additional_vocab_indices] = additional_embeddings # overwrite the records with high indices
+        full_vector[additional_vocab_indices] = additional_embeddings  # overwrite the records with high indices
         return full_vector
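
Note: the decoupled lookup above can be exercised standalone. A minimal sketch with hypothetical sizes (ids at or above num_embeddings are routed to a second, separately trainable table, while the main table can stay frozen):

    import torch
    import torch.nn.functional as F

    num_embeddings, num_additional, dim = 8, 2, 4
    weight = torch.randn(num_embeddings, dim)             # main (optionally frozen) table
    additional_weight = torch.randn(num_additional, dim)  # additional trainable table

    input_ids = torch.tensor([[1, 7, 8, 9]])              # 8 and 9 are "additional" ids
    additional_mask = input_ids >= num_embeddings
    additional_embeds = F.embedding(input_ids[additional_mask] - num_embeddings, additional_weight)

    ids = input_ids.clone()
    ids[additional_mask] = 0                              # any valid index; overwritten below
    full_vector = F.embedding(ids, weight)
    full_vector[additional_mask] = additional_embeds      # overwrite the records with high indices
    print(full_vector.shape)                              # torch.Size([1, 4, 4])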
@@ -124,10 +132,11 @@ class ModernVBertBaseModelOutput(BaseModelOutput):
             sequence_length, hidden_size)`.
         image_hidden_states of the model produced by the vision encoder
     """
+
     last_hidden_state: torch.FloatTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+    attentions: Optional[tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
 
 
 @dataclass
@@ -137,7 +146,7 @@ class ModernVBertMaskedLMOutput(MaskedLMOutput):
     Args:
         loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
             Masked language modeling (MLM) loss.
-        logits (`torch.FloatTensor`):
+        logits (`torch.FloatTensor`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
@@ -153,15 +162,17 @@ class ModernVBertMaskedLMOutput(MaskedLMOutput):
             sequence_length, hidden_size)`.
         image_hidden_states of the model produced by the vision encoder
     """
+
     loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
     image_hidden_states: Optional[torch.FloatTensor] = None
 
 
 class ModernVBertSimpleMLP(nn.Module):
     """A simple linear projection layer to project the vision hidden states to the text hidden states."""
+
     def __init__(self, input_size, output_size):
         super().__init__()
         self.proj = nn.Linear(input_size, output_size, bias=False)
@@ -175,26 +186,32 @@ class ModernVBertConnector(nn.Module):
     Connector module for ModernVBERT. It performs a pixel shuffle operation followed by a linear projection to match the text model's hidden size.
     Based on https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
     """
+
     def __init__(self, config):
         super().__init__()
-        self.scale_factor = config.pixel_shuffle_factor
+        self.pixel_shuffle_factor = config.pixel_shuffle_factor
         self.modality_projection = ModernVBertSimpleMLP(
-            input_size=config.vision_config.hidden_size * (config.scale_factor**2),
+            input_size=config.vision_config.hidden_size * (config.pixel_shuffle_factor**2),
             output_size=config.text_config.hidden_size,
         )
 
-    def pixel_shuffle(self, x, scale_factor):
+    def pixel_shuffle(self, x, pixel_shuffle_factor):
         bsz, seq, embed_dim = x.size()
         height = width = int(seq**0.5)
         x = x.view(bsz, height, width, embed_dim)
-        x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
+        x = x.view(bsz, height, int(width / pixel_shuffle_factor), embed_dim * pixel_shuffle_factor)
         x = x.permute(0, 2, 1, 3)
-        x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
+        x = x.reshape(
+            bsz,
+            int(width / pixel_shuffle_factor),
+            int(height / pixel_shuffle_factor),
+            embed_dim * (pixel_shuffle_factor**2),
+        )
         x = x.permute(0, 2, 1, 3)
-        return x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
+        return x.reshape(bsz, int(seq / (pixel_shuffle_factor**2)), embed_dim * (pixel_shuffle_factor**2))
 
     def forward(self, image_hidden_states):
-        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
+        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.pixel_shuffle_factor)
         return self.modality_projection(image_hidden_states)
 
200
 
@@ -217,55 +234,55 @@ class ModernVBertPreTrainedModel(PreTrainedModel):
217
  module.weight.data[module.padding_idx].zero_()
218
 
219
 
 
220
  class ModernVBertModel(ModernVBertPreTrainedModel):
221
  def __init__(self, config: ModernVBertConfig):
222
  super().__init__(config)
 
 
223
  self.vision_model = ModernVBertModel.init_vision_model(config)
224
  self.connector = ModernVBertConnector(config)
225
  self.text_model = ModernVBertModel.init_language_model(config)
226
- self.image_seq_len = int(
227
- ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
228
- )
229
- self.image_token_id = config.image_token_id
230
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
231
  # set the correct dtype for vision and text models
232
  self.vision_model.to(self.dtype)
233
  self.text_model.to(self.dtype)
 
 
 
 
 
 
 
234
  self.post_init()
235
 
236
  @staticmethod
237
  def init_vision_model(config: ModernVBertConfig):
238
- vision_model_config = AutoConfig.from_pretrained(
239
  config.vision_config.vision_model_name,
240
  _attn_implementation=config._attn_implementation,
241
  )
242
- vision_model = AutoModel.from_config(
243
- vision_model_config,
244
- trust_remote_code=True,
245
- )
246
- return getattr(vision_model, "vision_model", vision_model)
247
 
248
  @staticmethod
249
  def init_language_model(config: ModernVBertConfig):
250
- text_model_config = AutoConfig.from_pretrained(
251
  config.text_config.text_model_name,
252
  _attn_implementation=config._attn_implementation,
253
- trust_remote_code=True,
254
- )
255
- text_model = AutoModel.from_config(
256
- text_model_config,
257
- trust_remote_code=True
258
  )
 
259
  embed_layer = DecoupledEmbedding(
260
  num_embeddings=text_model_config.vocab_size,
261
  num_additional_embeddings=config.additional_vocab_size,
262
  embedding_dim=config.hidden_size,
263
- partially_freeze=config.freeze_config["freeze_text_layers"],
264
  padding_idx=config.pad_token_id,
265
  )
266
  text_model.set_input_embeddings(embed_layer)
267
  return text_model
268
-
 
269
  def enable_input_require_grads(self):
270
  """
271
  Enables the gradients for the input embeddings.
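
Note: init_vision_model and init_language_model now build SiglipVisionModel and ModernBertModel explicitly instead of going through AutoModel with trust_remote_code, and the freeze_config lookup gains a safe default. The image_seq_len formula is unchanged apart from the pixel_shuffle_factor rename; a worked example with hypothetical config values:

    # Hypothetical values: image_size=512, patch_size=16, pixel_shuffle_factor=2
    image_seq_len = int(((512 // 16) ** 2) / (2**2))
    print(image_seq_len)  # 256 image tokens: a 32x32 patch grid, reduced 4x by the pixel shuffle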
@@ -292,12 +309,65 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
             make_inputs_require_grads
         )
 
+    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.disable_input_require_grads
+    def disable_input_require_grads(self):
+        self._text_require_grads_hook.remove()
+        self._vision_require_grads_hook.remove()
+
     def get_input_embeddings(self):
         return self.text_model.get_input_embeddings()
 
     def set_input_embeddings(self, value):
         self.text_model.set_input_embeddings(value)
 
+    def get_image_features(
+        self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
+    ):
+        """
+        Derived from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/smolvlm/modeling_smolvlm.py
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+            pixel_attention_mask (`torch.LongTensor`, *optional*):
+                The attention mask indicating padded regions in the image.
+        """
+        batch_size, num_images, num_channels, height, width = pixel_values.shape
+        pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
+        pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+        # Remove padding images - padding images are full 0.
+        nb_values_per_image = pixel_values.shape[1:].numel()
+        real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+
+        if not any(real_images_inds):
+            real_images_inds[0] = True
+
+        pixel_values = pixel_values[real_images_inds].contiguous()
+        # Handle the vision attention mask
+        if pixel_attention_mask is None:
+            pixel_attention_mask = torch.ones(
+                size=[pixel_values.shape[i] for i in (0, 2, 3)],
+                dtype=torch.bool,
+                device=pixel_values.device,
+            )
+        else:
+            # Remove padding images from the mask
+            pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
+            pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+
+        patch_size = self.config.vision_config.patch_size
+        patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
+        patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
+        patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+        # Get sequence from the vision encoder
+        image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+        image_hidden_states = image_hidden_states.last_hidden_state
+
+        return image_hidden_states
+
     def inputs_merger(self, input_ids, inputs_embeds, image_hidden_states):
         """Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/smolvlm/modeling_smolvlm.py
@@ -311,21 +381,47 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
         """
 
         _, patch_size, _ = image_hidden_states.shape
-        image_mask = input_ids == self.image_token_id
+
+        if input_ids is None:
+            image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            image_mask = image_mask[..., 0]  # slice off the hidden dim
+        else:
+            image_mask = input_ids == self.config.image_token_id
+
+        # Assert that the input <image> tokens are valid (i.e. multiple of patch_size)
         num_image_tokens = image_mask.sum(dim=1)
         if not torch.all(num_image_tokens % patch_size == 0):
             raise ValueError("Number of <image> tokens not divisible by patch_size.")
+
         blocks_per_sample = num_image_tokens // patch_size
+
         offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
         block_offset = offsets[:-1]
         row_cum = image_mask.cumsum(dim=-1)
         chunk_idx = (row_cum - 1) // patch_size
         local_idx = (row_cum - 1) % patch_size
         block_idx = block_offset.unsqueeze(1) + chunk_idx
+
         image_embeds = torch.zeros_like(inputs_embeds)
         image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
+
         return torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
 
+    @can_return_tuple
+    @auto_docstring(
+        custom_intro="""
+        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
+        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
+        max_num_images is the maximum number of images among the batch_size samples in the batch.
+        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
+        For efficiency, we only pass through the vision_model's forward the real images by
+        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
+        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
+        """,
+        checkpoint="modernvbert/ModernVBert",
+    )
     def forward(
         self,
         input_ids: torch.LongTensor = None,
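
Note: inputs_merger now reads image_token_id from self.config and adds an inputs_embeds-only path that recovers the image-token mask from the embeddings themselves. The block/position index math is unchanged; a worked example with a hypothetical patch_size of 2:

    import torch

    patch_size = 2
    image_mask = torch.tensor([[False, True, True, False, True, True],
                               [True, True, False, False, False, False]])

    num_image_tokens = image_mask.sum(dim=1)            # tensor([4, 2])
    blocks_per_sample = num_image_tokens // patch_size  # tensor([2, 1])
    offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
    block_offset = offsets[:-1]                         # tensor([0, 2])

    row_cum = image_mask.cumsum(dim=-1)
    chunk_idx = (row_cum - 1) // patch_size
    local_idx = (row_cum - 1) % patch_size
    block_idx = block_offset.unsqueeze(1) + chunk_idx

    print(block_idx[image_mask])  # tensor([0, 0, 1, 1, 2, 2]) -> which image block
    print(local_idx[image_mask])  # tensor([0, 1, 0, 1, 0, 1]) -> position inside the block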
@@ -338,28 +434,44 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
+        r"""
+        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+            Mask to avoid performing attention on padding pixel indices.
+        image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            The hidden states of the image encoder after modality projection.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
+            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         if inputs_embeds is None:
             inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
+
+        # Images processing
         if pixel_values is not None:
-            batch_size, num_images, _, _, _ = pixel_values.shape
-            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
-            nb_values_per_image = pixel_values.shape[1:].numel()
-            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
-            if not any(real_images_inds):
-                real_images_inds[0] = True
-            pixel_values = pixel_values[real_images_inds].contiguous()
-            image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
+            # Vision encoder pass
+            image_hidden_states = self.get_image_features(
+                pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask
+            )
+            # Modality projection & resampling
             image_hidden_states = self.connector(image_hidden_states)
-        elif image_hidden_states is not None:
-            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
-        if inputs_embeds is not None and image_hidden_states is not None:
-            inputs_embeds = self.inputs_merger(input_ids, inputs_embeds, image_hidden_states)
+
+        # Merge image and text embeddings
+        if image_hidden_states is not None:
+            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)
+            inputs_embeds = self.inputs_merger(
+                input_ids=input_ids, inputs_embeds=inputs_embeds, image_hidden_states=image_hidden_states
+            )
+
+        # Language model pass
         outputs = self.text_model(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -367,9 +479,9 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            **kwargs,
         )
-        if not return_dict:
-            return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
+
         return ModernVBertBaseModelOutput(
             last_hidden_state=outputs.last_hidden_state,
             hidden_states=outputs.hidden_states,
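
Note: the manual `if not return_dict: return tuple(...)` branch is gone; @can_return_tuple converts the ModelOutput to a tuple when return_dict=False. A hypothetical smoke test of the new forward path (the repo id, trust_remote_code loading, and dummy inputs are illustrative assumptions, not part of this diff; real inputs should come from the model's processor):

    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained("ModernVBERT/modernvbert", trust_remote_code=True)
    vision_cfg = model.config.vision_config

    input_ids = torch.randint(0, 1000, (1, 8))  # no <image> tokens in this toy input
    pixel_values = torch.zeros(1, 2, 3, vision_cfg.image_size, vision_cfg.image_size)
    pixel_values[:, 0] = torch.randn(3, vision_cfg.image_size, vision_cfg.image_size)  # one real image, one padding

    out = model(input_ids=input_ids, pixel_values=pixel_values)
    print(out.last_hidden_state.shape)    # (1, 8, hidden_size)
    print(out.image_hidden_states.shape)  # (1, image_seq_len, hidden_size)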
@@ -377,11 +489,12 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
             image_hidden_states=image_hidden_states,
         )
 
+
 class ModernVBertLMHead(nn.Module):
     def __init__(self, config):
         super().__init__()
-        pretrained_config = AutoConfig.from_pretrained(config.text_config.text_model_name, trust_remote_code=True)
-        pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True)
+        pretrained_config = ModernBertConfig.from_pretrained(config.text_config.text_model_name)
+        pretrained_model = ModernBertForMaskedLM(pretrained_config)
         self.head = pretrained_model.head
         self.decoder = pretrained_model.decoder
 
@@ -389,10 +502,12 @@ class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
         return self.decoder(self.head(hidden_states))
 
 
+@auto_docstring
 class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
+    _tied_weights_keys = ["lm_head.decoder.weight", "model.text_model.embeddings.word_embeddings.weight"]
+
     def __init__(self, config):
         super().__init__(config)
-        self.image_token_id = config.image_token_id
         self.in_features = config.hidden_size
         self.out_additional_features = config.additional_vocab_size
         self.vocab_size = config.vocab_size
@@ -403,6 +518,24 @@ class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
         self.lm_head.to(self.dtype)
         self.post_init()
 
+    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.disable_input_require_grads
+    def disable_input_require_grads(self):
+        self._text_require_grads_hook.remove()
+        self._vision_require_grads_hook.remove()
+
+    @can_return_tuple
+    @auto_docstring(
+        custom_intro="""
+        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
+        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
+        max_num_images is the maximum number of images among the batch_size samples in the batch.
+        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
+        For efficiency, we only pass through the vision_model's forward the real images by
+        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
+        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
+        """,
+        checkpoint="modernvbert/ModernVBert",
+    )
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -416,7 +549,19 @@ class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, ModernVBertMaskedLMOutput]:
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Union[tuple, ModernVBertMaskedLMOutput]:
+        r"""
+        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+            Mask to avoid performing attention on padding pixel indices.
+        image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            The hidden states of the image encoder after modality projection.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
+            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        """
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -434,23 +579,32 @@ class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            **kwargs,
         )
         hidden_states = outputs[0]
+
         logits = self.lm_head(hidden_states)
+
         if self.out_additional_features > 0:
             proj_states = self.lm_head.head(hidden_states)
             additional_features = self.additional_fc(proj_states)
             logits = torch.cat((logits, additional_features), -1)
+
         loss = None
         if labels is not None:
             loss = CrossEntropyLoss()(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1))
+
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
+
         return ModernVBertMaskedLMOutput(
             loss=loss,
             logits=logits.float(),
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             image_hidden_states=outputs.image_hidden_states,
-        )
+        )
+
+
+__all__ = ["ModernVBertPreTrainedModel", "ModernVBertModel", "ModernVBertForMaskedLM"]
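
Note: the loss is computed over the widened vocabulary (vocab_size + additional_vocab_size), and CrossEntropyLoss ignores label positions set to -100 by default. A sketch with hypothetical sizes (vocab_size=10 plus 2 additional embeddings gives 12 classes):

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size, additional, seq = 10, 2, 5
    logits = torch.randn(1, seq, vocab_size + additional)
    labels = torch.tensor([[1, -100, 3, -100, 11]])  # -100 positions are ignored

    loss = CrossEntropyLoss()(logits.view(-1, vocab_size + additional), labels.view(-1))
    print(loss.item())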
 
 
 
 