This commit is contained in:
parent
737cc40353
commit
693d8b6aee
|
@ -3,14 +3,12 @@ from sqlalchemy import and_, select
|
||||||
from sqlalchemy.orm import aliased
|
from sqlalchemy.orm import aliased
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
from orm.author import Author
|
|
||||||
from orm.notification import Notification as NotificationMessage, NotificationSeen
|
from orm.notification import Notification as NotificationMessage, NotificationSeen
|
||||||
from services.auth import check_auth
|
from services.auth import LoginRequiredMiddleware
|
||||||
from services.db import local_session
|
from services.db import local_session
|
||||||
import strawberry
|
import strawberry
|
||||||
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
|
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
|
||||||
from strawberry.schema.config import StrawberryConfig
|
from strawberry.schema.config import StrawberryConfig
|
||||||
from strawberry.extensions import Extension
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
|
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
|
||||||
|
@ -135,21 +133,6 @@ class Mutation:
|
||||||
return NotificationSeenResult()
|
return NotificationSeenResult()
|
||||||
|
|
||||||
|
|
||||||
class LoginRequiredMiddleware(Extension):
|
|
||||||
async def on_request_start(self):
|
|
||||||
context = self.execution_context.context
|
|
||||||
req = context.get("request")
|
|
||||||
is_authenticated, user_id = await check_auth(req)
|
|
||||||
if is_authenticated:
|
|
||||||
with local_session() as session:
|
|
||||||
author = session.query(Author).filter(Author.user == user_id).first()
|
|
||||||
if author:
|
|
||||||
context["author_id"] = author.id
|
|
||||||
if user_id:
|
|
||||||
context["user_id"] = user_id
|
|
||||||
context["user_id"] = user_id or None
|
|
||||||
|
|
||||||
|
|
||||||
schema = strawberry.Schema(
|
schema = strawberry.Schema(
|
||||||
query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware]
|
query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware]
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp.web import HTTPUnauthorized
|
from aiohttp.web import HTTPUnauthorized
|
||||||
|
from strawberry.extensions import Extension
|
||||||
|
|
||||||
|
from orm.author import Author
|
||||||
|
from services.db import local_session
|
||||||
from settings import AUTH_URL
|
from settings import AUTH_URL
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,3 +65,18 @@ async def check_auth(req) -> (bool, int | None):
|
||||||
raise HTTPUnauthorized(message="Please, login first")
|
raise HTTPUnauthorized(message="Please, login first")
|
||||||
|
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequiredMiddleware(Extension):
|
||||||
|
async def on_request_start(self):
|
||||||
|
context = self.execution_context.context
|
||||||
|
req = context.get("request")
|
||||||
|
is_authenticated, user_id = await check_auth(req)
|
||||||
|
if is_authenticated:
|
||||||
|
with local_session() as session:
|
||||||
|
author = session.query(Author).filter(Author.user == user_id).first()
|
||||||
|
if author:
|
||||||
|
context["author_id"] = author.id
|
||||||
|
if user_id:
|
||||||
|
context["user_id"] = user_id
|
||||||
|
context["user_id"] = user_id or None
|
||||||
|
|
Loading…
Reference in New Issue
Block a user