Spaces:
Runtime error
Runtime error
Commit
·
fec8ab6
0
Parent(s):
Upload
Browse files- .gitattributes +36 -0
- README.md +13 -0
- app.py +107 -0
- libminigpt4.so +3 -0
- minigpt4_library.py +741 -0
- requirements.txt +21 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.so filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Minigpt4 Ggml
|
| 3 |
+
emoji: 🌍
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 3.36.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import ctypes
|
| 4 |
+
import pathlib
|
| 5 |
+
from typing import Optional, List
|
| 6 |
+
import enum
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import argparse
|
| 9 |
+
import gradio as gr
|
| 10 |
+
|
| 11 |
+
import minigpt4_library
|
| 12 |
+
|
| 13 |
+
from huggingface_hub import hf_hub_download
|
| 14 |
+
|
| 15 |
+
# Download the MiniGPT-4 projection/vision weights and the quantized Vicuna LLM
# from the Hugging Face Hub (cached locally between runs by hf_hub_download).
model_path = hf_hub_download(repo_id='maknee/minigpt4-13b-ggml', filename='minigpt4-13B-f16.bin')
llm_model_path = hf_hub_download(repo_id='maknee/ggml-vicuna-v0-quantized', filename='ggml-vicuna-13B-v0-q5_k.bin')

# Static HTML snippets rendered at the top of the Gradio page.
title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
description = """<h3>This is the demo of MiniGPT-4 with ggml (cpu only!). Upload your images and start chatting!</h3>"""
article = """<div style='display:flex; gap: 0.25rem; '><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></div>"""

# NOTE(review): `global` at module level is a no-op; kept as-is.
global minigpt4_chatbot
# Module-level chatbot handle; assigned at the bottom of the file before start().
minigpt4_chatbot: minigpt4_library.MiniGPT4ChatBot
|
| 24 |
+
|
| 25 |
+
def user(message, history):
    """Record a user turn: append (message, empty-reply) to the history.

    Returns ("", history): the empty string clears the textbox, and the
    updated history is forwarded to the chained `chat` handler.
    """
    if not history:
        history = []
    history.append([message, ""])
    return "", history
|
| 30 |
+
|
| 31 |
+
def chat(history, limit: int = 1024, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, repeat_penalty: float = 1.1):
    """Stream the model's answer to the most recent user message.

    Yields (chatbot, history) after every generated token so the UI can
    render the reply incrementally.

    NOTE(review): `repeat_penalty` is accepted (and wired to a UI slider)
    but never forwarded to `minigpt4_chatbot.generate` — confirm whether
    `generate` supports it before passing it through.
    """
    if not history:
        history = []

    last_turn = history[-1]
    prompt = last_turn[0]
    last_turn[1] = ""

    token_stream = minigpt4_chatbot.generate(
        prompt,
        limit=int(limit),
        temp=float(temp),
        top_k=int(top_k),
        top_p=float(top_p),
    )
    for token in token_stream:
        last_turn[1] += token
        # Stream the partial response to both the Chatbot widget and the state.
        yield history, history
|
| 48 |
+
|
| 49 |
+
def clear_state(history, chat_message, image):
    """Reset the conversation.

    Clears the UI history, resets the native chat context, re-enables the
    image upload widget, and locks the message box until a new image is
    uploaded.
    """
    minigpt4_chatbot.reset_chat()
    return (
        [],
        gr.update(value=None, interactive=True),
        gr.update(placeholder='Upload image first', interactive=False),
        gr.update(value="Upload & Start Chat", interactive=True),
    )
|
| 53 |
+
|
| 54 |
+
def upload_image(image, history):
    """Feed an uploaded PIL image to the chatbot and unlock the message box.

    With no image, all widgets are left interactive and the history is
    passed through unchanged.
    """
    if image is None:
        return None, None, gr.update(interactive=True), history

    minigpt4_chatbot.upload_image(image)
    return (
        gr.update(interactive=False),
        gr.update(interactive=True, placeholder='Type and press Enter'),
        gr.update(value="Start Chatting", interactive=False),
        [],
    )
