import os

import torch

from model import AutoModel, Config
|
| |
|
def load_model(model_path, config_path):
    """Load model weights and configuration, returning the model in eval mode.

    Args:
        model_path: Path to a ``torch.save``-style state-dict checkpoint.
        config_path: Path to the configuration file (existence is checked).

    Returns:
        tuple: ``(model, config)`` — the ``AutoModel`` with weights loaded and
        set to evaluation mode, and the ``Config`` instance used to build it.

    Raises:
        FileNotFoundError: If ``config_path`` or ``model_path`` does not exist.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"配置文件未找到: {config_path}")
    print(f"加载配置文件: {config_path}")
    # NOTE(review): config_path is only checked for existence — Config() is
    # constructed with its defaults and the file contents are never read.
    # Confirm whether Config is meant to be parsed from this JSON file.
    config = Config()

    model = AutoModel(config)

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型文件未找到: {model_path}")
    print(f"加载模型权重: {model_path}")
    # SECURITY: torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    # map_location forces tensors onto CPU regardless of where they were saved.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout/batch-norm updates for inference
    print("模型加载成功并设置为评估模式。")

    return model, config
|
| |
|
| |
|
def run_inference(model, config):
    """Run one forward pass on random dummy inputs and print output shapes.

    Args:
        model: Callable taking ``(image, text, audio)`` tensors and returning a
            5-tuple of output tensors (VQA, caption, retrieval, ASR,
            realtime ASR).
        config: Object exposing ``max_position_embeddings``, ``hidden_size``
            and ``audio_sample_rate`` attributes, used to size the dummy
            text and audio inputs.

    Returns:
        None. Results are reported via ``print`` only.
    """
    # Dummy batch of size 1: a 224x224 RGB image, a text-embedding sequence,
    # and one second of audio samples (assuming audio_sample_rate is in Hz —
    # TODO confirm against the model's expectations).
    image = torch.randn(1, 3, 224, 224)
    text = torch.randn(1, config.max_position_embeddings, config.hidden_size)
    audio = torch.randn(1, config.audio_sample_rate)

    outputs = model(image, text, audio)
    vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output = outputs

    print("\n推理结果:")
    print(f"VQA output shape: {vqa_output.shape}")
    print(f"Caption output shape: {caption_output.shape}")
    print(f"Retrieval output shape: {retrieval_output.shape}")
    print(f"ASR output shape: {asr_output.shape}")
    print(f"Realtime ASR output shape: {realtime_asr_output.shape}")
|
| |
|
if __name__ == "__main__":
    # Default artifact locations relative to the working directory.
    model_path = "AutoModel.pth"
    config_path = "config.json"

    try:
        model, config = load_model(model_path, config_path)
        run_inference(model, config)
    except Exception as e:
        # Top-level boundary: report the failure instead of a raw traceback.
        # NOTE(review): this swallows the exception and exits with status 0 —
        # consider re-raising or sys.exit(1) if callers script around this.
        print(f"运行失败: {e}")
|
| |
|