aboutsummaryrefslogtreecommitdiff
path: root/tagrss.py
diff options
context:
space:
mode:
Diffstat (limited to 'tagrss.py')
-rw-r--r--tagrss.py65
1 files changed, 55 insertions, 10 deletions
diff --git a/tagrss.py b/tagrss.py
index 41c8a70..5d5ad01 100644
--- a/tagrss.py
+++ b/tagrss.py
@@ -185,13 +185,32 @@ class TagRss:
self.connection.execute("DELETE FROM feeds WHERE id = ?;", (feed_id,))
def get_feeds(
- self, *, limit: int, offset: int = 0, get_tags: bool = False
+ self,
+ *,
+ limit: int,
+ offset: int = 0,
+ included_feeds: typing.Optional[list[int]] = None,
+ included_tags: typing.Optional[list[str]] = None,
+ get_tags: bool = False,
) -> list[dict[str, typing.Any]]:
+ where_clause = "WHERE 1"
+ if included_feeds:
+ where_clause += f" AND id IN ({','.join('?' * len(included_feeds))})"
+ if included_tags:
+ where_clause += " AND id IN (SELECT id FROM feed_tags WHERE tag = ?)" * len(
+ included_tags
+ )
with self.connection:
resp = self.connection.execute(
- "SELECT id, source, title FROM feeds \
+ f"SELECT id, source, title FROM feeds \
+ {where_clause} \
ORDER BY id ASC LIMIT ? OFFSET ?;",
- (limit, offset),
+ (
+ *(included_feeds if included_feeds else ()),
+ *(included_tags if included_tags else ()),
+ limit,
+ offset,
+ ),
).fetchall()
feeds: dict[int, dict[str, typing.Any]] = {}
for row in resp:
@@ -201,7 +220,7 @@ class TagRss:
}
if get_tags:
feed_ids = feeds.keys()
- placeholder_str = ",".join(["?"] * len(feed_ids))
+ placeholder_str = ",".join("?" * len(feed_ids))
with self.connection:
resp = self.connection.execute(
f"SELECT feed_id, tag FROM feed_tags WHERE feed_id in ({placeholder_str});",
@@ -220,7 +239,10 @@ class TagRss:
"title": item[1]["title"],
}
if get_tags:
- feed["tags"] = item[1]["tags"]
+ try:
+ feed["tags"] = item[1]["tags"]
+ except KeyError:
+ feed["tags"] = []
result.append(feed)
return result
@@ -255,11 +277,34 @@ class TagRss:
),
).fetchone()[0]
- def get_feed_count(self) -> int:
- with self.connection:
- return self.connection.execute("SELECT count from feed_count;").fetchone()[
- 0
- ]
+ def get_feed_count(
+ self,
+ *,
+ included_feeds: typing.Optional[typing.Collection[int]] = None,
+ included_tags: typing.Optional[typing.Collection[str]] = None,
+ ) -> int:
+ if not (included_feeds or included_tags):
+ with self.connection:
+ return self.connection.execute(
+ "SELECT count from feed_count;"
+ ).fetchone()[0]
+ else:
+ where_clause: str = "WHERE 1"
+ if included_feeds:
+ where_clause += f" AND id IN ({','.join('?' * len(included_feeds))})"
+ if included_tags:
+ where_clause += (
+ " AND id IN (SELECT id FROM feed_tags WHERE tag = ?)"
+ * len(included_tags)
+ )
+ with self.connection:
+ return self.connection.execute(
+ f"SELECT COUNT(*) FROM feeds {where_clause}",
+ (
+ *(included_feeds if included_feeds else ()),
+ *(included_tags if included_tags else ()),
+ ),
+ ).fetchone()[0]
def store_feed_entries(self, feed_id: int, parsed_feed, epoch_downloaded: int):
for entry in reversed(parsed_feed.entries):