432 lines
18 KiB
Python
432 lines
18 KiB
Python
import time
|
||
from datetime import UTC, datetime
|
||
from typing import Any
|
||
|
||
import orjson
|
||
from graphql import GraphQLResolveInfo
|
||
from sqlalchemy import and_, select
|
||
from sqlalchemy.exc import SQLAlchemyError
|
||
from sqlalchemy.orm import aliased
|
||
from sqlalchemy.sql import not_
|
||
|
||
from orm.author import Author
|
||
from orm.notification import (
|
||
Notification,
|
||
NotificationAction,
|
||
NotificationEntity,
|
||
NotificationSeen,
|
||
)
|
||
from orm.shout import Shout, ShoutReactionsFollower
|
||
from services.auth import login_required
|
||
from storage.db import local_session
|
||
from storage.schema import mutation, query
|
||
from utils.logger import root_logger as logger
|
||
|
||
|
||
def query_notifications(author_id: int, after: int = 0) -> tuple[int, int, list[tuple[Notification, bool]]]:
    """
    Load notifications with the viewer's "seen" marker plus two counters.

    Args:
        author_id: ID of the viewer whose NotificationSeen rows are joined in.
        after: Unix timestamp in seconds. When non-zero, only notifications
            created strictly after it are selected and counted.

    Returns:
        A ``(total, unread, notifications)`` tuple:

        - ``total``: number of CREATE notifications (after the cutoff, if any).
        - ``unread``: number of those for which ``Notification.seen`` is false.
        - ``notifications``: ``(Notification, seen)`` pairs, where ``seen`` is the
          joined ``NotificationSeen.viewer`` value - ``None`` when no seen-row
          exists for this viewer, ``author_id`` otherwise.

    Note:
        The row query is not restricted to CREATE actions, while both counters are.
    """
    # Compute the cutoff once; the row query and the counters share it.
    after_datetime = datetime.fromtimestamp(after, tz=UTC) if after else None

    # The join must target the same alias the selected column comes from.
    # Joining the un-aliased entity leaves the alias unjoined, which yields a
    # cartesian product and a meaningless "seen" column.
    seen_alias = aliased(NotificationSeen)
    q = select(Notification, seen_alias.viewer.label("seen")).outerjoin(
        seen_alias,
        and_(
            seen_alias.viewer == author_id,
            seen_alias.notification == Notification.id,
        ),
    )

    conditions = [Notification.action == NotificationAction.CREATE.value]
    if after_datetime is not None:
        q = q.where(Notification.created_at > after_datetime)
        conditions.append(Notification.created_at > after_datetime)

    with local_session() as session:
        total = session.query(Notification).where(and_(*conditions)).count()

        unread_conditions = [*conditions, not_(Notification.seen)]
        unread = session.query(Notification).where(and_(*unread_conditions)).count()

        notifications = [(n, seen) for n, seen in session.execute(q)]

    return total, unread, notifications
|
||
|
||
|
||
def check_subscription(shout_id: int, current_author_id: int) -> bool:
    """
    Tell whether an author is subscribed to notifications about a shout.

    The check is a lookup in ShoutReactionsFollower:
    - a row exists  -> subscribed
    - no row exists -> not subscribed (unsubscribed or never subscribed)

    Notes carried over from the original description (not enforced here):
    an automatic subscription (auto=True) is created when a post is created
    and on the first comment/reaction; unsubscribing deletes the row.

    Args:
        shout_id: ID of the shout.
        current_author_id: ID of the author whose subscription is checked.

    Returns:
        bool: True if a matching ShoutReactionsFollower row exists.
    """
    criteria = (
        ShoutReactionsFollower.follower == current_author_id,
        ShoutReactionsFollower.shout == shout_id,
    )
    with local_session() as session:
        # Only the existence of a row matters; its contents are not inspected.
        subscription = session.query(ShoutReactionsFollower).filter(*criteria).first()

    return subscription is not None
