"""Gradio demo app for the JiT-AnimeFace class-conditioned image generator.

The module downloads the model checkpoint, its config and its label map from
the Hugging Face Hub, caches the loaded model in-process, and exposes a small
Gradio UI that generates a batch of images from space-separated tags.
"""

# NOTE(review): `spaces` is kept as the very first import, matching the
# original ordering (it precedes `torch`) — confirm against the ZeroGPU docs
# before reordering.
import spaces

import json
import os

import gradio as gr
import torch
import yaml
from huggingface_hub import hf_hub_download

from model.config import ClassContextConfig
from model.pipeline import JiTConfig, JiTModel

# Hub locations; each one can be overridden through an environment variable.
MODEL_REPO = os.environ.get("MODEL_REPO", "p1atdev/JiT-AnimeFace-experiment")
MODEL_PATH = os.environ.get(
    "MODEL_PATH", "jit-b256-p16-cls/12-jit-animeface_00043e_033368s.safetensors"
)
LABEL2ID_PATH = os.environ.get("LABEL2ID_PATH", "jit-b256-p16-cls/label2id.json")
CONFIG_PATH = os.environ.get("CONFIG_PATH", "jit-b256-p16-cls/config.yml")

# Prefer CUDA, then Apple MPS, then CPU.
DEVICE = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)
# bfloat16 on CUDA, float16 everywhere else.
DTYPE = torch.bfloat16 if DEVICE.type in ["cuda"] else torch.float16
MAX_TOKEN_LENGTH = 32

# In-process caches so repeated requests do not reload from disk.
model_map: dict[str, JiTModel] = {}  # {model_path: model}
label2id_map: dict[str, dict] = {}  # {label2id_path: label2id}


def get_file_path(repo: str, path: str) -> str:
    """Fetch a file from the Hugging Face Hub and return its local path.

    Args:
        repo: Hub repository id.
        path: File path inside the repository.

    Returns:
        Whatever `hf_hub_download(repo, path)` returns.
    """
    return hf_hub_download(repo, path)


