import json
import logging
import time
from types import SimpleNamespace
from typing import Any, Dict, List, Tuple, Union

import strawberry
from sqlalchemy import and_, select
from sqlalchemy.orm import aliased
from sqlalchemy.sql import not_

from orm.notification import (
    Notification,
    NotificationAction,
    NotificationEntity,
    NotificationSeen,
)
from resolvers.model import (
    NotificationAuthor,
    NotificationGroup,
    NotificationReaction,
    NotificationShout,
    NotificationsResult,
)
from services.db import local_session

logger = logging.getLogger("[resolvers.schema]")
logger.setLevel(logging.DEBUG)


def _seen_by(author_id: int):
    """Build a correlated EXISTS clause: has *author_id* seen the current Notification row?

    The clause references ``Notification.id`` and is meant to be embedded in a
    query over ``Notification``; SQLAlchemy correlates it to the outer row.
    """
    seen_alias = aliased(NotificationSeen)
    return (
        select(seen_alias.notification)
        .where(
            and_(
                seen_alias.viewer == author_id,
                seen_alias.notification == Notification.id,
            )
        )
        .exists()
    )


def query_notifications(
    author_id: int, after: int = 0
) -> Tuple[int, int, List[Tuple[Notification, bool]]]:
    """Load notifications together with whether *author_id* has seen each one.

    Args:
        author_id: Viewer whose NotificationSeen rows decide the ``seen`` flag.
        after: If non-zero, only notifications with ``created_at > after`` are listed.

    Returns:
        ``(total, unread, notifications)``:
        - ``total`` counts CREATE notifications with ``created_at > after``.
        - ``unread`` counts those of them that *author_id* has not seen.
        - ``notifications`` is a newest-first list of ``(Notification, seen)`` pairs.

    NOTE(review): the listing is not restricted to notifications addressed to
    *author_id* (no recipient column is visible here) — confirm this is intended.
    """
    # EXISTS (rather than an outer join) yields exactly one row per notification
    # and a real boolean, regardless of how many NotificationSeen rows exist.
    seen_by_author = _seen_by(author_id)

    query = select(Notification, seen_by_author.label("seen"))
    if after:
        query = query.filter(Notification.created_at > after)
    # Newest first, so the caller's group limit keeps the most recent threads.
    query = query.order_by(Notification.created_at.desc())

    with local_session() as session:
        total = (
            session.query(Notification)
            .filter(
                and_(
                    Notification.action == NotificationAction.CREATE.value,
                    Notification.created_at > after,
                )
            )
            .count()
        )
        unread = (
            session.query(Notification)
            .filter(
                and_(
                    Notification.action == NotificationAction.CREATE.value,
                    Notification.created_at > after,
                    # Same per-viewer definition of "seen" as the listing above.
                    not_(seen_by_author),
                )
            )
            .count()
        )
        notifications = [
            (notification, bool(seen))
            for notification, seen in session.execute(query)
        ]

    return total, unread, notifications


def _load_payload(notification: Notification) -> Any:
    """Decode ``notification.payload`` (a JSON string) with attribute-style access.

    JSON objects become ``SimpleNamespace`` instances, so callers can write
    ``payload.id`` / ``payload.authors[0].slug`` as the processors below do.
    Returns None (and logs) when the payload is not valid JSON.

    NOTE(review): the results are not instances of the resolvers.model types;
    this assumes the GraphQL layer resolves their fields by attribute lookup.
    """
    try:
        return json.loads(
            notification.payload, object_hook=lambda obj: SimpleNamespace(**obj)
        )
    except json.JSONDecodeError as exc:
        logger.error("cannot decode payload of notification %s: %s", notification.id, exc)
        return None


def process_shout_notification(
    notification: Notification, seen: bool
) -> Union[Tuple[str, NotificationGroup], None]:
    """Turn a shout-creation notification into a ``(thread_id, group)`` pair.

    The thread id is the shout id. Returns None if payload/entity are not
    strings or the payload cannot be decoded.
    """
    if not isinstance(notification.payload, str) or not isinstance(
        notification.entity, str
    ):
        return None
    shout: NotificationShout = _load_payload(notification)
    if shout is None:
        return None
    thread_id = str(shout.id)
    group = NotificationGroup(
        id=thread_id,
        entity=notification.entity,
        shout=shout,
        authors=shout.authors,
        updated_at=shout.created_at,
        reactions=[],
        action="create",
        seen=seen,
    )
    return thread_id, group


def process_reaction_notification(
    notification: Notification, seen: bool
) -> Union[Tuple[str, NotificationGroup], None]:
    """Turn a reaction notification into a ``(thread_id, group)`` pair.

    Reactions are threaded per shout. Replies (COMMENT with ``reply_to``) get
    their own ``"<shout>::<reply_to>"`` thread. Returns None on unusable input.
    """
    if (
        not isinstance(notification, Notification)
        or not isinstance(notification.payload, str)
        or not isinstance(notification.entity, str)
    ):
        return None
    reaction: NotificationReaction = _load_payload(notification)
    if reaction is None:
        return None
    shout: NotificationShout = reaction.shout
    # The payload's shout may be an embedded object or a bare id; key on the id
    # either way, so that thread ids are stable and readable.
    # NOTE(review): a top-level reaction thread shares its id with that shout's
    # "create" thread, so the two are merged by get_notifications_grouped.
    thread_id = str(getattr(shout, "id", shout))
    if reaction.kind == "COMMENT" and reaction.reply_to:
        thread_id += f"::{reaction.reply_to}"
    group = NotificationGroup(
        id=thread_id,
        action=str(notification.action),
        entity=notification.entity,
        updated_at=reaction.created_at,
        reactions=[reaction.id],
        shout=shout,
        authors=[reaction.created_by],
        seen=seen,
    )
    return thread_id, group