|
| 60 |
+
|
| 61 |
+
def start():
    """Build the Gradio Blocks UI and launch the app (blocking)."""
    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown(article)

        with gr.Row():
            with gr.Column(scale=0.5):
                image = gr.Image(type="pil")
                upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")

                # Sampling controls forwarded positionally to chat().
                max_tokens = gr.Slider(1, 1024, label="Max Tokens", step=1, value=128)
                temperature = gr.Slider(0.0, 1.0, label="Temperature", step=0.05, value=0.8)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.95)
                top_k = gr.Slider(0, 100, label="Top K", step=1, value=40)
                repeat_penalty = gr.Slider(0.0, 2.0, label="Repetition Penalty", step=0.1, value=1.1)

            with gr.Column():
                chatbot = gr.Chatbot(label='MiniGPT-4')
                message = gr.Textbox(label='User', placeholder='Upload image first', interactive=False)
                history = gr.State()

                with gr.Row():
                    submit = gr.Button(value="Send message", variant="secondary").style(full_width=True)
                    clear = gr.Button(value="Reset", variant="secondary").style(full_width=False)
                    # stop = gr.Button(value="Stop", variant="secondary").style(full_width=False)

        clear.click(clear_state, inputs=[history, image, message], outputs=[history, image, message, upload_button], queue=False)

        upload_button.click(upload_image, inputs=[image, history], outputs=[image, message, upload_button, history])

        # BUG FIX: chat() takes (history, limit, temp, top_k, top_p, repeat_penalty),
        # so the Top-K slider must precede the Top-P slider here. The original
        # passed [..., top_p, top_k, ...], silently swapping the two values.
        sampling_inputs = [history, max_tokens, temperature, top_k, top_p, repeat_penalty]

        submit_click_event = submit.click(
            fn=user, inputs=[message, history], outputs=[message, history], queue=True
        ).then(
            fn=chat, inputs=sampling_inputs, outputs=[chatbot, history], queue=True
        )
        message_submit_event = message.submit(
            fn=user, inputs=[message, history], outputs=[message, history], queue=True
        ).then(
            fn=chat, inputs=sampling_inputs, outputs=[chatbot, history], queue=True
        )
        # stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_click_event, message_submit_event], queue=False)

    demo.launch(enable_queue=True)
|
| 105 |
+
|
| 106 |
+
# Instantiate the chatbot (loads both ggml models) with native logging silenced,
# then launch the blocking Gradio UI.
# NOTE(review): this runs at import time — there is no `if __name__ == "__main__"` guard.
minigpt4_chatbot = minigpt4_library.MiniGPT4ChatBot(model_path, llm_model_path, verbosity=minigpt4_library.Verbosity.SILENT)
start()
|
libminigpt4.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be54434ed2aa0f41c69bab0531b90e6a9ecb18c805bb8082307bb5e5aa1658d4
|
| 3 |
+
size 1227064
|
minigpt4_library.py
ADDED
|
@@ -0,0 +1,741 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import ctypes
|
| 4 |
+
import pathlib
|
| 5 |
+
from typing import Optional, List
|
| 6 |
+
import enum
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
class DataType(enum.IntEnum):
    """ggml tensor data types accepted by the native quantizer."""

    F16 = 0
    F32 = 1
    I32 = 2
    L64 = 3
    Q4_0 = 4
    Q4_1 = 5
    Q5_0 = 6
    Q5_1 = 7
    Q8_0 = 8
    Q8_1 = 9
    Q2_K = 10
    Q3_K = 11
    Q4_K = 12
    Q5_K = 13
    Q6_K = 14
    Q8_K = 15

    def __str__(self) -> str:
        # Render as the bare member name (e.g. "Q4_0") for CLI-friendly output.
        return self.name
|
| 29 |
+
|
| 30 |
+
class Verbosity(enum.IntEnum):
    """Native-library logging level: 0 silent, 1 errors, 2 info, 3 debug."""

    SILENT = 0
    ERR = 1
    INFO = 2
    DEBUG = 3
|
| 35 |
+
|
| 36 |
+
class ImageFormat(enum.IntEnum):
    """Pixel storage format of a MiniGPT4Image buffer."""

    UNKNOWN = 0
    F32 = 1  # 32-bit float pixels
    U8 = 2   # 8-bit unsigned pixels
|
| 40 |
+
|
| 41 |
+
# Shorthand aliases for the ctypes types used in the FFI signatures below.
# NOTE(review): F32 here shadows the enum values named F32 defined above.
I32 = ctypes.c_int32
U32 = ctypes.c_uint32
F32 = ctypes.c_float
SIZE_T = ctypes.c_size_t
VOID_PTR = ctypes.c_void_p
CHAR_PTR = ctypes.POINTER(ctypes.c_char)
FLOAT_PTR = ctypes.POINTER(ctypes.c_float)
INT_PTR = ctypes.POINTER(ctypes.c_int32)
CHAR_PTR_PTR = ctypes.POINTER(ctypes.POINTER(ctypes.c_char))

# Opaque handle to the native context (struct MiniGPT4Context *).
MiniGPT4ContextP = VOID_PTR
|
| 52 |
+
class MiniGPT4Context:
    """Thin wrapper that carries the opaque native context pointer."""

    def __init__(self, ptr: ctypes.pointer):
        # Opaque struct MiniGPT4Context * returned by minigpt4_model_load.
        self.ptr = ptr
