import time
from collections.abc import Iterable
from datetime import UTC, datetime
from typing import Any

import orjson
from graphql import GraphQLResolveInfo
from sqlalchemy import and_, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import aliased
from sqlalchemy.sql import not_

from orm.author import Author
from orm.notification import (
    Notification,
    NotificationAction,
    NotificationEntity,
    NotificationSeen,
)
from orm.shout import Shout
from services.auth import login_required
from storage.db import local_session
from storage.schema import mutation, query
from utils.logger import root_logger as logger


def _after_datetime(after: int | None) -> datetime | None:
    """Convert a Unix-timestamp filter into an aware UTC datetime; a falsy value means "no filter"."""
    # Convert Unix timestamp to datetime for PostgreSQL compatibility
    return datetime.fromtimestamp(after, tz=UTC) if after else None


def _to_timestamp(value: Any) -> int:
    """Convert a ``created_at`` value (datetime or epoch number) into Unix seconds.

    Falls back to the current time for values of any other type.
    """
    if isinstance(value, datetime):
        if value.tzinfo is None:
            # NOTE(review): naive datetimes are assumed to be UTC, matching the aware-UTC
            # filters used throughout this module — confirm against the column type.
            value = value.replace(tzinfo=UTC)
        return int(value.timestamp())
    if isinstance(value, (int, float)):
        return int(value)
    return int(time.time())


def _load_payload(notification: Any) -> dict | None:
    """Decode ``notification.payload`` into a dict, or return ``None`` if it is unusable.

    The payload may be stored as JSON text/bytes or already decoded; both are accepted.
    Malformed payloads are logged and skipped so one bad row cannot break a whole request.
    """
    raw = notification.payload
    if isinstance(raw, dict):
        return raw
    try:
        payload = orjson.loads(raw if isinstance(raw, (bytes, bytearray, str)) else str(raw))
        # Some payloads appear to be JSON-encoded twice (the old follower branch decoded a
        # second time), so unwrap one extra string layer when present.
        if isinstance(payload, str):
            payload = orjson.loads(payload)
    except orjson.JSONDecodeError as e:
        logger.warning(f"Skipping notification {notification.id}: invalid payload ({e})")
        return None
    if not isinstance(payload, dict):
        logger.warning(f"Skipping notification {notification.id}: payload is not an object")
        return None
    return payload


def _reaction_thread_id(reaction: dict) -> str:
    """Return the thread id a reaction payload belongs to.

    Comments replying to another reaction form the sub-thread ``shout-<shout>::<reply_to>``;
    every other reaction joins ``shout-<shout>``.
    """
    shout_id = reaction.get("shout")
    reply_id = reaction.get("reply_to")
    kind = str(reaction.get("kind") or "").lower()
    if reply_id and kind == "comment":
        return f"shout-{shout_id}::{reply_id}"
    return f"shout-{shout_id}"


def _normalize_thread_id(thread: str) -> str | None:
    """Map a client-supplied thread id onto the ids produced by ``get_notifications_grouped``.

    Accepts ``shout-<id>`` / ``shout-<id>::<reply_to>`` (as returned by ``load_notifications``)
    and the legacy ``<shout_id>:<reply_to>`` form. Returns ``None`` if it cannot be parsed.
    """
    if thread.startswith("shout-"):
        return thread
    shout_id, sep, reply_id = thread.partition(":")
    if not sep or not shout_id:
        return None
    reply_id = reply_id.lstrip(":")
    return f"shout-{shout_id}::{reply_id}" if reply_id else f"shout-{shout_id}"


def _mark_seen(session: Any, notification_ids: Iterable[int], viewer_id: int) -> int:
    """Add ``NotificationSeen`` rows for ``viewer_id``, skipping notifications already marked.

    Does not commit; the caller owns the transaction. Returns the number of rows added.
    """
    ids = sorted(set(notification_ids))
    if not ids:
        return 0
    already_seen = {
        row[0]
        for row in session.query(NotificationSeen.notification).where(
            NotificationSeen.viewer == viewer_id,
            NotificationSeen.notification.in_(ids),
        )
    }
    new_ids = [nid for nid in ids if nid not in already_seen]
    for nid in new_ids:
        session.add(NotificationSeen(notification=nid, viewer=viewer_id))
    return len(new_ids)