def process_follower_notification(
    notification: Notification, seen: bool
) -> Union[Tuple[str, NotificationGroup], None]:
    """Turn a follow notification into a pair on the shared ``"followers"`` thread.

    ``updated_at`` is the processing time, because this code reads no timestamp
    from the payload.
    """
    if not isinstance(notification.payload, str):
        return None
    follower: NotificationAuthor = _load_payload(notification)
    if follower is None:
        return None
    thread_id = "followers"
    group = NotificationGroup(
        id=thread_id,
        authors=[follower],
        updated_at=int(time.time()),
        shout=None,
        reactions=[],
        entity="follower",
        action="follow",
        seen=seen,
    )
    return thread_id, group


def _process_notification(
    notification: Notification, seen: bool
) -> Union[Tuple[str, NotificationGroup], None]:
    """Dispatch one notification to the processor matching its entity/action."""
    entity = str(notification.entity)
    action = str(notification.action)
    if entity == "shout" and action == "create":
        return process_shout_notification(notification, seen)
    if (
        entity == NotificationEntity.REACTION.value
        and action == NotificationAction.CREATE.value
    ):
        return process_reaction_notification(notification, seen)
    if entity == "follower":
        return process_follower_notification(notification, seen)
    return None


def _merge_group(target: NotificationGroup, incoming: NotificationGroup) -> None:
    """Fold *incoming* into *target*, which already owns the same thread id."""
    # A thread counts as seen only if every notification in it was seen.
    target.seen = bool(target.seen and incoming.seen)
    if incoming.updated_at is not None and (
        target.updated_at is None or incoming.updated_at > target.updated_at
    ):
        target.updated_at = incoming.updated_at
    if incoming.shout is not None:
        target.shout = incoming.shout
    authors = list(target.authors or [])
    for author in incoming.authors or []:
        if author not in authors:
            authors.append(author)
    target.authors = authors
    target.reactions = list(target.reactions or []) + list(incoming.reactions or [])


async def get_notifications_grouped(
    author_id: int, after: int = 0, limit: int = 10, offset: int = 0
) -> Tuple[Dict[str, NotificationGroup], int, int]:
    """
    Retrieves notifications for a given author, grouped into threads.

    Args:
        author_id (int): The ID of the author for whom notifications are retrieved.
        after (int, optional): If provided, only notifications created after this timestamp are considered.
        limit (int, optional): The maximum number of groups to return.
        offset (int, optional): Number of leading (newest) groups to skip, for pagination.

    Returns:
        Dict[str, NotificationGroup], int, int: A dictionary where keys are thread IDs
        and values are NotificationGroup objects, followed by the unread and total amounts.

    Notifications are walked newest first. Each one is converted into a group, and
    a group whose thread id already exists is merged into the existing one. Merging
    accumulates authors and reactions, keeps the latest updated_at, and leaves the
    thread seen only if all its notifications were seen. Collection stops once
    ``offset + limit`` threads exist.

    NotificationGroup structure:
    {
        entity: str,        # Type of entity (e.g., 'reaction', 'shout', 'follower').
        updated_at: int,    # Timestamp of the latest update in the thread.
        shout: Optional[NotificationShout]
        reactions: List[int],   # List of reaction ids within the thread.
        authors: List[NotificationAuthor],  # List of authors involved in the thread.
    }
    """
    total, unread, notifications = query_notifications(author_id, after)

    limit = max(limit, 0)
    offset = max(offset, 0)
    window = offset + limit

    groups_by_thread: Dict[str, NotificationGroup] = {}
    for notification, seen in notifications:
        if len(groups_by_thread) >= window:
            break
        result = _process_notification(notification, seen)
        if result is None:
            continue
        thread_id, group = result
        existing_group = groups_by_thread.get(thread_id)
        if existing_group is None:
            groups_by_thread[thread_id] = group
        else:
            _merge_group(existing_group, group)

    if offset:
        # dicts preserve insertion order, i.e. newest-first thread creation.
        groups_by_thread = dict(list(groups_by_thread.items())[offset:])

    return groups_by_thread, unread, total


@strawberry.type
class Query:
    @strawberry.field
    async def load_notifications(
        self, info, after: int, limit: int = 50, offset: int = 0
    ) -> NotificationsResult:
        """Resolve the caller's notification threads, sorted by updated_at descending.

        The viewer is read from ``info.context["author_id"]``. Without one, an
        empty result is returned.
        """
        author_id = info.context.get("author_id")
        if not author_id:
            return NotificationsResult(notifications=[], total=0, unread=0, error=None)

        groups, unread, total = await get_notifications_grouped(
            author_id, after, limit, offset
        )
        notifications = sorted(
            groups.values(), key=lambda group: group.updated_at, reverse=True
        )
        return NotificationsResult(
            notifications=notifications, total=total, unread=unread, error=None
        )