RetinalGPT: Large Language-and-Vision Assistant for Retinal Health 👁️

RetinalGPT is a specialized multimodal vision-language model (VLM) based on the LLaVA-v1.5 architecture. It is tailored to ophthalmology, with a focus on interpreting retinal fundus photographs and Optical Coherence Tomography (OCT) scans.


📌 Model Summary

RetinalGPT bridges the gap between general-purpose VLMs and specialized ophthalmic diagnostics. It is fine-tuned on a curated corpus of retinal image-text pairs to identify pathologies such as Diabetic Retinopathy (DR), Glaucoma, and Age-related Macular Degeneration (AMD).

  • Base LLM: Llama-7B
  • Vision Tower: CLIP ViT-L/14 (336px)
  • Connector: MLP Projection Layer
  • Domain: Ophthalmology / Retinal Imaging
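These components determine how much of the language model's input each image occupies. The sketch below works through that arithmetic under LLaVA-1.5 defaults that this card does not state, so treat them as assumptions: CLIP ViT-L/14 features of width 1024 taken from the patch tokens only, a Llama-7B hidden size of 4096, and a two-layer GELU MLP ("mlp2x_gelu") as the connector. The function names are illustrative.

```python
# Visual-token budget and connector size for a LLaVA-1.5-style model.
# ASSUMPTIONS (not stated on this card): CLIP ViT-L/14 hidden size 1024,
# Llama-7B hidden size 4096, a two-layer GELU MLP connector ("mlp2x_gelu"),
# and patch features only (CLS token dropped), as in LLaVA-1.5.


def num_visual_tokens(image_size: int = 336, patch_size: int = 14) -> int:
    """Number of patch tokens the vision tower emits for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    grid = image_size // patch_size  # patches per side
    return grid * grid


def mlp2x_params(vision_dim: int = 1024, llm_dim: int = 4096) -> int:
    """Parameters of Linear(vision_dim, llm_dim) -> GELU -> Linear(llm_dim, llm_dim)."""
    first = vision_dim * llm_dim + llm_dim   # weight matrix + bias
    second = llm_dim * llm_dim + llm_dim     # weight matrix + bias (GELU has no parameters)
    return first + second


if __name__ == "__main__":
    print(f"visual tokens per image: {num_visual_tokens()}")
    print(f"connector parameters: {mlp2x_params():,}")
```

Under these assumptions each fundus or OCT image becomes 576 visual tokens (a 24 x 24 patch grid) that share the context window with the text prompt, and the connector adds about 21M parameters, small next to the 7B language model.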

🚀 Key Capabilities

RetinalGPT is trained to perform complex visual reasoning tasks including:

  • Automated Screening: Grading Diabetic Retinopathy severity (Stage 0-4).
  • Lesion Characterization: Identifying and describing microaneurysms, hemorrhages, and exudates.
  • Anatomical Mapping: Describing the optic disc, cup-to-disc ratio, and foveal reflex.
  • Clinical QA: Engaging in multi-turn dialogues about specific clinical findings in a retinal scan.
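Because the model answers in free text, a screening pipeline has to extract the severity grade from its response. The helper below is a hypothetical sketch, not part of RetinalGPT: it assumes the "Stage 0-4" grades follow the five-level International Clinical Diabetic Retinopathy (ICDR) scale and that the answer names the grade with a keyword such as "grade", "stage", or "level". Neither assumption is documented for this model.

```python
import re
from typing import Optional

# Five-level International Clinical Diabetic Retinopathy (ICDR) severity scale.
# ASSUMPTION: RetinalGPT's "Stage 0-4" grades follow this scale.
ICDR_LABELS = {
    0: "No apparent DR",
    1: "Mild NPDR",
    2: "Moderate NPDR",
    3: "Severe NPDR",
    4: "Proliferative DR",
}

# Matches phrases such as "Stage 2", "grade: 3" or "level 4" (case-insensitive).
# ASSUMPTION: the model names the grade with one of these keywords.
_GRADE_PATTERN = re.compile(r"\b(?:grade|stage|level)\s*:?\s*([0-4])\b", re.IGNORECASE)


def extract_dr_grade(response: str) -> Optional[int]:
    """Return the first 0-4 grade mentioned in a free-text answer, or None if absent."""
    match = _GRADE_PATTERN.search(response)
    return int(match.group(1)) if match else None


if __name__ == "__main__":
    # Placeholder answer text, not real model output.
    answer = ("Scattered microaneurysms and dot-blot hemorrhages are visible, "
              "consistent with Stage 2 diabetic retinopathy.")
    grade = extract_dr_grade(answer)
    if grade is None:
        print("No grade found; route to manual review.")
    else:
        print(grade, ICDR_LABELS[grade])
```

Returning None instead of guessing lets unparseable answers be routed to a human grader rather than silently assigned a stage.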

💻 How to Use

RetinalGPT follows the standard LLaVA inference pipeline. You will need the llava package from the LLaVA repository (haotian-liu/LLaVA) installed, along with PyTorch and Pillow.

Usage

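import torch
from PIL import Image

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from llava.model.builder import load_pretrained_model

model_path = "ASU-GSL/RetinalGPT"  # Hugging Face repo ID of this model

# load_pretrained_model only attaches the vision tower and image processor
# when the model name contains "llava", so make sure it does.
model_name = get_model_name_from_path(model_path)
if "llava" not in model_name.lower():
    model_name = f"llava-{model_name}"

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name
)

# Prepare Image (convert to RGB in case the file is grayscale or has an alpha channel)
image = Image.open("fundus_sample.jpg").convert("RGB")
# process_images applies the model's configured resizing/padding and returns a batched tensor
image_tensor = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)

# Build the prompt with the LLaVA-v1.5 conversation template.
# The <image> token marks where the visual features are inserted.
question = "Can you describe this image?"
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

# Generate Response
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=0.2,
        max_new_tokens=512,
        use_cache=True
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True).strip())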
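For multi-turn Clinical QA, each follow-up question is appended to the same conversation together with the model's earlier replies, and the image token appears only in the first user turn. The pure-Python sketch below illustrates the resulting prompt layout; it assumes the llava_v1 template (its system prompt, USER/ASSISTANT roles, a space after user turns and "</s>" after assistant turns). In real use, prefer llava.conversation.conv_templates["llava_v1"] with its append_message and get_prompt methods over this hand-written copy.

```python
from typing import List, Optional, Tuple

# ASSUMPTION: mirrors the layout of LLaVA's "llava_v1" conversation template
# (SeparatorStyle.TWO). Illustration only; prefer llava.conversation.conv_templates.
SYSTEM = ("A chat between a curious human and an artificial intelligence assistant. "
          "The assistant gives helpful, detailed, and polite answers to the human's questions.")
IMAGE_TOKEN = "<image>"  # same string as llava.constants.DEFAULT_IMAGE_TOKEN
USER_SEP, ASSISTANT_SEP = " ", "</s>"


def build_multiturn_prompt(turns: List[Tuple[str, Optional[str]]]) -> str:
    """Build a prompt from (question, answer) pairs; only the last answer may be None."""
    if not turns:
        raise ValueError("need at least one turn")
    prompt = SYSTEM + USER_SEP
    for i, (question, answer) in enumerate(turns):
        if i == 0:
            question = IMAGE_TOKEN + "\n" + question  # image token goes in the first turn only
        prompt += "USER: " + question + USER_SEP
        if answer is None:
            if i != len(turns) - 1:
                raise ValueError("only the final turn may be unanswered")
            prompt += "ASSISTANT:"  # the model continues generating from here
        else:
            prompt += "ASSISTANT: " + answer + ASSISTANT_SEP
    return prompt


if __name__ == "__main__":
    history = [
        ("Can you describe this image?", "PLACEHOLDER: the model's first reply goes here."),
        ("Is the cup-to-disc ratio enlarged?", None),
    ]
    print(build_multiturn_prompt(history))
```

The resulting string is what tokenizer_image_token would receive for the second turn; the same image tensor is passed to generate again.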
Model repository: ASU-GSL/RetinalGPT