|
||
|
||
|
||
def group_notification(
|
||
thread: str,
|
||
authors: list[Any] | None = None,
|
||
shout: Any | None = None,
|
||
reactions: list[Any] | None = None,
|
||
entity: str = "follower",
|
||
action: str = "follow",
|
||
) -> dict:
|
||
reactions = reactions or []
|
||
authors = authors or []
|
||
return {
|
||
"thread": thread,
|
||
"authors": authors,
|
||
"updated_at": int(time.time()),
|
||
"shout": shout,
|
||
"reactions": reactions,
|
||
"entity": entity,
|
||
"action": action,
|
||
}
|
||
|
||
|
||
def get_notifications_grouped(author_id: int, after: int = 0, limit: int = 10, offset: int = 0) -> list[dict]:
    """
    Retrieve notifications for an author, grouped by thread.

    Args:
        author_id (int): ID of the author (viewer) the notifications are loaded for.
        after (int, optional): Unix timestamp; only notifications created after it
            are considered.
        limit (int, optional): Maximum number of groups to return.
        offset (int, optional): Number of leading groups to skip.

    Returns:
        list[dict]: Notification groups in the order their threads were first
        encountered. Unread/total counters are not returned by this function.

    Raises:
        ValueError: If a reaction notification payload is not a JSON object.

    Group structure (see ``group_notification``):
        {
            thread: str,      # Thread ID: "shout-<id>", "shout-<id>::<reply_id>" or "followers".
            entity: str,      # Entity type of the notification that created the group.
            action: str,      # Action of the notification that created the group.
            updated_at: int,  # Unix time at which the group dict was built.
            shout: dict | None,
            reactions: list,  # Reaction payloads collected in the thread.
            authors: list,    # Author dicts involved in the thread.
            seen: bool,       # Only present (False) once a reaction was merged into the group.
        }
    """
    _total, _unread, notifications = query_notifications(author_id, after)
    groups_by_thread: dict[str, dict] = {}

    # Collect enough groups to cover the requested page. With offset == 0 this is
    # the same stop condition as before; a positive offset now skips groups
    # instead of shrinking the page.
    offset = max(offset, 0)
    max_groups = offset + limit

    for notification, _seen in notifications:
        # Counting dict entries (rather than a manual counter) avoids
        # over-counting when a thread's group is overwritten.
        if len(groups_by_thread) >= max_groups:
            break

        payload = orjson.loads(str(notification.payload))
        entity = str(notification.entity)
        action = str(notification.action)

        if entity == NotificationEntity.SHOUT.value:
            shout_id = payload.get("id")
            shout_author_id = payload.get("created_by")
            thread_id = f"shout-{shout_id}"

            with local_session() as session:
                author = session.query(Author).where(Author.id == shout_author_id).first()
                shout = session.query(Shout).where(Shout.id == shout_id).first()
                if author and shout:
                    # Skip the notification if the viewer is not subscribed to the shout.
                    if not check_subscription(shout_id, author_id):
                        continue

                    groups_by_thread[thread_id] = group_notification(
                        thread_id,
                        shout=shout.dict(),
                        authors=[author.dict()],
                        action=action,
                        entity=entity,
                    )

        elif entity == NotificationEntity.REACTION.value:
            reaction = payload
            if not isinstance(reaction, dict):
                msg = "reaction data is not consistent"
                raise ValueError(msg)

            shout_id = reaction.get("shout")
            # Kept in its own name: reusing `author_id` here would clobber the
            # viewer ID that the subscription checks rely on.
            reaction_author_id = reaction.get("created_by", 0)
            if not (shout_id and reaction_author_id):
                continue

            with local_session() as session:
                author = session.query(Author).where(Author.id == reaction_author_id).first()
                shout = session.query(Shout).where(Shout.id == shout_id).first()
                if not (shout and author):
                    continue

                author_dict = author.dict()
                shout_dict = shout.dict()

                # Replies to a comment get their own sub-thread.
                reply_id = reaction.get("reply_to")
                thread_id = f"shout-{shout_id}"
                if reply_id and str(reaction.get("kind") or "").lower() == "comment":
                    thread_id = f"shout-{shout_id}::{reply_id}"

                existing_group = groups_by_thread.get(thread_id)
                if existing_group:
                    existing_group["seen"] = False
                    # `authors` holds author dicts, so append the dict, not the bare ID.
                    existing_group["authors"].append(author_dict)
                    existing_group["reactions"] = existing_group["reactions"] or []
                    existing_group["reactions"].append(reaction)
                else:
                    # Skip the notification if the viewer is not subscribed to the shout.
                    if not check_subscription(shout_id, author_id):
                        continue

                    groups_by_thread[thread_id] = group_notification(
                        thread_id,
                        authors=[author_dict],
                        shout=shout_dict,
                        reactions=[reaction],
                        entity=entity,
                        action=action,
                    )

        elif entity == "follower":
            thread_id = "followers"
            # `payload` is already decoded above. Decode again only if the stored
            # value was double-encoded; orjson.loads() rejects non-text input.
            follower = orjson.loads(payload) if isinstance(payload, (str, bytes, bytearray)) else payload

            existing_group = groups_by_thread.get(thread_id)
            if existing_group:
                if action == "follow":
                    existing_group["authors"].append(follower)
                elif action == "unfollow":
                    follower_id = follower.get("id")
                    for author in existing_group["authors"]:
                        if isinstance(author, dict) and author.get("id") == follower_id:
                            existing_group["authors"].remove(author)
                            break
            else:
                groups_by_thread[thread_id] = group_notification(
                    thread_id,
                    authors=[follower],
                    entity=entity,
                    action=action,
                )

    return list(groups_by_thread.values())[offset : offset + limit]
