From 713fb4d62b5dd6fffb91787f006559478e8fa1ce Mon Sep 17 00:00:00 2001 From: Untone Date: Wed, 5 Jun 2024 17:45:55 +0300 Subject: [PATCH] 0.4.1-following-update --- CHANGELOG.txt | 5 +- orm/community.py | 2 +- pyproject.toml | 2 +- resolvers/community.py | 36 +---- resolvers/follower.py | 296 +++++++++++------------------------------ resolvers/stat.py | 30 ++--- resolvers/topic.py | 37 ++---- services/cache.py | 30 ++++- 8 files changed, 143 insertions(+), 295 deletions(-) diff --git a/CHANGELOG.txt b/CHANGELOG.txt index afdaab8d..6f02d101 100644 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -1,3 +1,6 @@ +[0.4.1] +- follow/unfollow logic updated and unified with cache + [0.4.0] - chore: version migrator synced - feat: precache_data on start @@ -117,7 +120,7 @@ [0.2.12] - Author.userpic -> Author.pic -- CommunityAuthor.role is string now +- CommunityFollower.role is string now - Author.user is string now [0.2.11] diff --git a/orm/community.py b/orm/community.py index 26618c75..b98a0650 100644 --- a/orm/community.py +++ b/orm/community.py @@ -7,7 +7,7 @@ from orm.author import Author from services.db import Base -class CommunityAuthor(Base): +class CommunityFollower(Base): __tablename__ = "community_author" id = None # type: ignore diff --git a/pyproject.toml b/pyproject.toml index e615cc89..cf9c34f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "core" -version = "0.4.0" +version = "0.4.1" description = "core module for discours.io" authors = ["discoursio devteam"] license = "MIT" diff --git a/resolvers/community.py b/resolvers/community.py index 0c057b16..83bfc876 100644 --- a/resolvers/community.py +++ b/resolvers/community.py @@ -1,10 +1,9 @@ -from sqlalchemy import and_, distinct, func, select +from sqlalchemy import distinct, func, select from orm.author import Author -from orm.community import Community, CommunityAuthor +from orm.community import Community from orm.shout import ShoutCommunity from services.db import local_session -from services.logger import root_logger as logger from services.schema import query @@ -25,37 +24,6 @@ def get_communities_from_query(q): return ccc -# for mutation.field("follow") -def community_follow(follower_id, slug): - try: - with local_session() as session: - community = session.query(Community).where(Community.slug == slug).first() - if isinstance(community, Community): - cf = CommunityAuthor(author=follower_id, community=community.id) - session.add(cf) - session.commit() - return True - except Exception as ex: - logger.debug(ex) - return False - - -# for mutation.field("unfollow") -def community_unfollow(follower_id, slug): - with local_session() as session: - flw = ( - session.query(CommunityAuthor) - .join(Community, Community.id == CommunityAuthor.community) - .filter(and_(CommunityAuthor.author == follower_id, Community.slug == slug)) - .first() - ) - if flw: - session.delete(flw) - session.commit() - return True - return False - - @query.field("get_communities_all") async def get_communities_all(_, _info): q = select(Author) diff --git a/resolvers/follower.py b/resolvers/follower.py index 10aac656..a9e83e97 100644 --- a/resolvers/follower.py +++ b/resolvers/follower.py @@ -4,273 +4,139 @@ from sqlalchemy import select from sqlalchemy.sql import and_ from orm.author import Author, AuthorFollower -from orm.community import Community +from orm.community import Community, CommunityFollower from orm.reaction import Reaction from orm.shout import Shout, ShoutReactionsFollower from orm.topic import Topic, TopicFollower from resolvers.stat import get_with_stat from services.auth import login_required -from services.cache import ( - cache_author, - cache_topic, - get_cached_author_by_user_id, - get_cached_follower_authors, - get_cached_follower_topics, -) +from services.cache import cache_author, cache_topic, get_cached_follower_authors, get_cached_follower_topics from services.db import local_session -from services.logger import root_logger as logger from services.notify import notify_follower from services.schema import mutation, query -async def cache_by_slug(what: str, slug: str): - is_author = what == "AUTHOR" - alias = Author if is_author else Topic - caching_query = select(alias).filter(alias.slug == slug) - [x] = get_with_stat(caching_query) - if not x: - return - - d = x.dict() # convert object to dictionary - if is_author: - await cache_author(d) - else: - await cache_topic(d) - return d - - @mutation.field("follow") @login_required async def follow(_, info, what, slug): - error = None user_id = info.context.get("user_id") follower_dict = info.context.get("author") if not user_id or not follower_dict: return {"error": "unauthorized"} follower_id = follower_dict.get("id") - entity = what.lower() - if what == "AUTHOR": - follower_id = int(follower_id) - error = author_follow(follower_id, slug) - if not error: - follows = await get_cached_follower_authors(follower_id) - with local_session() as session: - [author_id] = session.query(Author.id).filter(Author.slug == slug).first() - if author_id and author_id not in follows: - follows.append(author_id) - await cache_author(follower_dict) - await notify_follower(follower_dict, author_id, "follow") - [author] = get_with_stat(select(Author).filter(Author.id == author_id)) - if author: - author_dict = author.dict() - await cache_author(author_dict) + entity_classes = { + "AUTHOR": (Author, AuthorFollower, get_cached_follower_authors, cache_author), + "TOPIC": (Topic, TopicFollower, get_cached_follower_topics, cache_topic), + "COMMUNITY": (Community, CommunityFollower, None, None), # No cache methods provided for community + "SHOUT": (Shout, ShoutReactionsFollower, None, None), # No cache methods provided for shout + } - elif what == "TOPIC": - error = topic_follow(follower_id, slug) - if not error: - follows = await get_cached_follower_topics(follower_id) - topic_dict = await cache_by_slug(what, slug) - await cache_topic(topic_dict) + if what not in entity_classes: + return {"error": "invalid follow type"} - elif what == "COMMUNITY": + entity_class, follower_class, get_cached_follows_method, cache_method = entity_classes[what] + entity_type = what.lower() + entity_id = None + entity_dict = None + + try: + # Fetch entity id from the database with local_session() as session: - follows = session.query(Community).all() + entity_query = select(entity_class).filter(entity_class.slug == slug) + [entity] = get_with_stat(entity_query) + if not entity: + return {"error": f"{what.lower()} not found"} + entity_id = entity.id + entity_dict = entity.dict() - elif what == "SHOUT": - error = reactions_follow(follower_id, slug) - if not error: - # TODO: follows = await get_cached_follower_reactions(follower_id) - # shout_dict = await cache_shout_by_slug(what, slug) - # await cache_topic(topic_dict) - pass + if entity_id: + # Update database + with local_session() as session: + sub = follower_class(follower=follower_id, **{entity_type: entity_id}) + session.add(sub) + session.commit() - return {f"{entity}s": follows, "error": error} + follows = None + # Update cache + if cache_method: + await cache_method(entity_dict) + if get_cached_follows_method: + follows = await get_cached_follows_method(follower_id) + + # Notify author (only for AUTHOR type) + if what == "AUTHOR": + await notify_follower(follower=follower_dict, author=entity_id, action="follow") + + except Exception as exc: + return {"error": str(exc)} + + return {f"{what.lower()}s": follows} @mutation.field("unfollow") @login_required async def unfollow(_, info, what, slug): - follows = [] - error = None user_id = info.context.get("user_id") follower_dict = info.context.get("author") if not user_id or not follower_dict: return {"error": "unauthorized"} follower_id = follower_dict.get("id") - entity = what.lower() - follows = [] - - if what == "AUTHOR": - follows = await get_cached_follower_authors(follower_id) - follower_id = int(follower_id) - error = author_unfollow(follower_id, slug) - # NOTE: after triggers should update cached stats - if not error: - logger.info(f"@{follower_dict.get('slug')} unfollowed @{slug}") - [author_id] = local_session().query(Author.id).filter(Author.slug == slug).first() - if author_id and author_id in follows: - follows.remove(author_id) - await cache_author(follower_dict) - await notify_follower(follower_dict, author_id, "follow") - [author] = get_with_stat(select(Author).filter(Author.id == author_id)) - if author: - author_dict = author.dict() - await cache_author(author_dict) - - elif what == "TOPIC": - error = topic_unfollow(follower_id, slug) - if not error: - follows = await get_cached_follower_topics(follower_id) - topic_dict = await cache_by_slug(what, slug) - await cache_topic(topic_dict) - - elif what == "COMMUNITY": - with local_session() as session: - follows = session.execute(select(Community)) - - elif what == "SHOUT": - error = reactions_unfollow(follower_id, slug) - if not error: - pass - - return {"error": error, f"{entity}s": follows} - - -async def get_follows_by_user_id(user_id: str): - if not user_id: - return {"error": "unauthorized"} - author = await get_cached_author_by_user_id(user_id, get_with_stat) - if not author: - with local_session() as session: - author = session.query(Author).filter(Author.user == user_id).first() - if not author: - return {"error": "cant find author"} - author = author.dict() - - author_id = author.get("id") - if author_id: - topics = await get_cached_follower_topics(author_id) - authors = await get_cached_follower_authors(author_id) - return { - "topics": topics or [], - "authors": authors or [], - "communities": [{"id": 1, "name": "Дискурс", "slug": "discours", "pic": ""}], + entity_classes = { + "AUTHOR": (Author, AuthorFollower, get_cached_follower_authors, cache_author), + "TOPIC": (Topic, TopicFollower, get_cached_follower_topics, cache_topic), + "COMMUNITY": (Community, CommunityFollower, None, None), # No cache methods provided for community + "SHOUT": ( + Shout, + ShoutReactionsFollower, + None, + ), # No cache methods provided for shout } + if what not in entity_classes: + return {"error": "invalid unfollow type"} + + entity_class, follower_class, get_cached_follows_method, cache_method = entity_classes[what] + entity_type = what.lower() + entity_id = None + follows = [] + error = None -def topic_follow(follower_id, slug): try: with local_session() as session: - topic = session.query(Topic).where(Topic.slug == slug).one() - _following = TopicFollower(topic=topic.id, follower=follower_id) - return None - except Exception as error: - logger.warn(error) - return "cant follow" + entity = session.query(entity_class).filter(entity_class.slug == slug).first() + if not entity: + return {"error": f"{what.lower()} not found"} + entity_id = entity.id - -def topic_unfollow(follower_id, slug): - try: - with local_session() as session: sub = ( - session.query(TopicFollower) - .join(Topic) - .filter(and_(TopicFollower.follower == follower_id, Topic.slug == slug)) + session.query(follower_class) + .filter( + and_( + getattr(follower_class, "follower") == follower_id, + getattr(follower_class, entity_type) == entity_id, + ) + ) .first() ) if sub: session.delete(sub) session.commit() - return None - except Exception as error: - logger.warn(error) - return "cant unfollow" + if cache_method: + await cache_method(entity.dict()) -def reactions_follow(author_id, shout_id, auto=False): - try: - with local_session() as session: - shout = session.query(Shout).where(Shout.id == shout_id).one() + if get_cached_follows_method: + follows = await get_cached_follows_method(follower_id) - following = ( - session.query(ShoutReactionsFollower) - .where( - and_( - ShoutReactionsFollower.follower == author_id, - ShoutReactionsFollower.shout == shout.id, - ) - ) - .first() - ) + if what == "AUTHOR": + await notify_follower(follower=follower_dict, author=entity_id, action="unfollow") - if not following: - following = ShoutReactionsFollower(follower=author_id, shout=shout.id, auto=auto) - session.add(following) - session.commit() - return None - except Exception as error: - logger.warn(error) - return "cant follow" + except Exception as exc: + return {"error": str(exc)} - -def reactions_unfollow(author_id, shout_id: int): - try: - with local_session() as session: - shout = session.query(Shout).where(Shout.id == shout_id).one() - - following = ( - session.query(ShoutReactionsFollower) - .where( - and_( - ShoutReactionsFollower.follower == author_id, - ShoutReactionsFollower.shout == shout.id, - ) - ) - .first() - ) - - if following: - session.delete(following) - session.commit() - return None - except Exception as error: - logger.warn(error) - return "cant unfollow" - - -# for mutation.field("follow") -def author_follow(follower_id, slug): - try: - with local_session() as session: - author = session.query(Author).where(Author.slug == slug).one() - af = AuthorFollower(follower=follower_id, author=author.id) - session.add(af) - session.commit() - return None - except Exception as error: - logger.warn(error) - return "cant follow" - - -# for mutation.field("unfollow") -def author_unfollow(follower_id, slug): - try: - with local_session() as session: - flw = ( - session.query(AuthorFollower) - .join(Author, Author.id == AuthorFollower.author) - .filter(and_(AuthorFollower.follower == follower_id, Author.slug == slug)) - .first() - ) - if flw: - session.delete(flw) - session.commit() - return None - except Exception as error: - logger.warn(error) - return "cant unfollow" + return {f"{entity_type}s": follows, "error": error} @query.field("get_shout_followers") diff --git a/resolvers/stat.py b/resolvers/stat.py index 6918440f..33cd8cc5 100644 --- a/resolvers/stat.py +++ b/resolvers/stat.py @@ -17,25 +17,24 @@ def add_topic_stat_columns(q): new_q = select(Topic) # Apply the necessary filters to the new query object - new_q = new_q.join( - aliased_shout, - aliased_shout.topic == Topic.id, - ).join( - Shout, - and_( - aliased_shout.shout == Shout.id, - Shout.deleted_at.is_(None), - ), - ).add_columns( - func.count(distinct(aliased_shout.shout)).label("shouts_stat") + new_q = ( + new_q.join( + aliased_shout, + aliased_shout.topic == Topic.id, + ) + .join( + Shout, + and_( + aliased_shout.shout == Shout.id, + Shout.deleted_at.is_(None), + ), + ) + .add_columns(func.count(distinct(aliased_shout.shout)).label("shouts_stat")) ) aliased_follower = aliased(TopicFollower) - new_q = new_q.outerjoin( - aliased_follower, - aliased_follower.topic == Topic.id - ).add_columns( + new_q = new_q.outerjoin(aliased_follower, aliased_follower.topic == Topic.id).add_columns( func.count(distinct(aliased_follower.follower)).label("followers_stat") ) @@ -44,7 +43,6 @@ def add_topic_stat_columns(q): return new_q - def add_author_stat_columns(q): # Соединяем таблицу Author с таблицей ShoutAuthor и таблицей Shout с использованием INNER JOIN q = ( diff --git a/resolvers/topic.py b/resolvers/topic.py index 221599de..c2d4848f 100644 --- a/resolvers/topic.py +++ b/resolvers/topic.py @@ -1,11 +1,11 @@ -from sqlalchemy import and_, distinct, func, join, select +from sqlalchemy import distinct, func, select from orm.author import Author -from orm.shout import Shout, ShoutAuthor, ShoutTopic +from orm.shout import ShoutTopic from orm.topic import Topic from resolvers.stat import get_with_stat from services.auth import login_required -from services.cache import get_cached_topic_authors, get_cached_topic_followers +from services.cache import get_cached_topic_authors, get_cached_topic_by_slug, get_cached_topic_followers from services.db import local_session from services.logger import root_logger as logger from services.memorycache import cache_region @@ -50,10 +50,9 @@ async def get_topics_by_author(_, _info, author_id=0, slug="", user=""): @query.field("get_topic") -def get_topic(_, _info, slug: str): - topic_query = select(Topic).filter(Topic.slug == slug) - result = get_with_stat(topic_query) - for topic in result: +async def get_topic(_, _info, slug: str): + topic = await get_cached_topic_by_slug(slug) + if topic: return topic @@ -125,9 +124,8 @@ def get_topics_random(_, _info, amount=12): @query.field("get_topic_followers") async def get_topic_followers(_, _info, slug: str): logger.debug(f"getting followers for @{slug}") - topic_query = select(Topic.id).filter(Topic.slug == slug).first() - topic_id_result = local_session().execute(topic_query) - topic_id = topic_id_result[0] if topic_id_result else None + topic = await get_cached_topic_by_slug(slug) + topic_id = topic.id if isinstance(topic, Topic) else topic.get("id") followers = await get_cached_topic_followers(topic_id) return followers @@ -135,20 +133,7 @@ async def get_topic_followers(_, _info, slug: str): @query.field("get_topic_authors") async def get_topic_authors(_, _info, slug: str): logger.debug(f"getting authors for @{slug}") - topic_query = select(Topic.id).filter(Topic.slug == slug).first() - topic_id_result = local_session().execute(topic_query) - topic_id = topic_id_result[0] if topic_id_result else None - topic_authors_query = ( - select(ShoutAuthor.author) - .select_from(join(ShoutTopic, Shout, ShoutTopic.shout == Shout.id)) - .join(ShoutAuthor, ShoutAuthor.shout == Shout.id) - .filter( - and_( - ShoutTopic.topic == topic_id, - Shout.published_at.is_not(None), - Shout.deleted_at.is_(None), - ) - ) - ) - authors = await get_cached_topic_authors(topic_id, topic_authors_query) + topic = await get_cached_topic_by_slug(slug) + topic_id = topic.id if isinstance(topic, Topic) else topic.get("id") + authors = await get_cached_topic_authors(topic_id) return authors diff --git a/services/cache.py b/services/cache.py index 1ad3d04c..8b0ac7dd 100644 --- a/services/cache.py +++ b/services/cache.py @@ -4,6 +4,7 @@ from typing import List from sqlalchemy import and_, join, select from orm.author import Author, AuthorFollower +from orm.shout import Shout, ShoutAuthor, ShoutTopic from orm.topic import Topic, TopicFollower from services.db import local_session from services.encoders import CustomJSONEncoder @@ -80,6 +81,21 @@ async def get_cached_author_by_user_id(user_id: str, get_with_stat): return await get_cached_author(int(author_id), get_with_stat) +async def get_cached_topic_by_slug(slug: str, get_with_stat): + cached_result = await redis.execute("GET", f"topic:slug:{slug}") + if isinstance(cached_result, str): + return json.loads(cached_result) + elif get_with_stat: + with local_session() as session: + topic_query = select(Topic).filter(Topic.slug == slug) + result = get_with_stat(session.execute(topic_query)) + if result: + [topic] = result + if topic: + await cache_topic(topic) + return topic + + async def get_cached_authors_by_ids(authors_ids: List[int]) -> List[Author | dict]: authors = [] for author_id in authors_ids: @@ -93,13 +109,25 @@ async def get_cached_authors_by_ids(authors_ids: List[int]) -> List[Author | dic return authors -async def get_cached_topic_authors(topic_id: int, topic_authors_query): +async def get_cached_topic_authors(topic_id: int): rkey = f"topic:authors:{topic_id}" cached = await redis.execute("GET", rkey) authors_ids = [] if isinstance(cached, str): authors_ids = json.loads(cached) else: + topic_authors_query = ( + select(ShoutAuthor.author) + .select_from(join(ShoutTopic, Shout, ShoutTopic.shout == Shout.id)) + .join(ShoutAuthor, ShoutAuthor.shout == Shout.id) + .filter( + and_( + ShoutTopic.topic == topic_id, + Shout.published_at.is_not(None), + Shout.deleted_at.is_(None), + ) + ) + ) with local_session() as session: authors_ids = [aid for (aid,) in session.execute(topic_authors_query)] await redis.execute("SET", rkey, json.dumps(authors_ids))