refresh-token
This commit is contained in:
parent
786bd20275
commit
84600308ad
|
@ -2,48 +2,14 @@ from functools import wraps
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from graphql.type import GraphQLResolveInfo
|
from graphql.type import GraphQLResolveInfo
|
||||||
from jwt import DecodeError, ExpiredSignatureError
|
|
||||||
from starlette.authentication import AuthenticationBackend
|
from starlette.authentication import AuthenticationBackend
|
||||||
from starlette.requests import HTTPConnection
|
from starlette.requests import HTTPConnection
|
||||||
|
|
||||||
from auth.credentials import AuthCredentials, AuthUser
|
from auth.credentials import AuthCredentials, AuthUser
|
||||||
from auth.jwtcodec import JWTCodec
|
|
||||||
from auth.tokenstorage import TokenStorage
|
|
||||||
from base.exceptions import ExpiredToken, InvalidToken
|
|
||||||
from services.auth.users import UserStorage
|
from services.auth.users import UserStorage
|
||||||
from settings import SESSION_TOKEN_HEADER
|
from settings import SESSION_TOKEN_HEADER
|
||||||
|
from auth.tokenstorage import SessionToken
|
||||||
|
from base.exceptions import InvalidToken
|
||||||
class SessionToken:
|
|
||||||
@classmethod
|
|
||||||
async def verify(cls, token: str):
|
|
||||||
"""
|
|
||||||
Rules for a token to be valid.
|
|
||||||
1. token format is legal &&
|
|
||||||
token exists in redis database &&
|
|
||||||
token is not expired
|
|
||||||
2. token format is legal &&
|
|
||||||
token exists in redis database &&
|
|
||||||
token is expired &&
|
|
||||||
token is of specified type
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
print('[auth.authenticate] session token verify')
|
|
||||||
payload = JWTCodec.decode(token)
|
|
||||||
except ExpiredSignatureError:
|
|
||||||
payload = JWTCodec.decode(token, verify_exp=False)
|
|
||||||
if not await cls.get(payload.user_id, token):
|
|
||||||
raise ExpiredToken("Token signature has expired, please try again")
|
|
||||||
except DecodeError as e:
|
|
||||||
raise InvalidToken("token format error") from e
|
|
||||||
else:
|
|
||||||
if not await cls.get(payload.user_id, token):
|
|
||||||
raise ExpiredToken("Session token has expired, please login again")
|
|
||||||
return payload
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get(cls, uid, token):
|
|
||||||
return await TokenStorage.get(f"{uid}-{token}")
|
|
||||||
|
|
||||||
|
|
||||||
class JWTAuthenticate(AuthenticationBackend):
|
class JWTAuthenticate(AuthenticationBackend):
|
||||||
|
@ -54,10 +20,18 @@ class JWTAuthenticate(AuthenticationBackend):
|
||||||
if SESSION_TOKEN_HEADER not in request.headers:
|
if SESSION_TOKEN_HEADER not in request.headers:
|
||||||
return AuthCredentials(scopes=[]), AuthUser(user_id=None)
|
return AuthCredentials(scopes=[]), AuthUser(user_id=None)
|
||||||
|
|
||||||
token = request.headers.get(SESSION_TOKEN_HEADER, "")
|
token = request.headers.get(SESSION_TOKEN_HEADER)
|
||||||
|
if not token:
|
||||||
|
print("[auth.authenticate] no token in header %s" % SESSION_TOKEN_HEADER)
|
||||||
|
return AuthCredentials(scopes=[], error_message=str("no token")), AuthUser(
|
||||||
|
user_id=None
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if len(token.split('.')) > 1:
|
||||||
payload = await SessionToken.verify(token)
|
payload = await SessionToken.verify(token)
|
||||||
|
else:
|
||||||
|
InvalidToken("please try again")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print("[auth.authenticate] session token verify error")
|
print("[auth.authenticate] session token verify error")
|
||||||
print(exc)
|
print(exc)
|
||||||
|
|
|
@ -20,7 +20,7 @@ class AuthCredentials(BaseModel):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def permissions(self) -> List[Permission]:
|
async def permissions(self) -> List[Permission]:
|
||||||
if self.user_id is not None:
|
if self.user_id is None:
|
||||||
raise OperationNotAllowed("Please login first")
|
raise OperationNotAllowed("Please login first")
|
||||||
return NotImplemented()
|
return NotImplemented()
|
||||||
|
|
||||||
|
|
|
@ -8,12 +8,11 @@ from settings import JWT_ALGORITHM, JWT_SECRET_KEY
|
||||||
class JWTCodec:
|
class JWTCodec:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def encode(user: AuthInput, exp: datetime) -> str:
|
def encode(user: AuthInput, exp: datetime) -> str:
|
||||||
issued = datetime.now(tz=timezone.utc)
|
|
||||||
payload = {
|
payload = {
|
||||||
"user_id": user.id,
|
"user_id": user.id,
|
||||||
"username": user.email or user.phone,
|
"username": user.email or user.phone,
|
||||||
"exp": exp,
|
"exp": exp,
|
||||||
"iat": issued,
|
"iat": datetime.now(tz=timezone.utc),
|
||||||
"iss": "discours"
|
"iss": "discours"
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -13,9 +13,30 @@ async def save(token_key, life_span, auto_delete=True):
|
||||||
await redis.execute("EXPIREAT", token_key, int(expire_at))
|
await redis.execute("EXPIREAT", token_key, int(expire_at))
|
||||||
|
|
||||||
|
|
||||||
|
class SessionToken:
|
||||||
|
@classmethod
|
||||||
|
async def verify(cls, token: str):
|
||||||
|
"""
|
||||||
|
Rules for a token to be valid.
|
||||||
|
- token format is legal
|
||||||
|
- token exists in redis database
|
||||||
|
- token is not expired
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return JWTCodec.decode(token)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get(cls, uid, token):
|
||||||
|
return await TokenStorage.get(f"{uid}-{token}")
|
||||||
|
|
||||||
|
|
||||||
class TokenStorage:
|
class TokenStorage:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get(token_key):
|
async def get(token_key):
|
||||||
|
print('[tokenstorage.get] ' + token_key)
|
||||||
|
# 2041-eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoyMDQxLCJ1c2VybmFtZSI6ImFudG9uLnJld2luK3Rlc3QtbG9hZGNoYXRAZ21haWwuY29tIiwiZXhwIjoxNjcxNzgwNjE2LCJpYXQiOjE2NjkxODg2MTYsImlzcyI6ImRpc2NvdXJzIn0.Nml4oV6iMjMmc6xwM7lTKEZJKBXvJFEIZ-Up1C1rITQ
|
||||||
return await redis.execute("GET", token_key)
|
return await redis.execute("GET", token_key)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -24,13 +24,24 @@ from settings import SESSION_TOKEN_HEADER
|
||||||
@mutation.field("refreshSession")
|
@mutation.field("refreshSession")
|
||||||
@login_required
|
@login_required
|
||||||
async def get_current_user(_, info):
|
async def get_current_user(_, info):
|
||||||
print('[resolvers.auth] get current user %s' % str(info))
|
|
||||||
user = info.context["request"].user
|
user = info.context["request"].user
|
||||||
|
# print(info.context["request"].headers)
|
||||||
|
old_token = info.context["request"].headers.get("Authorization")
|
||||||
user.lastSeen = datetime.now(tz=timezone.utc)
|
user.lastSeen = datetime.now(tz=timezone.utc)
|
||||||
with local_session() as session:
|
with local_session() as session:
|
||||||
session.add(user)
|
session.add(user)
|
||||||
session.commit()
|
session.commit()
|
||||||
token = await TokenStorage.create_session(user)
|
token = await TokenStorage.create_session(user)
|
||||||
|
print("[resolvers.auth] new session token created")
|
||||||
|
if old_token:
|
||||||
|
payload = await TokenStorage.get(str(user.id) + '-' + str(old_token))
|
||||||
|
if payload:
|
||||||
|
print("[resolvers.auth] got session from old token: %r" % payload)
|
||||||
|
return {
|
||||||
|
"token": token,
|
||||||
|
"user": user,
|
||||||
|
"news": await user_subscriptions(user.slug),
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
"token": token,
|
"token": token,
|
||||||
"user": user,
|
"user": user,
|
||||||
|
|
|
@ -53,7 +53,6 @@ if __name__ == "__main__":
|
||||||
if len(sys.argv) > 1:
|
if len(sys.argv) > 1:
|
||||||
x = sys.argv[1]
|
x = sys.argv[1]
|
||||||
if x == "dev":
|
if x == "dev":
|
||||||
print("DEV MODE")
|
|
||||||
if os.path.exists(DEV_SERVER_STATUS_FILE_NAME):
|
if os.path.exists(DEV_SERVER_STATUS_FILE_NAME):
|
||||||
os.remove(DEV_SERVER_STATUS_FILE_NAME)
|
os.remove(DEV_SERVER_STATUS_FILE_NAME)
|
||||||
|
|
||||||
|
@ -67,6 +66,12 @@ if __name__ == "__main__":
|
||||||
("Access-Control-Expose-Headers", "Content-Length,Content-Range"),
|
("Access-Control-Expose-Headers", "Content-Length,Content-Range"),
|
||||||
("Access-Control-Allow-Credentials", "true"),
|
("Access-Control-Allow-Credentials", "true"),
|
||||||
]
|
]
|
||||||
|
want_reload = False
|
||||||
|
if "reload" in sys.argv:
|
||||||
|
print("MODE: DEV + RELOAD")
|
||||||
|
want_reload = True
|
||||||
|
else:
|
||||||
|
print("MODE: DEV")
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"main:dev_app",
|
"main:dev_app",
|
||||||
host="localhost",
|
host="localhost",
|
||||||
|
@ -75,7 +80,7 @@ if __name__ == "__main__":
|
||||||
# log_config=LOGGING_CONFIG,
|
# log_config=LOGGING_CONFIG,
|
||||||
log_level=None,
|
log_level=None,
|
||||||
access_log=False,
|
access_log=False,
|
||||||
reload=True
|
reload=want_reload
|
||||||
) # , ssl_keyfile="discours.key", ssl_certfile="discours.crt")
|
) # , ssl_keyfile="discours.key", ssl_certfile="discours.crt")
|
||||||
elif x == "migrate":
|
elif x == "migrate":
|
||||||
from migration import migrate
|
from migration import migrate
|
||||||
|
|
Loading…
Reference in New Issue
Block a user