"""Strawberry GraphQL resolvers that load and group notifications for an author."""

import json
import logging
import time
from types import SimpleNamespace
from typing import Any, Dict, List, Tuple, Union

import strawberry
from sqlalchemy import and_, select
from sqlalchemy.orm import aliased
from sqlalchemy.sql import not_

from orm.notification import Notification, NotificationAction, NotificationEntity, NotificationSeen
from resolvers.model import (
    NotificationAuthor,
    NotificationGroup,
    NotificationReaction,
    NotificationShout,
    NotificationsResult,
)
from services.db import local_session

logger = logging.getLogger('[resolvers.schema]')
logger.setLevel(logging.DEBUG)


def query_notifications(author_id: int, after: int = 0) -> Tuple[int, int, List[Tuple[Notification, bool]]]:
    """Load notifications created after ``after`` together with a per-author seen flag.

    Args:
        author_id: viewer whose ``NotificationSeen`` rows decide the ``seen`` flag.
        after: lower bound (exclusive) on ``Notification.created_at``; a falsy
            value disables the bound for the listing query.

    Returns:
        ``(total, unread, notifications)``:
        ``total`` counts CREATE notifications created after ``after``;
        ``unread`` counts those of them with no ``NotificationSeen`` row for
        ``author_id``; ``notifications`` is a newest-first list of
        ``(notification, seen)`` pairs.
    """
    seen_alias = aliased(NotificationSeen)
    # A notification is "seen" by this author iff a matching seen row exists.
    seen_by_author = and_(seen_alias.viewer == author_id, seen_alias.notification == Notification.id)

    # Join the same alias we select from; joining the unaliased entity left the
    # selected alias column outside the join. No GROUP BY: the outer join on
    # (viewer, notification) yields one row per notification, assuming
    # NotificationSeen stores at most one row per viewer/notification pair
    # (TODO confirm against the table constraints).
    # Newest first, so the caller's per-group ``limit`` keeps the most recent.
    query = (
        select(Notification, seen_alias.viewer.label('seen'))
        .outerjoin(seen_alias, seen_by_author)
        .order_by(Notification.created_at.desc())
    )
    if after:
        query = query.filter(Notification.created_at > after)

    with local_session() as session:
        total = (
            session.query(Notification)
            .filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after))
            .count()
        )

        # Unread is computed per author with the same join as the listing, so
        # the count agrees with the ``seen`` flags returned below.
        unread = (
            session.query(Notification)
            .outerjoin(seen_alias, seen_by_author)
            .filter(
                and_(
                    Notification.action == NotificationAction.CREATE.value,
                    Notification.created_at > after,
                    seen_alias.viewer.is_(None),
                )
            )
            .count()
        )

        # The ``seen`` column holds the viewer id or NULL; expose it as a bool.
        notifications = [(notification, viewer is not None) for notification, viewer in session.execute(query)]

    return total, unread, notifications


def _decode_payload(notification: Notification, *required: str) -> Any:
    """Decode ``notification.payload`` JSON into attribute-accessible objects.

    Every JSON object (including nested ones) becomes a ``SimpleNamespace`` so
    fields can be read as attributes (``payload.id``, ``payload.shout.id``).

    Args:
        notification: notification whose ``payload`` string is decoded.
        *required: attribute names the decoded top-level object must have.

    Returns:
        The decoded object, or ``None`` when the payload is not a string, is
        not valid JSON, is not a JSON object, or lacks a required attribute.
    """
    raw = notification.payload
    if not isinstance(raw, str):
        return None
    try:
        payload = json.loads(raw, object_hook=lambda obj: SimpleNamespace(**obj))
    except json.JSONDecodeError:
        logger.warning('notification %s: payload is not valid JSON', notification.id)
        return None
    if not isinstance(payload, SimpleNamespace) or not all(hasattr(payload, name) for name in required):
        logger.warning('notification %s: unexpected payload shape', notification.id)
        return None
    return payload


def process_shout_notification(notification: Notification, seen: bool) -> Union[Tuple[str, NotificationGroup], None]:
    """Build the thread group for a shout-creation notification.

    The thread id is the shout id, which is also what reactions to that shout
    are keyed on in ``process_reaction_notification``.

    Returns:
        ``(thread_id, group)``, or ``None`` if the notification is unusable.
    """
    if not isinstance(notification.entity, str):
        return None
    shout = _decode_payload(notification, 'id', 'created_at')
    if shout is None:
        return None

    thread_id = str(shout.id)
    group = NotificationGroup(
        id=thread_id,
        entity=notification.entity,
        shout=shout,
        # Copy so later merges into this group do not mutate the shout payload.
        authors=list(getattr(shout, 'authors', None) or []),
        updated_at=shout.created_at,
        reactions=[],
        action='create',
        seen=seen,
    )
    return thread_id, group