|
| 55 |
+
|
| 56 |
+
class MiniGPT4Image(ctypes.Structure):
    """Mirror of the native `struct MiniGPT4Image` passed across the FFI boundary."""
    _fields_ = [
        ('data', VOID_PTR),  # raw pixel buffer (allocated by the native library)
        ('width', I32),
        ('height', I32),
        ('channels', I32),
        ('format', I32)      # one of ImageFormat
    ]
|
| 64 |
+
|
| 65 |
+
class MiniGPT4Embedding(ctypes.Structure):
    """Mirror of the native `struct MiniGPT4Embedding`: a float buffer plus its length."""
    _fields_ = [
        ('data', FLOAT_PTR),       # embedding values (allocated by the native library)
        ('n_embeddings', SIZE_T),  # number of floats in `data`
    ]

# Pointer typedefs used in the FFI signatures below.
MiniGPT4ImageP = ctypes.POINTER(MiniGPT4Image)
MiniGPT4EmbeddingP = ctypes.POINTER(MiniGPT4Embedding)
|
| 73 |
+
|
| 74 |
+
class MiniGPT4SharedLibrary:
|
| 75 |
+
"""
|
| 76 |
+
Python wrapper around minigpt4.cpp shared library.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
    def __init__(self, shared_library_path: str):
        """
        Loads the shared library from specified file.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library_path : str
            Path to minigpt4.cpp shared library. On Windows, it would look like 'minigpt4.dll'. On UNIX, 'minigpt4.so'.
        """

        self.library = ctypes.cdll.LoadLibrary(shared_library_path)

        # Declare argtypes/restype for every exported function up front so that
        # ctypes performs the correct conversions at call time.
        self.library.minigpt4_model_load.argtypes = [
            CHAR_PTR, # const char *path
            CHAR_PTR, # const char *llm_model
            I32, # int verbosity
            I32, # int seed
            I32, # int n_ctx
            I32, # int n_batch
            I32, # int numa
        ]
        self.library.minigpt4_model_load.restype = MiniGPT4ContextP

        self.library.minigpt4_image_load_from_file.argtypes = [
            MiniGPT4ContextP, # struct MiniGPT4Context *ctx
            CHAR_PTR, # const char *path
            MiniGPT4ImageP, # struct MiniGPT4Image *image
            I32, # int flags
        ]
        self.library.minigpt4_image_load_from_file.restype = I32

        self.library.minigpt4_encode_image.argtypes = [
            MiniGPT4ContextP, # struct MiniGPT4Context *ctx
            MiniGPT4ImageP, # const struct MiniGPT4Image *image
            MiniGPT4EmbeddingP, # struct MiniGPT4Embedding *embedding
            I32, # size_t n_threads
        ]
        self.library.minigpt4_encode_image.restype = I32

        self.library.minigpt4_begin_chat_image.argtypes = [
            MiniGPT4ContextP, # struct MiniGPT4Context *ctx
            MiniGPT4EmbeddingP, # struct MiniGPT4Embedding *embedding
            CHAR_PTR, # const char *s
            I32, # size_t n_threads
        ]
        self.library.minigpt4_begin_chat_image.restype = I32

        self.library.minigpt4_end_chat_image.argtypes = [
            MiniGPT4ContextP, # struct MiniGPT4Context *ctx
            CHAR_PTR_PTR, # const char **token
            I32, # size_t n_threads
            F32, # float temp
            I32, # int32_t top_k
            F32, # float top_p
            F32, # float tfs_z
            F32, # float typical_p
            I32, # int32_t repeat_last_n
            F32, # float repeat_penalty
            F32, # float alpha_presence
            F32, # float alpha_frequency
            I32, # int mirostat
            F32, # float mirostat_tau
            F32, # float mirostat_eta
            I32, # int penalize_nl
        ]
        self.library.minigpt4_end_chat_image.restype = I32

        self.library.minigpt4_system_prompt.argtypes = [
            MiniGPT4ContextP, # struct MiniGPT4Context *ctx
            I32, # size_t n_threads
        ]
        self.library.minigpt4_system_prompt.restype = I32

        self.library.minigpt4_begin_chat.argtypes = [
            MiniGPT4ContextP, # struct MiniGPT4Context *ctx
            CHAR_PTR, # const char *s
            I32, # size_t n_threads
        ]
        self.library.minigpt4_begin_chat.restype = I32

        self.library.minigpt4_end_chat.argtypes = [
            MiniGPT4ContextP, # struct MiniGPT4Context *ctx
            CHAR_PTR_PTR, # const char **token
            I32, # size_t n_threads
            F32, # float temp
            I32, # int32_t top_k
            F32, # float top_p
            F32, # float tfs_z
            F32, # float typical_p
            I32, # int32_t repeat_last_n
            F32, # float repeat_penalty
            F32, # float alpha_presence
            F32, # float alpha_frequency
            I32, # int mirostat
            F32, # float mirostat_tau
            F32, # float mirostat_eta
            I32, # int penalize_nl
        ]
        self.library.minigpt4_end_chat.restype = I32

        self.library.minigpt4_reset_chat.argtypes = [
            MiniGPT4ContextP, # struct MiniGPT4Context *ctx
        ]
        self.library.minigpt4_reset_chat.restype = I32

        self.library.minigpt4_contains_eos_token.argtypes = [
            CHAR_PTR, # const char *s
        ]
        self.library.minigpt4_contains_eos_token.restype = I32

        self.library.minigpt4_is_eos.argtypes = [
            CHAR_PTR, # const char *s
        ]
        self.library.minigpt4_is_eos.restype = I32

        self.library.minigpt4_free.argtypes = [
            MiniGPT4ContextP, # struct MiniGPT4Context *ctx
        ]
        self.library.minigpt4_free.restype = I32

        self.library.minigpt4_free_image.argtypes = [
            MiniGPT4ImageP, # struct MiniGPT4Image *image
        ]
        self.library.minigpt4_free_image.restype = I32

        self.library.minigpt4_free_embedding.argtypes = [
            MiniGPT4EmbeddingP, # struct MiniGPT4Embedding *embedding
        ]
        self.library.minigpt4_free_embedding.restype = I32

        self.library.minigpt4_error_code_to_string.argtypes = [
            I32, # int error_code
        ]
        self.library.minigpt4_error_code_to_string.restype = CHAR_PTR

        self.library.minigpt4_quantize_model.argtypes = [
            CHAR_PTR, # const char *in_path
            CHAR_PTR, # const char *out_path
            I32, # int data_type
        ]
        self.library.minigpt4_quantize_model.restype = I32

        self.library.minigpt4_set_verbosity.argtypes = [
            I32, # int verbosity
        ]
        self.library.minigpt4_set_verbosity.restype = None
|
| 226 |
+
|
| 227 |
+
def panic_if_error(self, error_code: int) -> None:
|
| 228 |
+
"""
|
| 229 |
+
Raises an exception if the error code is not 0.
|
| 230 |
+
|
| 231 |
+
Parameters
|
| 232 |
+
----------
|
| 233 |
+
error_code : int
|
| 234 |
+
Error code to check.
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
if error_code != 0:
|
| 238 |
+
raise RuntimeError(self.library.minigpt4_error_code_to_string(I32(error_code)))
|
| 239 |
+
|
| 240 |
+
    def minigpt4_model_load(self, model_path: str, llm_model_path: str, verbosity: int = 1, seed: int = 1337, n_ctx: int = 2048, n_batch: int = 512, numa: int = 0) -> MiniGPT4Context:
        """
        Loads a model from a file.

        Args:
            model_path (str): Path to model file.
            llm_model_path (str): Path to LLM model file.
            verbosity (int): Verbosity level: 0 = silent, 1 = error, 2 = info, 3 = debug. Defaults to 1.
            seed (int): Seed for llm model. Defaults to 1337.
            n_ctx (int): Size of context for llm model. Defaults to 2048.
            n_batch (int): Batch size for llm model. Defaults to 512.
            numa (int): NUMA node to use (0 = NUMA disabled, 1 = NUMA enabled). Defaults to 0.

        Returns:
            MiniGPT4Context: Context.
        """

        ptr = self.library.minigpt4_model_load(
            model_path.encode('utf-8'),
            llm_model_path.encode('utf-8'),
            I32(verbosity),
            I32(seed),
            I32(n_ctx),
            I32(n_batch),
            I32(numa),
        )

        # A NULL return (mapped to None by ctypes) signals load failure.
        assert ptr is not None, 'minigpt4_model_load failed'

        return MiniGPT4Context(ptr)
