File size: 2,208 Bytes
c4cc546 b265421 c4cc546 b265421 c4cc546 b265421 c4cc546 b265421 c4cc546 b265421 c4cc546 b265421 a87fd2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from PIL import Image
from torch import Tensor, stack
from typing import Union, List
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class EfficientNetImageProcessor(BaseImageProcessor):
    """Image processor that applies a timm model's preprocessing transform.

    The preprocessing settings come from the data config timm resolves for
    ``model_name``. Each image is transformed individually and the results are
    stacked into a single ``pixel_values`` batch.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        model_name: str,
        **kwargs,
    ):
        """Resolve the timm data config used for preprocessing.

        Parameters
        ----------
        model_name : str
            Name of the timm model whose data config drives preprocessing.
        **kwargs
            Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        self.model_name = model_name
        # A model without pretrained weights is instantiated only so that timm
        # can resolve its data config; the model itself is not kept.
        self.config = resolve_data_config({}, model=create_model(model_name, pretrained=False))
        # NOTE(review): the base initializer runs after the attributes above are
        # set; if it assigns leftover kwargs as attributes (as transformers'
        # ImageProcessingMixin does), a ``config`` kwarg would override the
        # config resolved here — confirm this is intended.
        super().__init__(**kwargs)

    def preprocess(
        self,
        images: Union[List[Union[Image.Image, Tensor]], Image.Image, Tensor],
        return_tensors: Union[str, None] = None,
    ) -> BatchFeature:
        """
        Preprocesses input images by applying transformations and returning them as a BatchFeature.

        Parameters
        ----------
        images : Union[List[PIL.Image.Image, torch.Tensor], PIL.Image.Image, torch.Tensor]
            A single image, or a list/tuple of images, each a PIL image or a
            torch tensor.
        return_tensors : str or None, optional
            Passed to ``BatchFeature`` as ``tensor_type``. ``None`` (default)
            keeps the stacked torch tensor as-is.

        Returns
        -------
        BatchFeature
            A batch whose ``"pixel_values"`` entry is the stacked transformed images.

        Raises
        ------
        ValueError
            If no images are given.
        TypeError
            If any element is neither a PIL image nor a torch tensor.
        """
        # Normalize a single image into a one-element list.
        if not isinstance(images, (list, tuple)):
            images = [images]
        images = list(images)

        if not images:
            raise ValueError("Received an empty list of images")

        # Validate every element up front (not just the first) so a bad entry
        # fails here with a clear message instead of deep inside the transform.
        for index, image in enumerate(images):
            if not isinstance(image, (Image.Image, Tensor)):
                raise TypeError(
                    f"Expected image to be of type PIL.Image.Image or torch.Tensor, "
                    f"but got {type(image).__name__} at index {index} instead."
                )

        # NOTE(review): whether torch.Tensor inputs are accepted depends on the
        # transform timm builds (some timm versions use a ToTensor step that
        # rejects tensors) — confirm against the installed timm version.
        transforms = create_transform(**self.config)
        transformed_images = [transforms(image) for image in images]

        # torch.stack requires every transformed image to have the same shape.
        pixel_values = stack(transformed_images)
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
# Public API of this module.
__all__ = [
    "EfficientNetImageProcessor",
]