aboutsummaryrefslogtreecommitdiff
path: root/tagrss.py
diff options
context:
space:
mode:
Diffstat (limited to 'tagrss.py')
-rw-r--r--tagrss.py62
1 files changed, 54 insertions, 8 deletions
diff --git a/tagrss.py b/tagrss.py
index 7c37bec..6ac5270 100644
--- a/tagrss.py
+++ b/tagrss.py
@@ -89,13 +89,34 @@ class TagRss:
self.store_feed_entries(feed_id, parsed_feed, epoch_downloaded)
def get_entries(
- self, *, limit: int, offset: int = 0
+ self,
+ *,
+ limit: int,
+ offset: int = 0,
+ included_feeds: typing.Optional[typing.Collection[int]] = None,
+ included_tags: typing.Optional[typing.Collection[str]] = None,
) -> list[dict[str, typing.Any]]:
+ where_clause: str = ""
+ if included_feeds or included_tags:
+ where_clause = "WHERE 1"
+ if included_feeds:
+ where_clause += f" AND feed_id IN ({','.join('?' * len(included_feeds))})"
+ if included_tags:
+ where_clause += (
+ " AND feed_id IN (SELECT feed_id FROM feed_tags WHERE tag = ?)"
+ * len(included_tags)
+ )
with self.connection:
resp = self.connection.execute(
- "SELECT id, feed_id, title, link, epoch_published, epoch_updated FROM entries \
+ f"SELECT id, feed_id, title, link, epoch_published, epoch_updated FROM entries \
+ {where_clause} \
ORDER BY id DESC LIMIT ? OFFSET ?;",
- (limit, offset),
+ (
+ *(included_feeds if included_feeds else ()),
+ *(included_tags if included_tags else ()),
+ limit,
+ offset,
+ ),
).fetchall()
entries = []
@@ -177,11 +198,36 @@ class TagRss:
)
return feeds
- def get_entry_count(self) -> int:
- with self.connection:
- return self.connection.execute("SELECT count from entry_count;").fetchone()[
- 0
- ]
+ def get_entry_count(
+ self,
+ *,
+ included_feeds: typing.Optional[typing.Collection[int]] = None,
+ included_tags: typing.Optional[typing.Collection[str]] = None,
+ ) -> int:
+ if not (included_feeds or included_tags):
+ with self.connection:
+ return self.connection.execute(
+ "SELECT count from entry_count;"
+ ).fetchone()[0]
+ else:
+ where_clause: str = "WHERE 1"
+ if included_feeds:
+ where_clause += (
+ f" AND feed_id IN ({','.join('?' * len(included_feeds))})"
+ )
+ if included_tags:
+ where_clause += (
+ " AND feed_id IN (SELECT feed_id FROM feed_tags WHERE tag = ?)"
+ * len(included_tags)
+ )
+ with self.connection:
+ return self.connection.execute(
+ f"SELECT COUNT(*) FROM entries {where_clause};",
+ (
+ *(included_feeds if included_feeds else ()),
+ *(included_tags if included_tags else ()),
+ ),
+ ).fetchone()[0]
def get_feed_count(self) -> int:
with self.connection: