Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import datetime | |
| import json | |
| import time | |
| import uuid | |
| from collections import OrderedDict | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio | |
| import gradio as gr | |
| import huggingface_hub | |
| from gradio import FlaggingCallback | |
| from gradio_client import utils as client_utils | |
class HuggingFaceDatasetSaver(gradio.HuggingFaceDatasetSaver):
    """Gradio ``HuggingFaceDatasetSaver`` with its own sample layout and row building.

    ``flag`` decides where a flagged sample is written: one sub-folder per sample,
    or a single shared CSV. ``_deserialize_components`` builds the dataset row.
    For components whose data model is ``FileData``, it unwraps the JSON returned
    by ``component.flag``, so the row stores the saved file's path rather than
    the serialized model.
    """

    def flag(
        self,
        flag_data: list[Any],
        flag_option: str = "",
        username: str | None = None,
    ) -> int:
        """Write one flagged sample through ``self._flag_in_dir``.

        When ``self.separate_dirs`` is true, each sample gets its own sub-folder.
        The folder is named ``<UTC timestamp, millisecond precision>Z`` followed
        by ``_U_<username>``, or by an 8-character random suffix when no username
        is given. The row is written to ``metadata.jsonl`` inside that folder, and
        the folder name is used as ``path_in_repo``. Otherwise every sample goes
        to ``data.csv`` at the dataset root.

        Args:
            flag_data: one value per entry of ``self.components``.
            flag_option: value recorded in the ``flag`` column.
            username: logged-in user; ``None`` or ``""`` means anonymous.

        Returns:
            Whatever ``self._flag_in_dir`` returns.
        """
        if self.separate_dirs:
            # JSONL files to support dataset preview on the Hub
            current_utc_time = datetime.now(timezone.utc)
            iso_format_without_microseconds = current_utc_time.strftime(
                "%Y-%m-%dT%H:%M:%S"
            )
            milliseconds = current_utc_time.microsecond // 1000
            unique_id = f"{iso_format_without_microseconds}.{milliseconds:03}Z"
            if username not in (None, ""):
                unique_id += f"_U_{username}"
            else:
                # Random suffix makes the folder name unique without a username
                unique_id += f"_{str(uuid.uuid4())[:8]}"
            components_dir = self.dataset_dir / unique_id
            data_file = components_dir / "metadata.jsonl"
            path_in_repo = unique_id  # upload in sub folder (safer for concurrency)
        else:
            # Unique CSV file
            components_dir = self.dataset_dir
            data_file = components_dir / "data.csv"
            path_in_repo = None  # upload at root level
        return self._flag_in_dir(
            data_file=data_file,
            components_dir=components_dir,
            path_in_repo=path_in_repo,
            flag_data=flag_data,
            flag_option=flag_option,
            username=username or "",
        )

    def _deserialize_components(
        self,
        data_dir: Path,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the corresponding row for the flagged sample.

        Each value is passed to ``component.flag``, which may save files such as
        images or audio under ``data_dir/<label>``. A value that ends up as an
        existing file under ``self.dataset_dir`` is stored as a path relative to
        that directory. Any other value is stored as a string, with ``None``
        becoming ``""``. ``gr.Audio`` and ``gr.Image`` components get an extra
        ``"<label> file"`` column holding the Hub URL of the file, or ``""`` when
        there is no file inside the dataset directory.

        Args:
            data_dir: directory under which component files are saved.
            flag_data: one value per entry of ``self.components``.
            flag_option: value stored in the ``flag`` column.
            username: value stored in the ``username`` column.

        Returns:
            ``(features, row)``: the feature spec keyed by column name, and the
            list of row values in the same column order.
        """
        # Components that can have a preview on dataset repos
        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

        # Generate the row corresponding to the flagged sample
        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
            label = component.label or ""
            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
            save_dir.mkdir(exist_ok=True, parents=True)
            deserialized = component.flag(sample, save_dir)

            # The base component .flag method returns the data model as JSON.
            # For plain FileData, keep only the path of the saved file. Values that
            # are not strings, or are not valid FileData JSON, are left unchanged
            # instead of crashing the whole flag operation.
            if component.data_model is gr.data_classes.FileData and isinstance(
                deserialized, str
            ):
                try:
                    deserialized = component.data_model.from_json(
                        json.loads(deserialized)
                    ).path
                except (TypeError, ValueError):
                    pass

            # Add deserialized object to row
            features[label] = {"dtype": "string", "_type": "Value"}
            try:
                candidate = Path(deserialized)
                relative_path = (
                    str(candidate.relative_to(self.dataset_dir))
                    if candidate.exists()
                    else None
                )
            except (OSError, TypeError, ValueError):
                # OSError, not only FileNotFoundError: on some Python versions
                # Path.exists() raises for long non-path strings
                # (e.g. "File name too long" for a long text value).
                relative_path = None
            if relative_path is not None:
                row.append(relative_path)
            else:
                deserialized = "" if deserialized is None else str(deserialized)
                row.append(deserialized)

            # If component is eligible for a preview, add the URL of the file
            # Be mindful that images and audio can be None
            if isinstance(component, tuple(file_preview_types)):  # type: ignore
                for _component, _type in file_preview_types.items():
                    if isinstance(component, _component):
                        features[label + " file"] = {"_type": _type}
                        break
                file_url = ""
                if deserialized:
                    try:
                        relative_file = Path(deserialized).relative_to(self.dataset_dir)
                    except (TypeError, ValueError):
                        pass  # not a path inside the dataset directory: nothing to link
                    else:
                        file_url = huggingface_hub.hf_hub_url(
                            repo_id=self.dataset_id,
                            # URL needs the repo-relative path with forward slashes
                            filename=str(relative_file).replace("\\", "/"),
                            repo_type="dataset",
                        )
                row.append(file_url)
        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row
class FlagMethod:
    """Event handler that forwards flagged values to a flagging callback.

    Calling an instance passes the received component values to
    ``flagging_callback.flag``. When ``visual_feedback`` is enabled, it returns a
    disabled ``gr.Button`` whose text reports success or failure.
    """

    def __init__(
        self,
        flagging_callback: FlaggingCallback,
        label: str,
        value: str,
        visual_feedback: bool = True,
    ):
        """Store the callback and the flag option this handler reports.

        Args:
            flagging_callback: object whose ``flag`` method receives the data.
            label: label of this flag option (stored; not read by ``__call__``).
            value: forwarded to the callback as ``flag_option``.
            visual_feedback: if True, ``__call__`` returns a button update.
        """
        self.flagging_callback = flagging_callback
        self.label = label
        self.value = value
        self.visual_feedback = visual_feedback
        # Callable instances have no __name__ of their own. Presumably Gradio reads
        # it when registering this object as an event function; confirm.
        self.__name__ = "Flag"

    def __call__(
        self,
        request: gr.Request,
        profile: gr.OAuthProfile | None,
        *flag_data,
    ):
        """Flag ``flag_data`` on behalf of the optional logged-in user.

        ``request`` is not read here. It and ``profile`` are presumably supplied
        by Gradio based on their type annotations, so the signature must stay
        as is. An exception raised by the callback is printed, not propagated.

        Returns:
            ``None`` when ``visual_feedback`` is off. Otherwise a non-interactive
            ``gr.Button`` labelled "Sharing error", or "Sharing complete" after a
            0.8 s pause.
        """
        username = None if profile is None else profile.username
        try:
            self.flagging_callback.flag(
                list(flag_data), flag_option=self.value, username=username
            )
        except Exception as exc:
            # Best effort: report the failure rather than breaking the UI event
            print(f"Error while sharing: {exc}")
            if self.visual_feedback:
                return gr.Button(value="Sharing error", interactive=False)
            return None
        if not self.visual_feedback:
            return None
        time.sleep(0.8)  # give the user time to notice the button change
        return gr.Button(value="Sharing complete", interactive=False)