|
|
from PIL import Image
|
|
|
from torch import Tensor, stack
|
|
|
from typing import Union, List
|
|
|
|
|
|
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
|
|
from timm import create_model
|
|
|
from timm.data import resolve_data_config
|
|
|
from timm.data.transforms_factory import create_transform
|
|
|
|
|
|
class EfficientNetImageProcessor(BaseImageProcessor):
    """Image processor that applies a timm model's data transforms.

    The data configuration is resolved from the timm model named
    ``model_name`` via ``resolve_data_config`` and turned into a transform
    pipeline with timm's ``create_transform``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        model_name: str,
        **kwargs,
    ):
        """
        Parameters
        ----------
        model_name : str
            Name of the timm model whose data configuration is used.
        **kwargs
            Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        self.model_name = model_name
        # A model instance is created (pretrained=False) only so that
        # resolve_data_config can read its data settings.
        self.config = resolve_data_config({}, model=create_model(model_name, pretrained=False))
        # NOTE(review): the base-class initializer runs after the attributes
        # above are set; if it assigns attributes from **kwargs it could
        # override them — confirm this ordering is intended.
        super().__init__(**kwargs)

    def preprocess(
        self,
        images: Union[List[Union[Image.Image, Tensor]], Image.Image, Tensor],
        return_tensors: Union[str, None] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocesses input images by applying transformations and returning them as a BatchFeature.

        Parameters
        ----------
        images : Union[List[PIL.Image.Image, torch.Tensor], PIL.Image.Image, torch.Tensor]
            A single image, or a list/tuple of images, each a PIL image or a
            torch tensor.
        return_tensors : str or None, optional
            Passed to ``BatchFeature`` as ``tensor_type``. ``None`` (default)
            leaves ``pixel_values`` as the stacked torch tensor.
        **kwargs
            Accepted so that callers passing extra processor options
            (e.g. through ``__call__``) do not fail; currently ignored.

        Returns
        -------
        BatchFeature
            A batch whose ``pixel_values`` entry holds the transformed images
            stacked along a new leading dimension.

        Raises
        ------
        ValueError
            If no images are given.
        TypeError
            If any image is not a PIL image or a torch tensor.
        """
        # Normalize the input to a list; a tuple is treated as a batch.
        if isinstance(images, tuple):
            images = list(images)
        elif not isinstance(images, list):
            images = [images]

        if not images:
            raise ValueError("Received an empty list of images")

        # Validate every item, not just the first, so a bad element fails
        # here with a clear message instead of inside the transforms.
        for image in images:
            if not isinstance(image, (Image.Image, Tensor)):
                raise TypeError(
                    "Expected image to be of type PIL.Image.Image or torch.Tensor, "
                    f"but got {type(image).__name__} instead."
                )

        # Rebuilt per call so later changes to self.config take effect.
        transforms = create_transform(**self.config)
        pixel_values = stack([transforms(image) for image in images])

        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
|
|
|
|
|
|
# Public API of this module: only the image processor class is exported.
__all__ = ["EfficientNetImageProcessor"]