diff --git a/resolvers/schema.py b/resolvers/schema.py index 669c6ac..ec1779a 100644 --- a/resolvers/schema.py +++ b/resolvers/schema.py @@ -3,14 +3,12 @@ from sqlalchemy import and_, select from sqlalchemy.orm import aliased from sqlalchemy.exc import SQLAlchemyError -from orm.author import Author from orm.notification import Notification as NotificationMessage, NotificationSeen -from services.auth import check_auth +from services.auth import LoginRequiredMiddleware from services.db import local_session import strawberry from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper from strawberry.schema.config import StrawberryConfig -from strawberry.extensions import Extension import logging strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper() @@ -135,21 +133,6 @@ class Mutation: return NotificationSeenResult() -class LoginRequiredMiddleware(Extension): - async def on_request_start(self): - context = self.execution_context.context - req = context.get("request") - is_authenticated, user_id = await check_auth(req) - if is_authenticated: - with local_session() as session: - author = session.query(Author).filter(Author.user == user_id).first() - if author: - context["author_id"] = author.id - if user_id: - context["user_id"] = user_id - context["user_id"] = user_id or None - - schema = strawberry.Schema( query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware] ) diff --git a/services/auth.py b/services/auth.py index 010287f..515b0fa 100644 --- a/services/auth.py +++ b/services/auth.py @@ -1,5 +1,9 @@ import aiohttp from aiohttp.web import HTTPUnauthorized +from strawberry.extensions import Extension + +from orm.author import Author +from services.db import local_session from settings import AUTH_URL @@ -61,3 +65,18 @@ async def check_auth(req) -> (bool, int | None): raise HTTPUnauthorized(message="Please, login first") return False, None + + +class LoginRequiredMiddleware(Extension): + async def on_request_start(self): + context = self.execution_context.context + req = context.get("request") + is_authenticated, user_id = await check_auth(req) + if is_authenticated: + with local_session() as session: + author = session.query(Author).filter(Author.user == user_id).first() + if author: + context["author_id"] = author.id + if user_id: + context["user_id"] = user_id + context["user_id"] = user_id or None