aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArjun Satarkar <me@arjunsatarkar.net>2024-12-21 10:49:13 +0000
committerArjun Satarkar <me@arjunsatarkar.net>2024-12-21 10:49:13 +0000
commit4ffc19bbbb2395b8204b852df325fe4b3c07e273 (patch)
tree18dcf0be788bf3a9823572a83c3bf7812d1210f8
parentc36ce550b282f11bac802e9224ed9927b24876a8 (diff)
downloadsrtfilter-4ffc19bbbb2395b8204b852df325fe4b3c07e273.tar
srtfilter-4ffc19bbbb2395b8204b852df325fe4b3c07e273.tar.gz
srtfilter-4ffc19bbbb2395b8204b852df325fe4b3c07e273.zip
Add type checking, rework parsing, fix bugs
-rw-r--r--parse_srt.py75
-rw-r--r--pyrightconfig.json3
-rwxr-xr-xrebreak_lines.py12
-rw-r--r--requirements.txt3
4 files changed, 42 insertions, 51 deletions
diff --git a/parse_srt.py b/parse_srt.py
index acd72c0..2e710ea 100644
--- a/parse_srt.py
+++ b/parse_srt.py
@@ -1,69 +1,45 @@
from __future__ import annotations
import dataclasses
-import enum
-import itertools
import re
-from typing import List
@dataclasses.dataclass
class Event:
- start: str | None = None
- end: str | None = None
- content: str | None = None
+ start: str
+ end: str
+ content: str
class SRT:
def __init__(self):
- self.events: List[Event] = []
+ self.events: list[Event] = []
@staticmethod
def from_str(text: str) -> SRT:
- class ParseState(enum.Enum):
- COUNTER = enum.auto()
- TIMING = enum.auto()
- CONTENT = enum.auto()
-
- PARSE_STATES = itertools.cycle(iter(ParseState))
TIMESTAMP_CAPTURE = r"(\d\d:\d\d:\d\d,\d\d\d)"
TIMING_REGEX = rf"{TIMESTAMP_CAPTURE} --> {TIMESTAMP_CAPTURE}"
srt = SRT()
- lines = text.split("\n")
counter = 1
- state = next(PARSE_STATES)
- event = Event()
- for line_num, line in enumerate(lines, 1):
- if not line:
- match state:
- case ParseState.CONTENT:
- srt.events.append(event)
- event = Event()
- state = next(PARSE_STATES)
- case ParseState.COUNTER:
- pass
- case _:
- raise ParseError(f"Unexpected blank line (line {line_num})")
- continue
- match state:
- case ParseState.COUNTER:
- if int(line) == counter:
- counter += 1
- state = next(PARSE_STATES)
- else:
- raise ParseError(
- f"Invalid counter, expected {counter} (line {line_num})"
- )
- case ParseState.TIMING:
- match = re.fullmatch(TIMING_REGEX, line)
- if match is None:
- raise ParseError(f"Invalid timing info (line {line_num})")
- event.start, event.end = match[1], match[2]
- state = next(PARSE_STATES)
- case ParseState.CONTENT:
- event.content = (
- event.content if event.content is not None else ""
- ) + f"{line}\n"
+ events = [event for event in text.split("\n\n") if event.strip()]
+ for event_str in events:
+ lines = event_str.split("\n")
+ counter_str, timing_str, content_lines = lines[0], lines[1], lines[2:]
+
+ if int(counter_str) != counter:
+ raise ParseError(
+ f"Invalid counter '{counter_str}'; expected {counter}", event_str
+ )
+ counter += 1
+
+ match = re.fullmatch(TIMING_REGEX, timing_str)
+ if match is None:
+ raise ParseError(f"Invalid timing info '{timing_str}'", event_str)
+
+ content = "\n".join(content_lines + [""])
+
+ srt.events.append(Event(match[1], match[2], content))
+
return srt
def __str__(self):
@@ -76,4 +52,7 @@ class SRT:
class ParseError(Exception):
- pass
+ def __init__(self, reason: str, event_str: str):
+ super().__init__(f"{reason}\nwhile parsing event:\n{event_str}")
+ self.reason = reason
+ self.event_str = event_str
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 0000000..864fa90
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,3 @@
+{
+ "strict": ["."]
+} \ No newline at end of file
diff --git a/rebreak_lines.py b/rebreak_lines.py
index 41594a3..54473e2 100755
--- a/rebreak_lines.py
+++ b/rebreak_lines.py
@@ -13,7 +13,7 @@ import click
import parse_srt
import math
import sys
-from typing import List
+import typing
# May still be exceeded if there are no word boundaries to wrap at
MAX_LINE_LENGTH = 42
@@ -33,21 +33,27 @@ def main(in_file_path: str):
def rebreak(text: str) -> str:
- get_target_line_num = lambda length: math.ceil(length / MAX_LINE_LENGTH)
+ get_target_line_num: typing.Callable[[int], int] = lambda length: math.ceil(
+ length / MAX_LINE_LENGTH
+ )
text = " ".join(text.split("\n"))
target_line_num = get_target_line_num(len(text))
- lines: List[str] = []
+ lines: list[str] = []
for _ in range(target_line_num):
partition_at = round(len(text) / target_line_num) - 1
# Move to a word boundary
+ steps_backward = 0
for steps_backward, c in enumerate(text[partition_at::-1]):
if c.isspace():
break
if partition_at - steps_backward != 0:
partition_at -= steps_backward
else:
+ # Moving the partition backward would give us an empty line, so
+ # move forward instead to ensure we always make progress.
+ steps_forward = 0
for steps_forward, c in enumerate(text[partition_at:]):
if c.isspace():
break
diff --git a/requirements.txt b/requirements.txt
index befb51e..c948824 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,9 @@
black==24.10.0
click==8.1.7
mypy-extensions==1.0.0
+nodeenv==1.9.1
packaging==24.2
pathspec==0.12.1
platformdirs==4.3.6
+pyright==1.1.391
+typing_extensions==4.12.2