"""Configuration classes for EfficientNet models and their ONNX export."""

from typing import Dict

from transformers.configuration_utils import PretrainedConfig

from optimum.exporters.onnx.model_configs import ViTOnnxConfig

# Accepted values for ``EfficientNetConfig.model_name``: the b0..b8 variants
# followed by the l2 variant, in that order.
MODEL_NAMES = [f"efficientnet_b{index}" for index in range(9)] + ["efficientnet_l2"]


class EfficientNetConfig(PretrainedConfig):
    """Configuration holding the settings of an EfficientNet model.

    Args:
        model_name: Variant identifier; must be one of ``MODEL_NAMES``.
        pretrained: Stored as-is on the config.
            NOTE(review): presumably tells the model class whether to load
            pretrained weights -- confirm against the consumer of this config.
        num_classes: Stored as-is on the config.
        global_pool: Stored as-is on the config.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_name`` is not listed in ``MODEL_NAMES``.
    """

    model_type = "efficientnet"

    def __init__(
        self,
        model_name: str = "efficientnet_b0",
        pretrained: bool = False,
        num_classes: int = 1000,
        global_pool: str = "avg",
        **kwargs,
    ):
        # Reject unknown variants before any attribute is written.
        if model_name not in MODEL_NAMES:
            message = f"`model_name` must be one of these: {MODEL_NAMES}, but got {model_name}"
            raise ValueError(message)

        self.model_name = model_name
        self.pretrained = pretrained
        self.num_classes = num_classes
        self.global_pool = global_pool

        # The base initializer runs last, after the model-specific fields
        # have been set.
        super().__init__(**kwargs)


class EfficientNetOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for EfficientNet, built on ``ViTOnnxConfig``."""

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the parent's output mapping, adding ``logits`` for classification.

        For the ``"image-classification"`` task the ``"logits"`` entry is set
        to ``{0: "batch_size", 1: "num_classes"}``; for any other task the
        parent's mapping is returned untouched.
        """
        output_axes = super().outputs
        if self.task != "image-classification":
            return output_axes

        output_axes["logits"] = {0: "batch_size", 1: "num_classes"}
        return output_axes


__all__ = ["EfficientNetConfig", "EfficientNetOnnxConfig"]