def query_notifications(author_id: int, after: int = 0) -> tuple[int, int, list[tuple[Notification, bool]]]:
    """Load notifications together with a per-viewer "seen" flag.

    Args:
        author_id: Viewer whose ``NotificationSeen`` rows determine the ``seen`` flag.
        after: Unix timestamp; when non-zero, only notifications created after it are used.

    Returns:
        ``(total, unread, notifications)``:
        ``total`` counts CREATE notifications.
        ``unread`` counts those CREATE notifications that have no ``NotificationSeen``
        row for ``author_id``.
        ``notifications`` is a list of ``(notification, seen)`` pairs, newest first.
    """
    seen_alias = aliased(NotificationSeen)
    # Join the same alias that is selected, so "seen" is NULL exactly when this viewer
    # has no NotificationSeen row for the notification.
    seen_join = and_(
        seen_alias.viewer == author_id,
        seen_alias.notification == Notification.id,
    )
    after_datetime = _after_datetime(after)

    q = select(Notification, seen_alias.viewer.label("seen")).outerjoin(seen_alias, seen_join)
    create_conditions = [Notification.action == NotificationAction.CREATE.value]
    if after_datetime is not None:
        q = q.where(Notification.created_at > after_datetime)
        create_conditions.append(Notification.created_at > after_datetime)
    # Newest first, so callers that stop after N groups keep the most recent ones
    q = q.order_by(Notification.created_at.desc())

    with local_session() as session:
        total = session.query(Notification).where(and_(*create_conditions)).count()
        unread = (
            session.query(Notification)
            .outerjoin(seen_alias, seen_join)
            .where(and_(*create_conditions), seen_alias.viewer.is_(None))
            .count()
        )
        notifications = [(n, seen is not None) for n, seen in session.execute(q)]

    return total, unread, notifications


def group_notification(
    thread: str,
    authors: list[Any] | None = None,
    shout: Any | None = None,
    reactions: list[Any] | None = None,
    entity: str = "follower",
    action: str = "follow",
    updated_at: int | None = None,
) -> dict:
    """Build a notification-group dict.

    Args:
        thread: Thread id of the group.
        authors: Authors involved (defaults to an empty list).
        shout: Shout data for shout/reaction threads.
        reactions: Reaction payloads in the thread (defaults to an empty list).
        entity: Entity of the notification that opened the group.
        action: Action of the notification that opened the group.
        updated_at: Unix timestamp of the latest update; defaults to the current time.
    """
    return {
        "thread": thread,
        "authors": authors or [],
        "updated_at": int(time.time()) if updated_at is None else updated_at,
        "shout": shout,
        "reactions": reactions or [],
        "entity": entity,
        "action": action,
    }


def _open_group(groups: dict[str, dict], thread_id: str, notification: Any, seen: bool, **fields: Any) -> None:
    """Create a new group for ``thread_id`` from ``notification`` and store it in ``groups``."""
    group = group_notification(
        thread_id,
        entity=str(notification.entity),
        action=str(notification.action),
        updated_at=_to_timestamp(notification.created_at),
        **fields,
    )
    group["seen"] = seen
    groups[thread_id] = group


def _touch_group(group: dict, notification: Any, seen: bool) -> None:
    """Fold another notification's timestamp and seen flag into an existing group."""
    group["updated_at"] = max(group.get("updated_at", 0), _to_timestamp(notification.created_at))
    # A group counts as seen only if every notification folded into it is seen
    group["seen"] = bool(group.get("seen", True)) and seen


def _add_author(group: dict, author: Any) -> None:
    """Append ``author`` to the group unless an author with the same id is already listed."""
    author_id = author.get("id") if isinstance(author, dict) else None
    if author_id is not None and any(
        isinstance(existing, dict) and existing.get("id") == author_id for existing in group["authors"]
    ):
        return
    group["authors"].append(author)


def _group_shout(session: Any, groups: dict[str, dict], notification: Any, payload: dict, seen: bool) -> bool:
    """Fold a shout notification into ``groups``; return True if a new group was created."""
    shout_id = payload.get("id")
    creator_id = payload.get("created_by")
    author = session.query(Author).where(Author.id == creator_id).first()
    shout = session.query(Shout).where(Shout.id == shout_id).first()
    if not (author and shout):
        return False

    thread_id = f"shout-{shout_id}"
    existing = groups.get(thread_id)
    if existing is not None:
        # Merge instead of overwriting, so reactions already grouped here are kept
        existing["shout"] = existing.get("shout") or shout.dict()
        _add_author(existing, author.dict())
        _touch_group(existing, notification, seen)
        return False

    _open_group(groups, thread_id, notification, seen, shout=shout.dict(), authors=[author.dict()])
    return True


