| from torch import nn, Tensor, tensor | |
| from typing import Union, List, Optional | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import ( | |
| BaseModelOutputWithPoolingAndNoAttention, | |
| ImageClassifierOutputWithNoAttention | |
| ) | |
| from timm import create_model | |
| from .configuration_efficientnet import EfficientNetConfig | |
class EfficientNetModel(PreTrainedModel):
    """
    EfficientNet feature extractor wrapped as a Hugging Face ``PreTrainedModel``.

    The underlying network is built with ``timm.create_model`` from the values
    stored in ``config``. ``forward`` returns feature representations rather
    than class logits.

    Attributes
    ----------
    config:
        Configuration object containing model parameters.
    model:
        Instantiated EfficientNet model.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = create_model(
            config.model_name,
            pretrained=config.pretrained,
            num_classes=config.num_classes,
            global_pool=config.global_pool,
        )

    def forward(self, pixel_values: Tensor) -> BaseModelOutputWithPoolingAndNoAttention:
        """
        Extract feature representations from a batch of images.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Input tensor representing image pixel values.

        Returns
        -------
        BaseModelOutputWithPoolingAndNoAttention
            Object with two fields:

            - ``last_hidden_state``: output of timm's ``forward_features``
              (presumably the unpooled feature map; confirm for the chosen
              ``model_name``).
            - ``pooler_output``: output of timm's ``forward_head`` called
              with ``pre_logits=True``, i.e. the pooled features taken before
              the classification layer.
        """
        last_hidden_state = self.model.forward_features(pixel_values)
        # pre_logits=True asks timm to stop before the classifier, so the head
        # built from config.num_classes is not applied here.
        pooler_output = self.model.forward_head(last_hidden_state, pre_logits=True)
        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
        )
class EfficientNetModelForImageClassification(PreTrainedModel):
    """
    EfficientNet image classifier wrapped as a Hugging Face ``PreTrainedModel``.

    The network, including its classification head, is built with
    ``timm.create_model`` from the values stored in ``config``. ``forward``
    returns logits and, when ``labels`` are given, a cross-entropy loss.

    Train/eval mode is controlled by the caller in the usual PyTorch way
    (``model.train()`` / ``model.eval()``). Passing ``labels`` only enables
    loss computation; it does not change the module's mode.

    Attributes
    ----------
    config :
        Configuration object containing model parameters.
    model :
        Instantiated EfficientNet model.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = create_model(
            config.model_name,
            pretrained=config.pretrained,
            num_classes=config.num_classes,
            global_pool=config.global_pool,
        )

    def forward(
        self,
        pixel_values: Tensor,
        labels: Optional[Union[List[int], Tensor]] = None,
    ) -> ImageClassifierOutputWithNoAttention:
        """
        Classify a batch of images, optionally computing the loss.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Input tensor representing image pixel values.
        labels : Optional[Union[List[int], torch.Tensor]]
            Ground-truth class IDs, given as a list of ints or a tensor.
            When provided, a cross-entropy loss is computed against the
            logits. Non-tensor labels are converted with ``torch.tensor``.
            Labels are moved to the logits' device.

        Returns
        -------
        ImageClassifierOutputWithNoAttention
            Object containing ``logits`` and ``loss``. ``loss`` is ``None``
            when ``labels`` is ``None``.
        """
        # Deliberately do not assign ``self.model.training`` here. Plain
        # assignment only flips the top-level flag (child modules keep their
        # own mode). It would also override the mode the caller selected via
        # .train()/.eval(), e.g. during a labelled evaluation pass.
        logits = self.model(pixel_values)

        loss = None
        if labels is not None:
            if not isinstance(labels, Tensor):
                labels = tensor(labels)
            # Labels built from Python lists live on CPU; align with logits so
            # the loss also works when the model runs on an accelerator.
            labels = labels.to(logits.device)
            loss = nn.CrossEntropyLoss()(logits, labels)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
        )
# Names exported by ``from <module> import *``: the feature-extractor model
# and the image-classification model defined above.
__all__ = ["EfficientNetModel", "EfficientNetModelForImageClassification"]