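# Overview (inferred from the code below): a FastAPI app, apparently a Hugging
# Face Space, that transcribes long audio with the Google Gemini API. The
# upload is decoded with pydub, split into ~11-minute chunks at silences found
# by webrtcvad, transcribed chunk by chunk as either SRT segments or
# timestamped plain text, and streamed back to the browser as JSON progress
# events. Prompts are read from instruct.yml; the UI lives in templates/ and
# static/.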
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from google import genai
from google.genai import types
from pydub import AudioSegment
import webrtcvad
import yaml
import json
import io
import os
import re
from datetime import timedelta
import logging
import asyncio
from pydantic import BaseModel, Field
logging.basicConfig(level=logging.INFO)

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
TARGET_CHUNK_DURATION_MIN = 11
TARGET_CHUNK_DURATION_MS = TARGET_CHUNK_DURATION_MIN * 60 * 1000
MIN_SPLIT_SEARCH_START_MIN = 10
MIN_SPLIT_SEARCH_START_MS = MIN_SPLIT_SEARCH_START_MIN * 60 * 1000
MAX_SPLIT_SEARCH_END_MIN = 13
MAX_SPLIT_SEARCH_END_MS = MAX_SPLIT_SEARCH_END_MIN * 60 * 1000
MIN_SILENCE_LEN_MS = 800
VAD_AGGRESSIVENESS = 3
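# Chunking strategy in plain terms: aim to cut roughly every 11 minutes, but
# only accept a cut between the 10- and 13-minute marks of the current chunk
# so it can land on a silence of at least 800 ms detected by webrtcvad
# (aggressiveness 3 is the mode most aggressive about filtering non-speech).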
class TranscriptionSegment(BaseModel):
    # Field descriptions are Hebrew because the app's prompts and UI are
    # Hebrew; in English they read: sequential subtitle number / start time in
    # HH:MM:SS,mmm format / end time in HH:MM:SS,mmm format / subtitle text.
    id: int = Field(description="מספר סידורי של הכתובית", ge=1)
    start_time: str = Field(description="שעת התחלה בפורמט HH:MM:SS,mmm")
    end_time: str = Field(description="שעת סיום בפורמט HH:MM:SS,mmm")
    text: str = Field(description="תוכן הכתובית")
def load_prompts():
    try:
        with open("instruct.yml", 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)
    except Exception as e:
        logging.error(f"Error loading instruct.yml: {e}")
        raise HTTPException(status_code=500, detail="Server configuration error.")
def parse_time_str_to_ms(time_str: str) -> int:
    if not isinstance(time_str, str):
        raise TypeError(f"Time string must be a string, got {type(time_str)}")
    time_str = time_str.replace(',', '.')
    last_colon_pos = time_str.rfind(':')
    last_period_pos = time_str.rfind('.')
    h, m, s, ms = 0, 0, 0, 0
    try:
        if last_period_pos > last_colon_pos:
            hms_part = time_str[:last_period_pos]
            ms_part = time_str[last_period_pos+1:]
            ms = int(ms_part.ljust(3, '0')[:3])
            time_components = list(map(int, hms_part.split(':')))
            if len(time_components) == 3: h, m, s = time_components
            elif len(time_components) == 2: m, s = time_components
            elif len(time_components) == 1: s = time_components[0]
        else:
            components = list(map(int, time_str.split(':')))
            # If the last component is > 59, it must be milliseconds
            if len(components) >= 2 and components[-1] > 59:
                ms = components[-1]
                s = components[-2]
                if len(components) == 3: m = components[0]
                elif len(components) > 3: h, m = components[0], components[1]  # For very long times
            # Otherwise, it's a standard HH:MM:SS format
            else:
                if len(components) == 3: h, m, s = components
                elif len(components) == 2: m, s = components
                elif len(components) == 1: s = components[0]
        return (h * 3600000) + (m * 60000) + (s * 1000) + ms
    except (ValueError, IndexError) as e:
        raise ValueError(f"Could not parse adaptive time string: '{time_str}'. Error: {e}")
def format_ms_to_srt_time(ms: int) -> str:
    td = timedelta(milliseconds=ms)
    total_seconds = int(td.total_seconds())
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = td.microseconds // 1000
    return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
def adjust_plain_text_timestamps(text_block: str, offset_ms: int) -> str:
    def replacer(match):
        time_str = match.group(1)
        try:
            time_ms = parse_time_str_to_ms(time_str)
            new_time_ms = time_ms + offset_ms
            new_time_str = format_ms_to_srt_time(new_time_ms).replace(',', '.')
            return f'[{new_time_str}]'
        except ValueError:
            return match.group(0)
    timestamp_pattern = r'\[(\d{2}:\d{2}:\d{2}[,.]\d{3})\]'
    return re.sub(timestamp_pattern, replacer, text_block)
def find_silence_points_webrtcvad(audio_segment: AudioSegment, min_silence_len_ms: int, vad_aggressiveness: int):
    # webrtcvad only accepts 8/16/32/48 kHz, mono, 16-bit PCM
    if audio_segment.frame_rate not in [8000, 16000, 32000, 48000]:
        audio_segment = audio_segment.set_frame_rate(16000)
    if audio_segment.channels > 1:
        audio_segment = audio_segment.set_channels(1)
    if audio_segment.sample_width != 2:
        audio_segment = audio_segment.set_sample_width(2)
    vad = webrtcvad.Vad(vad_aggressiveness)
    frame_duration_ms = 30
    frame_size_bytes = int(audio_segment.frame_rate * (frame_duration_ms / 1000.0) * 2)
    silence_points_ms, silence_start_ms = [], None
    raw_data, num_frames = audio_segment.raw_data, len(audio_segment.raw_data) // frame_size_bytes
    for i in range(num_frames):
        frame = raw_data[i*frame_size_bytes:(i+1)*frame_size_bytes]
        if len(frame) < frame_size_bytes:
            break
        is_speech = vad.is_speech(frame, audio_segment.frame_rate)
        current_time_ms = i * frame_duration_ms
        if not is_speech:
            if silence_start_ms is None:
                silence_start_ms = current_time_ms
        elif silence_start_ms is not None:
            if current_time_ms - silence_start_ms >= min_silence_len_ms:
                # Record the midpoint of the silence as a candidate split point
                silence_points_ms.append(silence_start_ms + (current_time_ms - silence_start_ms) // 2)
            silence_start_ms = None
    if silence_start_ms is not None and len(audio_segment) - silence_start_ms >= min_silence_len_ms:
        silence_points_ms.append(silence_start_ms + (len(audio_segment) - silence_start_ms) // 2)
    return silence_points_ms
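# Frame-size sanity check: webrtcvad accepts only 10/20/30 ms frames of 16-bit
# mono PCM. At 16 kHz, a 30 ms frame is 16000 * 0.030 * 2 bytes = 960 bytes,
# which is what frame_size_bytes computes after the conversions above.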
def split_audio_webrtcvad(audio_segment, min_silence_len):
    logging.info(f"Splitting with WebRTCVAD: Target Chunk {TARGET_CHUNK_DURATION_MIN}m, VAD Aggressiveness {VAD_AGGRESSIVENESS}")
    silence_points = find_silence_points_webrtcvad(audio_segment, min_silence_len, VAD_AGGRESSIVENESS)
    if not silence_points:
        logging.warning("WebRTCVAD found no significant silences. Splitting into fixed chunks.")
        return [audio_segment[i:i + TARGET_CHUNK_DURATION_MS] for i in range(0, len(audio_segment), TARGET_CHUNK_DURATION_MS)]
    final_chunks, current_offset, total_length = [], 0, len(audio_segment)
    while current_offset < total_length:
        if total_length - current_offset <= MAX_SPLIT_SEARCH_END_MS:
            final_chunks.append(audio_segment[current_offset:])
            break
        ideal_split_point = current_offset + TARGET_CHUNK_DURATION_MS
        candidate_points = [p for p in silence_points if (current_offset + MIN_SPLIT_SEARCH_START_MS) <= p < (current_offset + MAX_SPLIT_SEARCH_END_MS)]
        best_split_point = min(candidate_points, key=lambda p: abs(p - ideal_split_point)) if candidate_points else -1
        split_at = best_split_point if best_split_point != -1 else ideal_split_point
        final_chunks.append(audio_segment[current_offset:int(split_at)])
        current_offset = int(split_at)
    logging.info(f"File successfully split into {len(final_chunks)} chunks using WebRTCVAD.")
    logging.info(f"Chunk durations (seconds): {[round(len(c) / 1000) for c in final_chunks]}")
    return final_chunks
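# Worked example: for a 30-minute file the loop looks for a silence point
# between minutes 10 and 13 nearest the 11-minute mark, cuts there (say at
# 11:20), then searches 21:20-24:20 for the next cut; once the remainder is
# at most 13 minutes it is emitted whole as the final chunk. With no detected
# silences at all, the fallback is fixed 11-minute slices.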
def _trim_markdown_fences(text: str) -> str:
    """Removes markdown code block fences from a string."""
    text = text.strip()
    if text.startswith("```") and text.endswith("```"):
        text = text[3:-3].strip()
        if text.startswith("json"):
            text = text[4:].strip()
    return text
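# Example: '```json\n[{"id": 1}]\n```' -> '[{"id": 1}]'
# (the outer fences are stripped first, then a leading "json" language tag).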
def validate_and_correct_segments(segments_from_api, chunk_duration_ms):
    corrected_segments, last_corrected_end_ms = [], 0
    for seg in segments_from_api:
        try:
            start_ms = parse_time_str_to_ms(seg.get('start_time'))
            end_ms = parse_time_str_to_ms(seg.get('end_time'))
            if start_ms >= chunk_duration_ms:
                logging.warning(f"Skipping segment with hallucinated start_time ({format_ms_to_srt_time(start_ms)}) outside of chunk duration ({format_ms_to_srt_time(chunk_duration_ms)}).")
                continue
            if end_ms > chunk_duration_ms:
                end_ms = chunk_duration_ms
            if start_ms >= end_ms:
                end_ms = min(start_ms + 3000, chunk_duration_ms)
            if start_ms < last_corrected_end_ms:
                start_ms = last_corrected_end_ms
            if start_ms >= end_ms:
                continue
            seg['start_time_relative'], seg['end_time_relative'] = start_ms, end_ms
            corrected_segments.append(seg)
            last_corrected_end_ms = end_ms
        except (ValueError, TypeError, KeyError) as e:
            logging.warning(f"Skipping segment due to invalid format or value: {seg}. Error: {e}")
            continue
    return corrected_segments
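# Correction rules, in order: drop segments that start beyond the chunk's end
# (hallucinated timestamps); clamp end times to the chunk length; give
# inverted segments a nominal 3-second duration; push starts forward so
# segments never overlap; drop anything left with zero or negative length.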
def transcribe_chunk(chunk_audio, api_key, system_prompt, pydantic_schema, model_name, user_prompt):
    client = genai.Client(api_key=api_key)
    buffer = io.BytesIO()
    chunk_audio.export(buffer, format="wav")
    audio_part = types.Part.from_bytes(data=buffer.getvalue(), mime_type='audio/wav')
    contents: list[types.Content] = []
    if user_prompt:
        # Seed the conversation with the user's extra instructions and a canned
        # model acknowledgement ("I'm ready. Please send the audio segment for transcription.").
        contents.append(types.Content(
            role="user",
            parts=[types.Part(text=user_prompt)]
        ))
        contents.append(types.Content(
            role="model",
            parts=[types.Part(text="אני מוכן. אנא שלח את קטע השמע לתמלול.")]
        ))
    # "Please transcribe the audio segment per the system instruction and the JSON schema."
    prompt = "אנא תמלל את קטע השמע בהתאם להנחיית המערכת ולסכמת ה-JSON."
    contents.append(types.Content(
        role="user",
        parts=[types.Part(text=prompt), audio_part]
    ))
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=contents,
            config=types.GenerateContentConfig(
                system_instruction=system_prompt,
                response_mime_type="application/json",
                response_schema=list[pydantic_schema]
            )
        )
        cleaned_text = _trim_markdown_fences(response.text)
        return json.loads(cleaned_text), None
    except Exception as e:
        # Mirror transcribe_chunk_plain_text: report failures through the
        # (data, error) tuple the caller checks instead of raising.
        logging.error(f"Error in transcribe_chunk: {e}")
        return None, str(e)
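# Usage sketch (hypothetical, not part of the app): transcribing one chunk
# from a script. The file name and model id are placeholders.
#
#   prompts = load_prompts()
#   audio = AudioSegment.from_file("sample.mp3")[:60_000]  # first minute
#   segments, err = transcribe_chunk(
#       audio, os.environ["GOOGLE_API_KEY"], prompts["system_prompt"],
#       TranscriptionSegment, "gemini-1.5-flash", user_prompt="")
#   if err is None:
#       print(validate_and_correct_segments(segments, len(audio)))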
def transcribe_chunk_plain_text(chunk_audio, api_key, system_prompt, model_name, user_prompt):
    client = genai.Client(api_key=api_key)
    buffer = io.BytesIO()
    chunk_audio.export(buffer, format="wav")
    audio_part = types.Part.from_bytes(data=buffer.getvalue(), mime_type='audio/wav')
    contents: list[types.Content] = []
    if user_prompt:
        contents.append(types.Content(
            role="user", parts=[types.Part(text=user_prompt)]
        ))
        # Canned acknowledgement: "I'm ready. Please send the audio segment for processing."
        contents.append(types.Content(
            role="model", parts=[types.Part(text="אני מוכן. אנא שלח את קטע השמע לעיבוד.")]
        ))
    # "Please process the attached audio segment and produce the edited
    # transcript in the text format defined in the system instruction."
    prompt = "אנא עבד את קטע השמע המצורף והפק את התמלול הערוך בפורמט טקסט כפי שהוגדר בהנחיית המערכת."
    contents.append(types.Content(
        role="user", parts=[types.Part(text=prompt), audio_part]
    ))
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=contents,
            config=types.GenerateContentConfig(system_instruction=system_prompt)
        )
        cleaned_text = _trim_markdown_fences(response.text)
        return cleaned_text, None
    except Exception as e:
        logging.error(f"Error in transcribe_chunk_plain_text: {e}")
        return None, str(e)
def generate_srt_content(segments):
    lines = []
    for i, seg in enumerate(segments, 1):
        lines.extend([
            str(i),
            f"{format_ms_to_srt_time(seg['start_time_abs'])} --> {format_ms_to_srt_time(seg['end_time_abs'])}",
            seg['text'],
            ""
        ])
    return "\n".join(lines)
async def _transcribe_and_stream(api_key: str, file_content: bytes, model_name: str, user_prompt: str, output_format: str):
    # Each event is a JSON object terminated by a blank line (not strict SSE
    # "data:" framing; the front end presumably parses the raw stream). All
    # user-facing messages below are Hebrew UI strings.
    def send_event(event_type: str, message: str = "", percent: int = 0, data: str = ""):
        return json.dumps({"type": event_type, "message": message, "percent": percent, "data": data}) + "\n\n"
    try:
        prompts = load_prompts()
        yield send_event("progress", "מעבד את קובץ המדיה...", 5)  # "Processing the media file..."
        audio = AudioSegment.from_file(io.BytesIO(file_content))
        if len(audio) < MIN_SPLIT_SEARCH_START_MS:
            chunks = [audio]
            # "Soundtrack is X minutes long. The file is short and will be processed as a single segment."
            yield send_event("progress", f"אורך פס הקול {len(audio) / 60000:.1f} דקות. הקובץ קצר ויעובד כמקטע יחיד.", 15)
        else:
            # "Soundtrack is X minutes long. Splitting..."
            yield send_event("progress", f"אורך פס הקול {len(audio) / 60000:.1f} דקות. מבצע חלוקה...", 15)
            chunks = await asyncio.to_thread(split_audio_webrtcvad, audio, MIN_SILENCE_LEN_MS)
        if not chunks:
            raise ValueError("לא נוצרו מקטעי שמע לעיבוד.")  # "No audio segments were produced for processing."
        # "File split into N segments. Starting transcription..."
        yield send_event("progress", f"הקובץ חולק ל-{len(chunks)} מקטעים. מתחיל תמלול...", 20)
        offset = 0
        if output_format == 'srt':
            pydantic_schema = TranscriptionSegment
            system_prompt = prompts['system_prompt']
            all_segs = []
            for i, ch in enumerate(chunks):
                progress_percent = 20 + int(((i + 1) / len(chunks)) * 75)
                # "Transcribing segment i of n..."
                yield send_event("progress", f"מתמלל מקטע {i+1} מתוך {len(chunks)}...", progress_percent)
                data, error_msg = await asyncio.to_thread(transcribe_chunk, ch, api_key, system_prompt, pydantic_schema, model_name, user_prompt)
                if error_msg:
                    raise ValueError(f"שגיאה בעיבוד מקטע {i+1}: {error_msg}")  # "Error processing segment i: ..."
                if data and isinstance(data, list):
                    corrected_segments = validate_and_correct_segments(data, len(ch))
                    for seg in corrected_segments:
                        # Shift chunk-relative times to absolute positions in the full file.
                        seg['start_time_abs'] = seg['start_time_relative'] + offset
                        seg['end_time_abs'] = seg['end_time_relative'] + offset
                        all_segs.append(seg)
                offset += len(ch)
            if not all_segs:
                raise ValueError("התמלול נכשל. לא נוצר תוכן תקני.")  # "Transcription failed. No valid content was produced."
            yield send_event("result", "התהליך הושלם בהצלחה!", 100, data=generate_srt_content(all_segs))  # "Process completed successfully!"
        elif output_format == 'plain_text':
            system_prompt = prompts['plain_text_prompt']
            all_text_blocks = []
            for i, ch in enumerate(chunks):
                progress_percent = 20 + int(((i + 1) / len(chunks)) * 75)
                # "Processing text segment i of n..."
                yield send_event("progress", f"מעבד מקטע טקסט {i+1} מתוך {len(chunks)}...", progress_percent)
                text_block, error_msg = await asyncio.to_thread(transcribe_chunk_plain_text, ch, api_key, system_prompt, model_name, user_prompt)
                if error_msg:
                    raise ValueError(f"שגיאה בעיבוד מקטע {i+1}: {error_msg}")
                if text_block:
                    adjusted_block = adjust_plain_text_timestamps(text_block, offset)
                    all_text_blocks.append(adjusted_block)
                offset += len(ch)
            if not all_text_blocks:
                raise ValueError("התמלול נכשל. לא נוצר תוכן.")  # "Transcription failed. No content was produced."
            final_text = "\n\n".join(all_text_blocks).strip()
            # Normalize [HH:MM:SS,mmm] to [HH:MM:SS.mmm] so the plain-text format is uniform.
            standardized_text = re.sub(r'\[(\d{2}:\d{2}:\d{2}),(\d{3})\]', r'[\1.\2]', final_text)
            if standardized_text.startswith("```") and standardized_text.endswith("```"):
                standardized_text = standardized_text[3:-3].strip()
            total_duration_ms = len(audio)
            # The total duration rides along as an HTML comment so a later
            # text-to-SRT conversion can give the last segment a real end time.
            final_data_with_duration = f"{standardized_text}\n<!-- DURATION_MS:{total_duration_ms} -->"
            yield send_event("result", "התהליך הושלם בהצלחה!", 100, data=final_data_with_duration)
    except Exception as e:
        logging.error(f"Streaming transcription failed: {e}", exc_info=True)
        yield send_event("error", f"אירעה שגיאה: {e}", 100)  # "An error occurred: ..."
def parse_srt_for_text_conversion(srt_content: str):
    segments = []
    blocks = srt_content.strip().replace('\r\n', '\n').split('\n\n')
    for block in blocks:
        if not block.strip():
            continue
        lines = block.split('\n')
        if len(lines) >= 2:
            try:
                int(lines[0])  # a well-formed block starts with a numeric index
                if '-->' in lines[1]:
                    start_str, end_str = lines[1].split(' --> ')
                    text = '\n'.join(lines[2:])
                    segments.append({
                        'start_ms': parse_time_str_to_ms(start_str),
                        'end_ms': parse_time_str_to_ms(end_str),
                        'text': text.strip()
                    })
            except (ValueError, IndexError):
                logging.warning(f"Skipping malformed SRT block during text conversion: {block}")
                continue
    return segments
def convert_srt_to_formatted_text(srt_content: str) -> str:
    segments = parse_srt_for_text_conversion(srt_content)
    if not segments:
        return ""
    segments.sort(key=lambda s: s['start_ms'])
    result_parts = [segments[0]['text']]
    for i in range(1, len(segments)):
        prev_seg = segments[i-1]
        curr_seg = segments[i]
        time_gap_s = (curr_seg['start_ms'] - prev_seg['end_ms']) / 1000.0
        prev_text = prev_seg['text']
        # Rule 1: start a new paragraph after a long pause or a ?/! sentence end.
        if time_gap_s > 1.5 or prev_text.endswith(('?', '!')):
            prefix = '\n\n'
        # Rules 2-3: otherwise join with a single space, whether or not the
        # previous cue ended with . or , (both cases connect the same way).
        else:
            prefix = ' '
        result_parts.append(prefix + curr_seg['text'])
    return "".join(result_parts)
def convert_text_to_srt(text_content: str) -> str:
    duration_ms = None
    duration_match = re.search(r'<!-- DURATION_MS:(\d+) -->\s*$', text_content)
    if duration_match:
        duration_ms = int(duration_match.group(1))
        text_content = text_content[:duration_match.start()].strip()
    pattern = re.compile(r'\[(\d{2}:\d{2}:\d{2}[.,]\d{3})\]\s*([\s\S]*?)(?=\s*\[\d{2}:\d{2}:\d{2}[.,]\d{3}\]|$)', re.DOTALL)
    matches = pattern.findall(text_content)
    if not matches:
        return ""
    segments = []
    for start_str, text in matches:
        start_ms = parse_time_str_to_ms(start_str)
        cleaned_text = text.strip()
        if cleaned_text:
            segments.append({
                'id': len(segments) + 1,  # number sequentially, skipping empty matches
                'start_ms': start_ms,
                'text': cleaned_text
            })
    for i in range(len(segments)):
        if i < len(segments) - 1:
            # Each cue ends 1 ms before the next one starts.
            segments[i]['end_ms'] = segments[i+1]['start_ms'] - 1
        else:
            # The last cue ends at the known total duration, or 5 s after its start.
            if duration_ms is not None:
                segments[i]['end_ms'] = duration_ms
            else:
                segments[i]['end_ms'] = segments[i]['start_ms'] + 5000
        if segments[i]['end_ms'] <= segments[i]['start_ms']:
            segments[i]['end_ms'] = segments[i]['start_ms'] + 1000
    srt_output = []
    for seg in segments:
        start_time_srt = format_ms_to_srt_time(seg['start_ms'])
        end_time_srt = format_ms_to_srt_time(seg['end_ms'])
        srt_output.append(f"{seg['id']}\n{start_time_srt} --> {end_time_srt}\n{seg['text']}\n")
    return "\n".join(srt_output)
class SrtContent(BaseModel):
    srt_data: str

class PlainTextContent(BaseModel):
    text_data: str
# NOTE: route decorators were missing from this listing; the paths below are
# assumed so that the endpoints are actually reachable.
@app.post("/convert_srt_to_text")
async def handle_text_conversion(payload: SrtContent):
    if not payload.srt_data:
        raise HTTPException(status_code=400, detail="SRT data is missing.")
    try:
        formatted_text = convert_srt_to_formatted_text(payload.srt_data)
        return {"text": formatted_text}
    except Exception as e:
        logging.error(f"Error during SRT to text conversion: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to convert SRT to text.")

@app.post("/convert_text_to_srt")
async def handle_srt_conversion(payload: PlainTextContent):
    if not payload.text_data:
        raise HTTPException(status_code=400, detail="Text data is missing.")
    try:
        srt_formatted_text = convert_text_to_srt(payload.text_data)
        return {"srt": srt_formatted_text}
    except Exception as e:
        logging.error(f"Error during Text to SRT conversion: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to convert text to SRT.")
@app.get("/", response_class=HTMLResponse)  # route path assumed, as above
async def read_root(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})
@app.post("/transcribe")  # route path assumed, as above
async def handle_transcription_stream(
    api_key: str = Form(...),
    model_name: str = Form(...),
    output_format: str = Form(...),
    user_prompt: str = Form(""),
    audio_file: UploadFile = File(...)
):
    if not all([api_key, model_name, audio_file, output_format]):
        raise HTTPException(status_code=400, detail="Required form fields are missing.")
    file_content = await audio_file.read()
    return StreamingResponse(_transcribe_and_stream(api_key, file_content, model_name, user_prompt, output_format), media_type="text/event-stream")
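# Local run (assumed entry point; the module name "app" is a guess):
#   uvicorn app:app --host 0.0.0.0 --port 7860
# 7860 is the port Hugging Face Spaces expects a web app to listen on.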