| from typing import List, Union |
| from PIL import Image |
| import torch |
| import numpy as np |
|
|
| from diffusers.modular_pipelines import ( |
| PipelineState, |
| ModularPipelineBlocks, |
| InputParam, |
| ComponentSpec, |
| OutputParam, |
| ) |
| from controlnet_aux import CannyDetector |
| import numpy as np |
|
|
|
|
class CannyBlock(ModularPipelineBlocks):
    """Modular pipeline block that turns an input image into a Canny edge map.

    The block reads ``image`` plus the detector settings from the pipeline
    state, runs ``controlnet_aux.CannyDetector`` on it and writes the result
    back to the state as ``control_image``.
    """

    @property
    def description(self) -> str:
        return (
            "Computes a Canny edge map from `image` with controlnet_aux's "
            "CannyDetector and stores it as `control_image`."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # The detector is created on demand in `compute_canny`, so the block
        # does not request any pipeline components.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, np.ndarray],
                required=True,
                description="Image to compute canny filter on",
            ),
            InputParam(
                "low_threshold",
                type_hint=int,
                default=50,
                description="Value forwarded to CannyDetector as `low_threshold`.",
            ),
            InputParam(
                "high_threshold",
                type_hint=int,
                default=200,
                description="Value forwarded to CannyDetector as `high_threshold`.",
            ),
            InputParam(
                "detect_resolution",
                type_hint=int,
                default=1024,
                description="Resolution to resize to when running the Canny filtering process.",
            ),
            InputParam(
                "image_resolution",
                type_hint=int,
                default=1024,
                description="Resolution to resize the detected Canny edge map to.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "control_image",
                # `Image` is the PIL module; the image class is `Image.Image`
                # (the same spelling the `image` input above uses).
                # NOTE(review): the detector may hand back an ndarray when the
                # input is an ndarray -- confirm and widen the hint if so.
                type_hint=Image.Image,
                description="Canny map for input image",
            )
        ]

    def compute_canny(self, image, low_threshold, high_threshold, detect_resolution, image_resolution):
        """Run ``CannyDetector`` on ``image`` and return whatever it produces.

        Args:
            image: Image handed to the detector as ``input_image``.
            low_threshold: Forwarded as the detector's ``low_threshold``.
            high_threshold: Forwarded as the detector's ``high_threshold``.
            detect_resolution: Forwarded as the detector's ``detect_resolution``.
            image_resolution: Forwarded as the detector's ``image_resolution``.

        Returns:
            The detector's output (the Canny map for ``image``).
        """
        canny_detector = CannyDetector()
        canny_map = canny_detector(
            input_image=image,
            low_threshold=low_threshold,
            high_threshold=high_threshold,
            detect_resolution=detect_resolution,
            image_resolution=image_resolution,
        )
        return canny_map

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Compute the Canny map for the block's inputs and store it in ``state``.

        Args:
            components: Pipeline components; passed through untouched.
            state: Pipeline state holding this block's inputs.

        Returns:
            The ``(components, state)`` pair, with ``control_image`` written
            into ``state``.
        """
        # NOTE(review): the return annotation says `PipelineState`, but a
        # `(components, state)` tuple is returned; the signature is kept as-is.
        block_state = self.get_block_state(state)

        block_state.control_image = self.compute_canny(
            block_state.image,
            block_state.low_threshold,
            block_state.high_threshold,
            block_state.detect_resolution,
            block_state.image_resolution,
        )
        self.set_block_state(state, block_state)

        return components, state
|
|