auth-check-middleware
Some checks failed
deploy / deploy (push) Failing after 1m6s

This commit is contained in:
Untone 2023-12-17 14:42:04 +03:00
parent f061d8a523
commit 01d7935cbd
6 changed files with 49 additions and 16 deletions

View File

@ -9,6 +9,7 @@ from sentry_sdk.integrations.strawberry import StrawberryIntegration
from strawberry.asgi import GraphQL
from starlette.applications import Starlette
from services.auth import TokenMiddleware
from services.rediscache import redis
from resolvers.listener import reactions_worker
from resolvers.schema import schema
@ -49,4 +50,5 @@ async def shutdown():
app = Starlette(debug=True, on_startup=[start_up], on_shutdown=[shutdown])
app.add_middleware(TokenMiddleware)
app.mount("/", GraphQL(schema, debug=True))

View File

@ -1,8 +1,8 @@
from enum import Enum as Enumeration
import time
from sqlalchemy import Boolean, Column, Enum, Integer, ForeignKey, JSON as JSONType
from sqlalchemy import Column, Enum, Integer, ForeignKey, JSON as JSONType
from sqlalchemy.orm import relationship
# from sqlalchemy.dialects.postgresql import JSONB
from orm.author import Author
from services.db import Base

View File

@ -21,7 +21,8 @@ aiohttp = "^3.9.1"
[tool.poetry.dev-dependencies]
pytest = "^7.4.2"
black = { version = "^23.9.1", python = ">=3.12" }
black = { version = "^23.12.0", python = ">=3.12" }
ruff = { version = "^0.1.8", python = ">=3.12" }
mypy = { version = "^1.7", python = ">=3.12" }
setuptools = "^69.0.2"
@ -74,3 +75,12 @@ executionEnvironments = []
python_version = "3.12"
warn_unused_configs = true
plugins = ["mypy_sqlalchemy.plugin", "strawberry.ext.mypy_plugin"]
[tool.ruff]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
select = ["E4", "E7", "E9", "F"]
ignore = []
line-length = 120
target-version = "py312"

View File

@ -2,12 +2,16 @@ from typing import List
from sqlalchemy import and_, select
from sqlalchemy.orm import aliased
from sqlalchemy.exc import SQLAlchemyError
from orm.author import Author
from orm.notification import Notification as NotificationMessage, NotificationSeen
from services.auth import login_required
from services.auth import check_auth
from aiohttp.web import HTTPUnauthorized
from services.db import local_session
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
import strawberry
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
from strawberry.schema.config import StrawberryConfig
from strawberry.extensions import Extension
import logging
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
@ -70,7 +74,6 @@ def get_notifications(author_id: int, session, limit: int, offset: int) -> List[
@strawberry.type
class Query:
@login_required
@strawberry.field
async def load_notifications(self, info, limit: int = 50, offset: int = 0) -> NotificationsResult:
author_id = info.context.get("author_id")
@ -95,7 +98,6 @@ class Query:
@strawberry.type
class Mutation:
@login_required
@strawberry.mutation
async def mark_notification_as_read(self, info, notification_id: int) -> NotificationSeenResult:
author_id = info.context.get("author_id")
@ -113,7 +115,6 @@ class Mutation:
return NotificationSeenResult(error="cant mark as read")
return NotificationSeenResult()
@login_required
@strawberry.mutation
async def mark_all_notifications_as_read(self, info) -> NotificationSeenResult:
author_id = info.context.get("author_id")
@ -135,4 +136,23 @@ class Mutation:
return NotificationSeenResult()
schema = strawberry.Schema(query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False))
class LoginRequiredMiddleware(Extension):
async def on_request_start(self):
context = self.execution_context.context
req = context.get("request")
is_authenticated, user_id = await check_auth(req)
if not is_authenticated:
raise HTTPUnauthorized(text="Please, login first")
else:
with local_session() as session:
author = session.query(Author).filter(Author.user == user_id).first()
if author:
context["author_id"] = author.id
if user_id:
context["user_id"] = user_id
context["user_id"] = user_id
schema = strawberry.Schema(
query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware]
)

View File

@ -1,5 +1,6 @@
# from contextlib import contextmanager
from typing import Any, Callable, Dict, TypeVar
# from psycopg2.errors import UniqueViolation
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
@ -18,7 +19,7 @@ REGISTRY: Dict[str, type] = {}
# @contextmanager
def local_session(src=""):
return Session(bind=engine, expire_on_commit=False)
# try:
# yield session
# session.commit()
@ -60,8 +61,8 @@ class Base(declarative_base()):
except Exception as e:
print(f"[services.db] Error dict: {e}")
return {}
def update(self, values: Dict[str, Any]) -> None:
for key, value in values.items():
if hasattr(self, key):
setattr(self, key, value)
for key, value in values.items():
if hasattr(self, key):
setattr(self, key, value)

View File

@ -55,9 +55,9 @@ class RedisCache:
while True:
message = await pubsub.get_message()
if message and isinstance(message['data'], (str, bytes, bytearray)):
if message and isinstance(message["data"], (str, bytes, bytearray)):
try:
yield json.loads(message['data'])
yield json.loads(message["data"])
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}")
await asyncio.sleep(0.1)