def load_label2id(label2id_path: str) -> dict:
    """Read a label2id JSON file from the local filesystem.

    Args:
        label2id_path: Local path of the JSON file.

    Returns:
        The parsed JSON content.
    """
    # Explicit UTF-8: the platform default encoding is not reliable for JSON.
    with open(label2id_path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_config(config_path: str) -> JiTConfig:
    """Read a model configuration file and validate it into a `JiTConfig`.

    Args:
        config_path: Local path of a `.json`, `.yaml` or `.yml` file.

    Returns:
        The validated configuration object.

    Raises:
        ValueError: If the file extension is not one of the supported ones.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        if config_path.endswith(".json"):
            config_dict = json.load(f)
        elif config_path.endswith((".yaml", ".yml")):
            config_dict = yaml.safe_load(f)
        else:
            raise ValueError("Unsupported config file format. Use .json or .yaml/.yml")

    return JiTConfig.model_validate(config_dict)


def load_model(
    model_path: str,
    label2id_path: str,
    config_path: str,
    device: torch.device,
    dtype: torch.dtype = DTYPE,
) -> tuple[JiTModel, dict]:
    """Load the model and its label map, reusing cached instances.

    The model cache is keyed by `model_path` only; on a cache hit the
    `config_path`, `device` and `dtype` arguments are not re-applied.

    Args:
        model_path: Checkpoint path inside `MODEL_REPO`.
        label2id_path: Label-map path inside `MODEL_REPO`.
        config_path: Config path inside `MODEL_REPO`.
        device: Device the freshly loaded model is moved to.
        dtype: Dtype the freshly loaded model is cast to.

    Returns:
        A `(model, label2id)` tuple.
    """
    model = model_map.get(model_path)
    if model is None:
        config = load_config(get_file_path(MODEL_REPO, config_path))
        if isinstance(config.context_encoder, ClassContextConfig):
            # Point the class-context encoder at the downloaded label map.
            config.context_encoder.label2id_map_path = get_file_path(
                MODEL_REPO, label2id_path
            )

        model = JiTModel.from_pretrained(
            config=config,
            checkpoint_path=get_file_path(MODEL_REPO, model_path),
        )
        model.eval()
        model.requires_grad_(False)
        model.to(device=device, dtype=dtype)
        model_map[model_path] = model  # cache

    # Resolve the label map independently of the model cache: a cached model
    # requested with a not-yet-loaded label map must not raise KeyError.
    label2id = label2id_map.get(label2id_path)
    if label2id is None:
        label2id = load_label2id(get_file_path(MODEL_REPO, label2id_path))
        label2id_map[label2id_path] = label2id  # cache

    return model, label2id


@spaces.GPU(duration=6)
def generate_images(
    prompt: str,
    negative_prompt: str,
    num_steps: int,
    cfg_scale: float,
    batch_size: int,
    size: int,
    seed: int,
    #
    model_path: str = MODEL_PATH,
    label2id_path: str = LABEL2ID_PATH,
    config_path: str = CONFIG_PATH,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a batch of square images for one prompt.

    Args:
        prompt: Space-separated positive tags; repeated `batch_size` times.
        negative_prompt: Space-separated negative tags.
        num_steps: Number of inference steps.
        cfg_scale: Classifier-free guidance scale.
        batch_size: Number of images to generate.
        size: Used for both height and width.
        seed: Seed value; a negative or missing value means "no fixed seed".
        model_path: Checkpoint path inside `MODEL_REPO`.
        label2id_path: Label-map path inside `MODEL_REPO`.
        config_path: Config path inside `MODEL_REPO`.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        Whatever `model.generate` returns; it is wired to a `gr.Gallery`.
    """
    model, _label2id = load_model(
        model_path=model_path,
        label2id_path=label2id_path,
        config_path=config_path,
        device=DEVICE,
        dtype=DTYPE,
    )

    # The seed comes from a gr.Number field: guard against a missing value
    # (None) and normalise a float to int before handing it to the model.
    resolved_seed = int(seed) if seed is not None and seed >= 0 else None

    with torch.inference_mode(), torch.autocast(device_type=DEVICE.type, dtype=DTYPE):
        images = model.generate(
            prompt=[prompt] * batch_size,
            negative_prompt=negative_prompt,
            num_inference_steps=num_steps,
            cfg_scale=cfg_scale,
            height=size,
            width=size,
            max_token_length=MAX_TOKEN_LENGTH,
            cfg_time_range=[0.1, 1.0],
            seed=resolved_seed,
            device=DEVICE,
            execution_dtype=DTYPE,
        )

    return images


LABEL2ID_URL = f"https://huggingface.co/{MODEL_REPO}/blob/main/{LABEL2ID_PATH}"


def demo():
    """Build the Gradio Blocks UI and return it (without launching)."""
    with gr.Blocks() as ui:
        # Markdown text is kept flush-left so it renders the same whether or
        # not the component strips leading indentation.
        gr.Markdown(f"""
# JiT-AnimeFace Demo

Pixel-space x-prediction flow-matching 90M parameter model for anime face generation, trained from scratch.

- See full supported tags: [label2id.json]({LABEL2ID_URL}). 対応しているタグ一覧は [こちら]({LABEL2ID_URL}) から確認できます。ここに載っていないタグは反応しません。
- Current model: [{MODEL_PATH}](https://huggingface.co/{MODEL_REPO}/blob/main/{MODEL_PATH})
""")

        with gr.Row():
            with gr.Column():
                prompt = gr.TextArea(
                    label="Prompt",
                    info=f"Space-separated tags. Not all of danbooru tags are supported. See [the full supported tags]({LABEL2ID_URL}). スペースで区切ってください。カンマ区切りは対応してません。",
                    value="general 1girl solo portrait looking_at_viewer medium_hair parted_lips blue_ribbon hair_ornament hairclip half_updo halterneck bokeh depth_of_field blurry_background head_tilt",
                    placeholder="e.g.: general 1girl solo portrait looking_at_viewer",
                )
                negative_prompt = gr.TextArea(
                    label="Negative Prompt",
                    info="Space-separated negative tags to avoid in generation. スペースで区切ってください。カンマ区切りは対応してません。",
                    value="retro_artstyle 1990s_(style) sketch",
                    lines=2,
                    placeholder="e.g.: retro_artstyle 1990s_(style) sketch",
                )
                num_steps = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=25,
                    step=1,
                    label="Number of Steps",
                    info="Recommended: more than 20 steps for better quality.",
                )
                cfg_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=5.0,
                    step=0.25,
                    label="CFG Scale",
                    info="Recommended: more than 2.0 for better adherence to the prompt.",
                )
                batch_size = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=25,
                    step=1,
                    label="Batch Size",
                    info="Number of images to generate in one batch.",
                )
                size = gr.Slider(
                    minimum=64,
                    maximum=320,
                    value=256,
                    step=64,
                    label="Image Size",
                    info="Only 256x256 is supported in the current model. Other sizes may cause quality degradation.",
                )
                seed = gr.Number(
                    value=-1,
                    label="Seed (-1 for random)",
                )

            with gr.Column(scale=2):
                generate_button = gr.Button("Generate Images", variant="primary")
                output_gallery = gr.Gallery(
                    label="Generated Images",
                    columns=5,
                    height="768px",
                    preview=False,
                    show_label=True,
                )

        gr.Examples(
            examples=[
                [
                    "general 1girl solo portrait looking_at_viewer medium_hair parted_lips blue_ribbon hair_ornament hairclip half_updo halterneck bokeh depth_of_field blurry_background head_tilt",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl solo portrait looking_at_viewer",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl solo portrait looking_at_viewer blue_hair short_hair blush open_mouth cat_ears animal_ears red_eyes white_background",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl aqua_eyes baseball_cap blonde_hair closed_mouth earrings green_background hat jewelry looking_at_viewer shirt short_hair simple_background solo portrait yellow_shirt",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl solo portrait looking_at_viewer brown_hair ahoge long_hair :| expressionless closed_mouth swept_bangs pink_eyes pink_background simple_background dutch_angle",
                    "retro_artstyle 1990s_(style) sketch smile",
                ],
                [
                    "general 1girl solo portrait looking_at_viewer hatsune_miku twintails long_hair blue_eyes one_eye_closed simple_background green_background",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl portrait looking_at_viewer sketch head_tilt white_background monochrome open_mouth long_hair",
                    "retro_artstyle 1990s_(style)",
                ],
                [
                    "general 1girl solo from_behind short_hair simple_background black_background",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl portrait looking_to_the_side glasses",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl portrait looking_at_viewer cat_ears purple_theme ;d forehead animal_ears animal_ear_fluff cat_ears",
                    "retro_artstyle 1990s_(style) sketch",
                ],
            ],
            inputs=[prompt, negative_prompt],
            label="Examples",
            examples_per_page=20,
        )

        # Run generation on button click or when the prompt box is submitted.
        gr.on(
            triggers=[generate_button.click, prompt.submit],
            fn=generate_images,
            inputs=[
                prompt,
                negative_prompt,
                num_steps,
                cfg_scale,
                batch_size,
                size,
                seed,
            ],
            outputs=output_gallery,
        )

    return ui


if __name__ == "__main__":
    # Warm the caches before the UI starts serving requests.
    load_model(
        model_path=MODEL_PATH,
        label2id_path=LABEL2ID_PATH,
        config_path=CONFIG_PATH,
        device=DEVICE,
        dtype=DTYPE,
    )
    demo().launch()