# DepthCrafter / unit_tests/test_inference.py
# Uploaded by fusodoya via huggingface_hub (commit 40b178e, verified).
import pytest
from unittest.mock import MagicMock, patch
import numpy as np
import torch
import os
import tempfile
from depthcrafter.inference import DepthCrafterInference
@pytest.fixture
def dummy_video_path():
    """Yield the path of an empty temporary ``.mp4`` file, then delete it.

    The file's contents never matter: the test that consumes this fixture
    mocks ``read_video_frames``, so only an existing path is required.
    """
    handle = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    handle.close()
    video_path = handle.name
    yield video_path
    # Teardown: tolerate the file having already been removed by the test.
    if os.path.exists(video_path):
        os.remove(video_path)
@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
def test_init(mock_unet_cls, mock_pipeline_cls):
    """Check that each ``cpu_offload`` mode configures the pipeline correctly.

    Every construction receives the same mocked pipeline object, so its call
    history is reset before each scenario. Without the reset, an assertion
    could be satisfied by a call recorded during an earlier scenario:
    ``assert_called`` accepts any earlier call, and ``assert_called_with``
    inspects only the most recent one.
    """
    mock_unet_cls.from_pretrained.return_value = MagicMock()
    mock_pipeline = MagicMock()
    mock_pipeline_cls.from_pretrained.return_value = mock_pipeline

    def build(cpu_offload):
        # Clear recorded calls on the shared mock; configured return values
        # (and the class-level from_pretrained wiring) are preserved.
        mock_pipeline.reset_mock()
        return DepthCrafterInference(
            unet_path="dummy_unet",
            pre_train_path="dummy_pretrain",
            cpu_offload=cpu_offload,
            device="cpu",
        )

    # Model offload.
    build("model")
    mock_pipeline.enable_model_cpu_offload.assert_called_once()
    mock_pipeline.enable_sequential_cpu_offload.assert_not_called()

    # Sequential offload.
    build("sequential")
    mock_pipeline.enable_sequential_cpu_offload.assert_called_once()
    mock_pipeline.enable_model_cpu_offload.assert_not_called()

    # No offload: the pipeline is moved to the requested device instead.
    build(None)
    mock_pipeline.to.assert_called_with("cpu")
    mock_pipeline.enable_model_cpu_offload.assert_not_called()
    mock_pipeline.enable_sequential_cpu_offload.assert_not_called()

    # An unknown offload mode is rejected.
    with pytest.raises(ValueError):
        build("invalid")
@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
def test_clear_cache(mock_unet_cls, mock_pipeline_cls):
    """``clear_cache`` runs ``gc.collect`` and ``torch.cuda.empty_cache`` once each."""
    mock_pipeline_cls.from_pretrained.return_value = MagicMock()
    engine = DepthCrafterInference("dummy", "dummy")

    gc_patcher = patch("depthcrafter.inference.gc.collect")
    cuda_patcher = patch("depthcrafter.inference.torch.cuda.empty_cache")
    with gc_patcher as fake_collect, cuda_patcher as fake_empty_cache:
        engine.clear_cache()

    # Mocks keep their call history after the patches are undone.
    assert fake_collect.call_count == 1
    assert fake_empty_cache.call_count == 1
@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
def test_save_exr(mock_unet_cls, mock_pipeline_cls):
    """``_save_exr`` creates the output folder and writes one EXR file per frame.

    ``OpenEXR`` and ``Imath`` are swapped for mocks in ``sys.modules`` so the
    test does not depend on those packages being installed.
    """
    mock_pipeline_cls.from_pretrained.return_value = MagicMock()
    engine = DepthCrafterInference("dummy", "dummy")

    fake_openexr = MagicMock()
    fake_modules = {"OpenEXR": fake_openexr, "Imath": MagicMock()}
    num_frames = 2
    depth = np.random.rand(num_frames, 32, 32).astype(np.float32)

    with patch.dict("sys.modules", fake_modules), patch(
        "depthcrafter.inference.os.makedirs"
    ) as fake_makedirs:
        engine._save_exr(depth, "output_path")

    fake_makedirs.assert_called_with("output_path", exist_ok=True)
    assert fake_openexr.OutputFile.call_count == num_frames
@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
@patch("depthcrafter.inference.read_video_frames")
@patch("depthcrafter.inference.vis_sequence_depth")
@patch("depthcrafter.inference.save_video")
@patch("depthcrafter.inference.os.makedirs")
@patch("depthcrafter.inference.np.savez_compressed")
def test_infer(
    mock_savez,
    mock_makedirs,
    mock_save_video,
    mock_vis,
    mock_read_video,
    mock_unet_cls,
    mock_pipeline_cls,
    dummy_video_path,
):
    """Run ``infer`` end to end with every heavy dependency mocked out.

    ``_save_exr`` is patched too. Left real, its behaviour would depend on
    whether OpenEXR is installed: without it the method bails out, and with it
    the method attempts real EXR writes under ``save_folder`` (whose creation
    is mocked here). Patching it keeps the test hermetic and lets us assert
    that ``save_exr=True`` actually triggers EXR output.
    """
    mock_pipeline = MagicMock()
    mock_pipeline_cls.from_pretrained.return_value = mock_pipeline
    inference = DepthCrafterInference("dummy", "dummy", cpu_offload=None, device="cpu")

    # Ten 32x32 3-channel frames; the second tuple element is presumably the
    # frame rate -- TODO confirm against read_video_frames.
    frames = np.random.rand(10, 32, 32, 3).astype(np.float32)
    mock_read_video.return_value = (frames, 30)

    # The pipeline call returns an object whose ``frames`` holds the output.
    mock_output = MagicMock()
    mock_output.frames = [np.random.rand(10, 32, 32, 3)]
    mock_pipeline.return_value = mock_output

    mock_vis.return_value = np.random.rand(10, 32, 32, 3)

    # Patch on the class so the mock is hit however infer() reaches the method.
    with patch.object(DepthCrafterInference, "_save_exr") as mock_save_exr:
        result_paths = inference.infer(
            video_path=dummy_video_path,
            num_denoising_steps=1,
            guidance_scale=1.0,
            save_folder="output",
            save_npz=True,
            save_exr=True,
        )

    assert len(result_paths) == 3
    mock_read_video.assert_called()
    mock_pipeline.assert_called()
    mock_vis.assert_called()
    assert mock_save_video.call_count == 3  # depth, vis, input
    mock_savez.assert_called()
    mock_save_exr.assert_called()