auth-check-middleware-3
All checks were successful
deploy / deploy (push) Successful in 1m9s

This commit is contained in:
Untone 2023-12-17 14:47:05 +03:00
parent 737cc40353
commit 693d8b6aee
2 changed files with 20 additions and 18 deletions

View File

@ -3,14 +3,12 @@ from sqlalchemy import and_, select
from sqlalchemy.orm import aliased from sqlalchemy.orm import aliased
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from orm.author import Author
from orm.notification import Notification as NotificationMessage, NotificationSeen from orm.notification import Notification as NotificationMessage, NotificationSeen
from services.auth import check_auth from services.auth import LoginRequiredMiddleware
from services.db import local_session from services.db import local_session
import strawberry import strawberry
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
from strawberry.schema.config import StrawberryConfig from strawberry.schema.config import StrawberryConfig
from strawberry.extensions import Extension
import logging import logging
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper() strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
@ -135,21 +133,6 @@ class Mutation:
return NotificationSeenResult() return NotificationSeenResult()
class LoginRequiredMiddleware(Extension):
async def on_request_start(self):
context = self.execution_context.context
req = context.get("request")
is_authenticated, user_id = await check_auth(req)
if is_authenticated:
with local_session() as session:
author = session.query(Author).filter(Author.user == user_id).first()
if author:
context["author_id"] = author.id
if user_id:
context["user_id"] = user_id
context["user_id"] = user_id or None
schema = strawberry.Schema( schema = strawberry.Schema(
query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware] query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware]
) )

View File

@ -1,5 +1,9 @@
import aiohttp import aiohttp
from aiohttp.web import HTTPUnauthorized from aiohttp.web import HTTPUnauthorized
from strawberry.extensions import Extension
from orm.author import Author
from services.db import local_session
from settings import AUTH_URL from settings import AUTH_URL
@ -61,3 +65,18 @@ async def check_auth(req) -> (bool, int | None):
raise HTTPUnauthorized(message="Please, login first") raise HTTPUnauthorized(message="Please, login first")
return False, None return False, None
class LoginRequiredMiddleware(Extension):
async def on_request_start(self):
context = self.execution_context.context
req = context.get("request")
is_authenticated, user_id = await check_auth(req)
if is_authenticated:
with local_session() as session:
author = session.query(Author).filter(Author.user == user_id).first()
if author:
context["author_id"] = author.id
if user_id:
context["user_id"] = user_id
context["user_id"] = user_id or None