medgemma-spider-finetuned

Fine-tuned MedGemma model for the Spider dataset: medical image analysis where one patient corresponds to multiple images.

📋 Model Information

  • Base Model: google/medgemma-4b-it
  • Fine-tuning Method: LoRA (Low-Rank Adaptation)
  • Dataset: Spider dataset (series format - 1 patient = multiple images)
  • Training Framework: Unsloth (2x faster training)

📂 Folder Structure

output_medgemma_spider/
├── final_model/           # Full merged model (large)
├── lora_adapters/         # LoRA adapters only (recommended, lightweight)
├── checkpoint-*/          # Training checkpoints
├── trainer_state.json     # Training state
└── eval_metrics.json      # Evaluation metrics
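
To avoid downloading the large merged model when you only need the adapters, the Hub lets you fetch a single subfolder. A minimal sketch using huggingface_hub (folder names taken from the layout above):

from huggingface_hub import snapshot_download

# Fetch only the lightweight lora_adapters/ subfolder from the repo
local_dir = snapshot_download(
    repo_id="ImNotTam/medgemma-spider-finetuned",
    allow_patterns=["lora_adapters/*"],
)
print(local_dir)  # local cache path containing lora_adapters/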

🚀 Usage

1️⃣ Load LoRA Adapters (Recommended - Lightweight)

from unsloth import FastVisionModel

model, processor = FastVisionModel.from_pretrained(
    model_name="ImNotTam/medgemma-spider-finetuned",
    subfolder="lora_adapters",
    load_in_4bit=True,
)

# Enable inference mode
FastVisionModel.for_inference(model)

# Prepare input with multiple images
image_paths = ["path/to/image1.png", "path/to/image2.png", ...]
question = "What do you see in these images?"

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": img_path} for img_path in image_paths
        ] + [{"type": "text", "text": question}]
    }
]

# Generate response
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,  # required so generate() receives a dict of tensors
    return_tensors="pt",
).to("cuda")

outputs = model.generate(**inputs, max_new_tokens=512)
# Decode only the newly generated tokens, not the echoed prompt
response = processor.decode(
    outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
)
print(response)

2️⃣ Load Final Model (Full Model)

from transformers import AutoModelForVision2Seq, AutoProcessor

model = AutoModelForVision2Seq.from_pretrained(
    "ImNotTam/medgemma-spider-finetuned",
    subfolder="final_model",
    device_map="auto",
    torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(
    "ImNotTam/medgemma-spider-finetuned",
    subfolder="final_model"
)

# Use same inference code as above
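
If the merged model does not fit in GPU memory, it can also be loaded 4-bit quantized through bitsandbytes. A minimal sketch, assuming bitsandbytes is installed (the quantization settings here are illustrative, not the ones used in training):

import torch
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig

# Illustrative 4-bit config to reduce GPU memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForVision2Seq.from_pretrained(
    "ImNotTam/medgemma-spider-finetuned",
    subfolder="final_model",
    device_map="auto",
    quantization_config=bnb_config,
)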

3️⃣ Continue Training from LoRA Adapters

from unsloth import FastVisionModel
from trl import SFTTrainer

# Load LoRA adapter
model, processor = FastVisionModel.from_pretrained(
    model_name="ImNotTam/medgemma-spider-finetuned",
    subfolder="lora_adapters",
    load_in_4bit=True,
)

# Add a new LoRA config to continue training
model = FastVisionModel.get_peft_model(
    model,
    r=24,
    lora_alpha=48,
    lora_dropout=0.1,
    finetune_vision_layers=True,
    finetune_language_layers=True,
)

# Train with new data
trainer = SFTTrainer(
    model=model,
    tokenizer=processor,
    train_dataset=your_new_dataset,
    # ... training args (see the SFTConfig sketch under Training Details below)
)
trainer.train()

📊 Training Details

  • LoRA Rank: 24
  • LoRA Alpha: 48
  • LoRA Dropout: 0.1
  • Batch Size: 2 (per device)
  • Gradient Accumulation: 12 steps
  • Effective Batch Size: 24
  • Learning Rate: 2e-4
  • Max Sequence Length: 1280
  • Max Images per Sample: 18
  • Epochs: 7
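
As a reproduction aid, here is a hedged sketch of how these hyperparameters could map onto TRL's SFTConfig (the output_dir is taken from the folder layout above; the exact configuration used for this run is not published):

from trl import SFTConfig

# Assumed mapping of the hyperparameters listed above onto SFTConfig;
# effective batch size = 2 (per device) x 12 (grad accumulation) = 24
training_args = SFTConfig(
    output_dir="output_medgemma_spider",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=12,
    learning_rate=2e-4,
    num_train_epochs=7,
    max_seq_length=1280,
)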

💡 Recommendations

  • For Inference: Use lora_adapters/ (lightweight, fast)
  • For Production: Use final_model/ (full model)
  • For Continued Training: Load lora_adapters/ + add new LoRA config

📦 Requirements

pip install unsloth transformers torch trl pillow

📄 License

Apache 2.0
