# efficientnet_b1 / modeling_efficientnet.py
# Thastp's picture
# Upload model
# 469beaf verified
from torch import nn, Tensor, tensor
from typing import Union, List, Optional
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention
)
from timm import create_model
from .configuration_efficientnet import EfficientNetConfig
class EfficientNetModel(PreTrainedModel):
    """
    Hugging Face ``PreTrainedModel`` wrapper around a `timm` network.

    The wrapped network is built with ``timm.create_model`` from the values
    stored in the configuration. ``forward`` exposes the network's features
    (and the pre-logits output of its head) instead of classification logits.

    Attributes
    ----------
    config :
        Configuration object the network was built from.
    model :
        Network returned by ``timm.create_model``.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Every argument handed to timm is read from the configuration.
        timm_kwargs = {
            "pretrained": config.pretrained,
            "num_classes": config.num_classes,
            "global_pool": config.global_pool,
        }
        self.model = create_model(config.model_name, **timm_kwargs)

    def forward(self, pixel_values: Tensor) -> BaseModelOutputWithPoolingAndNoAttention:
        """
        Run the wrapped network up to (but not including) its classifier.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Input tensor representing image pixel values.

        Returns
        -------
        BaseModelOutputWithPoolingAndNoAttention
            ``last_hidden_state`` holds the result of the network's
            ``forward_features``; ``pooler_output`` holds the result of
            ``forward_head`` called with ``pre_logits=True`` on those features.
        """
        network = self.model
        features = network.forward_features(pixel_values)
        # pre_logits=True asks the head for its output before the classifier.
        pooled = network.forward_head(features, pre_logits=True)
        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=features,
            pooler_output=pooled,
        )
class EfficientNetModelForImageClassification(PreTrainedModel):
    """
    Hugging Face ``PreTrainedModel`` wrapper around a `timm` classifier.

    The wrapped network is built with ``timm.create_model`` from the values
    stored in the configuration. ``forward`` returns the network's logits and,
    when labels are supplied, the cross-entropy loss against them.

    The train/eval mode of the network is NOT changed by ``forward``; select
    it with the usual ``.train()`` / ``.eval()`` calls on this module.

    Attributes
    ----------
    config :
        Configuration object the network was built from.
    model :
        Network returned by ``timm.create_model``.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = create_model(
            config.model_name,
            pretrained=config.pretrained,
            num_classes=config.num_classes,
            global_pool=config.global_pool,
        )

    def forward(
        self,
        pixel_values: Tensor,
        labels: Optional[Union[List[int], Tensor]] = None
    ) -> ImageClassifierOutputWithNoAttention:
        """
        Compute logits and, if labels are given, the cross-entropy loss.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Input tensor representing image pixel values.
        labels : Optional[Union[List[int], torch.Tensor]]
            Ground truth labels used to compute the loss.
            List of integers/tensor representing class IDs.

        Returns
        -------
        ImageClassifierOutputWithNoAttention
            Object containing `logits` and `loss`; `loss` is ``None`` when
            no labels are passed.
        """
        # The mode is deliberately left alone: assigning `training` by hand
        # only touches the top-level flag (submodules keep their old mode) and
        # would silently override the caller's `.train()` / `.eval()` choice,
        # e.g. when computing a validation loss in eval mode.
        logits = self.model(pixel_values)

        loss = None
        if labels is not None:
            # Place the labels on the same device as the logits; a plain list
            # would otherwise become a CPU tensor and fail against GPU logits.
            if isinstance(labels, Tensor):
                labels = labels.to(logits.device)
            else:
                labels = tensor(labels, device=logits.device)
            ce_loss = nn.CrossEntropyLoss()
            loss = ce_loss(logits, labels)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
        )
# Public API of this module: the two model wrappers defined above.
__all__ = ["EfficientNetModel", "EfficientNetModelForImageClassification"]