File size: 4,147 Bytes
16c3fe2
 
 
 
4204256
 
16c3fe2
 
 
 
4204256
16c3fe2
 
 
 
 
 
4204256
16c3fe2
 
 
 
 
 
 
4204256
16c3fe2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4204256
 
 
 
2ee497f
4204256
9ebf373
4204256
2ee497f
 
 
0f5e048
 
 
 
4204256
 
2ee497f
0f5e048
 
 
 
 
 
4204256
9ebf373
 
 
2ee497f
4204256
 
16c3fe2
2ee497f
16c3fe2
 
 
 
 
 
 
2ee497f
16c3fe2
 
4204256
16c3fe2
 
 
 
2ee497f
 
16c3fe2
 
 
 
 
 
 
 
 
 
4204256
 
 
 
 
 
 
 
 
 
 
16c3fe2
 
 
 
 
 
 
 
 
 
ec98327
16c3fe2
 
 
 
 
 
4204256
16c3fe2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import torch
import shutil
import gradio as gr
from mutagen.mp3 import MP3
from mutagen.flac import FLAC
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate
from convert import midi2xml, xml2abc, xml2mxl, xml2jpg

EN_US = os.getenv("LANG") != "zh_CN.UTF-8"
TMP_DIR = "./__pycache__"

# Download the model repo into TMP_DIR: ModelScope is used when LANG is
# exactly "zh_CN.UTF-8", Hugging Face Hub otherwise. Both hubs are queried
# with the same repo id; the resulting local path is exposed as MODEL_PATH
# (find_audio_files also scans its "examples" sub-folder).
if not EN_US:
    import modelscope

    MODEL_PATH = modelscope.snapshot_download(
        "Genius-Society/piano_trans", cache_dir=TMP_DIR
    )
else:
    import huggingface_hub

    MODEL_PATH = huggingface_hub.snapshot_download(
        "Genius-Society/piano_trans", cache_dir=TMP_DIR
    )


# Chinese UI label -> English translation. _L() looks labels up here when
# EN_US is true; every label passed to _L() must have an entry, otherwise the
# English UI raises KeyError at startup.
ZH2EN = {
    "五线谱": "Staff",
    "状态栏": "Status",
    "下载 MXL": "Download MXL",
    "ABC 记谱": "ABC notation",
    "上传音频": "Upload an audio",
    "下载 MIDI": "Download MIDI",
    "下载 PDF 乐谱": "Download PDF score",
    "下载 MusicXML": "Download MusicXML",
    "钢琴转谱工具": "Piano Transcription Tool",
    "请上传音频 100% 后再点提交": "Please make sure the audio is completely uploaded before clicking Submit",
}


def _L(zh_txt: str):
    """Localise a UI label.

    The Chinese text is returned unchanged when EN_US is false; otherwise the
    English translation from ZH2EN is returned (KeyError if it is missing).
    """
    if not EN_US:
        return zh_txt

    return ZH2EN[zh_txt]


def clean_cache(cache_dir):
    """Reset *cache_dir* to an empty directory.

    An existing directory tree at *cache_dir* is deleted, then the directory
    is recreated. Missing parent directories are created too, so nested cache
    paths work on a fresh checkout.

    Raises:
        OSError: propagated from shutil.rmtree / os.makedirs, e.g. when
            *cache_dir* exists but is not a directory.
    """
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)

    # makedirs instead of mkdir: mkdir raised FileNotFoundError whenever the
    # parent directory did not exist yet.
    os.makedirs(cache_dir, exist_ok=True)


