diff options
Diffstat (limited to 'tagrss.py')
-rw-r--r-- | tagrss.py | 65 |
1 files changed, 55 insertions, 10 deletions
@@ -185,13 +185,32 @@ class TagRss: self.connection.execute("DELETE FROM feeds WHERE id = ?;", (feed_id,)) def get_feeds( - self, *, limit: int, offset: int = 0, get_tags: bool = False + self, + *, + limit: int, + offset: int = 0, + included_feeds: typing.Optional[list[int]] = None, + included_tags: typing.Optional[list[str]] = None, + get_tags: bool = False, ) -> list[dict[str, typing.Any]]: + where_clause = "WHERE 1" + if included_feeds: + where_clause += f" AND id IN ({','.join('?' * len(included_feeds))})" + if included_tags: + where_clause += " AND id IN (SELECT id FROM feed_tags WHERE tag = ?)" * len( + included_tags + ) with self.connection: resp = self.connection.execute( - "SELECT id, source, title FROM feeds \ + f"SELECT id, source, title FROM feeds \ + {where_clause} \ ORDER BY id ASC LIMIT ? OFFSET ?;", - (limit, offset), + ( + *(included_feeds if included_feeds else ()), + *(included_tags if included_tags else ()), + limit, + offset, + ), ).fetchall() feeds: dict[int, dict[str, typing.Any]] = {} for row in resp: @@ -201,7 +220,7 @@ class TagRss: } if get_tags: feed_ids = feeds.keys() - placeholder_str = ",".join(["?"] * len(feed_ids)) + placeholder_str = ",".join("?" * len(feed_ids)) with self.connection: resp = self.connection.execute( f"SELECT feed_id, tag FROM feed_tags WHERE feed_id in ({placeholder_str});", @@ -220,7 +239,10 @@ class TagRss: "title": item[1]["title"], } if get_tags: - feed["tags"] = item[1]["tags"] + try: + feed["tags"] = item[1]["tags"] + except KeyError: + feed["tags"] = [] result.append(feed) return result @@ -255,11 +277,34 @@ class TagRss: ), ).fetchone()[0] - def get_feed_count(self) -> int: - with self.connection: - return self.connection.execute("SELECT count from feed_count;").fetchone()[ - 0 - ] + def get_feed_count( + self, + *, + included_feeds: typing.Optional[typing.Collection[int]] = None, + included_tags: typing.Optional[typing.Collection[str]] = None, + ) -> int: + if not (included_feeds or included_tags): + with self.connection: + return self.connection.execute( + "SELECT count from feed_count;" + ).fetchone()[0] + else: + where_clause: str = "WHERE 1" + if included_feeds: + where_clause += f" AND id IN ({','.join('?' * len(included_feeds))})" + if included_tags: + where_clause += ( + " AND id IN (SELECT id FROM feed_tags WHERE tag = ?)" + * len(included_tags) + ) + with self.connection: + return self.connection.execute( + f"SELECT COUNT(*) FROM feeds {where_clause}", + ( + *(included_feeds if included_feeds else ()), + *(included_tags if included_tags else ()), + ), + ).fetchone()[0] def store_feed_entries(self, feed_id: int, parsed_feed, epoch_downloaded: int): for entry in reversed(parsed_feed.entries): |