cached-auth

This commit is contained in:
Untone 2024-01-23 21:34:51 +03:00
parent 987eb8c078
commit c41fe8b6c9
3 changed files with 48 additions and 42 deletions

View File

@ -63,12 +63,14 @@ def followed_communities(follower_id):
def community_follow(follower_id, slug): def community_follow(follower_id, slug):
try: try:
with local_session() as session: with local_session() as session:
community = session.query(Community).where(Community.slug == slug).one() community = session.query(Community).where(Community.slug == slug).first()
if isinstance(community, Community):
cf = CommunityAuthor(author=follower_id, community=community.id) cf = CommunityAuthor(author=follower_id, community=community.id)
session.add(cf) session.add(cf)
session.commit() session.commit()
return True return True
except Exception: except Exception:
pass
return False return False

View File

@ -142,7 +142,7 @@ async def set_published(session, shout_id, approver_id):
s = session.query(Shout).where(Shout.id == shout_id).first() s = session.query(Shout).where(Shout.id == shout_id).first()
s.published_at = int(time.time()) s.published_at = int(time.time())
s.published_by = approver_id s.published_by = approver_id
s.visibility = ShoutVisibility.PUBLIC.value Shout.update(s, {"visibility": ShoutVisibility.PUBLIC.value})
author = session.query(Author).filter(Author.id == s.created_by).first() author = session.query(Author).filter(Author.id == s.created_by).first()
if author: if author:
await add_user_role(str(author.user)) await add_user_role(str(author.user))
@ -152,7 +152,7 @@ async def set_published(session, shout_id, approver_id):
def set_hidden(session, shout_id): def set_hidden(session, shout_id):
s = session.query(Shout).where(Shout.id == shout_id).first() s = session.query(Shout).where(Shout.id == shout_id).first()
s.visibility = ShoutVisibility.COMMUNITY.value Shout.update(s, {"visibility": ShoutVisibility.COMMUNITY.value})
session.add(s) session.add(s)
session.commit() session.commit()

View File

@ -1,9 +1,9 @@
from functools import wraps from cachetools import TTLCache, cached
import logging import logging
import time
from aiohttp import ClientSession
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from aiohttp import ClientSession
from settings import AUTH_URL, AUTH_SECRET from settings import AUTH_URL, AUTH_SECRET
@ -11,10 +11,11 @@ logging.basicConfig()
logger = logging.getLogger("\t[services.auth]\t") logger = logging.getLogger("\t[services.auth]\t")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
# Define a TTLCache with a time-to-live of 100 seconds
token_cache = TTLCache(maxsize=99999, ttl=1799)
async def request_data(gql, headers = { "Content-Type": "application/json" }): async def request_data(gql, headers={"Content-Type": "application/json"}):
try: try:
# Asynchronous HTTP request to the authentication server
async with ClientSession() as session: async with ClientSession() as session:
async with session.post(AUTH_URL, json=gql, headers=headers) as response: async with session.post(AUTH_URL, json=gql, headers=headers) as response:
if response.status == 200: if response.status == 200:
@ -25,17 +26,11 @@ async def request_data(gql, headers = { "Content-Type": "application/json" }):
else: else:
return data return data
except Exception as e: except Exception as e:
# Handling and logging exceptions during authentication check
logger.error(f"[services.auth] request_data error: {e}") logger.error(f"[services.auth] request_data error: {e}")
return None return None
@cached(cache=token_cache)
async def user_id_from_token(token):
async def check_auth(req) -> str | None:
token = req.headers.get("Authorization")
user_id = ""
if token:
# Logging the authentication token
logger.error(f"[services.auth] checking auth token: {token}") logger.error(f"[services.auth] checking auth token: {token}")
query_name = "validate_jwt_token" query_name = "validate_jwt_token"
operation = "ValidateToken" operation = "ValidateToken"
@ -53,13 +48,25 @@ async def check_auth(req) -> str | None:
} }
data = await request_data(gql) data = await request_data(gql)
if data: if data:
expires_in = data.get("data", {}).get(query_name, {}).get("claims", {}).get("expires_in")
user_id = data.get("data", {}).get(query_name, {}).get("claims", {}).get("sub") user_id = data.get("data", {}).get(query_name, {}).get("claims", {}).get("sub")
if expires_in is not None and user_id is not None:
if expires_in < 100:
# Token will expire soon, remove it from cache
token_cache.pop(token, None)
else:
expires_at = time.time() + expires_in
return user_id, expires_at
async def check_auth(req) -> str | None:
token = req.headers.get("Authorization")
cached_result = await user_id_from_token(token)
if cached_result:
user_id, expires_at = cached_result
if expires_at > time.time():
return user_id return user_id
if not user_id:
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
async def add_user_role(user_id): async def add_user_role(user_id):
logger.info(f"[services.auth] add author role for user_id: {user_id}") logger.info(f"[services.auth] add author role for user_id: {user_id}")
query_name = "_update_user" query_name = "_update_user"
@ -77,7 +84,6 @@ async def add_user_role(user_id):
return user_id return user_id
def login_required(f): def login_required(f):
@wraps(f)
async def decorated_function(*args, **kwargs): async def decorated_function(*args, **kwargs):
info = args[1] info = args[1]
context = info.context context = info.context
@ -89,9 +95,7 @@ def login_required(f):
return decorated_function return decorated_function
def auth_request(f): def auth_request(f):
@wraps(f)
async def decorated_function(*args, **kwargs): async def decorated_function(*args, **kwargs):
req = args[0] req = args[0]
user_id = await check_auth(req) user_id = await check_auth(req)