from typing import Optional, Union

import torch
from torch import nn
from transformers import Cache, Qwen3Config
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel, Qwen3Model
from transformers.processing_utils import Unpack
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import auto_docstring
from transformers.utils import logging
from transformers.utils.generic import TransformersKwargs, can_return_tuple

logger = logging.get_logger(__name__)


class ZeroEntropyTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that renders ``(query, document)`` pairs through the chat template.

    Each pair becomes a two-message chat: the stripped query is the ``system``
    message and the stripped document is the ``user`` message. The chat is
    rendered to text with ``add_generation_prompt=True``, and the resulting
    strings are tokenized as one batch.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, pairs, *args, **kwargs):
        """Render and tokenize a batch of ``(query, document)`` string pairs.

        Args:
            pairs: Iterable of ``(query, document)`` tuples of strings.
            *args, **kwargs: Forwarded unchanged to ``PreTrainedTokenizerFast.__call__``
                (e.g. ``padding``, ``truncation``, ``return_tensors``).

        Returns:
            The result of ``PreTrainedTokenizerFast.__call__`` on the list of
            rendered prompt strings.

        Raises:
            TypeError: If the chat template does not render a single ``str``.
        """
        input_texts: list[str] = []
        for query, document in pairs:
            messages = [
                {"role": "system", "content": query.strip()},
                {"role": "user", "content": document.strip()},
            ]
            input_text = self.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if not isinstance(input_text, str):
                raise TypeError(
                    "Expected the chat template to render a str, got "
                    f"{type(input_text).__name__}"
                )
            input_texts.append(input_text)

        batch_inputs = super().__call__(input_texts, *args, **kwargs)
        return batch_inputs


class ZeroEntropyConfig(Qwen3Config):
    """Qwen3 configuration plus the vocabulary id whose logit is read out as the score."""

    model_type = "zeroentropy"

    def __init__(self, yes_token_id: int = 9454, **kwargs):
        super().__init__(**kwargs)
        # Vocabulary id whose next-token logit ZeroEntropyForSequenceClassification
        # returns. 9454 is presumably the id of "yes" in the Qwen3 vocabulary -- TODO confirm.
        self.yes_token_id = yes_token_id


def _last_token_positions(
    attention_mask: Optional[torch.Tensor],
    seq_len: int,
    batch_size: int,
    device: torch.device,
) -> torch.Tensor:
    """Return, per row, the index of the last attended position in the current chunk.

    Works for both left- and right-padded batches: the result is the largest
    index whose mask value is non-zero. A row with no attended position yields
    ``-1`` (i.e. the final position when used as an index). When
    ``attention_mask`` is ``None``, every row is treated as unpadded and
    ``seq_len - 1`` is returned.
    """
    if attention_mask is None:
        return torch.full((batch_size,), seq_len - 1, dtype=torch.long, device=device)
    # With a KV cache the mask also covers past tokens; keep only the trailing
    # columns that line up with the hidden states of this forward pass.
    mask = attention_mask[:, -seq_len:].to(device)
    positions = torch.arange(seq_len, device=device)
    # Replace padded positions with -1, then take the largest surviving index per row.
    return torch.where(mask != 0, positions, -1).amax(dim=-1)


class ZeroEntropyForSequenceClassification(Qwen3PreTrainedModel):
    """Qwen3 decoder whose output is the scaled logit of ``config.yes_token_id``.

    For each sequence, the language-model head is applied to the hidden state
    at the last attended position. The logit of ``config.yes_token_id`` is
    divided by 5.0 and returned as ``logits`` with shape ``(batch_size, 1)``.
    """

    config: ZeroEntropyConfig
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor`, *optional*):
            Accepted for API compatibility only. No loss is computed and the returned `loss` is
            always `None`.

        Example:

        ```python
        >>> # `checkpoint` is a path or hub id of a ZeroEntropy checkpoint.
        >>> tokenizer = ZeroEntropyTokenizer.from_pretrained(checkpoint)
        >>> model = ZeroEntropyForSequenceClassification.from_pretrained(checkpoint)
        >>> pairs = [("what is a panda?", "The giant panda is a bear native to China.")]
        >>> inputs = tokenizer(pairs, padding=True, return_tensors="pt")
        >>> scores = model(**inputs).logits.squeeze(-1)  # one score per (query, document) pair
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        batch_size, seq_len = hidden_states.shape[:2]
        device = hidden_states.device

        # `logits_to_keep` is kept for signature compatibility only. The score is
        # always read at each sequence's last attended token, so the hidden state
        # is gathered first and the vocab projection runs on (batch, hidden)
        # instead of (batch, seq, hidden). This also stays correct for
        # left-padded batches and when no attention_mask is given.
        last_positions = _last_token_positions(attention_mask, seq_len, batch_size, device)
        batch_indices = torch.arange(batch_size, device=device)
        last_hidden = hidden_states[batch_indices, last_positions]

        yes_logits = self.lm_head(last_hidden)[:, self.config.yes_token_id]
        # Fixed divisor on the raw logit (presumably a calibration temperature -- TODO confirm).
        yes_logits = yes_logits / 5.0
        yes_logits = yes_logits.unsqueeze(-1)

        return SequenceClassifierOutputWithPast(
            loss=None,
            logits=yes_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )