# medgemma-spider-finetuned

A MedGemma model fine-tuned on the Spider dataset for medical image analysis, with multiple images per patient.
## 📋 Model Information

- **Base Model:** `google/medgemma-4b-it`
- **Fine-tuning Method:** LoRA (Low-Rank Adaptation)
- **Dataset:** Spider dataset (series format: one patient = multiple images; see the sample sketch below)
- **Training Framework:** Unsloth (2x faster training)
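In the series format, a single training sample bundles all of a patient's images into one conversation. The sketch below is illustrative only: the field names and paths are assumptions that mirror the chat format used in the Usage section, not the dataset's actual schema.

```python
# Hypothetical series-format sample: one patient, many images, one conversation
sample = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "patient_001/slice_01.png"},
                {"type": "image", "image": "patient_001/slice_02.png"},
                # ... up to 18 images per sample (see Training Details)
                {"type": "text", "text": "Describe the findings across these images."},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "The images show ..."}],
        },
    ]
}
```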
## 📂 Folder Structure

```
output_medgemma_spider/
├── final_model/          # Full merged model (large)
├── lora_adapters/        # LoRA adapters only (recommended, lightweight)
├── checkpoint-*/         # Training checkpoints
├── trainer_state.json    # Training state
└── eval_metrics.json     # Evaluation metrics
```
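If you only need part of the repository, you can download a single subfolder with `huggingface_hub` instead of pulling everything. A minimal sketch, assuming the layout above:

```python
from huggingface_hub import snapshot_download

# Fetch only the lightweight LoRA adapters, skipping the large merged model
local_dir = snapshot_download(
    repo_id="ImNotTam/medgemma-spider-finetuned",
    allow_patterns=["lora_adapters/*"],
)
print(local_dir)
```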
## 🚀 Usage

### 1️⃣ Load LoRA Adapters (Recommended, Lightweight)

```python
from unsloth import FastVisionModel

model, processor = FastVisionModel.from_pretrained(
    model_name="ImNotTam/medgemma-spider-finetuned",
    subfolder="lora_adapters",
    load_in_4bit=True,
)

# Enable inference mode
FastVisionModel.for_inference(model)

# Prepare input with multiple images
image_paths = ["path/to/image1.png", "path/to/image2.png", ...]  # replace with your paths
question = "What do you see in these images?"

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": img_path} for img_path in image_paths
        ] + [{"type": "text", "text": question}],
    }
]

# Tokenize the chat template (return_dict=True so generate() receives
# input_ids, attention_mask, and pixel_values together)
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to("cuda")

# Generate response
outputs = model.generate(**inputs, max_new_tokens=512)
response = processor.decode(outputs[0], skip_special_tokens=True)
print(response)
```
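Note that decoding `outputs[0]` returns the prompt together with the answer. To keep only the newly generated text, slice off the prompt tokens first; a small sketch reusing `inputs` and `outputs` from above:

```python
# Drop the prompt tokens so only the model's generated answer remains
prompt_len = inputs["input_ids"].shape[-1]
answer = processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)
print(answer)
```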
### 2️⃣ Load Final Model (Full Model)

```python
from transformers import AutoModelForVision2Seq, AutoProcessor

model = AutoModelForVision2Seq.from_pretrained(
    "ImNotTam/medgemma-spider-finetuned",
    subfolder="final_model",
    device_map="auto",
    torch_dtype="auto",
)
processor = AutoProcessor.from_pretrained(
    "ImNotTam/medgemma-spider-finetuned",
    subfolder="final_model",
)

# Use the same inference code as above
```
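If the full merged model does not fit in GPU memory, it can also be loaded in 4-bit via bitsandbytes. This is an optional sketch (not part of the original card), assuming a CUDA GPU with the `bitsandbytes` package installed:

```python
import torch
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig

# NF4 4-bit quantization trades a little accuracy for much lower memory use
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForVision2Seq.from_pretrained(
    "ImNotTam/medgemma-spider-finetuned",
    subfolder="final_model",
    quantization_config=bnb_config,
    device_map="auto",
)
```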
### 3️⃣ Continue Training from LoRA Adapters

```python
from unsloth import FastVisionModel
from trl import SFTTrainer

# Load the LoRA adapters
model, processor = FastVisionModel.from_pretrained(
    model_name="ImNotTam/medgemma-spider-finetuned",
    subfolder="lora_adapters",
    load_in_4bit=True,
)

# Add a new LoRA config to continue training
model = FastVisionModel.get_peft_model(
    model,
    r=24,
    lora_alpha=48,
    lora_dropout=0.1,
    finetune_vision_layers=True,
    finetune_language_layers=True,
)

# Train on new data
trainer = SFTTrainer(
    model=model,
    tokenizer=processor,
    train_dataset=your_new_dataset,
    # ... training args
)
trainer.train()
```
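For the `# ... training args` placeholder, a reasonable starting point is to mirror the hyperparameters from the Training Details section below using TRL's `SFTConfig`. This is a sketch, not the original run's configuration: `output_dir` and `logging_steps` are assumptions, and the last three flags follow Unsloth's vision fine-tuning examples rather than this card.

```python
from trl import SFTConfig

# Hyperparameters mirrored from the Training Details section of this card
training_args = SFTConfig(
    output_dir="output_medgemma_spider_v2",   # assumed name, adjust as needed
    per_device_train_batch_size=2,
    gradient_accumulation_steps=12,           # effective batch size: 2 x 12 = 24
    learning_rate=2e-4,
    num_train_epochs=7,
    max_seq_length=1280,
    logging_steps=10,                         # assumed logging cadence
    remove_unused_columns=False,              # keep image columns for the collator
    dataset_text_field="",                    # multimodal dataset, no plain text field
    dataset_kwargs={"skip_prepare_dataset": True},
)
```

Pass it as `args=training_args` when constructing the `SFTTrainer` above. Unsloth's vision examples also supply `data_collator=UnslothVisionDataCollator(model, processor)` from `unsloth.trainer`; check the Unsloth documentation for the version you have installed.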
## 📊 Training Details

- **LoRA Rank:** 24
- **LoRA Alpha:** 48
- **LoRA Dropout:** 0.1
- **Batch Size:** 2 (per device)
- **Gradient Accumulation:** 12 steps
- **Effective Batch Size:** 24 (2 × 12)
- **Learning Rate:** 2e-4
- **Max Sequence Length:** 1280
- **Max Images per Sample:** 18
- **Epochs:** 7
## 💡 Recommendations

- **For Inference:** use `lora_adapters/` (lightweight, fast)
- **For Production:** use `final_model/` (full merged model)
- **For Continued Training:** load `lora_adapters/` and add a new LoRA config
## 📦 Requirements

```bash
pip install unsloth transformers torch trl pillow
```
## 📄 License

Apache 2.0