|
||
|
||
|
||
@query.field("load_notifications")
@login_required
async def load_notifications(_: None, info: GraphQLResolveInfo, after: int, limit: int = 50, offset: int = 0) -> dict:
    """
    GraphQL query resolver: load grouped notifications for the current author.

    Args:
        _: Unused parent value.
        info: Resolve info; the author is read from ``info.context["author"]``.
        after: Unix timestamp; only notifications created after it are considered.
        limit: Maximum number of groups requested.
        offset: Pagination offset, forwarded to ``get_notifications_grouped``.

    Returns:
        dict: ``notifications`` (groups sorted by ``updated_at``, newest first),
        ``total`` (number of returned groups), ``unread`` (groups without a
        truthy ``seen`` flag) and ``error`` (message string or ``None``).
    """
    author_dict = info.context.get("author") or {}
    author_id = author_dict.get("id")
    error = None
    total = 0
    unread = 0
    notifications = []
    try:
        if author_id:
            # `offset` was previously accepted by the resolver but silently dropped.
            groups_list = get_notifications_grouped(author_id, after, limit, offset)
            notifications = sorted(groups_list, key=lambda group: group.get("updated_at", 0), reverse=True)

            # Count the grouped notifications that are actually returned.
            total = len(notifications)
            unread = sum(1 for n in notifications if not n.get("seen", False))
    except Exception as e:
        # Resolver boundary: report the failure in the payload instead of raising.
        error = str(e)
        logger.error(e)
    return {
        "notifications": notifications,
        "total": total,
        "unread": unread,
        "error": error,
    }
|
||
|
||
|
||
@mutation.field("notification_mark_seen")
@login_required
async def notification_mark_seen(_: None, info: GraphQLResolveInfo, notification_id: int) -> dict:
    """
    GraphQL mutation resolver: mark one notification as seen by the current author.

    Args:
        _: Unused parent value.
        info: Resolve info; the author is read from ``info.context["author"]``.
        notification_id: ID of the notification to mark.

    Returns:
        dict: ``{"error": None}`` on success or when no author ID is present in
        the context, ``{"error": "cant mark as read"}`` if the database write fails.
    """
    # `or {}` also covers a context that holds an explicit `author: None`.
    author_id = (info.context.get("author") or {}).get("id")
    if not author_id:
        return {"error": None}

    with local_session() as session:
        try:
            session.add(NotificationSeen(notification=notification_id, viewer=author_id))
            session.commit()
        except SQLAlchemyError as e:
            session.rollback()
            logger.error(f"seen mutation failed: {e}")
            return {"error": "cant mark as read"}
    return {"error": None}
|
||
|
||
|
||
@mutation.field("notifications_seen_after")
@login_required
async def notifications_seen_after(_: None, info: GraphQLResolveInfo, after: int) -> dict:
    """
    Mark all notifications after the given timestamp as seen.

    Args:
        _: Unused parent value.
        info: Resolve info; the author is read from ``info.context["author"]``.
        after: Unix timestamp in seconds. When falsy, every notification is marked.

    Returns:
        dict: ``{"error": None}`` on success or when no author ID is present,
        ``{"error": "cant mark as read"}`` if anything fails.
    """
    error = None
    try:
        author_id = (info.context.get("author") or {}).get("id")
        if author_id:
            with local_session() as session:
                # Convert Unix timestamp to datetime for PostgreSQL compatibility
                after_datetime = datetime.fromtimestamp(after, tz=UTC) if after else None
                notifications_query = session.query(Notification)
                if after_datetime:
                    notifications_query = notifications_query.where(Notification.created_at > after_datetime)

                for notification in notifications_query.all():
                    # The column is `viewer` (as used by every other seen-marker in
                    # this module); the former `author=` keyword made this fail.
                    session.add(NotificationSeen(notification=notification.id, viewer=author_id))
                session.commit()
    except Exception as e:
        # Resolver boundary: log through the module logger and report a generic error.
        logger.error(f"seen after mutation failed: {e}")
        error = "cant mark as read"
    return {"error": error}
