File size: 5,014 Bytes
40b178e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import pytest
from unittest.mock import MagicMock, patch
import numpy as np
import torch
import os
import tempfile
from depthcrafter.inference import DepthCrafterInference


@pytest.fixture
def dummy_video_path():
    """Yield the path of an empty temporary ``.mp4`` file, removing it afterwards.

    The file carries no video data; an empty file is enough because the tests
    that use this fixture mock ``read_video_frames``.
    """
    # delete=False keeps the file on disk after the handle is closed, so the
    # test receives a path that really exists.
    handle = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    video_path = handle.name
    handle.close()

    yield video_path

    # Teardown: remove the file unless something else already did.
    if os.path.exists(video_path):
        os.remove(video_path)


@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
def test_init(mock_unet_cls, mock_pipeline_cls):
    """Check how ``DepthCrafterInference`` reacts to each ``cpu_offload`` value.

    Expected mapping, as asserted below:

    * ``"model"``      -> ``pipeline.enable_model_cpu_offload()``
    * ``"sequential"`` -> ``pipeline.enable_sequential_cpu_offload()``
    * ``None``         -> ``pipeline.to(device)``
    * anything else    -> ``ValueError``

    Both model classes are patched, so no weights are loaded.
    """
    mock_unet_cls.from_pretrained.return_value = MagicMock()

    mock_pipeline = MagicMock()
    mock_pipeline_cls.from_pretrained.return_value = mock_pipeline

    def build(cpu_offload):
        # The same pipeline mock is handed to every instance. Forget calls
        # recorded by earlier constructions so each assertion below can only
        # be satisfied by the instance built for that specific case.
        mock_pipeline.reset_mock()
        return DepthCrafterInference(
            unet_path="dummy_unet",
            pre_train_path="dummy_pretrain",
            cpu_offload=cpu_offload,
            device="cpu",
        )

    # Model offload
    build("model")
    mock_pipeline.enable_model_cpu_offload.assert_called()

    # Sequential offload
    build("sequential")
    mock_pipeline.enable_sequential_cpu_offload.assert_called()

    # No offload: the pipeline is moved to the requested device instead
    build(None)
    mock_pipeline.to.assert_called_with("cpu")

    # Invalid offload
    with pytest.raises(ValueError):
        build("invalid")


@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
def test_clear_cache(mock_unet_cls, mock_pipeline_cls):
    """``clear_cache`` must call ``gc.collect`` and ``torch.cuda.empty_cache`` once each."""
    mock_pipeline_cls.from_pretrained.return_value = MagicMock()
    inference = DepthCrafterInference("dummy", "dummy")

    # Both functions are patched through the names visible from
    # depthcrafter.inference, so the real ones are not executed.
    with patch("depthcrafter.inference.gc.collect") as collect_spy:
        with patch("depthcrafter.inference.torch.cuda.empty_cache") as empty_cache_spy:
            inference.clear_cache()

    # The mocks keep their call records after the patches are undone.
    collect_spy.assert_called_once()
    empty_cache_spy.assert_called_once()


@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
def test_save_exr(mock_unet_cls, mock_pipeline_cls):
    """``_save_exr`` must create the output folder and open one EXR file per frame.

    ``OpenEXR`` and ``Imath`` are replaced by mocks in ``sys.modules`` so the
    test does not need the real packages.
    """
    mock_pipeline_cls.from_pretrained.return_value = MagicMock()
    inference = DepthCrafterInference("dummy", "dummy")

    # Keep a handle on the fake OpenEXR module so its calls can be inspected.
    fake_openexr = MagicMock()
    stub_modules = {"OpenEXR": fake_openexr, "Imath": MagicMock()}

    with patch.dict("sys.modules", stub_modules):
        # Two frames of 32x32 float32 values
        depth = np.random.rand(2, 32, 32).astype(np.float32)
        with patch("depthcrafter.inference.os.makedirs") as makedirs_spy:
            inference._save_exr(depth, "output_path")

    makedirs_spy.assert_called_with("output_path", exist_ok=True)
    # One OutputFile per frame in the first axis
    assert fake_openexr.OutputFile.call_count == 2


@patch("depthcrafter.inference.DepthCrafterPipeline")
@patch("depthcrafter.inference.DiffusersUNetSpatioTemporalConditionModelDepthCrafter")
@patch("depthcrafter.inference.read_video_frames")
@patch("depthcrafter.inference.vis_sequence_depth")
@patch("depthcrafter.inference.save_video")
@patch("depthcrafter.inference.os.makedirs")
@patch("depthcrafter.inference.np.savez_compressed")
def test_infer(
    mock_savez,
    mock_makedirs,
    mock_save_video,
    mock_vis,
    mock_read_video,
    mock_unet_cls,
    mock_pipeline_cls,
    dummy_video_path,
):
    """Run ``infer`` end to end with every heavy collaborator mocked.

    Video reading, the diffusion pipeline, depth visualisation and all file
    writers are patched; the test checks that ``infer`` calls them and returns
    three result paths.
    """
    # Setup mocks
    mock_pipeline = MagicMock()
    mock_pipeline_cls.from_pretrained.return_value = mock_pipeline

    inference = DepthCrafterInference("dummy", "dummy", cpu_offload=None, device="cpu")

    # Mock read_video_frames: 10 RGB frames of 32x32 plus a value of 30
    # (presumably the frame rate - confirm against read_video_frames)
    frames = np.random.rand(10, 32, 32, 3).astype(np.float32)
    mock_read_video.return_value = (frames, 30)

    # Mock pipeline output
    mock_output = MagicMock()
    mock_output.frames = [np.random.rand(10, 32, 32, 3)]
    mock_pipeline.return_value = mock_output

    # Mock vis
    mock_vis.return_value = np.random.rand(10, 32, 32, 3)

    # A ``None`` entry in sys.modules makes ``import OpenEXR`` / ``import Imath``
    # fail even when the packages are installed. That pins the EXR branch to
    # its "modules missing" path, so the test behaves the same everywhere.
    # NOTE(review): assumes _save_exr imports these lazily and handles
    # ImportError by logging and returning - confirm against the implementation.
    with patch.dict("sys.modules", {"OpenEXR": None, "Imath": None}):
        result_paths = inference.infer(
            video_path=dummy_video_path,
            num_denoising_steps=1,
            guidance_scale=1.0,
            save_folder="output",
            save_npz=True,
            save_exr=True,
        )

    assert len(result_paths) == 3
    mock_read_video.assert_called()
    mock_pipeline.assert_called()
    mock_vis.assert_called()
    assert mock_save_video.call_count == 3  # depth, vis, input
    mock_savez.assert_called()