|
| 269 |
+
|
| 270 |
+
    def minigpt4_image_load_from_file(self, ctx: MiniGPT4Context, path: str, flags: int) -> MiniGPT4Image:
        """
        Loads an image from a file.

        Args:
            ctx (MiniGPT4Context): Context.
            path (str): Path to the image file.
            flags (int): Loading flags forwarded to the native library.

        Returns:
            MiniGPT4Image: Loaded image. The pixel buffer is allocated natively;
            release it with minigpt4_free_image.
        """

        # The native call fills the out-parameter struct in place.
        image = MiniGPT4Image()
        self.panic_if_error(self.library.minigpt4_image_load_from_file(ctx.ptr, path.encode('utf-8'), ctypes.pointer(image), I32(flags)))
        return image
|
| 286 |
+
|
| 287 |
+
    def minigpt4_preprocess_image(self, ctx: MiniGPT4Context, image: MiniGPT4Image, flags: int = 0) -> MiniGPT4Image:
        """
        Preprocesses an image.

        Args:
            ctx (MiniGPT4Context): Context.
            image (MiniGPT4Image): Image.
            flags (int): Flags. Defaults to 0.

        Returns:
            MiniGPT4Image: Preprocessed image.
        """

        # NOTE(review): __init__ never declares argtypes/restype for
        # minigpt4_preprocess_image, so ctypes applies its default conversions
        # here — confirm against the native header.
        preprocessed_image = MiniGPT4Image()
        self.panic_if_error(self.library.minigpt4_preprocess_image(ctx.ptr, ctypes.pointer(image), ctypes.pointer(preprocessed_image), I32(flags)))
        return preprocessed_image
