From e9611fc8c1041a7075c608daa82f21bb6e19b65b Mon Sep 17 00:00:00 2001 From: Untone Date: Mon, 25 Mar 2024 20:28:58 +0300 Subject: [PATCH] feed-filters-fix --- resolvers/reader.py | 204 ++++++++++++++++++++++---------------------- 1 file changed, 104 insertions(+), 100 deletions(-) diff --git a/resolvers/reader.py b/resolvers/reader.py index bf9c4afe..dcca3b02 100644 --- a/resolvers/reader.py +++ b/resolvers/reader.py @@ -1,5 +1,5 @@ from sqlalchemy import bindparam, distinct, or_, text -from sqlalchemy.orm import aliased, joinedload, selectinload +from sqlalchemy.orm import aliased, joinedload from sqlalchemy.sql.expression import and_, asc, case, desc, func, nulls_last, select from orm.author import Author, AuthorFollower @@ -16,8 +16,44 @@ from services.viewed import ViewedStorage from services.logger import root_logger as logger +def query_shouts(): + return select(Shout).options(joinedload(Shout.authors), joinedload(Shout.topics)).where( + and_( + Shout.published_at.is_not(None), + Shout.deleted_at.is_(None), + ) + ) + + +def filter_my(info, session, q): + reader_id = None + user_id = info.context.get('user_id') + if user_id: + reader = session.query(Author).filter(Author.user == user_id).first() + if reader: + reader_followed_authors = select(AuthorFollower.author).where( + AuthorFollower.follower == reader.id + ) + reader_followed_topics = select(TopicFollower.topic).where( + TopicFollower.follower == reader.id + ) + + subquery = ( + select(Shout.id) + .where(Shout.id == ShoutAuthor.shout) + .where(Shout.id == ShoutTopic.shout) + .where( + (ShoutAuthor.author.in_(reader_followed_authors)) + | (ShoutTopic.topic.in_(reader_followed_topics)) + ) + ) + q = q.filter(Shout.id.in_(subquery)) + reader_id = reader.id + return q, reader_id + + def apply_filters(q, filters, author_id=None): - if filters.get('reacted') and author_id: + if filters.get('reacted'): q.join(Reaction, Reaction.created_by == author_id) by_featured = filters.get('featured') @@ -43,11 +79,11 @@ def apply_filters(q, filters, author_id=None): @query.field('get_shout') async def get_shout(_, info, slug: str): with local_session() as session: - q = select(Shout).options(joinedload(Shout.authors), joinedload(Shout.topics)) + q = query_shouts() aliased_reaction = aliased(Reaction) q = add_reaction_stat_columns(q, aliased_reaction) q = q.filter(Shout.slug == slug) - q = q.filter(Shout.deleted_at.is_(None)).group_by(Shout.id) + q = q.group_by(Shout.id) results = session.execute(q).first() if results: @@ -122,11 +158,7 @@ async def load_shouts_by(_, _info, options): """ # base - q = ( - select(Shout) - .options(joinedload(Shout.authors), joinedload(Shout.topics)) - .where(and_(Shout.deleted_at.is_(None), Shout.published_at.is_not(None))) - ) + q = query_shouts() # stats aliased_reaction = aliased(Reaction) @@ -194,97 +226,73 @@ async def load_shouts_by(_, _info, options): @query.field('load_shouts_feed') @login_required async def load_shouts_feed(_, info, options): - user_id = info.context['user_id'] - shouts = [] with local_session() as session: - reader = session.query(Author).filter(Author.user == user_id).first() - if reader: - reader_followed_authors = select(AuthorFollower.author).where( - AuthorFollower.follower == reader.id - ) - reader_followed_topics = select(TopicFollower.topic).where( - TopicFollower.follower == reader.id - ) + q = query_shouts() - subquery = ( - select(Shout.id) - .where(Shout.id == ShoutAuthor.shout) - .where(Shout.id == ShoutTopic.shout) - .where( - (ShoutAuthor.author.in_(reader_followed_authors)) - | (ShoutTopic.topic.in_(reader_followed_topics)) - ) - ) + aliased_reaction = aliased(Reaction) + q = add_reaction_stat_columns(q, aliased_reaction) - q = ( - select(Shout) - .options(joinedload(Shout.authors), joinedload(Shout.topics)) - .where( + # filters + filters = options.get('filters') + if filters: + q, reader_id = filter_my(info, session, q) + q = apply_filters(q, filters, reader_id) + + # sort order + order_by = options.get( + 'order_by', + Shout.featured_at if filters.get('featured') else Shout.published_at, + ) + + query_order_by = ( + desc(order_by) if options.get('order_by_desc', True) else asc(order_by) + ) + + # pagination + offset = options.get('offset', 0) + limit = options.get('limit', 10) + + q = ( + q.group_by(Shout.id) + .order_by(nulls_last(query_order_by)) + .limit(limit) + .offset(offset) + ) + + # print(q.compile(compile_kwargs={"literal_binds": True})) + + for [ + shout, + reacted_stat, + commented_stat, + likes_stat, + dislikes_stat, + last_comment, + ] in session.execute(q).unique(): + main_topic = ( + session.query(Topic.slug) + .join( + ShoutTopic, and_( - Shout.published_at.is_not(None), - Shout.deleted_at.is_(None), - Shout.id.in_(subquery), - ) + ShoutTopic.topic == Topic.id, + ShoutTopic.shout == shout.id, + ShoutTopic.main.is_(True), + ), ) + .first() ) - aliased_reaction = aliased(Reaction) - q = add_reaction_stat_columns(q, aliased_reaction) - filters = options.get('filters', {}) - q = apply_filters(q, filters, reader.id) - - order_by = options.get( - 'order_by', - Shout.featured_at if filters.get('featured') else Shout.published_at, - ) - - query_order_by = ( - desc(order_by) if options.get('order_by_desc', True) else asc(order_by) - ) - offset = options.get('offset', 0) - limit = options.get('limit', 10) - - q = ( - q.group_by(Shout.id) - .order_by(nulls_last(query_order_by)) - .limit(limit) - .offset(offset) - ) - - # print(q.compile(compile_kwargs={"literal_binds": True})) - - for [ - shout, - reacted_stat, - commented_stat, - likes_stat, - dislikes_stat, - last_comment, - ] in session.execute(q).unique(): - main_topic = ( - session.query(Topic.slug) - .join( - ShoutTopic, - and_( - ShoutTopic.topic == Topic.id, - ShoutTopic.shout == shout.id, - ShoutTopic.main.is_(True), - ), - ) - .first() - ) - - if main_topic: - shout.main_topic = main_topic[0] - shout.stat = { - 'viewed': await ViewedStorage.get_shout(shout.slug), - 'reacted': reacted_stat, - 'commented': commented_stat, - 'rating': likes_stat - dislikes_stat, - 'last_comment': last_comment - } - shouts.append(shout) + if main_topic: + shout.main_topic = main_topic[0] + shout.stat = { + 'viewed': await ViewedStorage.get_shout(shout.slug), + 'reacted': reacted_stat, + 'commented': commented_stat, + 'rating': likes_stat - dislikes_stat, + 'last_comment': last_comment + } + shouts.append(shout) return shouts @@ -301,10 +309,8 @@ async def load_shouts_search(_, _info, text, limit=50, offset=0): @login_required @query.field('load_shouts_unrated') async def load_shouts_unrated(_, info, limit: int = 50, offset: int = 0): - q = ( - select(Shout) - .options(selectinload(Shout.authors), selectinload(Shout.topics)) - .outerjoin( + q = query_shouts() + q = q.outerjoin( Reaction, and_( Reaction.shout == Shout.id, @@ -313,16 +319,14 @@ async def load_shouts_unrated(_, info, limit: int = 50, offset: int = 0): [ReactionKind.LIKE.value, ReactionKind.DISLIKE.value] ), ), - ) - .outerjoin(Author, Author.user == bindparam('user_id')) - .where( + ).outerjoin(Author, Author.user == bindparam('user_id')).where( and_( Shout.deleted_at.is_(None), Shout.layout.is_not(None), or_(Author.id.is_(None), Reaction.created_by != Author.id), ) ) - ) + # 3 or fewer votes is 0, 1, 2 or 3 votes (null, reaction id1, reaction id2, reaction id3) q = q.having(func.count(distinct(Reaction.id)) <= 4)