def _group_reaction(session: Any, groups: dict[str, dict], notification: Any, payload: dict, seen: bool) -> bool:
    """Fold a reaction notification into ``groups``; return True if a new group was created."""
    shout_id = payload.get("shout")
    creator_id = payload.get("created_by", 0)
    if not (shout_id and creator_id):
        return False
    author = session.query(Author).where(Author.id == creator_id).first()
    shout = session.query(Shout).where(Shout.id == shout_id).first()
    if not (shout and author):
        return False

    thread_id = _reaction_thread_id(payload)
    existing = groups.get(thread_id)
    if existing is not None:
        _add_author(existing, author.dict())
        existing["reactions"] = existing.get("reactions") or []
        existing["reactions"].append(payload)
        _touch_group(existing, notification, seen)
        return False

    _open_group(
        groups,
        thread_id,
        notification,
        seen,
        authors=[author.dict()],
        shout=shout.dict(),
        reactions=[payload],
    )
    return True


def _group_follower(
    groups: dict[str, dict], notification: Any, follower: dict, seen: bool, handled: set[Any]
) -> bool:
    """Fold a follower notification into the "followers" group; return True if it was created.

    Notifications arrive newest first, so the first event seen for a follower decides
    whether they are listed: a newest "follow" lists them, anything else (e.g. "unfollow") does not.
    """
    follower_id = follower.get("id")
    if follower_id is not None:
        if follower_id in handled:
            return False
        handled.add(follower_id)
    if str(notification.action) != "follow":
        return False

    thread_id = "followers"
    existing = groups.get(thread_id)
    if existing is not None:
        _add_author(existing, follower)
        _touch_group(existing, notification, seen)
        return False

    _open_group(groups, thread_id, notification, seen, authors=[follower])
    return True


def get_notifications_grouped(author_id: int, after: int = 0, limit: int = 10, offset: int = 0) -> list[dict]:
    """
    Retrieve notifications for an author, grouped into threads.

    Args:
        author_id (int): The ID of the author for whom notifications are retrieved.
        after (int, optional): Unix timestamp; when non-zero, only notifications created
            after it are considered.
        limit (int, optional): The maximum number of groups to return.
        offset (int, optional): Number of groups to skip before returning up to ``limit``.

    Returns:
        list[dict]: Notification groups, ordered by their newest notification (newest first).
        Each group looks like::

            {
                thread: str,            # "shout-<id>", "shout-<id>::<reply_to>" or "followers"
                entity: str,            # Entity of the notification that opened the group
                action: str,            # Action of the notification that opened the group
                updated_at: int,        # Unix timestamp of the newest notification in the group
                shout: dict | None,     # Shout.dict() for shout/reaction threads
                reactions: list[dict],  # Reaction payloads within the thread
                authors: list[dict],    # Author.dict() / follower payloads involved
                seen: bool,             # True only if all grouped notifications are seen
            }

    Grouping stops once ``offset + limit`` groups exist; older notifications after that
    point do not extend existing groups.
    """
    _total, _unread, notifications = query_notifications(author_id, after)
    start = max(offset, 0)
    stop = start + max(limit, 0)

    groups_by_thread: dict[str, dict] = {}
    handled_followers: set[Any] = set()
    groups_amount = 0

    # One session for all author/shout lookups instead of one per notification
    with local_session() as session:
        for notification, seen in notifications:
            if groups_amount >= stop:
                break
            payload = _load_payload(notification)
            if payload is None:
                continue

            entity = str(notification.entity)
            if entity == NotificationEntity.SHOUT.value:
                created = _group_shout(session, groups_by_thread, notification, payload, seen)
            elif entity == NotificationEntity.REACTION.value:
                created = _group_reaction(session, groups_by_thread, notification, payload, seen)
            elif entity == "follower":
                created = _group_follower(groups_by_thread, notification, payload, seen, handled_followers)
            else:
                created = False

            if created:
                groups_amount += 1

    return list(groups_by_thread.values())[start:stop]


@query.field("load_notifications")
@login_required
async def load_notifications(_: None, info: GraphQLResolveInfo, after: int, limit: int = 50, offset: int = 0) -> dict:
    """Resolve ``load_notifications``: one page of grouped notifications for the current author.

    ``total`` and ``unread`` describe the returned page of groups. Any failure is logged and
    reported through ``error`` with empty results.
    """
    author_id = (info.context.get("author") or {}).get("id")
    error = None
    total = 0
    unread = 0
    notifications: list[dict] = []
    try:
        if author_id:
            groups_list = get_notifications_grouped(author_id, after, limit, offset)
            notifications = sorted(groups_list, key=lambda group: group.get("updated_at", 0), reverse=True)
            # Count the grouped notifications actually returned
            total = len(notifications)
            unread = sum(1 for n in notifications if not n.get("seen", False))
    except Exception as e:
        error = str(e)
        logger.error(e)

    return {
        "notifications": notifications,
        "total": total,
        "unread": unread,
        "error": error,
    }