|
| 303 |
+
|
| 304 |
+
    def minigpt4_encode_image(self, ctx: MiniGPT4Context, image: MiniGPT4Image, n_threads: int = 0) -> MiniGPT4Embedding:
        """
        Encodes an image into embedding.

        Args:
            ctx (MiniGPT4Context): Context.
            image (MiniGPT4Image): Image.
            n_threads (int): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            embedding (MiniGPT4Embedding): Output embedding. The float buffer is
            allocated natively; release it with minigpt4_free_embedding.
        """

        # Out-parameter struct filled in by the native call.
        embedding = MiniGPT4Embedding()
        self.panic_if_error(self.library.minigpt4_encode_image(ctx.ptr, ctypes.pointer(image), ctypes.pointer(embedding), n_threads))
        return embedding
|
| 320 |
+
|
| 321 |
+
    def minigpt4_begin_chat_image(self, ctx: MiniGPT4Context, image_embedding: MiniGPT4Embedding, s: str, n_threads: int = 0):
        """
        Begins a chat with an image.

        Args:
            ctx (MiniGPT4Context): Context.
            image_embedding (MiniGPT4Embedding): Image embedding.
            s (str): Question to ask about the image.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            None

        Raises:
            RuntimeError: If the native call reports an error.
        """

        self.panic_if_error(self.library.minigpt4_begin_chat_image(ctx.ptr, ctypes.pointer(image_embedding), s.encode('utf-8'), n_threads))
|
| 336 |
+
|
| 337 |
+
    def minigpt4_end_chat_image(self, ctx: MiniGPT4Context, n_threads: int = 0, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1) -> str:
        """
        Ends a chat with an image, sampling one token.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): Tfs Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat Tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat Eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Returns:
            str: Token generated.
        """

        # Out-parameter: the native side sets `token` to a C string.
        token = CHAR_PTR()
        self.panic_if_error(self.library.minigpt4_end_chat_image(ctx.ptr, ctypes.pointer(token), n_threads, temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl))
        # POINTER(c_char) has no .value; cast to c_char_p to read the bytes.
        return ctypes.cast(token, ctypes.c_char_p).value.decode('utf-8')
|
| 365 |
+
|
| 366 |
+
    def minigpt4_system_prompt(self, ctx: MiniGPT4Context, n_threads: int = 0):
        """
        Generates a system prompt.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.

        Raises:
            RuntimeError: If the native call reports an error.
        """

        self.panic_if_error(self.library.minigpt4_system_prompt(ctx.ptr, n_threads))
|
| 376 |
+
|
| 377 |
+
    def minigpt4_begin_chat(self, ctx: MiniGPT4Context, s: str, n_threads: int = 0):
        """
        Begins a chat continuing after minigpt4_begin_chat_image.

        Args:
            ctx (MiniGPT4Context): Context.
            s (str): Question to ask about the image.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            None

        Raises:
            RuntimeError: If the native call reports an error.
        """
        self.panic_if_error(self.library.minigpt4_begin_chat(ctx.ptr, s.encode('utf-8'), n_threads))
|
| 390 |
+
|
| 391 |
+
    def minigpt4_end_chat(self, ctx: MiniGPT4Context, n_threads: int = 0, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1) -> str:
        """
        Ends a chat, sampling one token.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): Tfs Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat Tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat Eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Returns:
            str: Token generated.
        """

        # Out-parameter: the native side sets `token` to a C string.
        token = CHAR_PTR()
        self.panic_if_error(self.library.minigpt4_end_chat(ctx.ptr, ctypes.pointer(token), n_threads, temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl))
        # POINTER(c_char) has no .value; cast to c_char_p to read the bytes.
        return ctypes.cast(token, ctypes.c_char_p).value.decode('utf-8')
