import os
import shutil

import torch
import gradio as gr
from mutagen.mp3 import MP3
from mutagen.flac import FLAC
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate

from convert import midi2xml, xml2abc, xml2mxl, xml2jpg

# The UI is shown in English unless LANG is exactly "zh_CN.UTF-8".
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"
# Directory used both as the model download cache and as the parent of the
# per-request working directory.
TMP_DIR = "./__pycache__"

# Fetch the model snapshot from Hugging Face for the English build and from
# ModelScope otherwise; both calls are given the same repo id and cache dir.
if EN_US:
    import huggingface_hub

    MODEL_PATH = huggingface_hub.snapshot_download(
        "Genius-Society/piano_trans",
        cache_dir=TMP_DIR,
    )

else:
    import modelscope

    MODEL_PATH = modelscope.snapshot_download(
        "Genius-Society/piano_trans",
        cache_dir=TMP_DIR,
    )

# Chinese UI string -> English UI string, looked up by `_L`.
ZH2EN = {
    "五线谱": "Staff",
    "状态栏": "Status",
    "下载 MXL": "Download MXL",
    "ABC 记谱": "ABC notation",
    "上传音频": "Upload an audio",
    "下载 MIDI": "Download MIDI",
    "下载 PDF 乐谱": "Download PDF score",
    "下载 MusicXML": "Download MusicXML",
    "钢琴转谱工具": "Piano Transcription Tool",
    "请上传音频 100% 后再点提交": "Please make sure the audio is completely uploaded before clicking Submit",
}


def _L(zh_txt: str):
    """Return the UI text for *zh_txt* in the active language.

    When ``EN_US`` is true the English entry of ``ZH2EN`` is returned (a
    missing key raises ``KeyError``); otherwise *zh_txt* is returned as is.
    """
    return ZH2EN[zh_txt] if EN_US else zh_txt


def clean_cache(cache_dir):
    """Replace *cache_dir* with a fresh, empty directory.

    Any existing directory tree at *cache_dir* is deleted first. Missing
    parent directories are created so the call also works on a clean
    checkout.
    """
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)


def extract_meta(audio_path: str):
    """Read the title and artist of an audio file.

    Tags are read from ``.mp3`` (``TIT2`` / ``TPE1``) and ``.flac``
    (``TITLE`` / ``ARTIST``) files; the extension check is case-insensitive.
    For any other extension, or when no title tag is present, the title
    falls back to the capitalised file name without its extension.

    Args:
        audio_path: Path of the audio file.

    Returns:
        A ``(title, artist)`` tuple; ``artist`` is ``None`` when unknown.

    Raises:
        ValueError: If *audio_path* is empty.
    """
    if not audio_path:
        raise ValueError("文件路径为空!")

    # Both must be pre-bound: formats without tag support (e.g. .wav) skip
    # every branch below and go straight to the file-name fallback.
    title = artist = None
    name, ext = os.path.splitext(os.path.basename(audio_path))
    ext = ext.lower()
    if ext == ".mp3":
        audio = MP3(audio_path)
        title = audio.get("TIT2")
        artist = audio.get("TPE1")
        if title:
            title = title.text[0]

        if artist:
            artist = artist.text[0]

    elif ext == ".flac":
        audio = FLAC(audio_path)
        title = audio.get("TITLE")
        artist = audio.get("ARTIST")
        if title:
            title = title[0]

        if artist:
            artist = artist[0]

    if not title:
        title = name.strip().capitalize()

    return title, artist


def audio2midi(audio_path: str, cache_dir: str):
    """Transcribe an audio file to MIDI inside *cache_dir*.

    The audio is loaded as mono at the library's ``sample_rate`` and
    transcribed on CUDA when available, otherwise on the CPU, using the
    checkpoint stored under ``MODEL_PATH``.

    Args:
        audio_path: Path of the audio file to transcribe.
        cache_dir: Existing directory that receives ``output.mid``.

    Returns:
        A ``(midi_path, title, artist)`` tuple, where title and artist come
        from ``extract_meta``.
    """
    title, artist = extract_meta(audio_path)
    audio, _ = load_audio(audio_path, sr=sample_rate, mono=True)
    transcriptor = PianoTranscription(
        device="cuda" if torch.cuda.is_available() else "cpu",
        checkpoint_path=f"{MODEL_PATH}/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth",
    )
    midi_path = f"{cache_dir}/output.mid"
    transcriptor.transcribe(audio, midi_path)
    return midi_path, title, artist


def upl_infer(audio_path: str, cache_dir=f"{TMP_DIR}/cache"):
    """Gradio handler: convert an uploaded audio file into score formats.

    The cache directory is wiped, the audio is transcribed to MIDI, and the
    MIDI is passed through the ``convert`` helpers to produce MusicXML, ABC,
    MXL, PDF and JPG outputs.

    Args:
        audio_path: Path of the uploaded audio file.
        cache_dir: Working directory, recreated on every call.

    Returns:
        ``(status, midi, pdf, xml, mxl, abc, jpg)``. ``status`` is
        ``"Success"`` or the text of the exception that interrupted the
        pipeline; outputs that were not produced before a failure are
        ``None``.
    """
    status = "Success"
    midi = pdf = xml = mxl = abc = jpg = None
    try:
        clean_cache(cache_dir)
        midi, title, artist = audio2midi(audio_path, cache_dir)
        xml = midi2xml(midi, title, artist)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)

    except Exception as e:
        # UI boundary: report the failure in the status box instead of
        # crashing the request.
        status = f"{e}"

    return status, midi, pdf, xml, mxl, abc, jpg


def find_audio_files(folder_path=f"{MODEL_PATH}/examples"):
    """Recursively collect the ``.wav`` and ``.mp3`` files under *folder_path*.

    Returns:
        A list of file paths in ``os.walk`` order; empty when nothing
        matches or the folder does not exist.
    """
    audio_files = []
    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.endswith((".wav", ".mp3")):
                audio_files.append(os.path.join(root, file))

    return audio_files


if __name__ == "__main__":
    gr.Interface(
        fn=upl_infer,
        inputs=gr.Audio(label=_L("上传音频"), type="filepath"),
        outputs=[
            gr.Textbox(label=_L("状态栏"), show_copy_button=True),
            gr.File(label=_L("下载 MIDI")),
            gr.File(label=_L("下载 PDF 乐谱")),
            gr.File(label=_L("下载 MusicXML")),
            gr.File(label=_L("下载 MXL")),
            gr.TextArea(label=_L("ABC 记谱"), show_copy_button=True),
            gr.Image(label=_L("五线谱"), type="filepath", show_share_button=False),
        ],
        title=_L("钢琴转谱工具"),
        description=_L("请上传音频 100% 后再点提交"),
        flagging_mode="never",
        cache_examples=False,
        examples=find_audio_files(),
    ).launch()