diff --git a/auth/authenticate.py b/auth/authenticate.py index 6a84214f..e8104fae 100644 --- a/auth/authenticate.py +++ b/auth/authenticate.py @@ -2,48 +2,14 @@ from functools import wraps from typing import Optional, Tuple from graphql.type import GraphQLResolveInfo -from jwt import DecodeError, ExpiredSignatureError from starlette.authentication import AuthenticationBackend from starlette.requests import HTTPConnection from auth.credentials import AuthCredentials, AuthUser -from auth.jwtcodec import JWTCodec -from auth.tokenstorage import TokenStorage -from base.exceptions import ExpiredToken, InvalidToken from services.auth.users import UserStorage from settings import SESSION_TOKEN_HEADER - - -class SessionToken: - @classmethod - async def verify(cls, token: str): - """ - Rules for a token to be valid. - 1. token format is legal && - token exists in redis database && - token is not expired - 2. token format is legal && - token exists in redis database && - token is expired && - token is of specified type - """ - try: - print('[auth.authenticate] session token verify') - payload = JWTCodec.decode(token) - except ExpiredSignatureError: - payload = JWTCodec.decode(token, verify_exp=False) - if not await cls.get(payload.user_id, token): - raise ExpiredToken("Token signature has expired, please try again") - except DecodeError as e: - raise InvalidToken("token format error") from e - else: - if not await cls.get(payload.user_id, token): - raise ExpiredToken("Session token has expired, please login again") - return payload - - @classmethod - async def get(cls, uid, token): - return await TokenStorage.get(f"{uid}-{token}") +from auth.tokenstorage import SessionToken +from base.exceptions import InvalidToken class JWTAuthenticate(AuthenticationBackend): @@ -54,10 +20,18 @@ class JWTAuthenticate(AuthenticationBackend): if SESSION_TOKEN_HEADER not in request.headers: return AuthCredentials(scopes=[]), AuthUser(user_id=None) - token = request.headers.get(SESSION_TOKEN_HEADER, "") + token = request.headers.get(SESSION_TOKEN_HEADER) + if not token: + print("[auth.authenticate] no token in header %s" % SESSION_TOKEN_HEADER) + return AuthCredentials(scopes=[], error_message=str("no token")), AuthUser( + user_id=None + ) try: - payload = await SessionToken.verify(token) + if len(token.split('.')) > 1: + payload = await SessionToken.verify(token) + else: + InvalidToken("please try again") except Exception as exc: print("[auth.authenticate] session token verify error") print(exc) diff --git a/auth/credentials.py b/auth/credentials.py index 5e2dfea9..401ae420 100644 --- a/auth/credentials.py +++ b/auth/credentials.py @@ -20,7 +20,7 @@ class AuthCredentials(BaseModel): return True async def permissions(self) -> List[Permission]: - if self.user_id is not None: + if self.user_id is None: raise OperationNotAllowed("Please login first") return NotImplemented() diff --git a/auth/jwtcodec.py b/auth/jwtcodec.py index 87bd2b5a..c2feacd3 100644 --- a/auth/jwtcodec.py +++ b/auth/jwtcodec.py @@ -8,12 +8,11 @@ from settings import JWT_ALGORITHM, JWT_SECRET_KEY class JWTCodec: @staticmethod def encode(user: AuthInput, exp: datetime) -> str: - issued = datetime.now(tz=timezone.utc) payload = { "user_id": user.id, "username": user.email or user.phone, "exp": exp, - "iat": issued, + "iat": datetime.now(tz=timezone.utc), "iss": "discours" } try: diff --git a/auth/tokenstorage.py b/auth/tokenstorage.py index ef6fa0d6..5c1b5b2d 100644 --- a/auth/tokenstorage.py +++ b/auth/tokenstorage.py @@ -13,9 +13,30 @@ async def save(token_key, life_span, auto_delete=True): await redis.execute("EXPIREAT", token_key, int(expire_at)) +class SessionToken: + @classmethod + async def verify(cls, token: str): + """ + Rules for a token to be valid. + - token format is legal + - token exists in redis database + - token is not expired + """ + try: + return JWTCodec.decode(token) + except Exception as e: + raise e + + @classmethod + async def get(cls, uid, token): + return await TokenStorage.get(f"{uid}-{token}") + + class TokenStorage: @staticmethod async def get(token_key): + print('[tokenstorage.get] ' + token_key) + # 2041-eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoyMDQxLCJ1c2VybmFtZSI6ImFudG9uLnJld2luK3Rlc3QtbG9hZGNoYXRAZ21haWwuY29tIiwiZXhwIjoxNjcxNzgwNjE2LCJpYXQiOjE2NjkxODg2MTYsImlzcyI6ImRpc2NvdXJzIn0.Nml4oV6iMjMmc6xwM7lTKEZJKBXvJFEIZ-Up1C1rITQ return await redis.execute("GET", token_key) @staticmethod diff --git a/resolvers/auth.py b/resolvers/auth.py index 54947f9a..648253c8 100644 --- a/resolvers/auth.py +++ b/resolvers/auth.py @@ -24,13 +24,24 @@ from settings import SESSION_TOKEN_HEADER @mutation.field("refreshSession") @login_required async def get_current_user(_, info): - print('[resolvers.auth] get current user %s' % str(info)) user = info.context["request"].user + # print(info.context["request"].headers) + old_token = info.context["request"].headers.get("Authorization") user.lastSeen = datetime.now(tz=timezone.utc) with local_session() as session: session.add(user) session.commit() token = await TokenStorage.create_session(user) + print("[resolvers.auth] new session token created") + if old_token: + payload = await TokenStorage.get(str(user.id) + '-' + str(old_token)) + if payload: + print("[resolvers.auth] got session from old token: %r" % payload) + return { + "token": token, + "user": user, + "news": await user_subscriptions(user.slug), + } return { "token": token, "user": user, diff --git a/server.py b/server.py index 2cb70715..265bfed6 100644 --- a/server.py +++ b/server.py @@ -53,7 +53,6 @@ if __name__ == "__main__": if len(sys.argv) > 1: x = sys.argv[1] if x == "dev": - print("DEV MODE") if os.path.exists(DEV_SERVER_STATUS_FILE_NAME): os.remove(DEV_SERVER_STATUS_FILE_NAME) @@ -67,6 +66,12 @@ if __name__ == "__main__": ("Access-Control-Expose-Headers", "Content-Length,Content-Range"), ("Access-Control-Allow-Credentials", "true"), ] + want_reload = False + if "reload" in sys.argv: + print("MODE: DEV + RELOAD") + want_reload = True + else: + print("MODE: DEV") uvicorn.run( "main:dev_app", host="localhost", @@ -75,7 +80,7 @@ if __name__ == "__main__": # log_config=LOGGING_CONFIG, log_level=None, access_log=False, - reload=True + reload=want_reload ) # , ssl_keyfile="discours.key", ssl_certfile="discours.crt") elif x == "migrate": from migration import migrate