diff --git a/resolvers/community.py b/resolvers/community.py index 483a0199..ad29609d 100644 --- a/resolvers/community.py +++ b/resolvers/community.py @@ -63,13 +63,15 @@ def followed_communities(follower_id): def community_follow(follower_id, slug): try: with local_session() as session: - community = session.query(Community).where(Community.slug == slug).one() - cf = CommunityAuthor(author=follower_id, community=community.id) - session.add(cf) - session.commit() - return True + community = session.query(Community).where(Community.slug == slug).first() + if isinstance(community, Community): + cf = CommunityAuthor(author=follower_id, community=community.id) + session.add(cf) + session.commit() + return True except Exception: - return False + pass + return False # for mutation.field("unfollow") diff --git a/resolvers/reaction.py b/resolvers/reaction.py index 806b713f..c1ddc802 100644 --- a/resolvers/reaction.py +++ b/resolvers/reaction.py @@ -142,7 +142,7 @@ async def set_published(session, shout_id, approver_id): s = session.query(Shout).where(Shout.id == shout_id).first() s.published_at = int(time.time()) s.published_by = approver_id - s.visibility = ShoutVisibility.PUBLIC.value + Shout.update(s, {"visibility": ShoutVisibility.PUBLIC.value}) author = session.query(Author).filter(Author.id == s.created_by).first() if author: await add_user_role(str(author.user)) @@ -152,7 +152,7 @@ async def set_published(session, shout_id, approver_id): def set_hidden(session, shout_id): s = session.query(Shout).where(Shout.id == shout_id).first() - s.visibility = ShoutVisibility.COMMUNITY.value + Shout.update(s, {"visibility": ShoutVisibility.COMMUNITY.value}) session.add(s) session.commit() diff --git a/services/auth.py b/services/auth.py index 8106c825..479c79de 100644 --- a/services/auth.py +++ b/services/auth.py @@ -1,9 +1,9 @@ -from functools import wraps +from cachetools import TTLCache, cached import logging - -from aiohttp import ClientSession +import time from starlette.exceptions import HTTPException +from aiohttp import ClientSession from settings import AUTH_URL, AUTH_SECRET @@ -11,10 +11,11 @@ logging.basicConfig() logger = logging.getLogger("\t[services.auth]\t") logger.setLevel(logging.DEBUG) +# Define a TTLCache with a time-to-live of 100 seconds +token_cache = TTLCache(maxsize=99999, ttl=1799) -async def request_data(gql, headers = { "Content-Type": "application/json" }): +async def request_data(gql, headers={"Content-Type": "application/json"}): try: - # Asynchronous HTTP request to the authentication server async with ClientSession() as session: async with session.post(AUTH_URL, json=gql, headers=headers) as response: if response.status == 200: @@ -25,40 +26,46 @@ async def request_data(gql, headers = { "Content-Type": "application/json" }): else: return data except Exception as e: - # Handling and logging exceptions during authentication check logger.error(f"[services.auth] request_data error: {e}") return None +@cached(cache=token_cache) +async def user_id_from_token(token): + logger.error(f"[services.auth] checking auth token: {token}") + query_name = "validate_jwt_token" + operation = "ValidateToken" + variables = { + "params": { + "token_type": "access_token", + "token": token, + } + } + gql = { + "query": f"query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}", + "variables": variables, + "operationName": operation, + } + data = await request_data(gql) + if data: + expires_in = data.get("data", {}).get(query_name, {}).get("claims", {}).get("expires_in") + user_id = data.get("data", {}).get(query_name, {}).get("claims", {}).get("sub") + if expires_in is not None and user_id is not None: + if expires_in < 100: + # Token will expire soon, remove it from cache + token_cache.pop(token, None) + else: + expires_at = time.time() + expires_in + return user_id, expires_at async def check_auth(req) -> str | None: token = req.headers.get("Authorization") - user_id = "" - if token: - # Logging the authentication token - logger.error(f"[services.auth] checking auth token: {token}") - query_name = "validate_jwt_token" - operation = "ValidateToken" - variables = { - "params": { - "token_type": "access_token", - "token": token, - } - } - - gql = { - "query": f"query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}", - "variables": variables, - "operationName": operation, - } - data = await request_data(gql) - if data: - user_id = data.get("data", {}).get(query_name, {}).get("claims", {}).get("sub") + cached_result = await user_id_from_token(token) + if cached_result: + user_id, expires_at = cached_result + if expires_at > time.time(): return user_id - - if not user_id: - raise HTTPException(status_code=401, detail="Unauthorized") - + raise HTTPException(status_code=401, detail="Unauthorized") async def add_user_role(user_id): logger.info(f"[services.auth] add author role for user_id: {user_id}") @@ -77,7 +84,6 @@ async def add_user_role(user_id): return user_id def login_required(f): - @wraps(f) async def decorated_function(*args, **kwargs): info = args[1] context = info.context @@ -89,9 +95,7 @@ def login_required(f): return decorated_function - def auth_request(f): - @wraps(f) async def decorated_function(*args, **kwargs): req = args[0] user_id = await check_auth(req)