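"""FastAPI service that transcribes audio with the Google Gemini API.

Uploaded files are split into roughly 11-minute chunks at detected silences
(WebRTC VAD), each chunk is transcribed separately, and the combined result is
streamed back to the client as SRT or as timestamped plain text.
"""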
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from google import genai
from google.genai import types
from pydub import AudioSegment
import webrtcvad
import yaml
import json
import io
import os
import re
from datetime import timedelta
import logging
import asyncio
from pydantic import BaseModel, Field
logging.basicConfig(level=logging.INFO)
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
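# Chunking parameters: target chunk length and the window (relative to each
# chunk's start) in which a silence-based split point is searched for.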
TARGET_CHUNK_DURATION_MIN = 11
TARGET_CHUNK_DURATION_MS = TARGET_CHUNK_DURATION_MIN * 60 * 1000
MIN_SPLIT_SEARCH_START_MIN = 10
MIN_SPLIT_SEARCH_START_MS = MIN_SPLIT_SEARCH_START_MIN * 60 * 1000
MAX_SPLIT_SEARCH_END_MIN = 13
MAX_SPLIT_SEARCH_END_MS = MAX_SPLIT_SEARCH_END_MIN * 60 * 1000
MIN_SILENCE_LEN_MS = 800
VAD_AGGRESSIVENESS = 3
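# Schema for a single subtitle entry, passed to Gemini as the response schema
# when SRT output is requested.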
class TranscriptionSegment(BaseModel):
id: int = Field(description="מספר סידורי של הכתובית", ge=1)
start_time: str = Field(description="שעת התחלה בפורמט HH:MM:SS,mmm")
end_time: str = Field(description="שעת סיום בפורמט HH:MM:SS,mmm")
text: str = Field(description="תוכן הכתובית")
def load_prompts():
try:
with open("instruct.yml", 'r', encoding='utf-8') as f:
return yaml.safe_load(f)
except Exception as e:
logging.error(f"Error loading instruct.yml: {e}")
raise HTTPException(status_code=500, detail="Server configuration error.")
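# Parse a timestamp string the model may emit in several shapes (HH:MM:SS,mmm,
# MM:SS.mmm, MM:SS:mmm, plain seconds, ...) into milliseconds.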
def parse_time_str_to_ms(time_str: str) -> int:
if not isinstance(time_str, str):
raise TypeError(f"Time string must be a string, got {type(time_str)}")
time_str = time_str.replace(',', '.')
last_colon_pos = time_str.rfind(':')
last_period_pos = time_str.rfind('.')
h, m, s, ms = 0, 0, 0, 0
try:
if last_period_pos > last_colon_pos:
hms_part = time_str[:last_period_pos]
ms_part = time_str[last_period_pos+1:]
ms = int(ms_part.ljust(3, '0')[:3])
time_components = list(map(int, hms_part.split(':')))
if len(time_components) == 3: h, m, s = time_components
elif len(time_components) == 2: m, s = time_components
elif len(time_components) == 1: s = time_components[0]
else:
components = list(map(int, time_str.split(':')))
# If the last component is > 59, it must be milliseconds
if len(components) >= 2 and components[-1] > 59:
ms = components[-1]
s = components[-2]
if len(components) == 3: m = components[0]
elif len(components) > 3: h, m = components[0], components[1] # For very long times
# Otherwise, it's a standard HH:MM:SS format
else:
if len(components) == 3: h, m, s = components
elif len(components) == 2: m, s = components
elif len(components) == 1: s = components[0]
return (h * 3600000) + (m * 60000) + (s * 1000) + ms
except (ValueError, IndexError) as e:
raise ValueError(f"Could not parse adaptive time string: '{time_str}'. Error: {e}")
def format_ms_to_srt_time(ms: int) -> str:
td = timedelta(milliseconds=ms)
total_seconds = int(td.total_seconds())
hours, remainder = divmod(total_seconds, 3600)
minutes, seconds = divmod(remainder, 60)
milliseconds = td.microseconds // 1000
return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
def adjust_plain_text_timestamps(text_block: str, offset_ms: int) -> str:
def replacer(match):
time_str = match.group(1)
try:
time_ms = parse_time_str_to_ms(time_str)
new_time_ms = time_ms + offset_ms
new_time_str = format_ms_to_srt_time(new_time_ms).replace(',', '.')
return f'[{new_time_str}]'
except ValueError:
return match.group(0)
timestamp_pattern = r'\[(\d{2}:\d{2}:\d{2}[,.]\d{3})\]'
return re.sub(timestamp_pattern, replacer, text_block)
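# Run WebRTC VAD over 30 ms frames (mono, 16-bit, supported sample rate) and
# return the midpoints of silences lasting at least min_silence_len_ms.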
def find_silence_points_webrtcvad(audio_segment: AudioSegment, min_silence_len_ms: int, vad_aggressiveness: int):
if audio_segment.frame_rate not in [8000, 16000, 32000, 48000]:
audio_segment = audio_segment.set_frame_rate(16000)
if audio_segment.channels > 1: audio_segment = audio_segment.set_channels(1)
if audio_segment.sample_width != 2: audio_segment = audio_segment.set_sample_width(2)
vad = webrtcvad.Vad(vad_aggressiveness)
frame_duration_ms = 30
frame_size_bytes = int(audio_segment.frame_rate * (frame_duration_ms / 1000.0) * 2)
silence_points_ms, silence_start_ms = [], None
raw_data, num_frames = audio_segment.raw_data, len(audio_segment.raw_data) // frame_size_bytes
for i in range(num_frames):
frame = raw_data[i*frame_size_bytes:(i+1)*frame_size_bytes]
if len(frame) < frame_size_bytes: break
is_speech, current_time_ms = vad.is_speech(frame, audio_segment.frame_rate), i * frame_duration_ms
if not is_speech:
if silence_start_ms is None: silence_start_ms = current_time_ms
elif silence_start_ms is not None:
if current_time_ms - silence_start_ms >= min_silence_len_ms:
silence_points_ms.append(silence_start_ms + (current_time_ms - silence_start_ms) // 2)
silence_start_ms = None
if silence_start_ms is not None and len(audio_segment) - silence_start_ms >= min_silence_len_ms:
silence_points_ms.append(silence_start_ms + (len(audio_segment) - silence_start_ms) // 2)
return silence_points_ms
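# Greedy splitter: walk forward, preferring the silence point closest to the
# 11-minute mark within the 10-13 minute search window; fall back to a hard cut
# at 11 minutes, and keep the remaining tail (13 minutes or less) whole.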
def split_audio_webrtcvad(audio_segment, min_silence_len):
logging.info(f"Splitting with WebRTCVAD: Target Chunk {TARGET_CHUNK_DURATION_MIN}m, VAD Aggressiveness {VAD_AGGRESSIVENESS}")
silence_points = find_silence_points_webrtcvad(audio_segment, min_silence_len, VAD_AGGRESSIVENESS)
if not silence_points:
logging.warning("WebRTCVAD found no significant silences. Splitting into fixed chunks.")
return [audio_segment[i:i + TARGET_CHUNK_DURATION_MS] for i in range(0, len(audio_segment), TARGET_CHUNK_DURATION_MS)]
final_chunks, current_offset, total_length = [], 0, len(audio_segment)
while current_offset < total_length:
if total_length - current_offset <= MAX_SPLIT_SEARCH_END_MS:
final_chunks.append(audio_segment[current_offset:])
break
ideal_split_point = current_offset + TARGET_CHUNK_DURATION_MS
candidate_points = [p for p in silence_points if (current_offset + MIN_SPLIT_SEARCH_START_MS) <= p < (current_offset + MAX_SPLIT_SEARCH_END_MS)]
best_split_point = min(candidate_points, key=lambda p: abs(p - ideal_split_point)) if candidate_points else -1
split_at = best_split_point if best_split_point != -1 else ideal_split_point
final_chunks.append(audio_segment[current_offset:int(split_at)])
current_offset = int(split_at)
logging.info(f"File successfully split into {len(final_chunks)} chunks using WebRTCVAD.")
logging.info(f"Chunk durations (seconds): {[round(len(c) / 1000) for c in final_chunks]}")
return final_chunks
def _trim_markdown_fences(text: str) -> str:
"""Removes markdown code block fences from a string."""
text = text.strip()
if text.startswith("```") and text.endswith("```"):
text = text[3:-3].strip()
if text.startswith("json"):
text = text[4:].strip()
return text
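# Sanity-check segment times returned by the model: drop segments that start
# beyond the chunk, clamp end times to the chunk length, and keep times monotonic.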
def validate_and_correct_segments(segments_from_api, chunk_duration_ms):
corrected_segments, last_corrected_end_ms = [], 0
for seg in segments_from_api:
try:
start_ms = parse_time_str_to_ms(seg.get('start_time'))
end_ms = parse_time_str_to_ms(seg.get('end_time'))
            if start_ms >= chunk_duration_ms:
                logging.warning(f"Skipping segment with hallucinated start_time ({format_ms_to_srt_time(start_ms)}) beyond the chunk duration ({format_ms_to_srt_time(chunk_duration_ms)}).")
                continue
if end_ms > chunk_duration_ms: end_ms = chunk_duration_ms
if start_ms >= end_ms: end_ms = min(start_ms + 3000, chunk_duration_ms)
if start_ms < last_corrected_end_ms: start_ms = last_corrected_end_ms
if start_ms >= end_ms: continue
seg['start_time_relative'], seg['end_time_relative'] = start_ms, end_ms
corrected_segments.append(seg)
last_corrected_end_ms = end_ms
except (ValueError, TypeError, KeyError) as e:
logging.warning(f"Skipping segment due to invalid format or value: {seg}. Error: {e}")
continue
return corrected_segments
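# SRT path: send one audio chunk to Gemini with a structured JSON response
# schema and return (parsed segment list, error message or None).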
def transcribe_chunk(chunk_audio, api_key, system_prompt, pydantic_schema, model_name, user_prompt):
client = genai.Client(api_key=api_key)
buffer = io.BytesIO()
chunk_audio.export(buffer, format="wav")
audio_part = types.Part.from_bytes(data=buffer.getvalue(), mime_type='audio/wav')
contents: list[types.Content] = []
if user_prompt:
contents.append(types.Content(
role="user",
parts=[types.Part(text=user_prompt)]
))
contents.append(types.Content(
role="model",
parts=[types.Part(text="אני מוכן. אנא שלח את קטע השמע לתמלול.")]
))
prompt = "אנא תמלל את קטע השמע בהתאם להנחיית המערכת ולסכמת ה-JSON."
contents.append(types.Content(
role="user",
parts=[types.Part(text=prompt), audio_part]
))
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=contents,
            config=types.GenerateContentConfig(
                system_instruction=system_prompt,
                response_mime_type="application/json",
                response_schema=list[pydantic_schema]
            )
        )
        cleaned_text = _trim_markdown_fences(response.text)
        return json.loads(cleaned_text), None
    except Exception as e:
        # Return the error instead of raising, matching transcribe_chunk_plain_text
        # and the (data, error_msg) contract expected by the caller.
        logging.error(f"Error in transcribe_chunk: {e}")
        return None, str(e)
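# Plain-text path: same request flow but without a response schema; downstream
# code expects inline [HH:MM:SS.mmm] timestamps in the returned text.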
def transcribe_chunk_plain_text(chunk_audio, api_key, system_prompt, model_name, user_prompt):
client = genai.Client(api_key=api_key)
buffer = io.BytesIO()
chunk_audio.export(buffer, format="wav")
audio_part = types.Part.from_bytes(data=buffer.getvalue(), mime_type='audio/wav')
contents: list[types.Content] = []
if user_prompt:
contents.append(types.Content(
role="user", parts=[types.Part(text=user_prompt)]
))
contents.append(types.Content(
role="model", parts=[types.Part(text="אני מוכן. אנא שלח את קטע השמע לעיבוד.")]
))
prompt = "אנא עבד את קטע השמע המצורף והפק את התמלול הערוך בפורמט טקסט כפי שהוגדר בהנחיית המערכת."
contents.append(types.Content(
role="user", parts=[types.Part(text=prompt), audio_part]
))
try:
response = client.models.generate_content(
model=model_name,
contents=contents,
config=types.GenerateContentConfig(system_instruction=system_prompt)
)
cleaned_text = _trim_markdown_fences(response.text)
return cleaned_text, None
except Exception as e:
logging.error(f"Error in transcribe_chunk_plain_text: {e}")
return None, str(e)
def generate_srt_content(segments):
lines = []
for i, seg in enumerate(segments, 1):
lines.extend([str(i), f"{format_ms_to_srt_time(seg['start_time_abs'])} --> {format_ms_to_srt_time(seg['end_time_abs'])}", seg['text'], ""])
return "\n".join(lines)
async def _transcribe_and_stream(api_key: str, file_content: bytes, model_name: str, user_prompt: str, output_format: str):
    def send_event(event_type: str, message: str = "", percent: int = 0, data: str = ""):
        return json.dumps({"type": event_type, "message": message, "percent": percent, "data": data}) + "\n\n"
try:
prompts = load_prompts()
yield send_event("progress", "מעבד את קובץ המדיה...", 5)
audio = AudioSegment.from_file(io.BytesIO(file_content))
if len(audio) < MIN_SPLIT_SEARCH_START_MS:
chunks = [audio]
yield send_event("progress", f"אורך פס הקול {len(audio) / 60000:.1f} דקות. הקובץ קצר ויעובד כמקטע יחיד.", 15)
else:
yield send_event("progress", f"אורך פס הקול {len(audio) / 60000:.1f} דקות. מבצע חלוקה...", 15)
chunks = await asyncio.to_thread(split_audio_webrtcvad, audio, MIN_SILENCE_LEN_MS)
if not chunks: raise ValueError("לא נוצרו מקטעי שמע לעיבוד.")
yield send_event("progress", f"הקובץ חולק ל-{len(chunks)} מקטעים. מתחיל תמלול...", 20)
offset = 0
if output_format == 'srt':
pydantic_schema = TranscriptionSegment
system_prompt = prompts['system_prompt']
all_segs = []
for i, ch in enumerate(chunks):
progress_percent = 20 + int(((i + 1) / len(chunks)) * 75)
yield send_event("progress", f"מתמלל מקטע {i+1} מתוך {len(chunks)}...", progress_percent)
data, error_msg = await asyncio.to_thread(transcribe_chunk, ch, api_key, system_prompt, pydantic_schema, model_name, user_prompt)
if error_msg: raise ValueError(f"שגיאה בעיבוד מקטע {i+1}: {error_msg}")
if data and isinstance(data, list):
corrected_segments = validate_and_correct_segments(data, len(ch))
for seg in corrected_segments:
seg['start_time_abs'] = seg['start_time_relative'] + offset
seg['end_time_abs'] = seg['end_time_relative'] + offset
all_segs.append(seg)
offset += len(ch)
if not all_segs: raise ValueError("התמלול נכשל. לא נוצר תוכן תקני.")
yield send_event("result", "התהליך הושלם בהצלחה!", 100, data=generate_srt_content(all_segs))
elif output_format == 'plain_text':
system_prompt = prompts['plain_text_prompt']
all_text_blocks = []
for i, ch in enumerate(chunks):
progress_percent = 20 + int(((i + 1) / len(chunks)) * 75)
yield send_event("progress", f"מעבד מקטע טקסט {i+1} מתוך {len(chunks)}...", progress_percent)
text_block, error_msg = await asyncio.to_thread(transcribe_chunk_plain_text, ch, api_key, system_prompt, model_name, user_prompt)
if error_msg: raise ValueError(f"שגיאה בעיבוד מקטע {i+1}: {error_msg}")
if text_block:
adjusted_block = adjust_plain_text_timestamps(text_block, offset)
all_text_blocks.append(adjusted_block)
offset += len(ch)
if not all_text_blocks: raise ValueError("התמלול נכשל. לא נוצר תוכן.")
final_text = "\n\n".join(all_text_blocks).strip()
standardized_text = re.sub(r'\[(\d{2}:\d{2}:\d{2}),(\d{3})\]', r'[\1.\2]', final_text)
if standardized_text.startswith("```") and standardized_text.endswith("```"):
standardized_text = standardized_text[3:-3].strip()
total_duration_ms = len(audio)
final_data_with_duration = f"{standardized_text}\n<!-- DURATION_MS:{total_duration_ms} -->"
yield send_event("result", "התהליך הושלם בהצלחה!", 100, data=final_data_with_duration)
except Exception as e:
logging.error(f"Streaming transcription failed: {e}", exc_info=True)
yield send_event("error", f"אירעה שגיאה: {e}", 100)
def parse_srt_for_text_conversion(srt_content: str):
segments = []
blocks = srt_content.strip().replace('\r\n', '\n').split('\n\n')
for block in blocks:
if not block.strip():
continue
lines = block.split('\n')
if len(lines) >= 2:
try:
int(lines[0])
if '-->' in lines[1]:
start_str, end_str = lines[1].split(' --> ')
text = '\n'.join(lines[2:])
segments.append({
'start_ms': parse_time_str_to_ms(start_str),
'end_ms': parse_time_str_to_ms(end_str),
'text': text.strip()
})
except (ValueError, IndexError):
logging.warning(f"Skipping malformed SRT block during text conversion: {block}")
continue
return segments
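# Join SRT segments into readable text: a paragraph break after pauses longer
# than 1.5 s or after '?'/'!', otherwise a single space.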
def convert_srt_to_formatted_text(srt_content: str) -> str:
segments = parse_srt_for_text_conversion(srt_content)
if not segments:
return ""
segments.sort(key=lambda s: s['start_ms'])
    result_parts = [segments[0]['text']]
    for i in range(1, len(segments)):
        prev_seg = segments[i-1]
        curr_seg = segments[i]
        time_gap_s = (curr_seg['start_ms'] - prev_seg['end_ms']) / 1000.0
        prev_text = prev_seg['text']
        if time_gap_s > 1.5 or prev_text.endswith(('?', '!')):
            # New paragraph after a long pause or sentence-ending punctuation.
            prefix = '\n\n'
        else:
            # Default: join with a single space (also after '.' or ',').
            prefix = ' '
        result_parts.append(prefix + curr_seg['text'])
return "".join(result_parts)
def convert_text_to_srt(text_content: str) -> str:
duration_ms = None
duration_match = re.search(r'<!-- DURATION_MS:(\d+) -->\s*$', text_content)
if duration_match:
duration_ms = int(duration_match.group(1))
text_content = text_content[:duration_match.start()].strip()
pattern = re.compile(r'\[(\d{2}:\d{2}:\d{2}[.,]\d{3})\]\s*([\s\S]*?)(?=\s*\[\d{2}:\d{2}:\d{2}[.,]\d{3}\]|$)', re.DOTALL)
matches = pattern.findall(text_content)
if not matches:
return ""
segments = []
for i, (start_str, text) in enumerate(matches):
start_ms = parse_time_str_to_ms(start_str)
cleaned_text = text.strip()
if cleaned_text:
segments.append({
'id': i + 1,
'start_ms': start_ms,
'text': cleaned_text
})
for i in range(len(segments)):
if i < len(segments) - 1:
segments[i]['end_ms'] = segments[i+1]['start_ms'] - 1
else:
if duration_ms is not None:
segments[i]['end_ms'] = duration_ms
else:
segments[i]['end_ms'] = segments[i]['start_ms'] + 5000
if segments[i]['end_ms'] <= segments[i]['start_ms']:
segments[i]['end_ms'] = segments[i]['start_ms'] + 1000
srt_output = []
for seg in segments:
start_time_srt = format_ms_to_srt_time(seg['start_ms'])
end_time_srt = format_ms_to_srt_time(seg['end_ms'])
srt_output.append(f"{seg['id']}\n{start_time_srt} --> {end_time_srt}\n{seg['text']}\n")
return "\n".join(srt_output)
class SrtContent(BaseModel):
srt_data: str
class PlainTextContent(BaseModel):
text_data: str
@app.post("/convert-to-text")
async def handle_text_conversion(payload: SrtContent):
if not payload.srt_data:
raise HTTPException(status_code=400, detail="SRT data is missing.")
try:
formatted_text = convert_srt_to_formatted_text(payload.srt_data)
return {"text": formatted_text}
except Exception as e:
logging.error(f"Error during SRT to text conversion: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to convert SRT to text.")
@app.post("/convert-to-srt")
async def handle_srt_conversion(payload: PlainTextContent):
if not payload.text_data:
raise HTTPException(status_code=400, detail="Text data is missing.")
try:
srt_formatted_text = convert_text_to_srt(payload.text_data)
return {"srt": srt_formatted_text}
except Exception as e:
logging.error(f"Error during Text to SRT conversion: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to convert text to SRT.")
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.post("/transcribe-stream")
async def handle_transcription_stream(
api_key: str = Form(...),
model_name: str = Form(...),
output_format: str = Form(...),
user_prompt: str = Form(""),
audio_file: UploadFile = File(...)
):
if not all([api_key, model_name, audio_file, output_format]):
raise HTTPException(status_code=400, detail="Required form fields are missing.")
file_content = await audio_file.read()
return StreamingResponse(_transcribe_and_stream(api_key, file_content, model_name, user_prompt, output_format), media_type="text/event-stream")
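# Example client (a sketch, not part of the app): streaming /transcribe-stream
# with httpx. The form field names match the endpoint above; the URL, model
# name, and file name are placeholders.
#
#   import httpx, json
#   with open("audio.mp3", "rb") as f, httpx.stream(
#       "POST", "http://localhost:8000/transcribe-stream",
#       data={"api_key": "YOUR_API_KEY", "model_name": "YOUR_MODEL",
#             "output_format": "srt", "user_prompt": ""},
#       files={"audio_file": f},
#       timeout=None,
#   ) as response:
#       for line in response.iter_lines():
#           if line.strip():
#               print(json.loads(line))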