merged
All checks were successful
Deploy to v2 / deploy (push) Successful in 2m9s

This commit is contained in:
Untone 2024-02-18 10:47:07 +03:00
commit e942d82412
6 changed files with 122 additions and 112 deletions

View File

@ -1,3 +1,7 @@
[0.2.22]
- precommit installed
- granian asgi added
[0.2.19]
- versioning sync with core, inbox, presence
- fix: auth connection user_id trimming

13
main.py
View File

@ -16,22 +16,23 @@ from services.rediscache import redis
from settings import DEV_SERVER_PID_FILE_NAME, MODE, SENTRY_DSN
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger('\t[main]\t')
logger = logging.getLogger("\t[main]\t")
logger.setLevel(logging.DEBUG)
async def start_up():
logger.info("[main] starting...")
await redis.connect()
task = asyncio.create_task(notifications_worker())
logger.info(task)
if MODE == 'dev':
if MODE == "dev":
if exists(DEV_SERVER_PID_FILE_NAME):
with open(DEV_SERVER_PID_FILE_NAME, 'w', encoding='utf-8') as f:
with open(DEV_SERVER_PID_FILE_NAME, "w", encoding="utf-8") as f:
f.write(str(os.getpid()))
else:
logger.info("[main] production mode")
try:
import sentry_sdk
@ -46,7 +47,7 @@ async def start_up():
],
)
except Exception as e:
logger.error('sentry init error', e)
logger.error("sentry init error", e)
async def shutdown():
@ -54,4 +55,4 @@ async def shutdown():
app = Starlette(debug=True, on_startup=[start_up], on_shutdown=[shutdown])
app.mount('/', GraphQL(schema, debug=True))
app.mount("/", GraphQL(schema, debug=True))

View File

@ -34,6 +34,7 @@ async def handle_notification(n: ServiceMessage, channel: str):
async def listen_task(pattern):
logger.info(f' listening {pattern} ...')
async for message_data, channel in redis.listen(pattern):
try:
if message_data:

View File