|
| 419 |
+
|
| 420 |
+
    def minigpt4_reset_chat(self, ctx: MiniGPT4Context):
        """
        Resets the chat.

        Args:
            ctx (MiniGPT4Context): Context.

        Raises:
            RuntimeError: If the native call reports an error.
        """
        self.panic_if_error(self.library.minigpt4_reset_chat(ctx.ptr))
|
| 428 |
+
|
| 429 |
+
def minigpt4_contains_eos_token(self, s: str) -> bool:
|
| 430 |
+
|
| 431 |
+
"""
|
| 432 |
+
Checks if a string contains an EOS token.
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
s (str): String to check.
|
| 436 |
+
|
| 437 |
+
Returns:
|
| 438 |
+
bool: True if the string contains an EOS token, False otherwise.
|
| 439 |
+
"""
|
| 440 |
+
|
| 441 |
+
return self.library.minigpt4_contains_eos_token(s.encode('utf-8'))
|
| 442 |
+
|
| 443 |
+
def minigpt4_is_eos(self, s: str) -> bool:
    """
    Checks if a string is EOS (end of the generated response).

    Args:
        s (str): String to check.

    Returns:
        bool: True if the string is an EOS, False otherwise.
    """
    # ctypes foreign functions return a plain int unless restype is set to
    # c_bool; coerce so the annotated bool return type actually holds.
    return bool(self.library.minigpt4_is_eos(s.encode('utf-8')))
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def minigpt4_free(self, ctx: MiniGPT4Context) -> None:
    """
    Releases a previously loaded context.

    Args:
        ctx (MiniGPT4Context): Context to free.
    """
    status = self.library.minigpt4_free(ctx.ptr)
    self.panic_if_error(status)
|
| 467 |
+
|
| 468 |
+
def minigpt4_free_image(self, image: MiniGPT4Image) -> None:
    """
    Releases the native data owned by an image.

    Args:
        image (MiniGPT4Image): Image to free.
    """
    image_ptr = ctypes.pointer(image)
    status = self.library.minigpt4_free_image(image_ptr)
    self.panic_if_error(status)
|
| 477 |
+
|
| 478 |
+
def minigpt4_free_embedding(self, embedding: MiniGPT4Embedding) -> None:
    """
    Releases the native data owned by an embedding.

    Args:
        embedding (MiniGPT4Embedding): Embedding to free.
    """
    embedding_ptr = ctypes.pointer(embedding)
    status = self.library.minigpt4_free_embedding(embedding_ptr)
    self.panic_if_error(status)
|
| 487 |
+
|
| 488 |
+
def minigpt4_error_code_to_string(self, error_code: int) -> str:
    """
    Translates a native error code into a human-readable message.

    Args:
        error_code (int): Error code returned by the library.

    Returns:
        str: Error string.
    """
    raw = self.library.minigpt4_error_code_to_string(error_code)
    return raw.decode('utf-8')
|
| 500 |
+
|
| 501 |
+
def minigpt4_quantize_model(self, in_path: str, out_path: str, data_type: DataType):
    """
    Quantizes a model file on disk.

    Args:
        in_path (str): Path to the input model file.
        out_path (str): Path where the quantized model file is written.
        data_type (DataType): Target data type; one of the DataType enum values.
    """
    status = self.library.minigpt4_quantize_model(
        in_path.encode('utf-8'),
        out_path.encode('utf-8'),
        data_type,
    )
    self.panic_if_error(status)
|
| 512 |
+
|
| 513 |
+
def minigpt4_set_verbosity(self, verbosity: Verbosity):
    """
    Sets the library's log verbosity.

    Args:
        verbosity (Verbosity): Desired verbosity level.
    """
    level = I32(verbosity)
    self.library.minigpt4_set_verbosity(level)
