| | from transformers import PretrainedConfig, PreTrainedModel, Pipeline
|
| | import torch
|
| |
|
| | from BeamDiffusionModel.beamInference import beam_inference
|
| | from BeamDiffusionModel.models.diffusionModel.StableDiffusion import StableDiffusion
|
| | from BeamDiffusionModel.models.diffusionModel.Flux import Flux
|
| |
|
class BeamDiffusionConfig(PretrainedConfig):
    """Configuration for beam-search diffusion inference.

    Records the name of the diffusion backbone, eagerly instantiates that
    backbone, and stores the search hyperparameters that
    ``BeamDiffusionModel.forward`` forwards to ``beam_inference``.

    Args:
        sd: Backbone identifier. ``"SD-2.1"`` builds ``StableDiffusion`` and
            ``"flux"`` builds ``Flux``; any other value leaves ``self.sd``
            as ``None``.
        latents_idx: Forwarded to ``beam_inference``. A falsy value
            (``None`` or an empty list) falls back to ``[0, 1, 2, 3]``.
        n_seeds: Forwarded to ``beam_inference``.
        seeds: Forwarded to ``beam_inference``. A falsy value becomes ``[]``.
        steps_back: Forwarded to ``beam_inference``.
        beam_width: Forwarded to ``beam_inference``.
        window_size: Forwarded to ``beam_inference``.
        use_rand: Forwarded to ``beam_inference``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "beam_diffusion"

    def __init__(self, sd="SD-2.1", latents_idx=None, n_seeds=4, seeds=None,
                 steps_back=2, beam_width=4, window_size=2, use_rand=True,
                 **kwargs):
        super().__init__(**kwargs)
        self.sd_name = sd
        # NOTE(review): ``self.sd`` holds a live backbone object on a config
        # instance; confirm this plays well with config serialization
        # (to_dict / save_pretrained).
        self.sd = None
        self.get_model(sd)
        self.latents_idx = latents_idx if latents_idx else [0, 1, 2, 3]
        self.n_seeds = n_seeds
        self.seeds = seeds if seeds else []
        self.steps_back = steps_back
        self.beam_width = beam_width
        self.window_size = window_size
        self.use_rand = use_rand

    def get_model(self, sd):
        """Instantiate the backbone named by ``sd`` and store it on ``self.sd``.

        Args:
            sd: Backbone identifier, ``"flux"`` or ``"SD-2.1"``.

        An unrecognized identifier leaves ``self.sd`` unchanged.
        """
        # Dispatch on the argument itself. The parameter used to be ignored
        # in favor of ``self.sd_name``, so an explicit ``get_model("flux")``
        # could build the wrong backbone.
        if sd == "flux":
            self.sd = Flux()
        elif sd == "SD-2.1":
            self.sd = StableDiffusion()
|
| |
|
| | import torch.nn as nn
|
| | from huggingface_hub import ModelHubMixin
|
| |
|
class BeamDiffusionModel(PreTrainedModel, ModelHubMixin):
    """Model wrapper that exposes ``beam_inference`` through ``forward``.

    All search hyperparameters and the diffusion backbone are read from the
    attached ``BeamDiffusionConfig``.
    """

    model_type = "beam_diffusion"
    config_class = BeamDiffusionConfig

    def __init__(self, config):
        """Attach ``config`` and register a single placeholder parameter."""
        super().__init__(config)
        self.config = config
        # One zero-valued parameter; presumably present so the module is not
        # parameter-less (e.g. for device/dtype lookup or checkpointing) --
        # TODO confirm.
        placeholder = torch.zeros(1)
        self.dummy_param = nn.Parameter(placeholder)

    def forward(self, input_data):
        """Run beam inference and return the result under the ``"images"`` key.

        Args:
            input_data: Mapping-like object; its ``"steps"`` entry is passed
                to ``beam_inference`` (an empty list when the key is absent).

        Returns:
            ``{"images": <value returned by beam_inference>}``.
        """
        cfg = self.config
        backbone = cfg.sd
        search_options = {
            "steps": input_data.get("steps", []),
            "latents_idx": cfg.latents_idx,
            "n_seeds": cfg.n_seeds,
            "seeds": cfg.seeds,
            "steps_back": cfg.steps_back,
            "beam_width": cfg.beam_width,
            "window_size": cfg.window_size,
            "use_rand": cfg.use_rand,
        }
        generated = beam_inference(backbone, **search_options)
        return {"images": generated}
|
| |
|
| |
|
| |
|
class BeamDiffusionPipeline(Pipeline, ModelHubMixin):
    """Thin ``Pipeline`` front end that hands its inputs straight to the model."""

    def __init__(self, model, tokenizer=None, device="cuda", framework="pt"):
        """Forward all constructor arguments to the parent initializer.

        NOTE(review): ``device`` defaults to ``"cuda"``; confirm the behavior
        expected on hosts without a GPU.
        """
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            device=device,
            framework=framework,
        )

    def __call__(self, inputs):
        """Invoke ``_forward`` on ``inputs`` and return its result unchanged.

        ``preprocess`` and ``postprocess`` are not called on this path, so
        the caller receives whatever the wrapped model returns.
        """
        raw_outputs = self._forward(inputs)
        return raw_outputs

    def preprocess(self, inputs):
        """Identity step: the inputs are returned untouched."""
        return inputs

    def postprocess(self, model_outputs):
        """Extract the ``"images"`` entry from the model outputs."""
        images = model_outputs["images"]
        return images

    def _sanitize_parameters(self, **kwargs):
        """Ignore every keyword argument and return three empty dicts."""
        preprocess_params, forward_params, postprocess_params = {}, {}, {}
        return preprocess_params, forward_params, postprocess_params

    def _forward(self, model_inputs):
        """Call the wrapped model on ``model_inputs``."""
        wrapped = self.model
        return wrapped(model_inputs)
|
| |
|