from typing import List, Optional

import strawberry
from sqlalchemy import and_, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import aliased
from strawberry.schema.config import StrawberryConfig
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper

from orm.author import Author
from orm.notification import Notification as NotificationMessage, NotificationSeen
from services.auth import login_required
from services.db import local_session

strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()


@strawberry_sqlalchemy_mapper.type(NotificationMessage)
class Notification:
    """GraphQL type for a notification message, mapped from ``NotificationMessage``."""

    id: int
    action: str  # create update delete join follow etc.
    entity: str  # REACTION SHOUT
    created_at: int
    payload: str  # JSON data
    seen: List[int]  # ids of viewers who have seen it (as filled by get_notifications)


@strawberry.type
class NotificationSeenResult:
    """Result of a "mark as read" mutation; ``error`` stays ``None`` on success."""

    # Optional, because the success path returns no error: a non-null
    # ``String!`` field would reject the ``None`` default at response time.
    error: Optional[str] = strawberry.field(default=None, name="error")


@strawberry.type
class NotificationsResult:
    """A page of notifications together with the unread and total counters."""

    notifications: List[Notification]
    unread: int
    total: int


def get_notifications(author, session, limit, offset) -> List[Notification]:
    """Load notifications and flag which ones ``author`` has already seen.

    Args:
        author: object whose ``id`` is matched against ``NotificationSeen.viewer``.
        session: SQLAlchemy session used to execute the query.
        limit: maximum number of rows to fetch; a falsy value disables the limit.
        offset: number of rows to skip; a falsy value disables the offset.

    Returns:
        A list of ``Notification`` objects. ``seen`` is ``[author.id]`` when the
        author has a matching ``NotificationSeen`` row and ``[]`` otherwise.
    """
    seen_alias = aliased(NotificationSeen)

    # The join must target the same alias that is selected; joining the
    # un-aliased entity would leave the alias cross-joined to every row.
    query = (
        select(NotificationMessage, seen_alias.viewer.label("seen"))
        .outerjoin(
            seen_alias,
            and_(
                seen_alias.viewer == author.id,
                seen_alias.notification == NotificationMessage.id,
            ),
        )
        # Group by the message id (not the seen-row column, which is NULL for
        # every unseen message) so each notification yields exactly one row,
        # even if duplicate seen-rows exist.
        # NOTE(review): assumes NotificationMessage.id is the primary key, so
        # strict databases accept the remaining selected columns - confirm.
        .group_by(NotificationMessage.id, seen_alias.viewer)
    )
    # NOTE(review): no ORDER BY is applied, so page contents depend on the
    # database's row order - confirm whether an explicit ordering is wanted.
    if limit:
        query = query.limit(limit)
    if offset:
        query = query.offset(offset)

    return [
        Notification(
            id=message.id,
            payload=message.payload,
            entity=message.entity,
            action=message.action,
            created_at=message.created_at,
            # Always a list so callers can safely test ``author.id in n.seen``.
            seen=[] if viewer is None else [viewer],
        )
        for message, viewer in session.execute(query)
    ]


@strawberry.type
class Query:
    """Root GraphQL query type."""

    @strawberry.field
    @login_required
    async def load_notifications(self, info, limit: int = 50, offset: int = 0) -> NotificationsResult:
        """Return one page of notifications for the logged-in user.

        ``unread`` counts notifications on the returned page that the author
        has not seen; ``total`` counts all notification messages. An empty
        result is returned when the author is unknown, there are no
        notifications, or loading fails.
        """
        user_id = info.context["user_id"]

        with local_session() as session:
            try:
                author = session.query(Author).filter(Author.user == user_id).first()
                if author:
                    notifications = get_notifications(author, session, limit, offset)
                    if notifications:
                        return NotificationsResult(
                            notifications=notifications,
                            # Unread means the author's id is NOT in ``seen``.
                            unread=sum(1 for n in notifications if author.id not in n.seen),
                            total=session.query(NotificationMessage).count(),
                        )
            except Exception as ex:
                # Resolver boundary: report the failure and fall through to
                # the empty result instead of failing the whole query.
                print(f"[resolvers.schema] {ex}")

        return NotificationsResult(notifications=[], total=0, unread=0)


@strawberry.type
class Mutation:
    """Root GraphQL mutation type."""

    @strawberry.mutation
    @login_required
    async def mark_notification_as_read(self, info, notification_id: int) -> NotificationSeenResult:
        """Record that the logged-in user's author has seen ``notification_id``.

        Returns a result whose ``error`` is set only when the database
        operation raised ``SQLAlchemyError``.
        """
        user_id = info.context["user_id"]

        with local_session() as session:
            try:
                author = session.query(Author).filter(Author.user == user_id).first()
                if author:
                    session.add(NotificationSeen(notification=notification_id, viewer=author.id))
                    session.commit()
            except SQLAlchemyError as e:
                session.rollback()
                print(f"[mark_notification_as_read] error: {str(e)}")
                return NotificationSeenResult(error="cant mark as read")

        return NotificationSeenResult()

    @strawberry.mutation
    @login_required
    async def mark_all_notifications_as_read(self, info) -> NotificationSeenResult:
        """Mark every notification the author has not yet seen as read.

        Returns a result whose ``error`` is set only when the database
        operation raised ``SQLAlchemyError``.
        """
        user_id = info.context["user_id"]

        with local_session() as session:
            try:
                author = session.query(Author).filter(Author.user == user_id).first()
                if author:
                    # No limit/offset: consider every notification.
                    for n in get_notifications(author, session, None, None):
                        if author.id not in n.seen:
                            session.add(NotificationSeen(viewer=author.id, notification=n.id))
                    # Single commit so the batch is stored (or rolled back) together.
                    session.commit()
            except SQLAlchemyError as e:
                session.rollback()
                print(f"[mark_all_notifications_as_read] error: {str(e)}")
                return NotificationSeenResult(error="cant mark as read")

        return NotificationSeenResult()


schema = strawberry.Schema(query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False))