| import argparse |
| import base64 |
| from io import BytesIO |
|
|
| from PIL import Image |
|
|
| from handler import EndpointHandler, decode_base64_image |
|
|
|
|
def local_predict(prompts, encode_image, output_path="local_output.png"):
    """Run ``EndpointHandler`` locally on a prompt and save the result image.

    Args:
        prompts: Prompt text, sent to the handler under the ``"inputs"`` key.
        encode_image: Base64-encoded init image as a string. Pass a falsy
            value (e.g. ``""``) to send no ``"image"`` key at all.
        output_path: File the decoded result image is saved to. Defaults to
            ``"local_output.png"``, the path this script has always written.

    Returns:
        The object returned by ``decode_base64_image`` for the handler's
        ``"image"`` field (presumably a PIL image; it must support ``.save``).

    Raises:
        KeyError: If the handler response (assumed dict-like) has no
            ``"image"`` entry.
    """
    handler = EndpointHandler()

    payload = {"inputs": prompts}
    # Only attach the init image when one was given. The original code sent
    # two different payload shapes, and this keeps that distinction.
    if encode_image:
        payload["image"] = encode_image
    response = handler(payload)

    image = decode_base64_image(response["image"])
    image.save(output_path)
    return image
|
|
|
|
# Command-line interface: run the endpoint handler locally on a prompt and an
# optional init image file.
opt = argparse.ArgumentParser("Diffuser local test")
opt.add_argument("-prompts", "--prompts", type=str, default="", help="Diffuser prompts")
opt.add_argument("-image", "--image", type=str, default="", help="Init image")

if __name__ == '__main__':
    cli = opt.parse_args()

    # The handler takes the init image as base64 text; an empty string
    # means no init image was supplied.
    init_b64 = ""
    if cli.image:
        with open(cli.image, "rb") as fh:
            raw_bytes = fh.read()
        init_b64 = base64.b64encode(raw_bytes).decode()

    local_predict(cli.prompts, init_b64)
|
|