# efficientnet_b1 / image_processing_efficientnet.py
# Author: Thastp
# Commit message: "Upload processor"
# Commit: c4cc546 (verified)
from PIL import Image
from torch import Tensor, stack
from typing import Union, List
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class EfficientNetImageProcessor(BaseImageProcessor):
    """
    Image processor that applies the timm data transform of a given model.

    The data config is resolved once, at construction time, by calling
    ``resolve_data_config`` on a non-pretrained instance of ``model_name``.
    ``preprocess`` turns that config into a transform with ``create_transform``
    and applies it to every input image.

    Parameters
    ----------
    model_name : str
        Name passed to ``timm.create_model`` to build the model whose data
        config is used.
    **kwargs
        Forwarded unchanged to ``BaseImageProcessor.__init__``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        model_name: str,
        **kwargs,
    ):
        self.model_name = model_name
        # The model is instantiated with pretrained=False: it is only used as
        # the source of the data config and is not kept on the instance.
        self.config = resolve_data_config({}, model=create_model(model_name, pretrained=False))
        # NOTE(review): the base initializer runs last on purpose here;
        # presumably it may assign attributes from kwargs -- confirm before
        # reordering these statements.
        super().__init__(**kwargs)

    def preprocess(
        self,
        images: Union[List[Union[Image.Image, Tensor]], Image.Image, Tensor],
        return_tensors: Union[str, None] = None,
    ) -> BatchFeature:
        """
        Preprocesses input images by applying transformations and returning them as a BatchFeature.

        Parameters
        ----------
        images : Union[List[PIL.Image.Image, torch.Tensor], PIL.Image.Image, torch.Tensor]
            A single image or a list (or tuple) of images in one of the accepted formats.
        return_tensors : str, optional
            Forwarded to ``BatchFeature`` as ``tensor_type``. When ``None``
            (the default) the stacked ``torch.Tensor`` is returned as is.

        Returns
        -------
        BatchFeature
            A batch of transformed images stored under the ``pixel_values`` key.

        Raises
        ------
        ValueError
            If ``images`` is an empty list or tuple.
        TypeError
            If any image is neither a ``PIL.Image.Image`` nor a ``torch.Tensor``.
        """
        # Normalize the input to a list: a bare image becomes a one-element
        # batch, a tuple is treated like a list.
        if isinstance(images, (list, tuple)):
            images = list(images)
        else:
            images = [images]

        if not images:
            raise ValueError("Received an empty list of images")

        # Validate every element, not just the first one, so that a bad item
        # is reported here instead of failing inside the transform.
        for index, image in enumerate(images):
            if not isinstance(image, (Image.Image, Tensor)):
                raise TypeError(
                    "Expected image to be of type PIL.Image.Image or torch.Tensor, "
                    f"but got {type(image).__name__} at index {index} instead."
                )

        # The transform is rebuilt from the stored data config on each call.
        transforms = create_transform(**self.config)
        transformed_images = [transforms(image) for image in images]

        # Stack the per-image tensors along a new leading batch dimension.
        transformed_image_tensors = stack(transformed_images)

        data = {'pixel_values': transformed_image_tensors}
        return BatchFeature(data=data, tensor_type=return_tensors)
# Public API of this module: only the image processor class is exported.
__all__ = ["EfficientNetImageProcessor"]