RetinalGPT: Large Language-and-Vision Assistant for Retinal Health
RetinalGPT is a specialized multimodal vision-language model (VLM) based on the LLaVA-v1.5 architecture. It is specifically engineered for the high-precision domain of ophthalmology, with a focus on interpreting retinal fundus photography and Optical Coherence Tomography (OCT) scans.
Model Summary
RetinalGPT bridges the gap between general-purpose VLMs and specialized ophthalmic diagnostics. It is fine-tuned on a curated corpus of retinal image-text pairs, which lets it identify pathologies such as Diabetic Retinopathy (DR), Glaucoma, and Age-related Macular Degeneration (AMD).
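The fine-tuning corpus is not described here. LLaVA-v1.5 instruction tuning reads each image-text pair as a record with "id", "image" and "conversations" fields, and the human turn carries an <image> placeholder. The sketch below assumes RetinalGPT's retinal pairs use that same layout. It shows how one such record could be built and sanity-checked. make_record and is_valid are hypothetical helpers and are not part of LLaVA. The record ID, file name and text are placeholders.

```python
import json

IMAGE_TOKEN = "<image>"  # placeholder LLaVA replaces with the projected image features


def make_record(record_id: str, image_file: str, question: str, answer: str) -> dict:
    """Build one single-turn record in LLaVA's conversation-JSON layout (hypothetical helper)."""
    return {
        "id": record_id,
        "image": image_file,
        "conversations": [
            {"from": "human", "value": f"{IMAGE_TOKEN}\n{question}"},
            {"from": "gpt", "value": answer},
        ],
    }


def is_valid(record: dict) -> bool:
    """Turns must alternate human/gpt, end on gpt, and contain exactly one image token."""
    turns = record["conversations"]
    roles_ok = all(t["from"] == ("human" if i % 2 == 0 else "gpt") for i, t in enumerate(turns))
    image_count = sum(t["value"].count(IMAGE_TOKEN) for t in turns)
    return roles_ok and len(turns) > 0 and len(turns) % 2 == 0 and image_count == 1


# Hypothetical record: ID, file name, and text are placeholders, not real data.
rec = make_record("sample_0001", "sample_0001.jpg",
                  "Describe the optic disc in this fundus photograph.",
                  "(reference answer)")
serialized = json.dumps(rec)  # round-trip through JSON as a training file would
print(is_valid(json.loads(serialized)))
```

The check on exactly one image token reflects how LLaVA-v1.5 data handles a single image: the placeholder appears once, in the first human turn.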
- Base LLM: Llama-7b
- Vision Tower: CLIP-ViT-L-14-336px
- Connector: MLP Projection Layer
- Domain: Ophthalmology / Retinal Imaging
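The component list above fixes two numbers: how many visual tokens each image contributes, and how large the connector is. The arithmetic below is a rough sketch. It assumes three things: the connector is LLaVA-v1.5's two-layer "mlp2x_gelu" projector, the vision tower's hidden size is 1024 (CLIP ViT-L/14), and the LLM's hidden size is 4096 (7B Llama). Token counts are given for the 336px tower named above and also for the 224px input of openai/clip-vit-large-patch14. The function names are illustrative.

```python
def num_patch_tokens(image_size: int, patch_size: int = 14) -> int:
    """Visual tokens from a ViT: one per non-overlapping patch (CLS token dropped, as in LLaVA-v1.5)."""
    assert image_size % patch_size == 0, "image size must be a multiple of the patch size"
    per_side = image_size // patch_size
    return per_side * per_side


def mlp2x_params(vision_dim: int = 1024, llm_dim: int = 4096) -> int:
    """Weights + biases of Linear(vision_dim, llm_dim) -> GELU -> Linear(llm_dim, llm_dim)."""
    first = vision_dim * llm_dim + llm_dim
    second = llm_dim * llm_dim + llm_dim
    return first + second  # GELU itself has no parameters


for size in (336, 224):
    print(f"{size}px input -> {num_patch_tokens(size)} visual tokens")
print(f"connector parameters: {mlp2x_params():,}")
```

At 336px each fundus or OCT image occupies 576 positions of the LLM's context window. The projector adds only about 21M parameters on top of the 7B LLM.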
Key Capabilities
RetinalGPT is trained to perform complex visual reasoning tasks including:
- Automated Screening: Grading Diabetic Retinopathy severity (Stage 0-4).
- Lesion Characterization: Identifying and describing microaneurysms, hemorrhages, and exudates.
- Anatomical Mapping: Precise description of the optic disc, cup-to-disc ratio, and foveal reflex.
- Clinical QA: Engaging in multi-turn dialogues about specific clinical findings in a retinal scan.
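RetinalGPT answers in free text, so any downstream screening code has to extract the DR grade from the reply. The sketch below makes two assumptions. First, the model states the grade as "stage N", "grade N" or "level N". Second, stages 0-4 follow the five-level International Clinical DR severity ordering. parse_dr_grade and DR_SCALE are hypothetical helpers and are not part of RetinalGPT or LLaVA. The sample reply is invented.

```python
import re
from typing import Optional

# Assumed label wording for the five-level scale; RetinalGPT's own labels are not documented here.
DR_SCALE = {
    0: "No DR",
    1: "Mild NPDR",
    2: "Moderate NPDR",
    3: "Severe NPDR",
    4: "Proliferative DR",
}

# "stage 2", "Grade 3", "level 0": a single digit 0-4 not followed by another digit.
_GRADE_RE = re.compile(r"\b(?:stage|grade|level)\s*([0-4])\b", re.IGNORECASE)


def parse_dr_grade(response: str) -> Optional[int]:
    """Return the first stage/grade/level number (0-4) mentioned in the text, else None."""
    match = _GRADE_RE.search(response)
    return int(match.group(1)) if match else None


reply = ("The image shows scattered microaneurysms and dot hemorrhages, "
         "consistent with Stage 2 diabetic retinopathy.")  # invented example reply
grade = parse_dr_grade(reply)
print(grade, DR_SCALE[grade] if grade is not None else "ungraded")
```

When no grade is found, the parser returns None instead of a default. This keeps an unparseable reply from being treated as "No DR".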
How to Use
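RetinalGPT follows the standard LLaVA inference pipeline. It requires the llava package from the official LLaVA repository (https://github.com/haotian-liu/LLaVA). To install it, clone that repository and run pip install -e . from its root.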
Usage
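from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from PIL import Image
import torch

# Hugging Face repo ID of this model
model_path = "ASU-GSL/RetinalGPT"
# LLaVA's builder uses its multimodal loader only when model_name contains "llava"
model_name = "llava-" + get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name
)

# Prepare image: force RGB, then resize/normalize with the CLIP image processor
image = Image.open("fundus_sample.jpg").convert("RGB")
image_tensor = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)

# Build the prompt: the <image> placeholder marks where the image features go.
# Assumes RetinalGPT uses the LLaVA-v1.5 "llava_v1" conversation template.
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nCan you describe this image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

# Generate response
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=0.2,
        max_new_tokens=512,
        use_cache=True
    )

print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip())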
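Two steps in the pipeline are easy to miss. First, the prompt must follow the model's chat template. Second, tokenizer_image_token only inserts the image sentinel where the literal <image> placeholder occurs, so a prompt without it silently ignores the image. The stdlib-only sketch below imitates both steps, including the multi-turn case used for clinical QA. It assumes the LLaVA-v1.5 "llava_v1" layout: a system sentence, then "USER:"/"ASSISTANT:" roles, with assistant turns closed by </s>. build_llava_v1_prompt, splice_image_token and toy_encode are simplified, hypothetical stand-ins and are not the llava functions themselves. The real tokenizer is a SentencePiece model, not a word splitter.

```python
from typing import Callable, Dict, List, Optional, Tuple

IMAGE_TOKEN = "<image>"   # same string as llava.constants.DEFAULT_IMAGE_TOKEN
IMAGE_TOKEN_INDEX = -200  # same sentinel as llava.constants.IMAGE_TOKEN_INDEX
SYSTEM = ("A chat between a curious human and an artificial intelligence assistant. "
          "The assistant gives helpful, detailed, and polite answers to the human's questions.")


def build_llava_v1_prompt(turns: List[Tuple[str, Optional[str]]]) -> str:
    """Imitate the llava_v1 layout: SYSTEM USER: q ASSISTANT: a</s>USER: ... ASSISTANT:"""
    seps = [" ", "</s>"]  # separator after user turns, after assistant turns
    out = SYSTEM + seps[0]
    for i, (role, message) in enumerate(turns):
        if message:
            out += f"{role}: {message}{seps[i % 2]}"
        else:
            out += f"{role}:"  # open slot the model completes
    return out


def splice_image_token(prompt: str, encode: Callable[[str], List[int]], bos_id: int) -> List[int]:
    """Encode the text around each <image> and put IMAGE_TOKEN_INDEX where it stood."""
    ids = [bos_id]
    for i, chunk in enumerate(prompt.split(IMAGE_TOKEN)):
        if i > 0:
            ids.append(IMAGE_TOKEN_INDEX)
        ids.extend(encode(chunk))
    return ids


_vocab: Dict[str, int] = {}


def toy_encode(text: str) -> List[int]:
    """Toy tokenizer for illustration only: one id (>= 2) per whitespace-separated word."""
    return [_vocab.setdefault(word, len(_vocab) + 2) for word in text.split()]


single = build_llava_v1_prompt([("USER", IMAGE_TOKEN + "\nCan you describe this image?"),
                                ("ASSISTANT", None)])
multi = build_llava_v1_prompt([("USER", IMAGE_TOKEN + "\nIs there a hemorrhage?"),
                               ("ASSISTANT", "(first answer)"),
                               ("USER", "Where is it located?"),
                               ("ASSISTANT", None)])
with_image = splice_image_token(single, toy_encode, bos_id=1)
without_image = splice_image_token("Can you describe this image?", toy_encode, bos_id=1)
print(multi)
print(with_image.count(IMAGE_TOKEN_INDEX), without_image.count(IMAGE_TOKEN_INDEX))  # prints 1 0
```

In the multi-turn prompt, the image placeholder appears only in the first user turn. Each later question is appended after the previous answer's </s>, and the prompt always ends on an open "ASSISTANT:" slot.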
Model tree for ASU-GSL/RetinalGPT
Base model: openai/clip-vit-large-patch14