File size: 4,021 Bytes
469beaf 5995f53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from torch import nn, Tensor, tensor
from typing import Union, List, Optional
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention
)
from timm import create_model
from .configuration_efficientnet import EfficientNetConfig
class EfficientNetModel(PreTrainedModel):
    """
    Feature-extraction wrapper exposing a `timm` EfficientNet through
    Hugging Face's `PreTrainedModel` interface.

    The network is built with `timm.create_model` from the values stored in
    the configuration object. `forward` returns the output of the network's
    `forward_features` together with the pre-logits output of its head.

    Attributes
    ----------
    config :
        Configuration object the network was built from.
    model :
        Network instance returned by `timm.create_model`.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Collect the timm construction options in one place, then build.
        architecture = config.model_name
        build_options = {
            "pretrained": config.pretrained,
            "num_classes": config.num_classes,
            "global_pool": config.global_pool,
        }
        self.model = create_model(architecture, **build_options)

    def forward(self, pixel_values: Tensor) -> BaseModelOutputWithPoolingAndNoAttention:
        """
        Run the wrapped network as a feature extractor.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Input tensor representing image pixel values.

        Returns
        -------
        BaseModelOutputWithPoolingAndNoAttention
            `last_hidden_state` holds the result of `forward_features`;
            `pooler_output` holds the result of `forward_head` called with
            `pre_logits=True` (i.e. the head output before the classifier).
        """
        network = self.model
        feature_map = network.forward_features(pixel_values)
        pre_logits = network.forward_head(feature_map, pre_logits=True)
        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=feature_map,
            pooler_output=pre_logits,
        )
class EfficientNetModelForImageClassification(PreTrainedModel):
    """
    Image-classification wrapper exposing a `timm` EfficientNet through
    Hugging Face's `PreTrainedModel` interface.

    The network is built with `timm.create_model` from the values stored in
    the configuration object. `forward` returns the classification logits
    and, when labels are supplied, the cross-entropy loss.

    Attributes
    ----------
    config :
        Configuration object the network was built from.
    model :
        Network instance returned by `timm.create_model`.
    """

    config_class = EfficientNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = create_model(
            config.model_name,
            pretrained=config.pretrained,
            num_classes=config.num_classes,
            global_pool=config.global_pool,
        )

    def forward(
        self,
        pixel_values: Tensor,
        labels: Optional[Union[List[int], Tensor]] = None,
    ) -> ImageClassifierOutputWithNoAttention:
        """
        Compute class logits and, if labels are given, the loss.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Input tensor representing image pixel values.
        labels : Optional[Union[List[int], torch.Tensor]]
            Ground truth class IDs, as a list of integers or a tensor.
            When provided, the cross-entropy loss is computed.

        Returns
        -------
        ImageClassifierOutputWithNoAttention
            Object containing `logits` and `loss` (`loss` is None when no
            labels were passed).
        """
        # The train/eval mode is intentionally NOT modified here. Whether
        # labels are present says nothing about the mode the caller wants
        # (evaluation loops pass labels too), and writing `.training` on the
        # top-level module alone would neither reach its submodules nor be
        # undone afterwards. Mode is controlled via `train()` / `eval()`.
        logits = self.model(pixel_values)

        loss = None
        if labels is not None:
            if isinstance(labels, Tensor):
                # Already a tensor: just make sure it lives with the logits.
                targets = labels.to(logits.device)
            else:
                # Plain Python list: build the tensor directly on the
                # logits' device so the loss does not mix devices.
                targets = tensor(labels, device=logits.device)
            loss = nn.CrossEntropyLoss()(logits, targets)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
        )
# Names exported by `from <module> import *`.
__all__ = [
    "EfficientNetModel",
    "EfficientNetModelForImageClassification",
]