unmiddlewared

This commit is contained in:
Tony Rewin 2023-10-04 23:42:39 +03:00
parent 2db81462d0
commit e76f924b2d
10 changed files with 75 additions and 96 deletions

View File

@ -1,3 +1,8 @@
[0.2.10]
- middlwares removed
- orm removed
- added core api connector
[0.2.9] [0.2.9]
- starlette is back - starlette is back
- auth middleware - auth middleware

12
main.py
View File

@ -3,22 +3,12 @@ from os.path import exists
from ariadne import load_schema_from_path, make_executable_schema from ariadne import load_schema_from_path, make_executable_schema
from ariadne.asgi import GraphQL from ariadne.asgi import GraphQL
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware
from services.auth import JWTAuthenticate
from services.redis import redis from services.redis import redis
from resolvers import resolvers from resolvers import resolvers
from settings import DEV_SERVER_PID_FILE_NAME, SENTRY_DSN, SESSION_SECRET_KEY, MODE from settings import DEV_SERVER_PID_FILE_NAME, SENTRY_DSN, MODE
schema = make_executable_schema(load_schema_from_path("schema.graphql"), resolvers) # type: ignore schema = make_executable_schema(load_schema_from_path("schema.graphql"), resolvers) # type: ignore
middleware = [
Middleware(AuthenticationMiddleware, backend=JWTAuthenticate()),
Middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY),
]
async def start_up(): async def start_up():
if MODE == "dev": if MODE == "dev":

View File

@ -1,6 +0,0 @@
from services.db import Base, engine
def init_tables():
Base.metadata.create_all(engine)
print("[orm] tables initialized")

View File

@ -1,41 +0,0 @@
from datetime import datetime
from sqlalchemy import JSON as JSONType
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from services.db import Base
class AuthorRating(Base):
__tablename__ = "author_rating"
id = None # type: ignore
rater = Column(ForeignKey("author.id"), primary_key=True, index=True)
author = Column(ForeignKey("author.id"), primary_key=True, index=True)
value = Column(Integer)
class AuthorFollower(Base):
__tablename__ = "author_follower"
id = None # type: ignore
follower = Column(ForeignKey("author.id"), primary_key=True, index=True)
author = Column(ForeignKey("author.id"), primary_key=True, index=True)
createdAt = Column(DateTime, nullable=False, default=datetime.now)
auto = Column(Boolean, nullable=False, default=False)
class Author(Base):
__tablename__ = "author"
user = Column(Integer, nullable=False) # unbounded link with authorizer's User type
bio = Column(String, nullable=True, comment="Bio") # status description
about = Column(String, nullable=True, comment="About") # long and formatted
userpic = Column(String, nullable=True, comment="Userpic")
name = Column(String, nullable=True, comment="Display name")
slug = Column(String, unique=True, comment="Author's slug")
muted = Column(Boolean, default=False)
createdAt = Column(DateTime, nullable=False, default=datetime.now)
lastSeen = Column(DateTime, nullable=False, default=datetime.now)
deletedAt = Column(DateTime, nullable=True, comment="Deleted at")
links = Column(JSONType, nullable=True, comment="Links")
ratings = relationship(AuthorRating, foreign_keys=AuthorRating.author)

View File

