Skip to content

Commit ee6308b

Browse files
committed
Auto save
1 parent 2b220a6 commit ee6308b

File tree

3 files changed

+31
-19
lines changed

3 files changed

+31
-19
lines changed

Diff for: config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
DEBUG = False
44
HERE = Path(__file__).parent.absolute()
5-
DOWNLOAD_DIR = HERE / ".download"
6-
OUTPUT_DIR = HERE / "output"
5+
DOWNLOAD_DIR: str | Path = HERE / ".download"
6+
OUTPUT_DIR: str | Path = HERE / "output"
77
WHISPER_MODEL = "small"
88
FFMPEG_BIN = "ffmpeg"
99
FFMPEG_PREFIX_OPTS = "-hide_banner -loglevel error -y"

Diff for: core/utils.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,23 @@ def format_timestamp(
2424
return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
2525

2626

27-
def write_vtt(transcript: list[dict|Segment], audio: str, language: str | None = None) -> str:
27+
def write_vtt(
28+
transcript: list[dict | Segment], audio: str, language: str | None = None
29+
) -> str:
2830
if language is not None:
2931
vtt_filename = f"{audio.removesuffix('.mp3')}.{language}.vtt"
3032
else:
3133
vtt_filename = f"{audio.removesuffix('.mp3')}.vtt"
3234
return _write_vtt(transcript, vtt_filename)
3335

3436

35-
def _write_vtt(transcript: list[dict|Segment], vtt_filename: str) -> str:
36-
37+
def _write_vtt(transcript: list[dict | Segment], vtt_filename: str) -> str:
3738
with open(vtt_filename, "w") as f:
3839
f.write("WEBVTT\n")
3940
for segment in transcript:
4041
if isinstance(segment, Segment):
4142
segment = segment.model_dump()
43+
print(segment, "segment")
4244
start = format_timestamp(segment["start"])
4345
end = format_timestamp(segment["end"])
4446
f.write(f"\n{start} --> {end}\n{segment['text'].strip()}\n")
@@ -100,8 +102,15 @@ def parse_vtt(text: str) -> list[Segment]:
100102
timestamp, text, _ = list(items)
101103
except ValueError:
102104
timestamp, text = list(items)
103-
start, end = timestamp.split('-->')
104-
texts.append(Segment(timestamp=timestamp, start=get_seconds(start), end=get_seconds(end), text=text))
105+
start, end = timestamp.split("-->")
106+
texts.append(
107+
Segment(
108+
timestamp=timestamp,
109+
start=get_seconds(start),
110+
end=get_seconds(end),
111+
text=text,
112+
)
113+
)
105114
return texts
106115

107116

Diff for: webui.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# TODO:
2+
# 1. auto save
13
import mimetypes
24
import os.path
35

@@ -19,13 +21,14 @@
1921
)
2022
from core.audio import generate_audio, generate_subtitle
2123
from core.downloader import download
24+
from core.utils import _write_vtt, parse_vtt
2225
from core.video import generate_video
23-
from core.utils import parse_vtt, _write_vtt
2426

2527
if "subtitle_path" not in st.session_state:
2628
st.session_state.update(
2729
{
2830
"segments": [],
31+
"replaced_segments": [],
2932
"current_seg_index": 0,
3033
"current_time": 0,
3134
"vtt_content": "",
@@ -46,6 +49,7 @@ def wrapper(label, *args, **kwargs):
4649
widget_value = f(*args, **kwargs)
4750
st.session_state.widget_values[label] = widget_value
4851
return widget_value
52+
4953
return wrapper
5054

5155

@@ -80,8 +84,6 @@ def wrapper(label, *args, **kwargs):
8084

8185
st.write("## Subtitle")
8286

83-
subtitle_auto_save = st.checkbox('Audo Save')
84-
8587
language: str = (
8688
st.selectbox(
8789
"Language",
@@ -109,12 +111,14 @@ def wrapper(label, *args, **kwargs):
109111

110112
video_path = st.text_input("Video Path or URL")
111113
subtitle_path = st.text_input("VTT Path or URL", value=st.session_state.subtitle_path)
112-
if subtitle_path:
114+
115+
if subtitle_path and not st.session_state.segments:
113116
with open(subtitle_path) as f:
114117
content = f.read()
115118
st.session_state.vtt_content = content
116119
if not st.session_state.segments:
117120
st.session_state.segments = parse_vtt(content)
121+
st.session_state.replaced_segments = parse_vtt(content)
118122

119123

120124
def get_current_vtt_content() -> str:
@@ -123,11 +127,12 @@ def get_current_vtt_content() -> str:
123127
return ""
124128
if not values["video_component"]:
125129
return ""
126-
current_time = values["video_component"]['current_time']
130+
current_time = values["video_component"]["current_time"]
127131
for index, seg in enumerate(st.session_state.segments):
128132
if seg.start <= current_time <= seg.end:
129-
st.session_state['current_seg_index'] = index
133+
st.session_state["current_seg_index"] = index
130134
return seg.text
135+
return ""
131136

132137

133138
def subtitle_callback(path: str) -> None:
@@ -152,8 +157,8 @@ def save_callback() -> None:
152157
if not subtitle_path:
153158
st.error("Subtitle is required.")
154159
else:
155-
_write_vtt(st.session_state.segments, subtitle_path)
156-
st.success(f"Subtitle saved")
160+
_write_vtt(st.session_state.replaced_segments, subtitle_path)
161+
st.session_state.segments = st.session_state.replaced_segments
157162

158163

159164
def preview_callback() -> None:
@@ -170,6 +175,7 @@ def preview_callback() -> None:
170175
current_time="",
171176
)
172177

178+
173179
def generate_callback() -> None:
174180
if not video_path:
175181
st.error("Video Path or URL is required.")
@@ -216,7 +222,4 @@ def generate_callback() -> None:
216222

217223
current_vtt = st.text_input("Subtitle", value=get_current_vtt_content())
218224
if st.session_state.segments and current_vtt:
219-
st.session_state.segments[st.session_state.current_seg_index].text = current_vtt
220-
221-
if subtitle_auto_save and subtitle_path and st.session_state.segments:
222-
_write_vtt(st.session_state.segments, subtitle_path)
225+
st.session_state.replaced_segments[st.session_state.current_seg_index].text = current_vtt

0 commit comments

Comments
 (0)