def extract_meta(audio_path: str):
    """Return ``(title, artist)`` for an audio file.

    Tags are read for ``.mp3`` (ID3 frames TIT2 / TPE1) and ``.flac``
    (TITLE / ARTIST comments); the extension match is case-insensitive.
    For any other extension, or when no title tag is present, the title falls
    back to the file's base name, stripped and capitalised. ``artist`` is
    ``None`` when it cannot be determined.

    Raises:
        ValueError: if *audio_path* is empty.
    """
    if not audio_path:
        raise ValueError("文件路径为空!")

    # Initialise both: other formats (e.g. .wav) never assign them below.
    title = artist = None
    name, ext = os.path.splitext(os.path.basename(audio_path))
    # Assignment, not comparison: the old `ext == ext.lower()` was a no-op, so
    # upper-case extensions such as ".MP3" skipped tag reading entirely.
    ext = ext.lower()

    if ext == ".mp3":
        audio = MP3(audio_path)
        title = audio.get("TIT2")
        artist = audio.get("TPE1")
        if title:
            title = title.text[0]
        if artist:
            artist = artist.text[0]

    elif ext == ".flac":
        audio = FLAC(audio_path)
        title = audio.get("TITLE")
        artist = audio.get("ARTIST")
        if title:
            title = title[0]
        if artist:
            artist = artist[0]

    if not title:
        title = name.strip().capitalize()

    return title, artist


def audio2midi(audio_path: str, cache_dir: str):
    """Transcribe *audio_path* to a MIDI file inside *cache_dir*.

    The audio is loaded as mono at the transcriber's ``sample_rate`` and fed
    to a PianoTranscription model that runs on CUDA when available, otherwise
    on the CPU. The result is written to ``<cache_dir>/output.mid``.

    Returns:
        ``(midi_path, title, artist)``, where title/artist come from
        extract_meta.
    """
    title, artist = extract_meta(audio_path)
    waveform, _ = load_audio(audio_path, sr=sample_rate, mono=True)

    # NOTE(review): a new PianoTranscription is built on every call, which
    # presumably reloads the checkpoint each time; consider caching it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = f"{MODEL_PATH}/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"
    transcriptor = PianoTranscription(device=device, checkpoint_path=checkpoint)

    midi_path = f"{cache_dir}/output.mid"
    transcriptor.transcribe(waveform, midi_path)
    return midi_path, title, artist


def upl_infer(audio_path: str, cache_dir=f"{TMP_DIR}/cache"):
    """Gradio callback: run the whole audio -> MIDI -> score pipeline.

    Returns a 7-tuple ``(status, midi, pdf, xml, mxl, abc, jpg)`` matching
    the interface's output components. The status is "Success" when every
    step completes. When any step raises, the status is the exception's
    message; artefacts produced before the failure are still returned and
    the rest stay None.
    """
    midi = pdf = xml = mxl = abc = jpg = None
    try:
        clean_cache(cache_dir)
        midi, title, artist = audio2midi(audio_path, cache_dir)
        xml = midi2xml(midi, title, artist)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        status = f"{e}"
    else:
        status = "Success"

    return status, midi, pdf, xml, mxl, abc, jpg


def find_audio_files(folder_path=f"{MODEL_PATH}/examples"):
    """Recursively collect ``.wav`` and ``.mp3`` files under *folder_path*.

    Used to populate the Gradio examples. The suffix match is case-sensitive
    and the order follows os.walk traversal. A missing folder yields an empty
    list, because os.walk ignores the listing error by default.
    """
    return [
        os.path.join(root, file_name)
        for root, _, file_names in os.walk(folder_path)
        for file_name in file_names
        if file_name.endswith((".wav", ".mp3"))
    ]


if __name__ == "__main__":
    # Single audio upload (passed to upl_infer as a file path).
    audio_input = gr.Audio(label=_L("上传音频"), type="filepath")

    # One component per element of upl_infer's return tuple, in the same order:
    # status, midi, pdf, xml, mxl, abc, jpg.
    output_components = [
        gr.Textbox(label=_L("状态栏"), show_copy_button=True),
        gr.File(label=_L("下载 MIDI")),
        gr.File(label=_L("下载 PDF 乐谱")),
        gr.File(label=_L("下载 MusicXML")),
        gr.File(label=_L("下载 MXL")),
        gr.TextArea(label=_L("ABC 记谱"), show_copy_button=True),
        gr.Image(label=_L("五线谱"), type="filepath", show_share_button=False),
    ]

    demo = gr.Interface(
        fn=upl_infer,
        inputs=audio_input,
        outputs=output_components,
        title=_L("钢琴转谱工具"),
        description=_L("请上传音频 100% 后再点提交"),
        flagging_mode="never",
        cache_examples=False,
        examples=find_audio_files(),
    )
    demo.launch()