"""Hugging Face ``PreTrainedModel`` wrappers around ``timm`` EfficientNet models."""

from torch import nn, Tensor, tensor
from typing import Union, List, Optional

from transformers import PreTrainedModel
from transformers.modeling_outputs import (
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
)
from timm import create_model

from .configuration_efficientnet import EfficientNetConfig


def _create_backbone(config: EfficientNetConfig) -> nn.Module:
    """Instantiate the ``timm`` model described by ``config``."""
    return create_model(
        config.model_name,
        pretrained=config.pretrained,
        num_classes=config.num_classes,
        global_pool=config.global_pool,
    )


class EfficientNetModel(PreTrainedModel):
    """
    EfficientNet model wrapper using Hugging Face's PreTrainedModel.

    This class initializes an EfficientNet model from the `timm` library and
    defines a forward method that extracts feature representations.

    Attributes
    ----------
    config:
        Configuration object containing model parameters.
    model:
        Instantiated EfficientNet model.
    """

    config_class = EfficientNetConfig

    def __init__(self, config: EfficientNetConfig) -> None:
        super().__init__(config)
        self.config = config
        self.model = _create_backbone(config)

    def forward(self, pixel_values: Tensor) -> BaseModelOutputWithPoolingAndNoAttention:
        """
        Extract features from a batch of images.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Input tensor representing image pixel values.

        Returns
        -------
        BaseModelOutputWithPoolingAndNoAttention
            Object containing `last_hidden_state` (the output of the wrapped
            model's `forward_features`) and `pooler_output` (the output of
            its `forward_head` called with `pre_logits=True`).
        """
        last_hidden_state = self.model.forward_features(pixel_values)
        pooler_output = self.model.forward_head(last_hidden_state, pre_logits=True)

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
        )


class EfficientNetModelForImageClassification(PreTrainedModel):
    """
    EfficientNet classification wrapper using Hugging Face's PreTrainedModel.

    This class initializes an EfficientNet model from the `timm` library and
    defines a forward method that returns logits. When labels are provided
    it additionally returns the cross-entropy loss.

    Attributes
    ----------
    config :
        Configuration object containing model parameters.
    model :
        Instantiated EfficientNet model.
    """

    config_class = EfficientNetConfig

    def __init__(self, config: EfficientNetConfig) -> None:
        super().__init__(config)
        self.config = config
        self.model = _create_backbone(config)

    def forward(
        self,
        pixel_values: Tensor,
        labels: Optional[Union[List[int], Tensor]] = None,
    ) -> ImageClassifierOutputWithNoAttention:
        """
        Compute classification logits and, optionally, the loss.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Input tensor representing image pixel values.
        labels : Optional[Union[List[int], torch.Tensor]]
            Ground truth labels used to compute the loss.
            List of integers/tensor representing class IDs.

        Returns
        -------
        ImageClassifierOutputWithNoAttention
            Object containing `logits` and `loss`. `loss` is `None` when no
            labels are given.
        """
        # The train/eval mode is deliberately left untouched here: it is the
        # caller's responsibility (`.train()` / `.eval()`), and the presence
        # of labels must not change it -- labels are also passed when
        # computing a validation loss.
        logits = self.model(pixel_values)

        loss = None
        if labels is not None:
            if isinstance(labels, Tensor):
                # Avoid re-wrapping an existing tensor; just align the device.
                labels = labels.to(logits.device)
            else:
                # Build the tensor directly on the logits' device so the loss
                # also works when the model does not live on the CPU.
                labels = tensor(labels, device=logits.device)
            ce_loss = nn.CrossEntropyLoss()
            loss = ce_loss(logits, labels)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
        )


__all__ = [
    "EfficientNetModel",
    "EfficientNetModelForImageClassification",
]