Feature/notifications (#77)

feature - notifications

Co-authored-by: Igor Lobanov <igor.lobanov@onetwotrip.com>
This commit is contained in:
Ilya Y 2023-10-10 09:35:27 +03:00 committed by GitHub
parent 702219769a
commit 889f802429
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 412 additions and 305 deletions

View File

@ -36,7 +36,7 @@ class JWTCodec:
issuer="discours"
)
r = TokenPayload(**payload)
print('[auth.jwtcodec] debug token %r' % r)
# print('[auth.jwtcodec] debug token %r' % r)
return r
except jwt.InvalidIssuedAtError:
print('[auth.jwtcodec] invalid issued at: %r' % payload)

View File

@ -1,5 +1,4 @@
from ariadne import MutationType, QueryType, SubscriptionType, ScalarType
from ariadne import MutationType, QueryType, ScalarType
datetime_scalar = ScalarType("DateTime")
@ -11,5 +10,4 @@ def serialize_datetime(value):
query = QueryType()
mutation = MutationType()
subscription = SubscriptionType()
resolvers = [query, mutation, subscription, datetime_scalar]
resolvers = [query, mutation, datetime_scalar]

36
main.py
View File

@ -18,20 +18,18 @@ from base.resolvers import resolvers
from resolvers.auth import confirm_email_handler
from resolvers.upload import upload_handler
from services.main import storages_init
from services.notifications.notification_service import notification_service
from services.stat.viewed import ViewedStorage
from services.zine.gittask import GitTask
from settings import DEV_SERVER_PID_FILE_NAME, SENTRY_DSN
# from sse.transport import GraphQLSSEHandler
from services.inbox.presence import on_connect, on_disconnect
# from services.inbox.sse import sse_messages
from ariadne.asgi.handlers import GraphQLTransportWSHandler
# from services.zine.gittask import GitTask
from settings import DEV_SERVER_PID_FILE_NAME, SENTRY_DSN, SESSION_SECRET_KEY
from services.notifications.sse import sse_subscribe_handler
import_module("resolvers")
schema = make_executable_schema(load_schema_from_path("schema.graphql"), resolvers) # type: ignore
middleware = [
Middleware(AuthenticationMiddleware, backend=JWTAuthenticate()),
Middleware(SessionMiddleware, secret_key="!secret"),
Middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY),
]
@ -41,8 +39,11 @@ async def start_up():
await storages_init()
views_stat_task = asyncio.create_task(ViewedStorage().worker())
print(views_stat_task)
git_task = asyncio.create_task(GitTask.git_task_worker())
print(git_task)
# git_task = asyncio.create_task(GitTask.git_task_worker())
# print(git_task)
notification_service_task = asyncio.create_task(notification_service.worker())
print(notification_service_task)
try:
import sentry_sdk
sentry_sdk.init(SENTRY_DSN)
@ -71,7 +72,8 @@ routes = [
Route("/oauth/{provider}", endpoint=oauth_login),
Route("/oauth-authorize", endpoint=oauth_authorize),
Route("/confirm/{token}", endpoint=confirm_email_handler),
Route("/upload", endpoint=upload_handler, methods=['POST'])
Route("/upload", endpoint=upload_handler, methods=['POST']),
Route("/subscribe/{user_id}", endpoint=sse_subscribe_handler),
]
app = Starlette(
@ -83,14 +85,10 @@ app = Starlette(
)
app.mount("/", GraphQL(
schema,
debug=True,
websocket_handler=GraphQLTransportWSHandler(
on_connect=on_connect,
on_disconnect=on_disconnect
)
debug=True
))
dev_app = app = Starlette(
dev_app = Starlette(
debug=True,
on_startup=[dev_start_up],
on_shutdown=[shutdown],
@ -99,9 +97,5 @@ dev_app = app = Starlette(
)
dev_app.mount("/", GraphQL(
schema,
debug=True,
websocket_handler=GraphQLTransportWSHandler(
on_connect=on_connect,
on_disconnect=on_disconnect
)
debug=True
))

View File

@ -7,7 +7,18 @@ from orm.shout import Shout
from orm.topic import Topic, TopicFollower
from orm.user import User, UserRating
# NOTE: keep orm module isolated
def init_tables():
Base.metadata.create_all(engine)
Operation.init_table()
Resource.init_table()
User.init_table()
Community.init_table()
Role.init_table()
UserRating.init_table()
Shout.init_table()
print("[orm] tables initialized")
__all__ = [
"User",
@ -21,16 +32,5 @@ __all__ = [
"Notification",
"Reaction",
"UserRating",
"init_tables"
]
def init_tables():
Base.metadata.create_all(engine)
Operation.init_table()
Resource.init_table()
User.init_table()
Community.init_table()
Role.init_table()
UserRating.init_table()
Shout.init_table()
print("[orm] tables initialized")

View File

@ -1,5 +1,7 @@
from datetime import datetime
from sqlalchemy import Column, Enum, JSON, ForeignKey, DateTime, Boolean, Integer
from sqlalchemy import Column, Enum, ForeignKey, DateTime, Boolean, Integer
from sqlalchemy.dialects.postgresql import JSONB
from base.orm import Base
from enum import Enum as Enumeration
@ -18,5 +20,5 @@ class Notification(Base):
createdAt = Column(DateTime, nullable=False, default=datetime.now, index=True)
seen = Column(Boolean, nullable=False, default=False, index=True)
type = Column(Enum(NotificationType), nullable=False)
data = Column(JSON, nullable=True)
data = Column(JSONB, nullable=True)
occurrences = Column(Integer, default=1)

View File

@ -11,14 +11,12 @@ gql~=3.4.0
uvicorn>=0.18.3
pydantic>=1.10.2
passlib~=1.7.4
itsdangerous
authlib>=1.1.0
httpx>=0.23.0
psycopg2-binary
transliterate~=1.10.2
requests~=2.28.1
bcrypt>=4.0.0
websockets
bson~=0.5.10
flake8
DateTime~=4.7
@ -38,3 +36,4 @@ python-multipart~=0.0.6
alembic==1.11.3
Mako==1.2.4
MarkupSafe==2.1.3
sse-starlette=1.6.5

0
resetdb.sh Normal file → Executable file
View File

View File

@ -55,7 +55,6 @@ from resolvers.inbox.messages import (
create_message,
delete_message,
update_message,
message_generator,
mark_as_read
)
from resolvers.inbox.load import (
@ -65,56 +64,4 @@ from resolvers.inbox.load import (
)
from resolvers.inbox.search import search_recipients
__all__ = [
# auth
"login",
"register_by_email",
"is_email_used",
"confirm_email",
"auth_send_link",
"sign_out",
"get_current_user",
# zine.profile
"load_authors_by",
"rate_user",
"update_profile",
"get_authors_all",
# zine.load
"load_shout",
"load_shouts_by",
# zine.following
"follow",
"unfollow",
# create
"create_shout",
"update_shout",
"delete_shout",
"markdown_body",
# zine.topics
"topics_all",
"topics_by_community",
"topics_by_author",
"topic_follow",
"topic_unfollow",
"get_topic",
# zine.reactions
"reactions_follow",
"reactions_unfollow",
"create_reaction",
"update_reaction",
"delete_reaction",
"load_reactions_by",
# inbox
"load_chats",
"load_messages_by",
"create_chat",
"delete_chat",
"update_chat",
"create_message",
"delete_message",
"update_message",
"message_generator",
"mark_as_read",
"load_recipients",
"search_recipients"
]
from resolvers.notifications import load_notifications

View File

@ -6,7 +6,7 @@ from graphql.type import GraphQLResolveInfo
from auth.authenticate import login_required
from auth.credentials import AuthCredentials
from base.redis import redis
from base.resolvers import mutation, subscription
from base.resolvers import mutation
from services.following import FollowingManager, FollowingResult, Following
from validations.inbox import Message
@ -140,40 +140,3 @@ async def mark_as_read(_, info, chat_id: str, messages: [int]):
return {
"error": None
}
@subscription.source("newMessage")
async def message_generator(_, info: GraphQLResolveInfo):
print(f"[resolvers.messages] generator {info}")
auth: AuthCredentials = info.context["request"].auth
user_id = auth.user_id
try:
user_following_chats = await redis.execute("GET", f"chats_by_user/{user_id}")
if user_following_chats:
user_following_chats = list(json.loads(user_following_chats)) # chat ids
else:
user_following_chats = []
tasks = []
updated = {}
for chat_id in user_following_chats:
chat = await redis.execute("GET", f"chats/{chat_id}")
updated[chat_id] = chat['updatedAt']
user_following_chats_sorted = sorted(user_following_chats, key=lambda x: updated[x], reverse=True)
for chat_id in user_following_chats_sorted:
following_chat = Following('chat', chat_id)
await FollowingManager.register('chat', following_chat)
chat_task = following_chat.queue.get()
tasks.append(chat_task)
while True:
msg = await asyncio.gather(*tasks)
yield msg
finally:
await FollowingManager.remove('chat', following_chat)
@subscription.field("newMessage")
@login_required
async def message_resolver(message: Message, info: Any):
return message

View File

@ -0,0 +1,84 @@
from sqlalchemy import select, desc, and_, update
from auth.credentials import AuthCredentials
from base.resolvers import query, mutation
from auth.authenticate import login_required
from base.orm import local_session
from orm import Notification
@query.field("loadNotifications")
@login_required
async def load_notifications(_, info, params=None):
if params is None:
params = {}
auth: AuthCredentials = info.context["request"].auth
user_id = auth.user_id
limit = params.get('limit', 50)
offset = params.get('offset', 0)
q = select(Notification).where(
Notification.user == user_id
).order_by(desc(Notification.createdAt)).limit(limit).offset(offset)
with local_session() as session:
total_count = session.query(Notification).where(
Notification.user == user_id
).count()
total_unread_count = session.query(Notification).where(
and_(
Notification.user == user_id,
Notification.seen is False
)
).count()
notifications = session.execute(q).fetchall()
return {
"notifications": notifications,
"totalCount": total_count,
"totalUnreadCount": total_unread_count
}
@mutation.field("markNotificationAsRead")
@login_required
async def mark_notification_as_read(_, info, notification_id: int):
auth: AuthCredentials = info.context["request"].auth
user_id = auth.user_id
with local_session() as session:
notification = session.query(Notification).where(
and_(Notification.id == notification_id, Notification.user == user_id)
).one()
notification.seen = True
session.commit()
return {}
@mutation.field("markAllNotificationsAsRead")
@login_required
async def mark_all_notifications_as_read(_, info):
auth: AuthCredentials = info.context["request"].auth
user_id = auth.user_id
statement = update(Notification).where(
and_(
Notification.user == user_id,
Notification.seen == False
)
).values(seen=True)
with local_session() as session:
try:
session.execute(statement)
session.commit()
except Exception as e:
session.rollback()
print(f"[mark_all_notifications_as_read] error: {str(e)}")
return {}

View File

@ -1,6 +1,6 @@
import asyncio
from base.orm import local_session
from base.resolvers import mutation, subscription
from base.resolvers import mutation
from auth.authenticate import login_required
from auth.credentials import AuthCredentials
# from resolvers.community import community_follow, community_unfollow
@ -69,79 +69,3 @@ async def unfollow(_, info, what, slug):
return {"error": str(e)}
return {}
# by author and by topic
@subscription.source("newShout")
@login_required
async def shout_generator(_, info: GraphQLResolveInfo):
print(f"[resolvers.zine] shouts generator {info}")
auth: AuthCredentials = info.context["request"].auth
user_id = auth.user_id
try:
tasks = []
with local_session() as session:
# notify new shout by followed authors
following_topics = session.query(TopicFollower).where(TopicFollower.follower == user_id).all()
for topic_id in following_topics:
following_topic = Following('topic', topic_id)
await FollowingManager.register('topic', following_topic)
following_topic_task = following_topic.queue.get()
tasks.append(following_topic_task)
# by followed topics
following_authors = session.query(AuthorFollower).where(
AuthorFollower.follower == user_id).all()
for author_id in following_authors:
following_author = Following('author', author_id)
await FollowingManager.register('author', following_author)
following_author_task = following_author.queue.get()
tasks.append(following_author_task)
# TODO: use communities
# by followed communities
# following_communities = session.query(CommunityFollower).where(
# CommunityFollower.follower == user_id).all()
# for community_id in following_communities:
# following_community = Following('community', author_id)
# await FollowingManager.register('community', following_community)
# following_community_task = following_community.queue.get()
# tasks.append(following_community_task)
while True:
shout = await asyncio.gather(*tasks)
yield shout
finally:
pass
@subscription.source("newReaction")
@login_required
async def reaction_generator(_, info):
print(f"[resolvers.zine] reactions generator {info}")
auth: AuthCredentials = info.context["request"].auth
user_id = auth.user_id
try:
with local_session() as session:
followings = session.query(ShoutReactionsFollower.shout).where(
ShoutReactionsFollower.follower == user_id).unique()
# notify new reaction
tasks = []
for shout_id in followings:
following_shout = Following('shout', shout_id)
await FollowingManager.register('shout', following_shout)
following_author_task = following_shout.queue.get()
tasks.append(following_author_task)
while True:
reaction = await asyncio.gather(*tasks)
yield reaction
finally:
pass

View File

@ -183,6 +183,7 @@ async def load_shouts_by(_, info, options):
@query.field("loadDrafts")
@login_required
async def get_drafts(_, info):
auth: AuthCredentials = info.context["request"].auth
user_id = auth.user_id

View File

@ -10,6 +10,7 @@ from base.resolvers import mutation, query
from orm.reaction import Reaction, ReactionKind
from orm.shout import Shout, ShoutReactionsFollower
from orm.user import User
from services.notifications.notification_service import notification_service
def add_reaction_stat_columns(q):
@ -198,29 +199,32 @@ async def create_reaction(_, info, reaction):
r = Reaction.create(**reaction)
# Proposal accepting logix
if r.replyTo is not None and \
r.kind == ReactionKind.ACCEPT and \
auth.user_id in shout.dict()['authors']:
replied_reaction = session.query(Reaction).where(Reaction.id == r.replyTo).first()
if replied_reaction and replied_reaction.kind == ReactionKind.PROPOSE:
if replied_reaction.range:
old_body = shout.body
start, end = replied_reaction.range.split(':')
start = int(start)
end = int(end)
new_body = old_body[:start] + replied_reaction.body + old_body[end:]
shout.body = new_body
# TODO: update git version control
# # Proposal accepting logix
# FIXME: will break if there will be 2 proposals, will break if shout will be changed
# if r.replyTo is not None and \
# r.kind == ReactionKind.ACCEPT and \
# auth.user_id in shout.dict()['authors']:
# replied_reaction = session.query(Reaction).where(Reaction.id == r.replyTo).first()
# if replied_reaction and replied_reaction.kind == ReactionKind.PROPOSE:
# if replied_reaction.range:
# old_body = shout.body
# start, end = replied_reaction.range.split(':')
# start = int(start)
# end = int(end)
# new_body = old_body[:start] + replied_reaction.body + old_body[end:]
# shout.body = new_body
# # TODO: update git version control
session.add(r)
session.commit()
await notification_service.handle_new_reaction(r.id)
rdict = r.dict()
rdict['shout'] = shout.dict()
rdict['createdBy'] = author.dict()
# self-regulation mechanics
if check_to_hide(session, auth.user_id, r):
set_hidden(session, r.shout)
elif check_to_publish(session, auth.user_id, r):

View File

@ -179,7 +179,6 @@ type Mutation {
# user profile
rateUser(slug: String!, value: Int!): Result!
updateOnlineStatus: Result!
updateProfile(profile: ProfileInput!): Result!
# topics
@ -196,6 +195,9 @@ type Mutation {
# following
follow(what: FollowingEntity!, slug: String!): Result!
unfollow(what: FollowingEntity!, slug: String!): Result!
markNotificationAsRead(notification_id: Int!): Result!
markAllNotificationsAsRead: Result!
}
input MessagesBy {
@ -249,7 +251,17 @@ input ReactionBy {
days: Int # before
sort: String # how to sort, default createdAt
}
################################### Query
input NotificationsQueryParams {
limit: Int
offset: Int
}
type NotificationsQueryResult {
notifications: [Notification]!
totalCount: Int!
totalUnreadCount: Int!
}
type Query {
# inbox
@ -286,14 +298,8 @@ type Query {
topicsRandom(amount: Int): [Topic]!
topicsByCommunity(community: String!): [Topic]!
topicsByAuthor(author: String!): [Topic]!
}
############################################ Subscription
type Subscription {
newMessage: Message # new messages in inbox
newShout: Shout # personal feed new shout
newReaction: Reaction # new reactions to notify
loadNotifications(params: NotificationsQueryParams!): NotificationsQueryResult!
}
############################################ Entities

View File

@ -55,7 +55,7 @@ log_settings = {
local_headers = [
("Access-Control-Allow-Methods", "GET, POST, OPTIONS, HEAD"),
("Access-Control-Allow-Origin", "http://localhost:3000"),
("Access-Control-Allow-Origin", "https://localhost:3000"),
(
"Access-Control-Allow-Headers",
"DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization",

View File

@ -1,46 +0,0 @@
# from base.exceptions import Unauthorized
from auth.tokenstorage import SessionToken
from base.redis import redis
async def set_online_status(user_id, status):
if user_id:
if status:
await redis.execute("SADD", "users-online", user_id)
else:
await redis.execute("SREM", "users-online", user_id)
async def on_connect(req, params):
if not isinstance(params, dict):
req.scope["connection_params"] = {}
return
token = params.get('token')
if not token:
# raise Unauthorized("Please login")
return {
"error": "Please login first"
}
else:
payload = await SessionToken.verify(token)
if payload and payload.user_id:
req.scope["user_id"] = payload.user_id
await set_online_status(payload.user_id, True)
async def on_disconnect(req):
user_id = req.scope.get("user_id")
await set_online_status(user_id, False)
# FIXME: not used yet
def context_value(request):
context = {}
print(f"[inbox.presense] request debug: {request}")
if request.scope["type"] == "websocket":
# request is an instance of WebSocket
context.update(request.scope["connection_params"])
else:
context["token"] = request.META.get("authorization")
return context

View File

@ -1,22 +0,0 @@
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request
from graphql.type import GraphQLResolveInfo
from resolvers.inbox.messages import message_generator
# from base.exceptions import Unauthorized
# https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md
async def sse_messages(request: Request):
print(f'[SSE] request\n{request}\n')
info = GraphQLResolveInfo()
info.context['request'] = request.scope
user_id = request.scope['user'].user_id
if user_id:
event_generator = await message_generator(None, info)
return EventSourceResponse(event_generator)
else:
# raise Unauthorized("Please login")
return {
"error": "Please login first"
}

View File

@ -0,0 +1,137 @@
import asyncio
import json
from datetime import datetime, timezone
from sqlalchemy import and_
from base.orm import local_session
from orm import Reaction, Shout, Notification, User
from orm.notification import NotificationType
from orm.reaction import ReactionKind
from services.notifications.sse import connection_manager
def update_prev_notification(notification, user):
notification_data = json.loads(notification.data)
notification_data["users"] = [
user for user in notification_data["users"] if user['id'] != user.id
]
notification_data["users"].append({
"id": user.id,
"name": user.name
})
notification.data = json.dumps(notification_data, ensure_ascii=False)
notification.seen = False
notification.occurrences = notification.occurrences + 1
notification.createdAt = datetime.now(tz=timezone.utc)
class NewReactionNotificator:
def __init__(self, reaction_id):
self.reaction_id = reaction_id
async def run(self):
with local_session() as session:
reaction = session.query(Reaction).where(Reaction.id == self.reaction_id).one()
shout = session.query(Shout).where(Shout.id == reaction.shout).one()
user = session.query(User).where(User.id == reaction.createdBy).one()
notify_user_ids = []
if reaction.kind == ReactionKind.COMMENT:
parent_reaction = None
if reaction.replyTo:
parent_reaction = session.query(Reaction).where(Reaction.id == reaction.replyTo).one()
if parent_reaction.createdBy != reaction.createdBy:
prev_new_reply_notification = session.query(Notification).where(
and_(
Notification.user == shout.createdBy,
Notification.type == NotificationType.NEW_REPLY,
Notification.shout == shout.id,
Notification.reaction == parent_reaction.id
)
).first()
if prev_new_reply_notification:
update_prev_notification(prev_new_reply_notification, user)
else:
reply_notification_data = json.dumps({
"shout": {
"title": shout.title
},
"users": [
{"id": user.id, "name": user.name}
]
}, ensure_ascii=False)
reply_notification = Notification.create(**{
"user": parent_reaction.createdBy,
"type": NotificationType.NEW_REPLY.name,
"shout": shout.id,
"reaction": parent_reaction.id,
"data": reply_notification_data
})
session.add(reply_notification)
notify_user_ids.append(parent_reaction.createdBy)
if reaction.createdBy != shout.createdBy and (
parent_reaction is None or parent_reaction.createdBy != shout.createdBy
):
prev_new_comment_notification = session.query(Notification).where(
and_(
Notification.user == shout.createdBy,
Notification.type == NotificationType.NEW_COMMENT,
Notification.shout == shout.id
)
).first()
if prev_new_comment_notification:
update_prev_notification(prev_new_comment_notification, user)
else:
notification_data_string = json.dumps({
"shout": {
"title": shout.title
},
"users": [
{"id": user.id, "name": user.name}
]
}, ensure_ascii=False)
author_notification = Notification.create(**{
"user": shout.createdBy,
"type": NotificationType.NEW_COMMENT.name,
"shout": shout.id,
"data": notification_data_string
})
session.add(author_notification)
notify_user_ids.append(shout.createdBy)
session.commit()
for user_id in notify_user_ids:
await connection_manager.notify_user(user_id)
class NotificationService:
def __init__(self):
self._queue = asyncio.Queue()
async def handle_new_reaction(self, reaction_id):
notificator = NewReactionNotificator(reaction_id)
await self._queue.put(notificator)
async def worker(self):
while True:
notificator = await self._queue.get()
try:
await notificator.run()
except Exception as e:
print(f'[NotificationService.worker] error: {str(e)}')
notification_service = NotificationService()

View File

@ -0,0 +1,72 @@
import json
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request
import asyncio
class ConnectionManager:
def __init__(self):
self.connections_by_user_id = {}
def add_connection(self, user_id, connection):
if user_id not in self.connections_by_user_id:
self.connections_by_user_id[user_id] = []
self.connections_by_user_id[user_id].append(connection)
def remove_connection(self, user_id, connection):
if user_id not in self.connections_by_user_id:
return
self.connections_by_user_id[user_id].remove(connection)
if len(self.connections_by_user_id[user_id]) == 0:
del self.connections_by_user_id[user_id]
async def notify_user(self, user_id):
if user_id not in self.connections_by_user_id:
return
for connection in self.connections_by_user_id[user_id]:
data = {
"type": "newNotifications"
}
data_string = json.dumps(data, ensure_ascii=False)
await connection.put(data_string)
async def broadcast(self, data: str):
for user_id in self.connections_by_user_id:
for connection in self.connections_by_user_id[user_id]:
await connection.put(data)
class Connection:
def __init__(self):
self._queue = asyncio.Queue()
async def put(self, data: str):
await self._queue.put(data)
async def listen(self):
data = await self._queue.get()
return data
connection_manager = ConnectionManager()
async def sse_subscribe_handler(request: Request):
user_id = int(request.path_params["user_id"])
connection = Connection()
connection_manager.add_connection(user_id, connection)
async def event_publisher():
try:
while True:
data = await connection.listen()
yield data
except asyncio.CancelledError as e:
connection_manager.remove_connection(user_id, connection)
raise e
return EventSourceResponse(event_publisher())

View File

@ -27,6 +27,7 @@ SHOUTS_REPO = "content"
SESSION_TOKEN_HEADER = "Authorization"
SENTRY_DSN = environ.get("SENTRY_DSN")
SESSION_SECRET_KEY = environ.get("SESSION_SECRET_KEY") or "!secret"
# for local development
DEV_SERVER_PID_FILE_NAME = 'dev-server.pid'

43
test/test.json Normal file

File diff suppressed because one or more lines are too long