diff --git a/CHANGELOG.txt b/CHANGELOG.txt
index 7075c4c..7bced8c 100644
--- a/CHANGELOG.txt
+++ b/CHANGELOG.txt
@@ -1,3 +1,7 @@
+[0.2.22]
+- precommit installed
+- granian asgi added
+
 [0.2.19]
 - versioning sync with core, inbox, presence
 - fix: auth connection user_id trimming
diff --git a/main.py b/main.py
index 440c7d3..8d0c1ea 100644
--- a/main.py
+++ b/main.py
@@ -16,22 +16,23 @@
 from services.rediscache import redis
 from settings import DEV_SERVER_PID_FILE_NAME, MODE, SENTRY_DSN
 
-logging.basicConfig(level=logging.DEBUG)
-logger = logging.getLogger('\t[main]\t')
+logger = logging.getLogger("\t[main]\t")
 logger.setLevel(logging.DEBUG)
 
 
 async def start_up():
+    logger.info("[main] starting...")
     await redis.connect()
     task = asyncio.create_task(notifications_worker())
     logger.info(task)
 
-    if MODE == 'dev':
+    if MODE == "dev":
         if exists(DEV_SERVER_PID_FILE_NAME):
-            with open(DEV_SERVER_PID_FILE_NAME, 'w', encoding='utf-8') as f:
+            with open(DEV_SERVER_PID_FILE_NAME, "w", encoding="utf-8") as f:
                 f.write(str(os.getpid()))
     else:
+        logger.info("[main] production mode")
        try:
             import sentry_sdk
 
@@ -46,7 +47,7 @@ async def start_up():
                 ],
             )
         except Exception as e:
-            logger.error('sentry init error', e)
+            logger.error("sentry init error: %s", e)
 
 
 async def shutdown():
@@ -54,4 +55,4 @@ async def shutdown():
 
 
 app = Starlette(debug=True, on_startup=[start_up], on_shutdown=[shutdown])
-app.mount('/', GraphQL(schema, debug=True))
+app.mount("/", GraphQL(schema, debug=True))
diff --git a/resolvers/listener.py b/resolvers/listener.py
index 95b762c..00f826a 100644
--- a/resolvers/listener.py
+++ b/resolvers/listener.py
@@ -34,6 +34,7 @@ async def handle_notification(n: ServiceMessage, channel: str):
 
 
 async def listen_task(pattern):
+    logger.info(f" listening {pattern} ...")
     async for message_data, channel in redis.listen(pattern):
         try:
             if message_data:
diff --git a/resolvers/load.py b/resolvers/load.py
index 04de03b..d74ad49 100644
--- a/resolvers/load.py
+++ b/resolvers/load.py
@@ -8,7 +8,12 @@
 from sqlalchemy import and_, select
 from sqlalchemy.orm import aliased
 from sqlalchemy.sql import not_
-from orm.notification import Notification, NotificationAction, NotificationEntity, NotificationSeen
+from orm.notification import (
+    Notification,
+    NotificationAction,
+    NotificationEntity,
+    NotificationSeen,
+)
 from resolvers.model import (
     NotificationAuthor,
     NotificationGroup,
@@ -19,15 +24,22 @@ from resolvers.model import (
 )
 from services.db import local_session
 
-logger = logging.getLogger('[resolvers.schema]')
+logger = logging.getLogger("[resolvers.schema]")
 logger.setLevel(logging.DEBUG)
 
 
-def query_notifications(author_id: int, after: int = 0) -> Tuple[int, int, List[Tuple[Notification, bool]]]:
+def query_notifications(
+    author_id: int, after: int = 0
+) -> Tuple[int, int, List[Tuple[Notification, bool]]]:
     notification_seen_alias = aliased(NotificationSeen)
-    query = select(Notification, notification_seen_alias.viewer.label('seen')).outerjoin(
+    query = select(
+        Notification, notification_seen_alias.viewer.label("seen")
+    ).outerjoin(
         NotificationSeen,
-        and_(NotificationSeen.viewer == author_id, NotificationSeen.notification == Notification.id),
+        and_(
+            NotificationSeen.viewer == author_id,
+            NotificationSeen.notification == Notification.id,
+        ),
     )
     if after:
         query = query.filter(Notification.created_at > after)
@@ -36,7 +48,12 @@ def query_notifications(author_id: int, after: int = 0) -> Tuple[int, int, List[
     with local_session() as session:
         total = (
             session.query(Notification)
-            .filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after))
+            .filter(
+                and_(
+                    Notification.action == NotificationAction.CREATE.value,
+                    Notification.created_at > after,
+                )
+            )
             .count()
         )
 
@@ -63,7 +80,9 @@ def query_notifications(author_id: int, after: int = 0) -> Tuple[int, int, List[
 def process_shout_notification(
     notification: Notification, seen: bool
 ) -> Union[Tuple[str, NotificationGroup], None] | None:
-    if not isinstance(notification.payload, str) or not isinstance(notification.entity, str):
+    if not isinstance(notification.payload, str) or not isinstance(
+        notification.entity, str
+    ):
         return
     payload = json.loads(notification.payload)
     shout: NotificationShout = payload
@@ -75,7 +94,7 @@ def process_shout_notification(
         authors=shout.authors,
         updated_at=shout.created_at,
         reactions=[],
-        action='create',
+        action="create",
         seen=seen,
     )
     return thread_id, group
@@ -94,8 +113,8 @@ def process_reaction_notification(
     reaction: NotificationReaction = payload
     shout: NotificationShout = reaction.shout
     thread_id = str(reaction.shout)
-    if reaction.kind == 'COMMENT' and reaction.reply_to:
-        thread_id += f'::{reaction.reply_to}'
+    if reaction.kind == "COMMENT" and reaction.reply_to:
+        thread_id += f"::{reaction.reply_to}"
     group = NotificationGroup(
         id=thread_id,
         action=str(notification.action),
@@ -116,15 +135,15 @@ def process_follower_notification(
         return
     payload = json.loads(notification.payload)
     follower: NotificationAuthor = payload
-    thread_id = 'followers'
+    thread_id = "followers"
     group = NotificationGroup(
         id=thread_id,
         authors=[follower],
         updated_at=int(time.time()),
         shout=None,
         reactions=[],
-        entity='follower',
-        action='follow',
+        entity="follower",
+        action="follow",
         seen=seen,
     )
     return thread_id, group
@@ -133,6 +152,30 @@
 async def get_notifications_grouped(
     author_id: int, after: int = 0, limit: int = 10
 ) -> Tuple[Dict[str, NotificationGroup], int, int]:
+    """
+    Retrieves notifications for a given author.
+
+    Args:
+        author_id (int): The ID of the author for whom notifications are retrieved.
+        after (int, optional): If provided, only notifications created after this timestamp are considered.
+        limit (int, optional): The maximum number of groups to retrieve.
+
+    Returns:
+        Dict[str, NotificationGroup], int, int: A dictionary mapping thread IDs to NotificationGroup objects, plus the unread and total counts.
+
+    This function queries the database to retrieve notifications for the specified author, considering optional filters.
+    The result is a dictionary where each key is a thread ID, and the corresponding value is a NotificationGroup
+    containing information about the notifications within that thread.
+
+    NotificationGroup structure:
+    {
+        entity: str,  # Type of entity (e.g., 'reaction', 'shout', 'follower').
+        updated_at: int,  # Timestamp of the latest update in the thread.
+        shout: Optional[NotificationShout],
+        reactions: List[int],  # List of reaction ids within the thread.
+        authors: List[NotificationAuthor],  # List of authors involved in the thread.
+    }
+    """
     total, unread, notifications = query_notifications(author_id, after)
     groups_by_thread: Dict[str, NotificationGroup] = {}
     groups_amount = 0
@@ -141,7 +184,7 @@ async def get_notifications_grouped(
         if groups_amount >= limit:
             break
 
-        if str(notification.entity) == 'shout' and str(notification.action) == 'create':
+        if str(notification.entity) == "shout" and str(notification.action) == "create":
             result = process_shout_notification(notification, seen)
             if result:
                 thread_id, group = result
@@ -168,7 +211,7 @@ async def get_notifications_grouped(
                     groups_by_thread[thread_id] = group
                     groups_amount += 1
 
-        elif str(notification.entity) == 'follower':
+        elif str(notification.entity) == "follower":
             result = process_follower_notification(notification, seen)
             if result:
                 thread_id, group = result
@@ -181,10 +224,18 @@ async def get_notifications_grouped(
 @strawberry.type
 class Query:
     @strawberry.field
-    async def load_notifications(self, info, after: int, limit: int = 50, offset: int = 0) -> NotificationsResult:
-        author_id = info.context.get('author_id')
+    async def load_notifications(
+        self, info, after: int, limit: int = 50, offset: int = 0
+    ) -> NotificationsResult:
+        author_id = info.context.get("author_id")
         if author_id:
-            groups, unread, total = await get_notifications_grouped(author_id, after, limit)
-            notifications = sorted(groups.values(), key=lambda group: group.updated_at, reverse=True)
-            return NotificationsResult(notifications=notifications, total=total, unread=unread, error=None)
+            groups, unread, total = await get_notifications_grouped(
+                author_id, after, limit
+            )
+            notifications = sorted(
+                groups.values(), key=lambda group: group.updated_at, reverse=True
+            )
+            return NotificationsResult(
+                notifications=notifications, total=total, unread=unread, error=None
+            )
         return NotificationsResult(notifications=[], total=0, unread=0, error=None)
diff --git a/server.py b/server.py
index 9591541..15fca66 100644
--- a/server.py
+++ b/server.py
@@ -1,71 +1,14 @@
-import logging
-import sys
-
+from granian.constants import Interfaces
+from granian.server import Granian
 from settings import PORT
 
-log_settings = {
-    'version': 1,
-    'disable_existing_loggers': True,
-    'formatters': {
-        'default': {
-            '()': 'uvicorn.logging.DefaultFormatter',
-            'fmt': '%(levelprefix)s %(message)s',
-            'use_colors': None,
-        },
-        'access': {
-            '()': 'uvicorn.logging.AccessFormatter',
-            'fmt': '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s',
-        },
-    },
-    'handlers': {
-        'default': {
-            'formatter': 'default',
-            'class': 'logging.StreamHandler',
-            'stream': 'ext://sys.stderr',
-        },
-        'access': {
-            'formatter': 'access',
-            'class': 'logging.StreamHandler',
-            'stream': 'ext://sys.stdout',
-        },
-    },
-    'loggers': {
-        'uvicorn': {'handlers': ['default'], 'level': 'INFO'},
-        'uvicorn.error': {'level': 'INFO', 'handlers': ['default'], 'propagate': True},
-        'uvicorn.access': {'handlers': ['access'], 'level': 'INFO', 'propagate': False},
-    },
-}
-
-local_headers = [
-    ('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, HEAD'),
-    ('Access-Control-Allow-Origin', 'https://localhost:3000'),
-    (
-        'Access-Control-Allow-Headers',
-        'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization',
-    ),
-    ('Access-Control-Expose-Headers', 'Content-Length,Content-Range'),
-    ('Access-Control-Allow-Credentials', 'true'),
-]
-
-logger = logging.getLogger('[server] ')
-logger.setLevel(logging.DEBUG)
-
-
-def exception_handler(_et, exc, _tb):
-    logger.error(..., exc_info=(type(exc), exc, exc.__traceback__))
-
-
-if __name__ == '__main__':
-    sys.excepthook = exception_handler
-    from granian.constants import Interfaces
-    from granian.server import Granian
-
-    print('[server] started')
+if __name__ == "__main__":
+    print("[server] started")
     granian_instance = Granian(
-        'main:app',
-        address='0.0.0.0',  # noqa S104
+        "main:app",
+        address="0.0.0.0",  # noqa S104
         port=PORT,
         workers=2,
         threads=2,
diff --git a/services/auth.py b/services/auth.py
index b84a9bd..104ab30 100644
--- a/services/auth.py
+++ b/services/auth.py
@@ -8,53 +8,62 @@
 from services.db import local_session
 from settings import AUTH_URL
 
-logger = logging.getLogger('\t[services.auth]\t')
+logger = logging.getLogger("\t[services.auth]\t")
 logger.setLevel(logging.DEBUG)
 
 
 async def check_auth(req) -> str | None:
-    token = req.headers.get('Authorization')
-    user_id = ''
+    token = req.headers.get("Authorization")
+    user_id = ""
     if token:
-        query_name = 'validate_jwt_token'
-        operation = 'ValidateToken'
+        query_name = "validate_jwt_token"
+        operation = "ValidateToken"
         headers = {
-            'Content-Type': 'application/json',
+            "Content-Type": "application/json",
         }
 
         variables = {
-            'params': {
-                'token_type': 'access_token',
-                'token': token,
+            "params": {
+                "token_type": "access_token",
+                "token": token,
             }
         }
 
         gql = {
-            'query': f'query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}',
-            'variables': variables,
-            'operationName': operation,
+            "query": f"query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}",
+            "variables": variables,
+            "operationName": operation,
         }
         try:
             # Asynchronous HTTP request to the authentication server
             async with ClientSession() as session:
-                async with session.post(AUTH_URL, json=gql, headers=headers) as response:
-                    print(f'[services.auth] HTTP Response {response.status} {await response.text()}')
+                async with session.post(
+                    AUTH_URL, json=gql, headers=headers
+                ) as response:
+                    logger.debug(
+                        f"HTTP Response {response.status} {await response.text()}"
+                    )
                     if response.status == 200:
                         data = await response.json()
-                        errors = data.get('errors')
+                        errors = data.get("errors")
                         if errors:
-                            print(f'[services.auth] errors: {errors}')
+                            logger.error(f"errors: {errors}")
                         else:
-                            user_id = data.get('data', {}).get(query_name, {}).get('claims', {}).get('sub')
+                            user_id = (
+                                data.get("data", {})
+                                .get(query_name, {})
+                                .get("claims", {})
+                                .get("sub")
+                            )
                             if user_id:
-                                print(f'[services.auth] got user_id: {user_id}')
+                                logger.info(f"got user_id: {user_id}")
                                 return user_id
         except Exception as e:
             import traceback
 
             traceback.print_exc()  # Handling and logging exceptions during authentication check
-            print(f'[services.auth] Error {e}')
+            logger.error(f"Error {e}")
 
     return None
 
@@ -62,14 +71,14 @@ async def check_auth(req) -> str | None:
 class LoginRequiredMiddleware(Extension):
     async def on_request_start(self):
         context = self.execution_context.context
-        req = context.get('request')
+        req = context.get("request")
         user_id = await check_auth(req)
         if user_id:
-            context['user_id'] = user_id.strip()
+            context["user_id"] = user_id.strip()
             with local_session() as session:
                 author = session.query(Author).filter(Author.user == user_id).first()
                 if author:
-                    context['author_id'] = author.id
-                    context['user_id'] = user_id or None
+                    context["author_id"] = author.id
+                    context["user_id"] = user_id or None
 
         self.execution_context.context = context
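
The server.py hunk above is truncated after "threads=2,". For orientation, a minimal
sketch of how such a Granian ASGI entrypoint is typically completed; the
interface=Interfaces.ASGI argument and the closing .serve() call are assumptions
drawn from Granian's public API, not lines taken from this diff:

    from granian.constants import Interfaces
    from granian.server import Granian

    from settings import PORT

    if __name__ == "__main__":
        print("[server] started")
        granian_instance = Granian(
            "main:app",                 # module:attribute path to the Starlette app
            address="0.0.0.0",          # noqa S104 -- bind all interfaces, as in the diff
            port=PORT,
            interface=Interfaces.ASGI,  # assumed: the reason Interfaces is imported
            workers=2,
            threads=2,
        )
        granian_instance.serve()        # assumed: blocks and runs the worker processes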