notifier/resolvers/schema.py
Untone 737cc40353
All checks were successful
deploy / deploy (push) Successful in 1m8s
auth-check-middleware-2
2023-12-17 14:45:20 +03:00

156 lines
5.6 KiB
Python

from typing import List
from sqlalchemy import and_, select
from sqlalchemy.orm import aliased
from sqlalchemy.exc import SQLAlchemyError
from orm.author import Author
from orm.notification import Notification as NotificationMessage, NotificationSeen
from services.auth import check_auth
from services.db import local_session
import strawberry
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
from strawberry.schema.config import StrawberryConfig
from strawberry.extensions import Extension
import logging
# Module-wide logger used by the resolvers in this file.
logger = logging.getLogger(__name__)

# Shared mapper that builds Strawberry GraphQL types from SQLAlchemy ORM models;
# it has to exist before the @strawberry_sqlalchemy_mapper.type decorator runs.
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
@strawberry_sqlalchemy_mapper.type(NotificationMessage)
class Notification:
    """GraphQL notification type mapped from the ``NotificationMessage`` ORM model.

    The fields below are declared explicitly; resolvers in this module build
    instances directly from ``NotificationMessage`` rows.
    """

    id: int
    action: str  # create update delete join follow etc.
    entity: str  # REACTION SHOUT
    created_at: int  # NOTE(review): presumably a Unix timestamp — confirm against the ORM model
    payload: str  # JSON data
    seen: List[int]  # viewer ids recorded as having seen it; callers test `author_id in n.seen`
@strawberry.type
class NotificationSeenResult:
    """Result of a mark-as-read mutation.

    ``error`` is ``None`` on success and carries a short message on failure.
    """

    # Strawberry wraps types as keyword-only dataclasses, so an Optional field
    # without a default is still a *required* constructor argument and
    # ``NotificationSeenResult()`` would raise TypeError. Default it to None.
    error: str | None = None
@strawberry.type
class NotificationsResult:
    """Response of ``load_notifications``: one page of notifications plus counters."""

    notifications: List[Notification]  # the requested page
    unread: int  # intended: notifications the requesting author has not seen yet
    total: int  # total number of notifications (a count over NotificationMessage)
def get_notifications(
    author_id: int, session, limit: int | None, offset: int | None
) -> List[Notification]:
    """Load notifications annotated with whether ``author_id`` has seen them.

    Args:
        author_id: id of the author whose ``NotificationSeen`` rows are checked.
        session: an open SQLAlchemy session.
        limit: maximum number of rows; falsy (``None`` or ``0``) means no limit.
        offset: number of rows to skip; falsy means start at the first row.

    Returns:
        ``Notification`` objects ordered by ``created_at`` (newest first, ties
        broken by id). ``seen`` is ``[author_id]`` when the author has a
        ``NotificationSeen`` row for the notification, otherwise ``[]``.
    """
    # A correlated EXISTS yields exactly one row per notification, so no
    # GROUP BY (which would require aggregating every selected column) and no
    # duplicate rows if NotificationSeen holds repeated (viewer, notification) pairs.
    seen_by_author = (
        select(NotificationSeen.notification)
        .where(
            NotificationSeen.viewer == author_id,
            NotificationSeen.notification == NotificationMessage.id,
        )
        .exists()
    )
    # An explicit ORDER BY keeps LIMIT/OFFSET pages stable between requests.
    query = select(NotificationMessage, seen_by_author.label("seen")).order_by(
        NotificationMessage.created_at.desc(), NotificationMessage.id.desc()
    )
    if limit:
        query = query.limit(limit)
    if offset:
        query = query.offset(offset)

    return [
        Notification(
            id=n.id,
            payload=n.payload,
            entity=n.entity,
            action=n.action,
            created_at=n.created_at,
            # `seen` is declared List[int]; callers test `author_id in n.seen`.
            seen=[author_id] if seen else [],
        )
        for n, seen in session.execute(query)
    ]