@ -1,9 +1,9 @@
import json import json
from services.core import get_author
from services.db import local_session from services.db import local_session
from services.redis import redis from services.redis import redis
from resolvers import query from resolvers import query
from orm.author import Author
from services.auth import login_required from services.auth import login_required
from .chats import create_chat from .chats import create_chat
from .unread import get_unread_counter from .unread import get_unread_counter
@ -63,7 +63,7 @@ async def load_chats(_, info, limit: int = 50, offset: int = 0):
member_ids = c["members"].copy() member_ids = c["members"].copy()
c["members"] = [] c["members"] = []
for member_id in member_ids: for member_id in member_ids:
a = session.query(Author).where(Author.id == member_id).first() a = await get_author(member_id)
if a: if a:
c["members"].append( c["members"].append(
{ {

View File

@ -1,10 +1,9 @@
import json import json
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
from services.auth import login_required from services.auth import login_required
from services.core import get_network
from services.redis import redis from services.redis import redis
from resolvers import query from resolvers import query
from services.db import local_session
from orm.author import AuthorFollower, Author
from resolvers.load import load_messages from resolvers.load import load_messages
@ -13,8 +12,8 @@ from resolvers.load import load_messages
async def search_recipients(_, info, text: str, limit: int = 50, offset: int = 0): async def search_recipients(_, info, text: str, limit: int = 50, offset: int = 0):
result = [] result = []
# TODO: maybe redis scan? # TODO: maybe redis scan?
author = info.context["author"] author_id = info.context["author_id"]
talk_before = await redis.execute("GET", f"/chats_by_author/{author.id}") talk_before = await redis.execute("GET", f"/chats_by_author/{author_id}")
if talk_before: if talk_before:
talk_before = list(json.loads(talk_before))[offset : (offset + limit)] talk_before = list(json.loads(talk_before))[offset : (offset + limit)]
for chat_id in talk_before: for chat_id in talk_before:
@ -27,25 +26,8 @@ async def search_recipients(_, info, text: str, limit: int = 50, offset: int = 0
result.append(member) result.append(member)
more_amount = limit - len(result) more_amount = limit - len(result)
if more_amount > 0:
with local_session() as session: result += await get_network(author_id, more_amount)
# followings
result += (
session.query(AuthorFollower.author)
.join(Author, Author.id == AuthorFollower.follower)
.where(Author.slug.startswith(text))
.offset(offset + len(result))
.limit(more_amount)
)
# followers
result += (
session.query(AuthorFollower.follower)
.join(Author, Author.id == AuthorFollower.author)
.where(Author.slug.startswith(text))
.offset(offset + len(result))
.limit(offset + len(result) + limit)
)
return {"members": list(result), "error": None} return {"members": list(result), "error": None}
@ -84,7 +66,7 @@ async def search_in_chats(_, info, by, limit, offset):
) )
messages_set.union(set(mmm)) messages_set.union(set(mmm))
messages_sorted = list(messages_set).sort()
return {"messages": messages_sorted, "error": None} return {"messages": messages_sorted, "error": None}

View File

@ -48,6 +48,7 @@ local_headers = [
def exception_handler(exception_type, exception, traceback, debug_hook=sys.excepthook): def exception_handler(exception_type, exception, traceback, debug_hook=sys.excepthook):
print(traceback)
print("%s: %s" % (exception_type.__name__, exception)) print("%s: %s" % (exception_type.__name__, exception))

View File

@ -11,9 +11,7 @@ from settings import AUTH_URL
from orm.author import Author from orm.author import Author
class AuthUser(BaseModel): INTERNAL_AUTH_SERVER = "v2.discours" in AUTH_URL
user_id: Optional[int]
username: Optional[str]
class AuthCredentials(BaseModel): class AuthCredentials(BaseModel):
@ -26,17 +24,14 @@ class AuthCredentials(BaseModel):
class JWTAuthenticate(AuthenticationBackend): class JWTAuthenticate(AuthenticationBackend):
async def authenticate(self, request: HTTPConnection): async def authenticate(self, request: HTTPConnection):
logged_in, user_id = await check_auth(request) logged_in, user_id = await check_auth(request)
return ( return AuthCredentials(user_id=user_id, logged_in=logged_in), user_id
AuthCredentials(user_id=user_id, logged_in=logged_in),
AuthUser(user_id=user_id),
)
async def check_auth(req): async def check_auth(req):
token = req.headers.get("Authorization") token = req.headers.get("Authorization")
gql = ( gql = (
{"mutation": "{ getSession { user { id } } }"} {"mutation": "{ getSession { user { id } } }"}
if "v2" in AUTH_URL if INTERNAL_AUTH_SERVER
else {"query": "{ session { user { id } } }"} else {"query": "{ session { user { id } } }"}
) )
headers = {"Authorization": token, "Content-Type": "application/json"} headers = {"Authorization": token, "Content-Type": "application/json"}
@ -84,8 +79,9 @@ def auth_request(f):
if not is_authenticated: if not is_authenticated:
raise HTTPError("please, login first") raise HTTPError("please, login first")
else: else:
author_id = await author_id_by_user_id(user_id) req["author_id"] = (
req["author_id"] = author_id user_id if INTERNAL_AUTH_SERVER else await author_id_by_user_id(user_id)
)
return await f(*args, **kwargs) return await f(*args, **kwargs)
return decorated_function return decorated_function

54
services/core.py Normal file
View File

@ -0,0 +1,54 @@
from httpx import AsyncClient
from settings import API_BASE
async def get_author(author_id):
gql = {
"query": "{ getAuthor(author_id: %s) { id slug userpic name lastSeen } }"
% author_id
}
headers = {"Content-Type": "application/json"}
try:
async with AsyncClient() as client:
response = await client.post(API_BASE, headers=headers, data=gql)
if response.status_code != 200:
return False, None
r = response.json()
author = r.get("data", {}).get("getAuthor")
return author
except Exception:
pass
async def get_network(author_id, limit=50, offset=0):
headers = {"Content-Type": "application/json"}
gql = {
"query": "{ authorFollowings(author_id: %s, limit: %s, offset: %s) { id slug userpic name } }"
% (author_id, limit, offset)
}
followings = []
followers = []
try:
async with AsyncClient() as client:
response = await client.post(API_BASE, headers=headers, data=gql)
if response.status_code != 200:
return False, None
r = response.json()
followings = r.get("data", {}).get("authorFollowers", [])
more_amount = limit - len(followings)
if more_amount > 0:
gql = {
"query": "{ authorFollowers(author_id: %s, limit: %s) { id slug userpic name } }"
% (author_id, more_amount)
}
response = await client.post(API_BASE, headers=headers, data=gql)
if response.status_code != 200:
return False, None
r = response.json()
followers = r.get("data", {}).get("authorFollowers", [])
except Exception as e:
pass
return followings + followers

View File

@ -11,6 +11,4 @@ API_BASE = environ.get("API_BASE") or ""
AUTH_URL = environ.get("AUTH_URL") or "" AUTH_URL = environ.get("AUTH_URL") or ""
MODE = environ.get("MODE") or "production" MODE = environ.get("MODE") or "production"
SENTRY_DSN = environ.get("SENTRY_DSN") SENTRY_DSN = environ.get("SENTRY_DSN")
SESSION_SECRET_KEY = environ.get("SESSION_SECRET_KEY") or "!secret"
DEV_SERVER_PID_FILE_NAME = "dev-server.pid" DEV_SERVER_PID_FILE_NAME = "dev-server.pid"
SESSION_TOKEN_HEADER = "Authorization"