def process_reaction_notification(
    notification: Notification, seen: bool
) -> Union[Tuple[str, NotificationGroup], None]:
    """Build the thread group for a reaction-creation notification.

    Reactions are threaded under their shout; replies to a comment get their
    own ``<shout_id>::<reply_to>`` thread.

    Returns:
        ``(thread_id, group)``, or ``None`` if the notification is unusable.
    """
    if not isinstance(notification, Notification) or not isinstance(notification.entity, str):
        return None
    reaction = _decode_payload(notification, 'id', 'shout', 'created_at', 'created_by')
    if reaction is None:
        return None

    shout = reaction.shout
    # NOTE(review): the payload's ``shout`` may be an embedded object or a bare
    # id (can't tell from here); key the thread on the id in both cases so it
    # matches the thread id used for the shout's own notification.
    thread_id = str(getattr(shout, 'id', shout))
    reply_to = getattr(reaction, 'reply_to', None)
    if getattr(reaction, 'kind', None) == 'COMMENT' and reply_to:
        thread_id += f'::{reply_to}'

    group = NotificationGroup(
        id=thread_id,
        action=str(notification.action),
        entity=notification.entity,
        updated_at=reaction.created_at,
        reactions=[reaction.id],
        shout=shout,
        authors=[reaction.created_by],
        seen=seen,
    )
    return thread_id, group


def process_follower_notification(
    notification: Notification, seen: bool
) -> Union[Tuple[str, NotificationGroup], None]:
    """Build the single ``'followers'`` thread group for a follow notification.

    Returns:
        ``(thread_id, group)``, or ``None`` if the payload is unusable.
    """
    follower = _decode_payload(notification)
    if follower is None:
        return None

    thread_id = 'followers'
    group = NotificationGroup(
        id=thread_id,
        authors=[follower],
        # Use the notification's own timestamp (request time only as a
        # fallback) so the group sorts by when the follow happened.
        updated_at=notification.created_at or int(time.time()),
        shout=None,
        reactions=[],
        entity='follower',
        action='follow',
        seen=seen,
    )
    return thread_id, group


def _merge_group(existing: NotificationGroup, incoming: NotificationGroup) -> None:
    """Fold ``incoming`` into ``existing`` (same thread id) in place."""
    # The thread stays "seen" only if every merged notification was seen.
    existing.seen = bool(existing.seen and incoming.seen)
    if incoming.shout is not None:
        existing.shout = incoming.shout
    # Build new lists rather than mutating, so payload-owned lists stay intact.
    existing.authors = [*(existing.authors or []), *(incoming.authors or [])]
    existing.reactions = [*(existing.reactions or []), *(incoming.reactions or [])]


async def get_notifications_grouped(
    author_id: int, after: int = 0, limit: int = 10
) -> Tuple[Dict[str, NotificationGroup], int, int]:
    """Group notifications for ``author_id`` into threads.

    Notifications are visited newest first; each one is turned into a group
    and merged into an existing group with the same thread id, if any.
    Grouping stops once ``limit`` distinct threads have been collected.

    Returns:
        ``(groups_by_thread, unread, total)`` — note the count order differs
        from ``query_notifications``.
    """
    total, unread, notifications = query_notifications(author_id, after)
    groups_by_thread: Dict[str, NotificationGroup] = {}

    for notification, seen in notifications:
        if len(groups_by_thread) >= limit:
            break

        entity = str(notification.entity)
        action = str(notification.action)
        if entity == 'shout' and action == 'create':
            result = process_shout_notification(notification, seen)
        elif entity == NotificationEntity.REACTION.value and action == NotificationAction.CREATE.value:
            result = process_reaction_notification(notification, seen)
        elif entity == 'follower':
            result = process_follower_notification(notification, seen)
        else:
            result = None

        if result is None:
            continue

        thread_id, group = result
        existing = groups_by_thread.get(thread_id)
        if existing is None:
            groups_by_thread[thread_id] = group
        else:
            # Merge instead of overwriting, so earlier reactions/followers in
            # the same thread are kept and the thread is counted only once.
            _merge_group(existing, group)

    return groups_by_thread, unread, total


@strawberry.type
class Query:
    @strawberry.field
    async def load_notifications(self, info, after: int, limit: int = 50, offset: int = 0) -> NotificationsResult:
        # Return up to ``limit`` notification groups for the requesting author
        # (``info.context['author_id']``), newest first, skipping the first
        # ``offset`` groups. Unauthenticated requests get an empty result.
        author_id = info.context.get('author_id')
        if not author_id:
            return NotificationsResult(notifications=[], total=0, unread=0, error=None)

        offset = max(offset, 0)
        # Collect enough groups to cover the skipped page(s) plus this page.
        groups, unread, total = await get_notifications_grouped(author_id, after, offset + limit)
        notifications = sorted(groups.values(), key=lambda group: group.updated_at or 0, reverse=True)
        return NotificationsResult(
            notifications=notifications[offset : offset + limit],
            total=total,
            unread=unread,
            error=None,
        )