|
| 522 |
+
|
| 523 |
+
def load_library() -> MiniGPT4SharedLibrary:
    """
    Attempts to find the minigpt4.cpp shared library and load it.

    Candidate locations are checked relative to the current working
    directory and relative to this file; the first existing file wins.
    If none exists, the last candidate is used anyway so the resulting
    loader error names a concrete path.

    Returns:
        MiniGPT4SharedLibrary: Wrapper around the loaded shared library.
    """

    file_name: str

    # Platform-specific shared-library file name.
    if 'win32' in sys.platform or 'cygwin' in sys.platform:
        file_name = 'minigpt4.dll'
    elif 'darwin' in sys.platform:
        file_name = 'libminigpt4.dylib'
    else:
        file_name = 'libminigpt4.so'

    this_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
    repo_root_dir: pathlib.Path = this_dir.parent

    paths = [
        # If we are in "minigpt4" directory
        f'../bin/Release/{file_name}',
        # If we are in repo root directory
        f'bin/Release/{file_name}',
        # If we compiled in build directory
        f'build/bin/Release/{file_name}',
        # If we compiled in build directory
        f'build/{file_name}',
        f'../build/{file_name}',
        # Search relative to this file
        str(repo_root_dir / 'bin' / 'Release' / file_name),
        # Prebuilt library shipped right next to this file
        str(this_dir / file_name),
        # Fallback
        str(repo_root_dir / file_name)
    ]

    for path in paths:
        if os.path.isfile(path):
            return MiniGPT4SharedLibrary(path)

    # Nothing found: let the constructor fail with a path in the message.
    return MiniGPT4SharedLibrary(paths[-1])
