Spaces:
Sleeping
Sleeping
File size: 5,014 Bytes
40b178e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import pytest
from unittest.mock import MagicMock, patch
import numpy as np
import torch
import os
import tempfile
from depthcrafter.inference import DepthCrafterInference
@pytest.fixture
def dummy_video_path():
    """Yield the path of a throwaway ``.mp4`` file that exists on disk.

    The file is created empty and closed immediately. Its contents never
    matter because the tests that use it patch ``read_video_frames``. After
    the test, the file is removed if it is still present.
    """
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as placeholder:
        video_file = placeholder.name
    yield video_file
    # Teardown: guard against the file having been removed during the test.
    if os.path.exists(video_file):
        os.remove(video_file)
@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
def test_init(mock_unet_cls, mock_pipeline_cls):
    """Check how each ``cpu_offload`` mode configures the pipeline.

    Both model classes are patched, so ``from_pretrained`` returns mocks and
    no weights are loaded. Every construction receives the same pipeline
    mock, so its call history is reset before each case. Without the reset,
    calls recorded by an earlier construction would leak into later
    assertions, and the ``assert_not_called`` checks could never detect a
    mode that applies the wrong configuration.
    """
    mock_unet = MagicMock()
    mock_unet_cls.from_pretrained.return_value = mock_unet
    mock_pipeline = MagicMock()
    mock_pipeline_cls.from_pretrained.return_value = mock_pipeline

    def build(cpu_offload):
        # Start every case from a clean call history on the shared mock.
        mock_pipeline.reset_mock()
        return DepthCrafterInference(
            unet_path="dummy_unet",
            pre_train_path="dummy_pretrain",
            cpu_offload=cpu_offload,
            device="cpu",
        )

    # Model offload (passed explicitly): only model-level offload is enabled.
    build("model")
    mock_pipeline.enable_model_cpu_offload.assert_called()
    mock_pipeline.enable_sequential_cpu_offload.assert_not_called()

    # Sequential offload: only sequential offload is enabled.
    build("sequential")
    mock_pipeline.enable_sequential_cpu_offload.assert_called()
    mock_pipeline.enable_model_cpu_offload.assert_not_called()

    # No offload: the pipeline is moved to the requested device instead.
    build(None)
    mock_pipeline.to.assert_called_with("cpu")
    mock_pipeline.enable_model_cpu_offload.assert_not_called()
    mock_pipeline.enable_sequential_cpu_offload.assert_not_called()

    # Unknown offload modes are rejected.
    with pytest.raises(ValueError):
        build("invalid")
@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
def test_clear_cache(mock_unet_cls, mock_pipeline_cls):
    """``clear_cache`` should run one GC pass and one CUDA cache flush."""
    mock_pipeline_cls.from_pretrained.return_value = MagicMock()
    inference = DepthCrafterInference("dummy", "dummy")

    collect_patch = patch("depthcrafter.inference.gc.collect")
    empty_cache_patch = patch("depthcrafter.inference.torch.cuda.empty_cache")
    with collect_patch as fake_collect, empty_cache_patch as fake_empty_cache:
        inference.clear_cache()

    # Mocks keep their call records after the patches are undone.
    fake_collect.assert_called_once()
    fake_empty_cache.assert_called_once()
@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
def test_save_exr(mock_unet_cls, mock_pipeline_cls):
    """Check that ``_save_exr`` creates the folder and writes the expected EXR files.

    ``OpenEXR`` and ``Imath`` are replaced in ``sys.modules`` by mocks, so
    neither package needs to be installed. A two-frame input array is
    expected to produce exactly two ``OpenEXR.OutputFile`` calls.
    """
    mock_pipeline_cls.from_pretrained.return_value = MagicMock()
    inference = DepthCrafterInference("dummy", "dummy")

    # Keep a direct handle on the fake module rather than re-importing it.
    fake_openexr = MagicMock()
    fake_modules = {"OpenEXR": fake_openexr, "Imath": MagicMock()}

    with patch.dict("sys.modules", fake_modules):
        depth_frames = np.random.rand(2, 32, 32).astype(np.float32)
        with patch("depthcrafter.inference.os.makedirs") as fake_makedirs:
            inference._save_exr(depth_frames, "output_path")
        fake_makedirs.assert_called_with("output_path", exist_ok=True)

    assert fake_openexr.OutputFile.call_count == 2
@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
@patch("depthcrafter.inference.read_video_frames")
@patch("depthcrafter.inference.vis_sequence_depth")
@patch("depthcrafter.inference.save_video")
@patch("depthcrafter.inference.os.makedirs")
@patch("depthcrafter.inference.np.savez_compressed")
def test_infer(
    mock_savez,
    mock_makedirs,
    mock_save_video,
    mock_vis,
    mock_read_video,
    mock_unet_cls,
    mock_pipeline_cls,
    dummy_video_path,
):
    """Run ``infer`` end to end with video I/O, the model and saving all mocked.

    The test checks that three result paths are returned and that the
    reader, the pipeline, the visualiser, the three ``save_video`` calls
    (depth, vis, input) and the NPZ writer were all used.
    """
    # Setup mocks
    mock_pipeline = MagicMock()
    mock_pipeline_cls.from_pretrained.return_value = mock_pipeline
    inference = DepthCrafterInference("dummy", "dummy", cpu_offload=None, device="cpu")

    # Mock read_video_frames: 10 RGB frames at 30 fps.
    frames = np.random.rand(10, 32, 32, 3).astype(np.float32)
    mock_read_video.return_value = (frames, 30)

    # Mock pipeline output
    mock_output = MagicMock()
    mock_output.frames = [np.random.rand(10, 32, 32, 3)]
    mock_pipeline.return_value = mock_output

    # Mock vis
    mock_vis.return_value = np.random.rand(10, 32, 32, 3)

    # save_exr=True routes output through _save_exr, which presumably imports
    # OpenEXR/Imath at call time. Mapping both names to None in sys.modules
    # makes those imports raise ModuleNotFoundError whether or not the
    # packages are installed. _save_exr therefore always takes its
    # "log an error and return" path, which this test expects. Without this,
    # a machine with OpenEXR installed would attempt real EXR writes under
    # "output" even though os.makedirs is mocked.
    with patch.dict("sys.modules", {"OpenEXR": None, "Imath": None}):
        result_paths = inference.infer(
            video_path=dummy_video_path,
            num_denoising_steps=1,
            guidance_scale=1.0,
            save_folder="output",
            save_npz=True,
            save_exr=True,
        )

    assert len(result_paths) == 3
    mock_read_video.assert_called()
    mock_pipeline.assert_called()
    mock_vis.assert_called()
    assert mock_save_video.call_count == 3  # depth, vis, input
    mock_savez.assert_called()
|