auth-check-middleware
Some checks failed
deploy / deploy (push) Failing after 1m6s

This commit is contained in:
Untone 2023-12-17 14:42:04 +03:00
parent f061d8a523
commit 01d7935cbd
6 changed files with 49 additions and 16 deletions

View File

@ -9,6 +9,7 @@ from sentry_sdk.integrations.strawberry import StrawberryIntegration
from strawberry.asgi import GraphQL from strawberry.asgi import GraphQL
from starlette.applications import Starlette from starlette.applications import Starlette
from services.auth import TokenMiddleware
from services.rediscache import redis from services.rediscache import redis
from resolvers.listener import reactions_worker from resolvers.listener import reactions_worker
from resolvers.schema import schema from resolvers.schema import schema
@ -49,4 +50,5 @@ async def shutdown():
app = Starlette(debug=True, on_startup=[start_up], on_shutdown=[shutdown]) app = Starlette(debug=True, on_startup=[start_up], on_shutdown=[shutdown])
app.add_middleware(TokenMiddleware)
app.mount("/", GraphQL(schema, debug=True)) app.mount("/", GraphQL(schema, debug=True))

View File

@ -1,8 +1,8 @@
from enum import Enum as Enumeration from enum import Enum as Enumeration
import time import time
from sqlalchemy import Boolean, Column, Enum, Integer, ForeignKey, JSON as JSONType from sqlalchemy import Column, Enum, Integer, ForeignKey, JSON as JSONType
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
# from sqlalchemy.dialects.postgresql import JSONB
from orm.author import Author from orm.author import Author
from services.db import Base from services.db import Base

View File

@ -21,7 +21,8 @@ aiohttp = "^3.9.1"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^7.4.2" pytest = "^7.4.2"
black = { version = "^23.9.1", python = ">=3.12" } black = { version = "^23.12.0", python = ">=3.12" }
ruff = { version = "^0.1.8", python = ">=3.12" }
mypy = { version = "^1.7", python = ">=3.12" } mypy = { version = "^1.7", python = ">=3.12" }
setuptools = "^69.0.2" setuptools = "^69.0.2"
@ -74,3 +75,12 @@ executionEnvironments = []
python_version = "3.12" python_version = "3.12"
warn_unused_configs = true warn_unused_configs = true
plugins = ["mypy_sqlalchemy.plugin", "strawberry.ext.mypy_plugin"] plugins = ["mypy_sqlalchemy.plugin", "strawberry.ext.mypy_plugin"]
[tool.ruff]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
select = ["E4", "E7", "E9", "F"]
ignore = []
line-length = 120
target-version = "py312"

View File

@ -2,12 +2,16 @@ from typing import List
from sqlalchemy import and_, select from sqlalchemy import and_, select
from sqlalchemy.orm import aliased from sqlalchemy.orm import aliased
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from orm.author import Author
from orm.notification import Notification as NotificationMessage, NotificationSeen from orm.notification import Notification as NotificationMessage, NotificationSeen
from services.auth import login_required from services.auth import check_auth
from aiohttp.web import HTTPUnauthorized
from services.db import local_session from services.db import local_session
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
import strawberry import strawberry
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
from strawberry.schema.config import StrawberryConfig from strawberry.schema.config import StrawberryConfig
from strawberry.extensions import Extension
import logging import logging
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper() strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
@ -70,7 +74,6 @@ def get_notifications(author_id: int, session, limit: int, offset: int) -> List[
@strawberry.type @strawberry.type
class Query: class Query:
@login_required
@strawberry.field @strawberry.field
async def load_notifications(self, info, limit: int = 50, offset: int = 0) -> NotificationsResult: async def load_notifications(self, info, limit: int = 50, offset: int = 0) -> NotificationsResult:
author_id = info.context.get("author_id") author_id = info.context.get("author_id")
@ -95,7 +98,6 @@ class Query:
@strawberry.type @strawberry.type
class Mutation: class Mutation:
@login_required
@strawberry.mutation @strawberry.mutation
async def mark_notification_as_read(self, info, notification_id: int) -> NotificationSeenResult: async def mark_notification_as_read(self, info, notification_id: int) -> NotificationSeenResult:
author_id = info.context.get("author_id") author_id = info.context.get("author_id")
@ -113,7 +115,6 @@ class Mutation:
return NotificationSeenResult(error="cant mark as read") return NotificationSeenResult(error="cant mark as read")
return NotificationSeenResult() return NotificationSeenResult()
@login_required
@strawberry.mutation @strawberry.mutation
async def mark_all_notifications_as_read(self, info) -> NotificationSeenResult: async def mark_all_notifications_as_read(self, info) -> NotificationSeenResult:
author_id = info.context.get("author_id") author_id = info.context.get("author_id")
@ -135,4 +136,23 @@ class Mutation:
return NotificationSeenResult() return NotificationSeenResult()
schema = strawberry.Schema(query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False)) class LoginRequiredMiddleware(Extension):
async def on_request_start(self):
context = self.execution_context.context
req = context.get("request")
is_authenticated, user_id = await check_auth(req)
if not is_authenticated:
raise HTTPUnauthorized(text="Please, login first")
else:
with local_session() as session:
author = session.query(Author).filter(Author.user == user_id).first()
if author:
context["author_id"] = author.id
if user_id:
context["user_id"] = user_id
context["user_id"] = user_id
schema = strawberry.Schema(
query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware]
)

View File

@ -1,5 +1,6 @@
# from contextlib import contextmanager # from contextlib import contextmanager
from typing import Any, Callable, Dict, TypeVar from typing import Any, Callable, Dict, TypeVar
# from psycopg2.errors import UniqueViolation # from psycopg2.errors import UniqueViolation
from sqlalchemy import Column, Integer, create_engine from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base

View File

@ -55,9 +55,9 @@ class RedisCache:
while True: while True:
message = await pubsub.get_message() message = await pubsub.get_message()
if message and isinstance(message['data'], (str, bytes, bytearray)): if message and isinstance(message["data"], (str, bytes, bytearray)):
try: try:
yield json.loads(message['data']) yield json.loads(message["data"])
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}") print(f"Error decoding JSON: {e}")
await asyncio.sleep(0.1) await asyncio.sleep(0.1)