|
| 560 |
+
|
| 561 |
+
class MiniGPT4ChatBot:
    """High-level chat interface over the minigpt4.cpp shared library.

    Owns a native context plus a torchvision preprocessing pipeline used to
    turn PIL images into the float tensor layout the native encoder expects.
    Call upload_image() to attach an image, then stream replies via
    generate(); call free() when done with the instance.
    """

    def __init__(self, model_path: str, llm_model_path: str, verbosity: Verbosity = Verbosity.SILENT, n_threads: int = 0):
        """
        Creates a new MiniGPT4ChatBot instance.

        Args:
            model_path (str): Path to model file.
            llm_model_path (str): Path to language model model file.
            verbosity (Verbosity, optional): Verbosity. Defaults to Verbosity.SILENT.
            n_threads (int, optional): Number of threads to use. Defaults to 0.
        """

        self.library = load_library()
        self.ctx = self.library.minigpt4_model_load(model_path, llm_model_path, verbosity)
        self.n_threads = n_threads

        # Imported lazily so the rest of the module works without
        # PIL/torchvision when no chat bot is constructed.
        from PIL import Image  # noqa: F401 -- verifies PIL is installed
        from torchvision import transforms
        from torchvision.transforms.functional import InterpolationMode
        self.image_size = 224

        # CLIP-style channel normalization constants.
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        # NOTE(review): RandomResizedCrop is a training-time augmentation;
        # inference usually wants a deterministic Resize/CenterCrop. Kept
        # as-is to preserve existing behavior -- confirm intent.
        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    self.image_size,
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]
        )
        self.embedding: Optional[MiniGPT4Embedding] = None
        # True only between upload_image() and the first generate() call.
        self.is_image_chat = False
        self.chat_history = []

    def free(self):
        """Releases the native context, if one was created."""
        if self.ctx:
            self.library.minigpt4_free(self.ctx)

    def _stream_tokens(self, sample_token, limit, temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl):
        """Yields up to `limit` tokens from `sample_token`, stopping at EOS.

        `sample_token` is either minigpt4_end_chat or minigpt4_end_chat_image
        bound to self.library; both share the same sampling signature. Tokens
        that contain an EOS fragment are suppressed (not yielded); once the
        accumulated text is EOS, generation stops.
        """
        chat = ''
        for _ in range(limit):
            token = sample_token(self.ctx, self.n_threads, temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl)
            chat += token
            if self.library.minigpt4_contains_eos_token(token):
                continue
            if self.library.minigpt4_is_eos(chat):
                break
            yield token

    def generate(self, message: str, limit: int = 1024, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1):
        """
        Generates a chat response, yielding tokens as they are produced.

        Args:
            message (str): Message.
            limit (int, optional): Maximum number of tokens to sample. Defaults to 1024.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): TFS Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Yields:
            str: Next generated token.
        """
        if self.is_image_chat:
            # First message after upload_image(): feed the image embedding.
            self.is_image_chat = False
            self.library.minigpt4_begin_chat_image(self.ctx, self.embedding, message, self.n_threads)
            sample_token = self.library.minigpt4_end_chat_image
        else:
            self.library.minigpt4_begin_chat(self.ctx, message, self.n_threads)
            sample_token = self.library.minigpt4_end_chat
        yield from self._stream_tokens(sample_token, limit, temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl)

    def reset_chat(self):
        """
        Resets the chat: frees any cached image embedding, clears the native
        conversation state, and re-issues the system prompt.
        """

        self.is_image_chat = False
        if self.embedding:
            self.library.minigpt4_free_embedding(self.embedding)
            self.embedding = None

        self.library.minigpt4_reset_chat(self.ctx)
        self.library.minigpt4_system_prompt(self.ctx, self.n_threads)

    def upload_image(self, image):
        """
        Uploads an image: resets the chat, preprocesses and encodes the image,
        and marks the next generate() call as an image chat.

        Args:
            image (Image): PIL image to attach to the conversation.
        """

        self.reset_chat()

        tensor = self.transform(image)   # (3, H, W) normalized float tensor
        tensor = tensor.unsqueeze(0)     # add batch dimension
        array = tensor.numpy()
        # data_as keeps a reference to the array, so the buffer stays alive
        # while the native side reads it.
        data_ptr = array.ctypes.data_as(ctypes.c_void_p)
        minigpt4_image = MiniGPT4Image(data_ptr, self.image_size, self.image_size, 3, ImageFormat.F32)
        self.embedding = self.library.minigpt4_encode_image(self.ctx, minigpt4_image, self.n_threads)

        self.is_image_chat = True
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
if __name__ == "__main__":
    # Smoke test: load the library, encode one image, ask the first prompt
    # about it, then continue the chat with the remaining prompts.
    import argparse
    parser = argparse.ArgumentParser(description='Test loading minigpt4')
    parser.add_argument('model_path', help='Path to model file')
    parser.add_argument('llm_model_path', help='Path to llm model file')
    parser.add_argument('-i', '--image_path', help='Image to test', default='images/llama.png')
    parser.add_argument('-p', '--prompts', help='Text to test', default='what is the text in the picture?,what is the color of it?')
    args = parser.parse_args()

    model_path = args.model_path
    llm_model_path = args.llm_model_path
    image_path = args.image_path
    prompts = args.prompts

    # Validate input paths before touching the native library.
    if not Path(model_path).exists():
        print(f'Model does not exist: {model_path}')
        exit(1)

    if not Path(llm_model_path).exists():
        print(f'LLM Model does not exist: {llm_model_path}')
        exit(1)

    # Prompts are passed as a single comma-separated string.
    prompts = prompts.split(',')

    print('Loading minigpt4 shared library...')
    library = load_library()
    print(f'Loaded library {library}')
    ctx = library.minigpt4_model_load(model_path, llm_model_path, Verbosity.DEBUG)
    image = library.minigpt4_image_load_from_file(ctx, image_path, 0)
    preprocessed_image = library.minigpt4_preprocess_image(ctx, image, 0)

    # First prompt goes through the image-chat path; n_threads=0 lets the
    # library pick a thread count.
    question = prompts[0]
    n_threads = 0
    embedding = library.minigpt4_encode_image(ctx, preprocessed_image, n_threads)
    library.minigpt4_system_prompt(ctx, n_threads)
    library.minigpt4_begin_chat_image(ctx, embedding, question, n_threads)
    chat = ''
    # Stream tokens until the accumulated text is EOS; tokens containing an
    # EOS fragment are swallowed rather than printed.
    while True:
        token = library.minigpt4_end_chat_image(ctx, n_threads)
        chat += token
        if library.minigpt4_contains_eos_token(token):
            continue
        if library.minigpt4_is_eos(chat):
            break
        print(token, end='')

    # Remaining prompts continue the conversation (text-only path).
    for i in range(1, len(prompts)):
        prompt = prompts[i]
        library.minigpt4_begin_chat(ctx, prompt, n_threads)
        chat = ''
        while True:
            token = library.minigpt4_end_chat(ctx, n_threads)
            chat += token
            if library.minigpt4_contains_eos_token(token):
                continue
            if library.minigpt4_is_eos(chat):
                break
            print(token, end='')

    # Release native resources in reverse order of creation.
    library.minigpt4_free_image(image)
    library.minigpt4_free_image(preprocessed_image)
    library.minigpt4_free(ctx)
|
requirements.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
timm==0.6.7
|
| 2 |
+
deepspeed==0.9.2
|
| 3 |
+
data
|
| 4 |
+
einops==0.6.1
|
| 5 |
+
ftfy==6.1.1
|
| 6 |
+
iopath==0.1.10
|
| 7 |
+
ipdb==0.13.13
|
| 8 |
+
numpy==1.24.3
|
| 9 |
+
peft==0.3.0
|
| 10 |
+
Pillow==9.5.0
|
| 11 |
+
PyYAML==6.0
|
| 12 |
+
regex==2022.10.31
|
| 13 |
+
torchvision==0.14.1
|
| 14 |
+
torchaudio==0.13.1
|
| 15 |
+
pytorchvideo
|
| 16 |
+
fvcore
|
| 17 |
+
decord==0.6.0
|
| 18 |
+
tqdm==4.64.1
|
| 19 |
+
transformers==4.29.1
|
| 20 |
+
gradio
|
| 21 |
+
huggingface_hub
|