@strawberry.type
class Query:
    """Root GraphQL query type."""

    @strawberry.field
    async def load_notifications(self, info, limit: int = 50, offset: int = 0) -> NotificationsResult:
        """Return a page of notifications plus unread/total counters.

        Args:
            info: Strawberry resolver info; ``info.context["author_id"]`` is
                expected to be populated by ``LoginRequiredMiddleware``.
            limit: page size handed to ``get_notifications`` (falsy = no limit).
            offset: number of notifications to skip.

        Returns:
            ``NotificationsResult`` holding the page, the number of
            notifications the author has no ``NotificationSeen`` row for, and
            the total notification count. An empty result is returned when
            there is no author in the context or when the database work fails
            (the failure is logged with its traceback).
        """
        author_id = info.context.get("author_id")
        if not author_id:
            return NotificationsResult(notifications=[], total=0, unread=0)

        with local_session() as session:
            try:
                notifications = get_notifications(author_id, session, limit, offset)
                total = session.query(NotificationMessage).count()
                # Count unread in the database (NOT EXISTS) so the figure covers
                # all notifications, like `total`, not just the current page.
                seen_by_author = (
                    select(NotificationSeen.notification)
                    .where(
                        NotificationSeen.viewer == author_id,
                        NotificationSeen.notification == NotificationMessage.id,
                    )
                    .exists()
                )
                unread = session.query(NotificationMessage).filter(~seen_by_author).count()
                return NotificationsResult(notifications=notifications, unread=unread, total=total)
            except Exception as ex:
                # Resolver boundary: log with traceback and fall back to an empty result.
                logger.exception(f"[load_notifications] Ошибка при выполнении запроса к базе данных: {ex}")
        return NotificationsResult(notifications=[], total=0, unread=0)
@strawberry.type
class Mutation:
    """Root GraphQL mutation type for notification read state."""

    @strawberry.mutation
    async def mark_notification_as_read(self, info, notification_id: int) -> NotificationSeenResult:
        """Record that the authenticated author has seen ``notification_id``.

        Idempotent: nothing is inserted when the author already has a
        ``NotificationSeen`` row for the notification. Without an
        ``author_id`` in the context this is a no-op that reports success.
        Returns an error result if the database write fails.
        """
        author_id = info.context.get("author_id")
        if author_id:
            with local_session() as session:
                try:
                    already_seen = (
                        session.query(NotificationSeen)
                        .filter(
                            NotificationSeen.viewer == author_id,
                            NotificationSeen.notification == notification_id,
                        )
                        .first()
                    )
                    if already_seen is None:
                        session.add(NotificationSeen(notification=notification_id, viewer=author_id))
                        session.commit()
                except SQLAlchemyError as e:
                    session.rollback()
                    logger.error(
                        f"[mark_notification_as_read] Ошибка при обновлении статуса прочтения уведомления: {str(e)}"
                    )
                    return NotificationSeenResult(error="cant mark as read")
        # Pass error explicitly: an Optional strawberry field may have no default.
        return NotificationSeenResult(error=None)

    @strawberry.mutation
    async def mark_all_notifications_as_read(self, info) -> NotificationSeenResult:
        """Insert a ``NotificationSeen`` row for every notification the author has not seen.

        Without an ``author_id`` in the context this is a no-op that reports
        success. Returns an error result if the database work fails.
        """
        author_id = info.context.get("author_id")
        if author_id:
            # `with` wraps the `try`, so `session` is always bound in the
            # `except` branch when rolling back.
            with local_session() as session:
                try:
                    seen_by_author = (
                        select(NotificationSeen.notification)
                        .where(
                            NotificationSeen.viewer == author_id,
                            NotificationSeen.notification == NotificationMessage.id,
                        )
                        .exists()
                    )
                    # Fetch only the ids still missing a seen row instead of
                    # materialising every notification.
                    unseen_ids = (
                        session.execute(select(NotificationMessage.id).where(~seen_by_author))
                        .scalars()
                        .all()
                    )
                    session.add_all(
                        NotificationSeen(viewer=author_id, notification=nid) for nid in unseen_ids
                    )
                    session.commit()
                except SQLAlchemyError as e:
                    session.rollback()
                    logger.error(
                        f"[mark_all_notifications_as_read] Ошибка при обновлении статуса прочтения всех уведомлений: {str(e)}"
                    )
                    return NotificationSeenResult(error="cant mark as read")
        return NotificationSeenResult(error=None)
class LoginRequiredMiddleware(Extension):
    async def on_request_start(self):
        """Store the caller's identity in the request context.

        ``context["user_id"]`` always receives the id from ``check_auth``
        (falsy ids become ``None``). ``context["author_id"]`` is set only when
        the user is authenticated and an ``Author`` row with ``Author.user``
        equal to that id exists.
        """
        ctx = self.execution_context.context
        authenticated, uid = await check_auth(ctx.get("request"))

        if authenticated:
            with local_session() as session:
                matched_author = session.query(Author).filter(Author.user == uid).first()
                if matched_author:
                    ctx["author_id"] = matched_author.id

        # Single write of the user id, normalising any falsy value to None.
        ctx["user_id"] = uid or None
# strawberry-sqlalchemy-mapper requires finalize() after all mapped types are
# declared and before the schema is built, so that related models picked up
# during mapping are converted into Strawberry types.
strawberry_sqlalchemy_mapper.finalize()

# Field names are exposed as written (no camelCase conversion), and every
# request passes through LoginRequiredMiddleware to populate the context.
schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
    config=StrawberryConfig(auto_camel_case=False),
    extensions=[LoginRequiredMiddleware],
)