Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from transformers import ViTFeatureExtractor, ViTModel | |
| from skops import hub_utils | |
| from einops import reduce | |
| from torchvision.transforms.functional import to_pil_image | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import pickle | |
| import os | |
# Imagenette class names. The rest of the file pairs labels[i] with column i
# of predict_proba's output and with row i of the classifier coefficients.
labels = [
    'tench', 'English springer', 'cassette player', 'chain saw', 'church',
    'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute',
]
# Load the DINO ViT backbone and its preprocessing; use the GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
model = ViTModel.from_pretrained('facebook/dino-vitb16').eval().to(device)

# Load the logistic-regression head from the Hugging Face Hub.
# A bare os.mkdir() raised FileExistsError on every restart once the directory
# existed. Create it idempotently, and only download when the pickle is missing
# (skops' download expects an empty or non-existent destination).
model_dir = 'emb-gam-dino'
model_path = os.path.join(model_dir, 'model.pkl')
if not os.path.exists(model_path):
    os.makedirs(model_dir, exist_ok=True)
    hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst=model_dir)
# NOTE(security): pickle.load can execute arbitrary code contained in the file;
# only load artifacts from a repository you trust.
with open(model_path, 'rb') as file:
    logistic_regression = pickle.load(file)
def _unnormalize_image(pixel_values):
    """Undo the feature extractor's mean/std normalization and return a PIL image."""
    # Move to CPU first: the mean/std tensors below are created on the CPU, and
    # combining them with a CUDA tensor raises a device-mismatch error.
    pixels = pixel_values.squeeze(0).cpu()
    std = torch.tensor(feature_extractor.image_std).view(-1, 1, 1)
    mean = torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
    return to_pil_image(pixels * std + mean)


def _patch_contributions(patch_embeddings, num_patches_side):
    """Split each class logit into per-patch terms, shaped (classes, side, side)."""
    embeddings = patch_embeddings.numpy()  # (num_patches, embed_dim)
    num_patches = embeddings.shape[0]
    # Classification uses the SUM of patch embeddings, so for a linear head
    #   logit_c = coef_c . sum_p(e_p) + b_c = sum_p (coef_c . e_p + b_c / P):
    # each patch gets its own dot product plus an equal share of the intercept.
    contributions = (
        logistic_regression.coef_ @ embeddings.T
        + logistic_regression.intercept_.reshape(-1, 1) / num_patches
    )
    return contributions.reshape(-1, num_patches_side, num_patches_side)


def _plot_contributions(img, contributions):
    """Draw the input image next to one contribution heatmap per label."""
    fig, axs = plt.subplots(2, 6, figsize=(12, 5))
    # Replace the first column of the 2x6 grid with a single tall axis for the image.
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:, 0]:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0])
    ax_orig_img.imshow(img)
    ax_orig_img.axis('off')

    # Shared color scale so the heatmaps are comparable across labels.
    vmin = contributions.min()
    vmax = contributions.max()
    for ax, label, contribution in zip(axs[:, 1:].flat, labels, contributions):
        sns.heatmap(contribution, ax=ax, square=True, vmin=vmin, vmax=vmax)
        ax.set_title(label)
        # The per-patch terms sum to this class's logit.
        ax.set_xlabel(f'score={contribution.sum():.2f}')
        ax.set_xticks([])
        ax.set_yticks([])
    return fig


def classify_and_heatmap(input_img):
    """Classify an image and visualize each patch's contribution to each label.

    The image is embedded with DINO; the patch-token embeddings are summed and
    passed to the logistic-regression head (assumed to be a scikit-learn
    ``LogisticRegression``) to obtain class probabilities. Because the head is
    linear in the summed embedding, each class logit decomposes into per-patch
    terms, which are drawn as one heatmap per label.

    Args:
        input_img: image from the Gradio ``Image`` input (``type='pil'``).

    Returns:
        ``(scores, fig)`` where ``scores`` maps each label to its predicted
        probability and ``fig`` is a matplotlib Figure showing the input image
        and the per-label patch-contribution heatmaps.
    """
    inputs = {
        k: v.to(device)
        for k, v in feature_extractor(input_img, return_tensors='pt').items()
    }
    with torch.no_grad():
        # Skip token 0 ([CLS] in ViTModel's output); keep only the patch tokens.
        patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()

    # Classify the sum of the patch embeddings.
    summed = reduce(patch_embeddings, 'p d -> () d', 'sum').numpy()
    scores = dict(zip(labels, logistic_regression.predict_proba(summed)[0]))

    num_patches_side = model.config.image_size // model.config.patch_size
    fig = _plot_contributions(
        _unnormalize_image(inputs['pixel_values']),
        _patch_contributions(patch_embeddings, num_patches_side),
    )
    # Return the figure itself rather than the global pyplot module, which is
    # shared (and racy) across concurrent requests. Close it so figures do not
    # pile up in pyplot's registry on every call; a closed Figure can still be
    # rendered with fig.savefig.
    plt.close(fig)
    return scores, fig
# Markdown passed to gr.Interface as `description`.
description = '''
This demo is a simple extension of [Emb-GAM (Singh & Gao, 2022)](https://arxiv.org/abs/2209.11799) to images. It does image classification on [Imagenette](https://github.com/fastai/imagenette) and visualizes the contributions of each image patch to each label.
'''
# Markdown passed to gr.Interface as `article`. It is a raw string so the BibTeX
# accent escapes (\') reach the page intact; in a normal string literal Python
# would consume the backslash and render "J'egou" instead of "J\'egou".
article = r'''
Under the hood, we use [DINO](https://arxiv.org/abs/2104.14294) to extract patch embeddings and a logistic regression model following the setup of the [official Emb-GAM implementation](https://github.com/csinva/emb-gam).
Citations for the work involved (not our papers):
```bibtex
@article{singh2022emb,
  title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
  author={Singh, Chandan and Gao, Jianfeng},
  journal={arXiv preprint arXiv:2209.11799},
  year={2022}
}
@InProceedings{Caron_2021_ICCV,
  author = {Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  title = {Emerging Properties in Self-Supervised Vision Transformers},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  month = {October},
  year = {2021},
  pages = {9650-9660}
}
@misc{imagenette,
  author = {fast.ai},
  title = {Imagenette},
  url = {https://github.com/fastai/imagenette},
}
```
'''
# Gradio UI: one image in; predicted class probabilities and the
# patch-contribution figure out. Components are created in the same order as
# before (Image, Label, Plot).
input_image = gr.Image(shape=(224, 224), type='pil', label='Input Image')
output_components = [
    gr.Label(label='Class'),
    gr.Plot(label='Patch Contributions'),
]
example_images = [
    './examples/english_springer.png',
    './examples/golf_ball.png',
]
demo = gr.Interface(
    fn=classify_and_heatmap,
    inputs=input_image,
    outputs=output_components,
    title='Emb-GAM DINO',
    description=description,
    article=article,
    examples=example_images,
)
demo.launch(debug=True)