Spaces:
Sleeping
Sleeping
| """ | |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
| import torch | |
| import torchvision | |
| from ..core import register | |
| from .utils import IntermediateLayerGetter | |
| __all__ = ['TorchVisionModel'] | |
class TorchVisionModel(torch.nn.Module):
    """Backbone built from a torchvision model, wrapped in IntermediateLayerGetter.

    The model is created with ``torchvision.models.get_model`` and then wrapped
    in the project's ``IntermediateLayerGetter`` together with
    ``return_layers``. If the created model has a ``features`` attribute, that
    sub-module is wrapped instead of the whole model.

    Args:
        name: Model name passed to ``torchvision.models.get_model`` (and to
            ``torchvision.models.get_model_weights`` when ``weights`` is a
            string).
        return_layers: Passed unchanged to ``IntermediateLayerGetter``. The
            usage examples in this file pass a list of layer names such as
            ``['layer2', 'layer3', 'layer4']``.
        weights: ``None`` to pass no weights, or the name of a member of the
            model's weights enum (e.g. ``'DEFAULT'``), which is looked up with
            ``getattr``. Any other non-string value is forwarded to
            ``get_model`` as-is.
        **kwargs: Extra keyword arguments forwarded to ``get_model``.

    Raises:
        AttributeError: If ``weights`` is a string that is not a member of the
            weights enum for ``name``.
    """

    # NOTE(review): `register` is imported at the top of this file but is not
    # applied to this class in the visible source -- confirm whether a
    # `@register()` decorator is expected here.

    def __init__(self, name, return_layers, weights=None, **kwargs) -> None:
        super().__init__()

        # Only string names need resolving against the weights enum; calling
        # getattr with a non-string attribute name would raise TypeError, so
        # already-resolved weight objects are passed straight through.
        if isinstance(weights, str):
            weights = getattr(torchvision.models.get_model_weights(name), weights)

        model = torchvision.models.get_model(name, weights=weights, **kwargs)

        # TODO hard code.
        # Models that expose a `features` sub-module are wrapped at that level;
        # everything else is wrapped as a whole.
        if hasattr(model, 'features'):
            model = IntermediateLayerGetter(model.features, return_layers)
        else:
            model = IntermediateLayerGetter(model, return_layers)

        self.model = model

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output unchanged."""
        return self.model(x)
# Usage examples (direct construction):
# TorchVisionModel('swin_t', return_layers=['5', '7'])
# TorchVisionModel('resnet34', return_layers=['layer2','layer3', 'layer4'])

# Usage examples (YAML config):
# TorchVisionModel:
#   name: swin_t
#   return_layers: ['5', '7']
#   weights: DEFAULT

# model:
#   type: TorchVisionModel
#   name: resnet34
#   return_layers: ['layer2','layer3', 'layer4']
#   weights: DEFAULT