@ -8,7 +8,12 @@ from sqlalchemy import and_, select
from sqlalchemy.orm import aliased
from sqlalchemy.sql import not_
from orm.notification import Notification, NotificationAction, NotificationEntity, NotificationSeen
from orm.notification import (
Notification,
NotificationAction,
NotificationEntity,
NotificationSeen,
)
from resolvers.model import (
NotificationAuthor,
NotificationGroup,
@ -19,15 +24,22 @@ from resolvers.model import (
from services.db import local_session
logger = logging.getLogger('[resolvers.schema]')
logger = logging.getLogger("[resolvers.schema]")
logger.setLevel(logging.DEBUG)
def query_notifications(author_id: int, after: int = 0) -> Tuple[int, int, List[Tuple[Notification, bool]]]:
def query_notifications(
author_id: int, after: int = 0
) -> Tuple[int, int, List[Tuple[Notification, bool]]]:
notification_seen_alias = aliased(NotificationSeen)
query = select(Notification, notification_seen_alias.viewer.label('seen')).outerjoin(
query = select(
Notification, notification_seen_alias.viewer.label("seen")
).outerjoin(
NotificationSeen,
and_(NotificationSeen.viewer == author_id, NotificationSeen.notification == Notification.id),
and_(
NotificationSeen.viewer == author_id,
NotificationSeen.notification == Notification.id,
),
)
if after:
query = query.filter(Notification.created_at > after)
@ -36,7 +48,12 @@ def query_notifications(author_id: int, after: int = 0) -> Tuple[int, int, List[
with local_session() as session:
total = (
session.query(Notification)
.filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after))
.filter(
and_(
Notification.action == NotificationAction.CREATE.value,
Notification.created_at > after,
)
)
.count()
)
@ -63,7 +80,9 @@ def query_notifications(author_id: int, after: int = 0) -> Tuple[int, int, List[
def process_shout_notification(
notification: Notification, seen: bool
) -> Union[Tuple[str, NotificationGroup], None] | None:
if not isinstance(notification.payload, str) or not isinstance(notification.entity, str):
if not isinstance(notification.payload, str) or not isinstance(
notification.entity, str
):
return
payload = json.loads(notification.payload)
shout: NotificationShout = payload
@ -75,7 +94,7 @@ def process_shout_notification(
authors=shout.authors,
updated_at=shout.created_at,
reactions=[],
action='create',
action="create",
seen=seen,
)
return thread_id, group
@ -94,8 +113,8 @@ def process_reaction_notification(
reaction: NotificationReaction = payload
shout: NotificationShout = reaction.shout
thread_id = str(reaction.shout)
if reaction.kind == 'COMMENT' and reaction.reply_to:
thread_id += f'::{reaction.reply_to}'
if reaction.kind == "COMMENT" and reaction.reply_to:
thread_id += f"::{reaction.reply_to}"
group = NotificationGroup(
id=thread_id,
action=str(notification.action),
@ -116,15 +135,15 @@ def process_follower_notification(
return
payload = json.loads(notification.payload)
follower: NotificationAuthor = payload
thread_id = 'followers'
thread_id = "followers"
group = NotificationGroup(
id=thread_id,
authors=[follower],
updated_at=int(time.time()),
shout=None,
reactions=[],
entity='follower',
action='follow',
entity="follower",
action="follow",
seen=seen,
)
return thread_id, group
@ -133,6 +152,31 @@ def process_follower_notification(
async def get_notifications_grouped(
author_id: int, after: int = 0, limit: int = 10
) -> Tuple[Dict[str, NotificationGroup], int, int]:
"""
Retrieves notifications for a given author.
Args:
author_id (int): The ID of the author for whom notifications are retrieved.
after (int, optional): If provided, selects only notifications created after this timestamp will be considered.
limit (int, optional): The maximum number of groupa to retrieve.
offset (int, optional): Offset for pagination
Returns:
Dict[str, NotificationGroup], int, int: A dictionary where keys are thread IDs and values are NotificationGroup objects, unread and total amounts.
This function queries the database to retrieve notifications for the specified author, considering optional filters.
The result is a dictionary where each key is a thread ID, and the corresponding value is a NotificationGroup
containing information about the notifications within that thread.
NotificationGroup structure:
{
entity: str, # Type of entity (e.g., 'reaction', 'shout', 'follower').
updated_at: int, # Timestamp of the latest update in the thread.
shout: Optional[NotificationShout]
reactions: List[int], # List of reaction ids within the thread.
authors: List[NotificationAuthor], # List of authors involved in the thread.
}
"""
total, unread, notifications = query_notifications(author_id, after)
groups_by_thread: Dict[str, NotificationGroup] = {}
groups_amount = 0
@ -141,7 +185,7 @@ async def get_notifications_grouped(
if groups_amount >= limit:
break
if str(notification.entity) == 'shout' and str(notification.action) == 'create':
if str(notification.entity) == "shout" and str(notification.action) == "create":
result = process_shout_notification(notification, seen)
if result:
thread_id, group = result
@ -168,7 +212,7 @@ async def get_notifications_grouped(
groups_by_thread[thread_id] = group
groups_amount += 1
elif str(notification.entity) == 'follower':
elif str(notification.entity) == "follower":
result = process_follower_notification(notification, seen)
if result:
thread_id, group = result
@ -181,10 +225,18 @@ async def get_notifications_grouped(
@strawberry.type
class Query:
@strawberry.field
async def load_notifications(self, info, after: int, limit: int = 50, offset: int = 0) -> NotificationsResult:
author_id = info.context.get('author_id')
async def load_notifications(
self, info, after: int, limit: int = 50, offset: int = 0
) -> NotificationsResult:
author_id = info.context.get("author_id")
if author_id:
groups, unread, total = await get_notifications_grouped(author_id, after, limit)
notifications = sorted(groups.values(), key=lambda group: group.updated_at, reverse=True)
return NotificationsResult(notifications=notifications, total=total, unread=unread, error=None)
groups, unread, total = await get_notifications_grouped(
author_id, after, limit
)
notifications = sorted(
groups.values(), key=lambda group: group.updated_at, reverse=True
)
return NotificationsResult(
notifications=notifications, total=total, unread=unread, error=None
)
return NotificationsResult(notifications=[], total=0, unread=0, error=None)

View File

@ -1,71 +1,14 @@
import logging
import sys
from granian.constants import Interfaces
from granian.server import Granian
from settings import PORT
log_settings = {
'version': 1,
'disable_existing_loggers': True,
'formatters': {
'default': {
'()': 'uvicorn.logging.DefaultFormatter',
'fmt': '%(levelprefix)s %(message)s',
'use_colors': None,
},
'access': {
'()': 'uvicorn.logging.AccessFormatter',
'fmt': '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s',
},
},
'handlers': {
'default': {
'formatter': 'default',
'class': 'logging.StreamHandler',
'stream': 'ext://sys.stderr',
},
'access': {
'formatter': 'access',
'class': 'logging.StreamHandler',
'stream': 'ext://sys.stdout',
},
},
'loggers': {
'uvicorn': {'handlers': ['default'], 'level': 'INFO'},
'uvicorn.error': {'level': 'INFO', 'handlers': ['default'], 'propagate': True},
'uvicorn.access': {'handlers': ['access'], 'level': 'INFO', 'propagate': False},
},
}
local_headers = [
('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, HEAD'),
('Access-Control-Allow-Origin', 'https://localhost:3000'),
(
'Access-Control-Allow-Headers',
'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization',
),
('Access-Control-Expose-Headers', 'Content-Length,Content-Range'),
('Access-Control-Allow-Credentials', 'true'),
]
logger = logging.getLogger('[server] ')
logger.setLevel(logging.DEBUG)
def exception_handler(_et, exc, _tb):
logger.error(..., exc_info=(type(exc), exc, exc.__traceback__))
if __name__ == '__main__':
sys.excepthook = exception_handler
from granian.constants import Interfaces
from granian.server import Granian
print('[server] started')
if __name__ == "__main__":
print("[server] started")
granian_instance = Granian(
'main:app',
address='0.0.0.0', # noqa S104
"main:app",
address="0.0.0.0", # noqa S104
port=PORT,
workers=2,
threads=2,

View File

@ -8,53 +8,62 @@ from services.db import local_session
from settings import AUTH_URL
logger = logging.getLogger('\t[services.auth]\t')
logger = logging.getLogger("\t[services.auth]\t")
logger.setLevel(logging.DEBUG)
async def check_auth(req) -> str | None:
token = req.headers.get('Authorization')
user_id = ''
token = req.headers.get("Authorization")
user_id = ""
if token:
query_name = 'validate_jwt_token'
operation = 'ValidateToken'
query_name = "validate_jwt_token"
operation = "ValidateToken"
headers = {
'Content-Type': 'application/json',
"Content-Type": "application/json",
}
variables = {
'params': {
'token_type': 'access_token',
'token': token,
"params": {
"token_type": "access_token",
"token": token,
}
}
gql = {
'query': f'query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}',
'variables': variables,
'operationName': operation,
"query": f"query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}",
"variables": variables,
"operationName": operation,
}
try:
# Asynchronous HTTP request to the authentication server
async with ClientSession() as session:
async with session.post(AUTH_URL, json=gql, headers=headers) as response:
print(f'[services.auth] HTTP Response {response.status} {await response.text()}')
async with session.post(
AUTH_URL, json=gql, headers=headers
) as response:
logger.debug(
f"HTTP Response {response.status} {await response.text()}"
)
if response.status == 200:
data = await response.json()
errors = data.get('errors')
errors = data.get("errors")
if errors:
print(f'[services.auth] errors: {errors}')
logger.error(f"errors: {errors}")
else:
user_id = data.get('data', {}).get(query_name, {}).get('claims', {}).get('sub')
user_id = (
data.get("data", {})
.get(query_name, {})
.get("claims", {})
.get("sub")
)
if user_id:
print(f'[services.auth] got user_id: {user_id}')
logger.info(f"got user_id: {user_id}")
return user_id
except Exception as e:
import traceback
traceback.print_exc()
# Handling and logging exceptions during authentication check
print(f'[services.auth] Error {e}')
logger.error(f"Error {e}")
return None
@ -62,14 +71,14 @@ async def check_auth(req) -> str | None:
class LoginRequiredMiddleware(Extension):
async def on_request_start(self):
context = self.execution_context.context
req = context.get('request')
req = context.get("request")
user_id = await check_auth(req)
if user_id:
context['user_id'] = user_id.strip()
context["user_id"] = user_id.strip()
with local_session() as session:
author = session.query(Author).filter(Author.user == user_id).first()
if author:
context['author_id'] = author.id
context['user_id'] = user_id or None
context["author_id"] = author.id
context["user_id"] = user_id or None
self.execution_context.context = context