|
||
|
||
|
||
def _parse_thread(thread: str) -> tuple[str, str | None] | None:
    """Split a thread ID into ``(shout_id, reply_to_id)`` strings; None if no shout ID is present.

    Accepts the IDs produced by ``get_notifications_grouped`` ("shout-<id>",
    "shout-<id>::<reply_id>") as well as the bare "<id>" / "<id>:<reply_id>" forms.
    """
    head, _sep, tail = thread.partition(":")
    shout_id = head.removeprefix("shout-").strip()
    if not shout_id:
        return None
    reply_to_id = tail.lstrip(":").strip() or None
    return shout_id, reply_to_id


def _mark_notification_seen(session: Any, notification_id: int, viewer_id: int, label: str) -> None:
    """Best-effort: store one NotificationSeen row and commit; on failure roll back and log a warning."""
    try:
        session.add(NotificationSeen(notification=notification_id, viewer=viewer_id))
        session.commit()
    except Exception as e:
        # Roll back so one failed row does not poison the remaining ones.
        session.rollback()
        logger.warning(f"Failed to mark {label} notification as seen: {e}")


@mutation.field("notifications_seen_thread")
@login_required
async def notifications_seen_thread(_: None, info: GraphQLResolveInfo, thread: str, after: int) -> dict:
    """
    GraphQL mutation resolver: mark the notifications of one thread as seen.

    Args:
        _: Unused parent value.
        info: Resolve info; the author is read from ``info.context["author"]``.
        thread: Thread ID - "followers", "shout-<id>" or "shout-<id>::<reply_id>"
            (the bare "<id>" / "<id>:<reply_id>" forms are accepted too).
        after: Unix timestamp in seconds; when non-zero only notifications created
            after it are marked.

    Returns:
        dict: ``{"error": None}`` on completion, ``{"error": "Invalid thread format"}``
        for an unparsable thread ID, ``{"error": "You are not logged in"}`` when
        no author ID is present. Failures on individual rows are logged, not reported.
    """
    author_id = (info.context.get("author") or {}).get("id")
    if not author_id:
        return {"error": "You are not logged in"}

    # Convert Unix timestamp to datetime for PostgreSQL compatibility
    after_datetime = datetime.fromtimestamp(after, tz=UTC) if after else None

    def _select(*conditions: Any) -> list:
        """Fetch notifications matching the conditions plus the optional time cutoff."""
        criteria = list(conditions)
        if after_datetime:
            criteria.append(Notification.created_at > after_datetime)
        return session.query(Notification).where(and_(*criteria)).all()

    with local_session() as session:
        if thread == "followers":
            # NOTE(review): grouping matches follower notifications by the literal
            # entity "follower" while this filter uses NotificationEntity.AUTHOR -
            # confirm both refer to the same stored value.
            for n in _select(Notification.entity == NotificationEntity.AUTHOR.value):
                _mark_notification_seen(session, n.id, author_id, "follower")
            return {"error": None}

        parsed = _parse_thread(thread)
        if parsed is None:
            return {"error": "Invalid thread format"}
        shout_id, reply_to_id = parsed

        # New-shout notifications of this thread. Each row is committed by the
        # helper, so these marks no longer depend on a later reaction commit.
        shout_notifications = _select(
            Notification.entity == NotificationEntity.SHOUT.value,
            Notification.action == NotificationAction.CREATE.value,
        )
        for n in shout_notifications:
            payload = orjson.loads(str(n.payload))
            if str(payload.get("id")) == shout_id:
                _mark_notification_seen(session, n.id, author_id, "shout")

        # Reaction notifications: created ones minus those that were deleted again.
        new_reaction_notifications = _select(
            Notification.action == NotificationAction.CREATE.value,
            Notification.entity == NotificationEntity.REACTION.value,
        )
        removed_reaction_notifications = _select(
            Notification.action == NotificationAction.DELETE.value,
            Notification.entity == NotificationEntity.REACTION.value,
        )
        exclude = {orjson.loads(str(nr.payload)).get("id") for nr in removed_reaction_notifications}

        for n in new_reaction_notifications:
            reaction = orjson.loads(str(n.payload))
            if reaction.get("id") in exclude:
                continue

            # Derive the reaction's thread the same way grouping does: only a
            # comment that replies to something lives in a "::<reply_id>" sub-thread.
            reply_id = reaction.get("reply_to")
            is_comment_reply = bool(reply_id) and str(reaction.get("kind") or "").lower() == "comment"
            reaction_reply = str(reply_id) if is_comment_reply else None

            # Compare as strings: IDs parsed from `thread` are text, while the
            # payload values may be ints.
            if str(reaction.get("shout")) == shout_id and reaction_reply == reply_to_id:
                _mark_notification_seen(session, n.id, author_id, "reaction")

    return {"error": None}
|