diff --git a/resolvers/reaction.py b/resolvers/reaction.py index e2407ce8..d2bb7586 100644 --- a/resolvers/reaction.py +++ b/resolvers/reaction.py @@ -18,25 +18,14 @@ logging.basicConfig() logger = logging.getLogger("\t[resolvers.reaction]\t") logger.setLevel(logging.DEBUG) -def add_reaction_stat_columns(q): +def add_stat_columns(q): aliased_reaction = aliased(Reaction) - q = q.outerjoin(aliased_reaction, Reaction.id == aliased_reaction.reply_to).add_columns( - func.sum(aliased_reaction.id).label("reacted_stat"), - func.sum(case((aliased_reaction.kind == ReactionKind.COMMENT.value, 1), else_=0)).label("commented_stat"), - func.sum( - case( - (aliased_reaction.kind == ReactionKind.AGREE.value, 1), - (aliased_reaction.kind == ReactionKind.DISAGREE.value, -1), - (aliased_reaction.kind == ReactionKind.PROOF.value, 1), - (aliased_reaction.kind == ReactionKind.DISPROOF.value, -1), - (aliased_reaction.kind == ReactionKind.ACCEPT.value, 1), - (aliased_reaction.kind == ReactionKind.REJECT.value, -1), - (aliased_reaction.kind == ReactionKind.LIKE.value, 1), - (aliased_reaction.kind == ReactionKind.DISLIKE.value, -1), - else_=0, - ) - ).label("rating_stat"), + q = q.outerjoin(aliased_reaction).add_columns( + func.sum(case((aliased_reaction.kind == ReactionKind.COMMENT.value, 1), else_=0)).label("comments_stat"), + func.sum(case((aliased_reaction.kind == ReactionKind.LIKE.value, 1), else_=0)).label("likes_stat"), + func.sum(case((aliased_reaction.kind == ReactionKind.DISLIKE.value, 1), else_=0)).label("dislikes_stat"), + func.max(case((aliased_reaction.kind != ReactionKind.COMMENT.value, None),else_=aliased_reaction.created_at)).label("last_comment"), ) return q, aliased_reaction @@ -281,10 +270,10 @@ async def update_reaction(_, info, rid, reaction): user_id = info.context["user_id"] with local_session() as session: q = select(Reaction).filter(Reaction.id == rid) - q, aliased_reaction = add_reaction_stat_columns(q) + q, aliased_reaction = add_stat_columns(q) q = q.group_by(Reaction.id) - [r, reacted_stat, commented_stat, rating_stat] = session.execute(q).unique().one() + [r, commented_stat, likes_stat, dislikes_stat, _l] = session.execute(q).unique().one() if not r: return {"error": "invalid reaction id"} @@ -303,15 +292,14 @@ async def update_reaction(_, info, rid, reaction): session.commit() r.stat = { "commented": commented_stat, - "reacted": reacted_stat, - "rating": rating_stat, + "rating": int(likes_stat or 0) - int(dislikes_stat or 0), } await notify_reaction(r.dict(), "update") return {"reaction": r} else: - return {"error": "not authorized"} + return {"error": "not authorized"} return {"error": "cannot create reaction"} @mutation.field("delete_reaction") @@ -397,7 +385,7 @@ async def load_reactions_by(_, info, by, limit=50, offset=0): ) # calculate counters - q, aliased_reaction = add_reaction_stat_columns(q) + q, aliased_reaction = add_stat_columns(q) # filter q = apply_reaction_filters(by, q) diff --git a/resolvers/reader.py b/resolvers/reader.py index a392daa7..845d69f0 100644 --- a/resolvers/reader.py +++ b/resolvers/reader.py @@ -13,18 +13,7 @@ from services.schema import query from services.search import SearchService from services.viewed import ViewedStorage from resolvers.topic import get_random_topic - -def add_stat_columns(q): - aliased_reaction = aliased(Reaction) - - q = q.outerjoin(aliased_reaction).add_columns( - func.sum(case((aliased_reaction.kind == ReactionKind.COMMENT.value, 1), else_=0)).label("comments_stat"), - func.sum(case((aliased_reaction.kind == ReactionKind.LIKE.value, 1), else_=0)).label("likes_stat"), - func.sum(case((aliased_reaction.kind == ReactionKind.DISLIKE.value, 1), else_=0)).label("dislikes_stat"), - func.max(case((aliased_reaction.kind != ReactionKind.COMMENT.value, None),else_=aliased_reaction.created_at)).label("last_comment"), - ) - - return q +from resolvers.reaction import add_stat_columns def apply_filters(q, filters, author_id=None): # noqa: C901 @@ -59,7 +48,7 @@ async def get_shout(_, _info, slug=None, shout_id=None): joinedload(Shout.topics), ) - q = add_stat_columns(q) + q, _ar = add_stat_columns(q) if slug is not None: q = q.filter(Shout.slug == slug) @@ -133,7 +122,7 @@ async def load_shouts_by(_, _info, options): ) # stats - q = add_stat_columns(q) + q, _ar = add_stat_columns(q) # filters q = apply_filters(q, options.get("filters", {})) @@ -242,7 +231,7 @@ async def load_shouts_feed(_, info, options): .where(and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None), Shout.id.in_(subquery))) ) - q = add_stat_columns(q) + q, _ar = add_stat_columns(q) q = apply_filters(q, options.get("filters", {}), reader.id) order_by = options.get("order_by", Shout.published_at) @@ -342,7 +331,7 @@ async def load_shouts_unrated(_, info, limit: int = 50, offset: int = 0): # 3 or fewer votes is 0, 1, 2 or 3 votes (null, reaction id1, reaction id2, reaction id3) q = q.having(func.count(distinct(Reaction.id)) <= 4) - q = add_stat_columns(q) + q, _ar = add_stat_columns(q) q = q.group_by(Shout.id).order_by(func.random()).limit(limit).offset(offset) user_id = info.context.get("user_id") @@ -408,7 +397,7 @@ async def load_shouts_random_top(_, _info, options): .where(Shout.id.in_(subquery)) ) - q = add_stat_columns(q) + q, _ar = add_stat_columns(q) limit = options.get("limit", 10) q = q.group_by(Shout.id).order_by(func.random()).limit(limit) @@ -431,7 +420,7 @@ async def load_shouts_random_topic(_, info, limit: int = 10): .filter(and_(Shout.deleted_at.is_(None), Shout.visibility == "public", Shout.topics.any(slug=topic.slug))) ) - q = add_stat_columns(q) + q, _ar = add_stat_columns(q) q = q.group_by(Shout.id).order_by(desc(Shout.created_at)).limit(limit)