diff options
Diffstat (limited to 'tagrss.py')
-rw-r--r-- | tagrss.py | 83 |
1 files changed, 36 insertions, 47 deletions
@@ -22,9 +22,34 @@ class SqliteMissingForeignKeySupportError(Exception): pass +class Sqlite3NotSerializedModeError(Exception): + pass + + +def fetch_parsed_feed(feed_source: str) -> tuple[feedparser.FeedParserDict, int]: + response = requests.get(feed_source) + epoch_downloaded: int = int(time.time()) + if response.status_code != requests.codes.ok: + raise FeedFetchError(feed_source, response.status_code) + try: + base: str = response.headers["Content-Location"] + except KeyError: + base: str = feed_source + parsed = feedparser.parse( + io.BytesIO(bytes(response.text, encoding="utf-8")), + response_headers={"Content-Location": base}, + ) + return (parsed, epoch_downloaded) + + class TagRss: def __init__(self, *, storage_path: str | pathlib.Path): - self.connection: sqlite3.Connection = sqlite3.connect(storage_path) + if sqlite3.threadsafety != 3: + raise Sqlite3NotSerializedModeError + + self.connection: sqlite3.Connection = sqlite3.connect( + storage_path, check_same_thread=False + ) with self.connection: with open("setup.sql", "r") as setup_script: @@ -32,21 +57,16 @@ class TagRss: if (1,) not in self.connection.execute("PRAGMA foreign_keys;").fetchmany(1): raise SqliteMissingForeignKeySupportError - def add_feed(self, feed_source: str, tags: list[str]) -> None: - response = requests.get(feed_source) - epoch_downloaded: int = int(time.time()) - if response.status_code != requests.codes.ok: - raise FeedFetchError(feed_source, response.status_code) - try: - base: str = response.headers["Content-Location"] - except KeyError: - base: str = feed_source - parsed = feedparser.parse( - io.BytesIO(bytes(response.text, encoding="utf-8")), - response_headers={"Content-Location": base}, - ) + def add_feed( + self, + *, + feed_source: str, + parsed_feed: feedparser.FeedParserDict, + epoch_downloaded: int, + tags: list[str], + ) -> None: with self.connection: - feed_title: str = parsed.feed.get("title", "") + feed_title: str = parsed_feed.feed.get("title", "") # type: ignore try: self.connection.execute( "INSERT INTO feeds(source, title) VALUES(?, ?);", @@ -61,7 +81,7 @@ class TagRss: "INSERT INTO feed_tags(feed_id, tag) VALUES(?, ?);", ((feed_id, tag) for tag in tags), ) - self.store_feed_entries(feed_id, parsed, epoch_downloaded) + self.store_feed_entries(feed_id, parsed_feed, epoch_downloaded) def get_entries( self, *, limit: int, offset: int = 0 @@ -152,18 +172,6 @@ class TagRss: ) return feeds - def get_all_feed_ids(self): - def inner(): - with self.connection: - resp = self.connection.execute("SELECT id FROM feeds;") - while True: - row = resp.fetchone() - if not row: - break - yield row[0] - - return inner - def get_entry_count(self) -> int: with self.connection: return self.connection.execute("SELECT count from entry_count;").fetchone()[ @@ -176,25 +184,6 @@ class TagRss: 0 ] - def fetch_new_feed_entries(self, feed_id: int) -> None: - with self.connection: - feed_source: str = self.connection.execute( - "SELECT source FROM feeds WHERE id = ?;", (feed_id,) - ).fetchone()[0] - response = requests.get(feed_source) - epoch_downloaded: int = int(time.time()) - if response.status_code != requests.codes.ok: - raise FeedFetchError(feed_source, response.status_code) - try: - base: str = response.headers["Content-Location"] - except KeyError: - base: str = feed_source - parsed = feedparser.parse( - io.BytesIO(bytes(response.text, encoding="utf-8")), - response_headers={"Content-Location": base}, - ) - self.store_feed_entries(feed_id, parsed, epoch_downloaded) - def store_feed_entries(self, feed_id: int, parsed_feed, epoch_downloaded: int): for entry in reversed(parsed_feed.entries): link: str = entry.get("link", None) |