@mutation.field("notification_mark_seen")
@login_required
async def notification_mark_seen(_: None, info: GraphQLResolveInfo, notification_id: int) -> dict:
    """Mark a single notification as seen by the current author (idempotent)."""
    author_id = (info.context.get("author") or {}).get("id")
    if author_id:
        with local_session() as session:
            try:
                _mark_seen(session, [notification_id], author_id)
                session.commit()
            except SQLAlchemyError as e:
                session.rollback()
                logger.error(f"seen mutation failed: {e}")
                return {"error": "cant mark as read"}
    return {"error": None}


@mutation.field("notifications_seen_after")
@login_required
async def notifications_seen_after(_: None, info: GraphQLResolveInfo, after: int) -> dict:
    """Mark all notifications after given timestamp as seen.

    A falsy ``after`` marks every notification. Already-seen notifications are skipped.
    """
    error = None
    try:
        author_id = (info.context.get("author") or {}).get("id")
        if author_id:
            with local_session() as session:
                after_datetime = _after_datetime(after)
                q = session.query(Notification.id)
                if after_datetime is not None:
                    q = q.where(Notification.created_at > after_datetime)
                _mark_seen(session, [row[0] for row in q], author_id)
                session.commit()
    except Exception as e:
        logger.error(f"notifications_seen_after failed: {e}")
        error = "cant mark as read"
    return {"error": error}


def _notifications_where(session: Any, after_datetime: datetime | None, *conditions: Any) -> list:
    """Return notifications matching ``conditions``, optionally created after ``after_datetime``."""
    if after_datetime is not None:
        conditions = (*conditions, Notification.created_at > after_datetime)
    return session.query(Notification).where(and_(*conditions)).all()


def _thread_notification_ids(session: Any, thread_id: str, after_datetime: datetime | None) -> list[int]:
    """Collect ids of shout and reaction notifications belonging to ``thread_id``.

    Thread membership uses the same ids as ``get_notifications_grouped``. Reactions that
    have a matching DELETE notification are excluded.
    """
    ids: list[int] = []

    new_shouts = _notifications_where(
        session,
        after_datetime,
        Notification.entity == NotificationEntity.SHOUT.value,
        Notification.action == NotificationAction.CREATE.value,
    )
    for n in new_shouts:
        payload = _load_payload(n)
        if payload is not None and f"shout-{payload.get('id')}" == thread_id:
            ids.append(n.id)

    removed_reactions = _notifications_where(
        session,
        after_datetime,
        Notification.action == NotificationAction.DELETE.value,
        Notification.entity == NotificationEntity.REACTION.value,
    )
    exclude = {p.get("id") for p in map(_load_payload, removed_reactions) if p is not None}

    new_reactions = _notifications_where(
        session,
        after_datetime,
        Notification.action == NotificationAction.CREATE.value,
        Notification.entity == NotificationEntity.REACTION.value,
    )
    for n in new_reactions:
        payload = _load_payload(n)
        if payload is not None and payload.get("id") not in exclude and _reaction_thread_id(payload) == thread_id:
            ids.append(n.id)

    return ids


@mutation.field("notifications_seen_thread")
@login_required
async def notifications_seen_thread(_: None, info: GraphQLResolveInfo, thread: str, after: int) -> dict:
    """Mark every notification in ``thread`` (optionally only those after ``after``) as seen.

    ``thread`` is ``"followers"`` or a thread id returned by ``load_notifications``
    (``shout-<id>`` / ``shout-<id>::<reply_to>``); the legacy ``<shout_id>:<reply_to>`` form
    is also accepted.
    """
    author_id = (info.context.get("author") or {}).get("id")
    if not author_id:
        return {"error": "You are not logged in"}

    thread_id = thread if thread == "followers" else _normalize_thread_id(thread)
    if thread_id is None:
        return {"error": "Invalid thread format"}

    after_datetime = _after_datetime(after)
    with local_session() as session:
        try:
            if thread_id == "followers":
                follower_notifications = _notifications_where(
                    session,
                    after_datetime,
                    Notification.entity == NotificationEntity.FOLLOWER.value,
                )
                ids = [n.id for n in follower_notifications]
            else:
                ids = _thread_notification_ids(session, thread_id, after_datetime)
            _mark_seen(session, ids, author_id)
            # Single commit so shout and reaction rows are persisted together
            session.commit()
        except SQLAlchemyError as e:
            session.rollback()
            logger.warning(f"Failed to mark thread {thread!r} as seen: {e}")
            return {"error": "cant mark as read"}

    return {"error": None}