diff --git a/server/constants/env.go b/server/constants/env.go index f36f278..e36c5e3 100644 --- a/server/constants/env.go +++ b/server/constants/env.go @@ -43,6 +43,10 @@ const ( EnvKeyJwtType = "JWT_TYPE" // EnvKeyJwtSecret key for env variable JWT_SECRET EnvKeyJwtSecret = "JWT_SECRET" + // EnvKeyJwtPrivateKey key for env variable JWT_PRIVATE_KEY + EnvKeyJwtPrivateKey = "JWT_PRIVATE_KEY" + // EnvKeyJwtPublicKey key for env variable JWT_PUBLIC_KEY + EnvKeyJwtPublicKey = "JWT_PUBLIC_KEY" // EnvKeyAllowedOrigins key for env variable ALLOWED_ORIGINS EnvKeyAllowedOrigins = "ALLOWED_ORIGINS" // EnvKeyAppURL key for env variable APP_URL diff --git a/server/env/env.go b/server/env/env.go index 7e65e9f..9658e8b 100644 --- a/server/env/env.go +++ b/server/env/env.go @@ -19,7 +19,7 @@ func InitEnv() { envData := envstore.EnvInMemoryStoreObj.GetEnvStoreClone() if envData.StringEnv[constants.EnvKeyEnv] == "" { - envData.StringEnv[constants.EnvKeyEnv] = os.Getenv("ENV") + envData.StringEnv[constants.EnvKeyEnv] = os.Getenv(constants.EnvKeyEnv) if envData.StringEnv[constants.EnvKeyEnv] == "" { envData.StringEnv[constants.EnvKeyEnv] = "production" } @@ -50,18 +50,18 @@ func InitEnv() { } if envData.StringEnv[constants.EnvKeyPort] == "" { - envData.StringEnv[constants.EnvKeyPort] = os.Getenv("PORT") + envData.StringEnv[constants.EnvKeyPort] = os.Getenv(constants.EnvKeyPort) if envData.StringEnv[constants.EnvKeyPort] == "" { envData.StringEnv[constants.EnvKeyPort] = "8080" } } if envData.StringEnv[constants.EnvKeyAdminSecret] == "" { - envData.StringEnv[constants.EnvKeyAdminSecret] = os.Getenv("ADMIN_SECRET") + envData.StringEnv[constants.EnvKeyAdminSecret] = os.Getenv(constants.EnvKeyAdminSecret) } if envData.StringEnv[constants.EnvKeyDatabaseType] == "" { - envData.StringEnv[constants.EnvKeyDatabaseType] = os.Getenv("DATABASE_TYPE") + envData.StringEnv[constants.EnvKeyDatabaseType] = os.Getenv(constants.EnvKeyDatabaseType) if envstore.ARG_DB_TYPE != nil && *envstore.ARG_DB_TYPE != "" { envData.StringEnv[constants.EnvKeyDatabaseType] = *envstore.ARG_DB_TYPE @@ -73,7 +73,7 @@ func InitEnv() { } if envData.StringEnv[constants.EnvKeyDatabaseURL] == "" { - envData.StringEnv[constants.EnvKeyDatabaseURL] = os.Getenv("DATABASE_URL") + envData.StringEnv[constants.EnvKeyDatabaseURL] = os.Getenv(constants.EnvKeyDatabaseURL) if envstore.ARG_DB_URL != nil && *envstore.ARG_DB_URL != "" { envData.StringEnv[constants.EnvKeyDatabaseURL] = *envstore.ARG_DB_URL @@ -85,48 +85,56 @@ func InitEnv() { } if envData.StringEnv[constants.EnvKeyDatabaseName] == "" { - envData.StringEnv[constants.EnvKeyDatabaseName] = os.Getenv("DATABASE_NAME") + envData.StringEnv[constants.EnvKeyDatabaseName] = os.Getenv(constants.EnvKeyDatabaseName) if envData.StringEnv[constants.EnvKeyDatabaseName] == "" { envData.StringEnv[constants.EnvKeyDatabaseName] = "authorizer" } } if envData.StringEnv[constants.EnvKeySmtpHost] == "" { - envData.StringEnv[constants.EnvKeySmtpHost] = os.Getenv("SMTP_HOST") + envData.StringEnv[constants.EnvKeySmtpHost] = os.Getenv(constants.EnvKeySmtpHost) } if envData.StringEnv[constants.EnvKeySmtpPort] == "" { - envData.StringEnv[constants.EnvKeySmtpPort] = os.Getenv("SMTP_PORT") + envData.StringEnv[constants.EnvKeySmtpPort] = os.Getenv(constants.EnvKeySmtpPort) } if envData.StringEnv[constants.EnvKeySmtpUsername] == "" { - envData.StringEnv[constants.EnvKeySmtpUsername] = os.Getenv("SMTP_USERNAME") + envData.StringEnv[constants.EnvKeySmtpUsername] = os.Getenv(constants.EnvKeySmtpUsername) } if envData.StringEnv[constants.EnvKeySmtpPassword] == "" { - envData.StringEnv[constants.EnvKeySmtpPassword] = os.Getenv("SMTP_PASSWORD") + envData.StringEnv[constants.EnvKeySmtpPassword] = os.Getenv(constants.EnvKeySmtpPassword) } if envData.StringEnv[constants.EnvKeySenderEmail] == "" { - envData.StringEnv[constants.EnvKeySenderEmail] = os.Getenv("SENDER_EMAIL") + envData.StringEnv[constants.EnvKeySenderEmail] = os.Getenv(constants.EnvKeySenderEmail) } if envData.StringEnv[constants.EnvKeyJwtSecret] == "" { - envData.StringEnv[constants.EnvKeyJwtSecret] = os.Getenv("JWT_SECRET") + envData.StringEnv[constants.EnvKeyJwtSecret] = os.Getenv(constants.EnvKeyJwtSecret) if envData.StringEnv[constants.EnvKeyJwtSecret] == "" { envData.StringEnv[constants.EnvKeyJwtSecret] = uuid.New().String() } } + if envData.StringEnv[constants.EnvKeyJwtPrivateKey] == "" { + envData.StringEnv[constants.EnvKeyJwtPrivateKey] = os.Getenv(constants.EnvKeyJwtPrivateKey) + } + + if envData.StringEnv[constants.EnvKeyJwtPublicKey] == "" { + envData.StringEnv[constants.EnvKeyJwtPublicKey] = os.Getenv(constants.EnvKeyJwtPublicKey) + } + if envData.StringEnv[constants.EnvKeyJwtType] == "" { - envData.StringEnv[constants.EnvKeyJwtType] = os.Getenv("JWT_TYPE") + envData.StringEnv[constants.EnvKeyJwtType] = os.Getenv(constants.EnvKeyJwtType) if envData.StringEnv[constants.EnvKeyJwtType] == "" { envData.StringEnv[constants.EnvKeyJwtType] = "HS256" } } if envData.StringEnv[constants.EnvKeyJwtRoleClaim] == "" { - envData.StringEnv[constants.EnvKeyJwtRoleClaim] = os.Getenv("JWT_ROLE_CLAIM") + envData.StringEnv[constants.EnvKeyJwtRoleClaim] = os.Getenv(constants.EnvKeyJwtRoleClaim) if envData.StringEnv[constants.EnvKeyJwtRoleClaim] == "" { envData.StringEnv[constants.EnvKeyJwtRoleClaim] = "role" @@ -134,48 +142,48 @@ func InitEnv() { } if envData.StringEnv[constants.EnvKeyRedisURL] == "" { - envData.StringEnv[constants.EnvKeyRedisURL] = os.Getenv("REDIS_URL") + envData.StringEnv[constants.EnvKeyRedisURL] = os.Getenv(constants.EnvKeyRedisURL) } if envData.StringEnv[constants.EnvKeyCookieName] == "" { - envData.StringEnv[constants.EnvKeyCookieName] = os.Getenv("COOKIE_NAME") + envData.StringEnv[constants.EnvKeyCookieName] = os.Getenv(constants.EnvKeyCookieName) if envData.StringEnv[constants.EnvKeyCookieName] == "" { envData.StringEnv[constants.EnvKeyCookieName] = "authorizer" } } if envData.StringEnv[constants.EnvKeyGoogleClientID] == "" { - envData.StringEnv[constants.EnvKeyGoogleClientID] = os.Getenv("GOOGLE_CLIENT_ID") + envData.StringEnv[constants.EnvKeyGoogleClientID] = os.Getenv(constants.EnvKeyGoogleClientID) } if envData.StringEnv[constants.EnvKeyGoogleClientSecret] == "" { - envData.StringEnv[constants.EnvKeyGoogleClientSecret] = os.Getenv("GOOGLE_CLIENT_SECRET") + envData.StringEnv[constants.EnvKeyGoogleClientSecret] = os.Getenv(constants.EnvKeyGoogleClientSecret) } if envData.StringEnv[constants.EnvKeyGithubClientID] == "" { - envData.StringEnv[constants.EnvKeyGithubClientID] = os.Getenv("GITHUB_CLIENT_ID") + envData.StringEnv[constants.EnvKeyGithubClientID] = os.Getenv(constants.EnvKeyGithubClientID) } if envData.StringEnv[constants.EnvKeyGithubClientSecret] == "" { - envData.StringEnv[constants.EnvKeyGithubClientSecret] = os.Getenv("GITHUB_CLIENT_SECRET") + envData.StringEnv[constants.EnvKeyGithubClientSecret] = os.Getenv(constants.EnvKeyGithubClientSecret) } if envData.StringEnv[constants.EnvKeyFacebookClientID] == "" { - envData.StringEnv[constants.EnvKeyFacebookClientID] = os.Getenv("FACEBOOK_CLIENT_ID") + envData.StringEnv[constants.EnvKeyFacebookClientID] = os.Getenv(constants.EnvKeyFacebookClientID) } if envData.StringEnv[constants.EnvKeyFacebookClientSecret] == "" { - envData.StringEnv[constants.EnvKeyFacebookClientSecret] = os.Getenv("FACEBOOK_CLIENT_SECRET") + envData.StringEnv[constants.EnvKeyFacebookClientSecret] = os.Getenv(constants.EnvKeyFacebookClientSecret) } if envData.StringEnv[constants.EnvKeyResetPasswordURL] == "" { - envData.StringEnv[constants.EnvKeyResetPasswordURL] = strings.TrimPrefix(os.Getenv("RESET_PASSWORD_URL"), "/") + envData.StringEnv[constants.EnvKeyResetPasswordURL] = strings.TrimPrefix(os.Getenv(constants.EnvKeyResetPasswordURL), "/") } - envData.BoolEnv[constants.EnvKeyDisableBasicAuthentication] = os.Getenv("DISABLE_BASIC_AUTHENTICATION") == "true" - envData.BoolEnv[constants.EnvKeyDisableEmailVerification] = os.Getenv("DISABLE_EMAIL_VERIFICATION") == "true" - envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = os.Getenv("DISABLE_MAGIC_LINK_LOGIN") == "true" - envData.BoolEnv[constants.EnvKeyDisableLoginPage] = os.Getenv("DISABLE_LOGIN_PAGE") == "true" + envData.BoolEnv[constants.EnvKeyDisableBasicAuthentication] = os.Getenv(constants.EnvKeyDisableBasicAuthentication) == "true" + envData.BoolEnv[constants.EnvKeyDisableEmailVerification] = os.Getenv(constants.EnvKeyDisableEmailVerification) == "true" + envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = os.Getenv(constants.EnvKeyDisableMagicLinkLogin) == "true" + envData.BoolEnv[constants.EnvKeyDisableLoginPage] = os.Getenv(constants.EnvKeyDisableLoginPage) == "true" // no need to add nil check as its already done above if envData.StringEnv[constants.EnvKeySmtpHost] == "" || envData.StringEnv[constants.EnvKeySmtpUsername] == "" || envData.StringEnv[constants.EnvKeySmtpPassword] == "" || envData.StringEnv[constants.EnvKeySenderEmail] == "" && envData.StringEnv[constants.EnvKeySmtpPort] == "" { @@ -187,7 +195,7 @@ func InitEnv() { envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = true } - allowedOriginsSplit := strings.Split(os.Getenv("ALLOWED_ORIGINS"), ",") + allowedOriginsSplit := strings.Split(os.Getenv(constants.EnvKeyAllowedOrigins), ",") allowedOrigins := []string{} hasWildCard := false @@ -215,14 +223,14 @@ func InitEnv() { envData.SliceEnv[constants.EnvKeyAllowedOrigins] = allowedOrigins - rolesEnv := strings.TrimSpace(os.Getenv("ROLES")) + rolesEnv := strings.TrimSpace(os.Getenv(constants.EnvKeyRoles)) rolesSplit := strings.Split(rolesEnv, ",") roles := []string{} if len(rolesEnv) == 0 { roles = []string{"user"} } - defaultRolesEnv := strings.TrimSpace(os.Getenv("DEFAULT_ROLES")) + defaultRolesEnv := strings.TrimSpace(os.Getenv(constants.EnvKeyDefaultRoles)) defaultRoleSplit := strings.Split(defaultRolesEnv, ",") defaultRoles := []string{} @@ -230,7 +238,7 @@ func InitEnv() { defaultRoles = []string{"user"} } - protectedRolesEnv := strings.TrimSpace(os.Getenv("PROTECTED_ROLES")) + protectedRolesEnv := strings.TrimSpace(os.Getenv(constants.EnvKeyProtectedRoles)) protectedRolesSplit := strings.Split(protectedRolesEnv, ",") protectedRoles := []string{} @@ -259,12 +267,12 @@ func InitEnv() { envData.SliceEnv[constants.EnvKeyDefaultRoles] = defaultRoles envData.SliceEnv[constants.EnvKeyProtectedRoles] = protectedRoles - if os.Getenv("ORGANIZATION_NAME") != "" { - envData.StringEnv[constants.EnvKeyOrganizationName] = os.Getenv("ORGANIZATION_NAME") + if os.Getenv(constants.EnvKeyOrganizationName) != "" { + envData.StringEnv[constants.EnvKeyOrganizationName] = os.Getenv(constants.EnvKeyOrganizationName) } - if os.Getenv("ORGANIZATION_LOGO") != "" { - envData.StringEnv[constants.EnvKeyOrganizationLogo] = os.Getenv("ORGANIZATION_LOGO") + if os.Getenv(constants.EnvKeyOrganizationLogo) != "" { + envData.StringEnv[constants.EnvKeyOrganizationLogo] = os.Getenv(constants.EnvKeyOrganizationLogo) } envstore.EnvInMemoryStoreObj.UpdateEnvStore(envData) diff --git a/server/handlers/verify_email.go b/server/handlers/verify_email.go index a2d7644..9aecff4 100644 --- a/server/handlers/verify_email.go +++ b/server/handlers/verify_email.go @@ -33,13 +33,13 @@ func VerifyEmailHandler() gin.HandlerFunc { } // verify if token exists in db - claim, err := token.VerifyVerificationToken(tokenInQuery) + claim, err := token.ParseJWTToken(tokenInQuery) if err != nil { c.JSON(400, errorRes) return } - user, err := db.Provider.GetUserByEmail(claim.Email) + user, err := db.Provider.GetUserByEmail(claim["email"].(string)) if err != nil { c.JSON(400, gin.H{ "message": err.Error(), @@ -68,6 +68,6 @@ func VerifyEmailHandler() gin.HandlerFunc { cookie.SetCookie(c, authToken.AccessToken.Token, authToken.RefreshToken.Token, authToken.FingerPrintHash) utils.SaveSessionInDB(user.ID, c) - c.Redirect(http.StatusTemporaryRedirect, claim.RedirectURL) + c.Redirect(http.StatusTemporaryRedirect, claim["redirect_url"].(string)) } } diff --git a/server/resolvers/is_valid_jwt.go b/server/resolvers/is_valid_jwt.go index 7061dcf..f9e33e3 100644 --- a/server/resolvers/is_valid_jwt.go +++ b/server/resolvers/is_valid_jwt.go @@ -26,7 +26,7 @@ func IsValidJwtResolver(ctx context.Context, params *model.IsValidJWTQueryInput) } } - claims, err := tokenHelper.VerifyJWTToken(token) + claims, err := tokenHelper.ParseJWTToken(token) if err != nil { return nil, err } diff --git a/server/resolvers/logout.go b/server/resolvers/logout.go index 0926e70..4653896 100644 --- a/server/resolvers/logout.go +++ b/server/resolvers/logout.go @@ -38,7 +38,7 @@ func LogoutResolver(ctx context.Context) (*model.Response, error) { fingerPrint := string(decryptedFingerPrint) // verify refresh token and fingerprint - claims, err := token.VerifyJWTToken(refreshToken) + claims, err := token.ParseJWTToken(refreshToken) if err != nil { return res, err } diff --git a/server/resolvers/reset_password.go b/server/resolvers/reset_password.go index cc482b8..d48ae4b 100644 --- a/server/resolvers/reset_password.go +++ b/server/resolvers/reset_password.go @@ -31,12 +31,12 @@ func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput) } // verify if token exists in db - claim, err := token.VerifyVerificationToken(params.Token) + claim, err := token.ParseJWTToken(params.Token) if err != nil { return res, fmt.Errorf(`invalid token`) } - user, err := db.Provider.GetUserByEmail(claim.Email) + user, err := db.Provider.GetUserByEmail(claim["email"].(string)) if err != nil { return res, err } diff --git a/server/resolvers/session.go b/server/resolvers/session.go index 175fcf9..e390b5a 100644 --- a/server/resolvers/session.go +++ b/server/resolvers/session.go @@ -41,7 +41,7 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod fingerPrint := string(decryptedFingerPrint) // verify refresh token and fingerprint - claims, err := token.VerifyJWTToken(refreshToken) + claims, err := token.ParseJWTToken(refreshToken) if err != nil { return res, err } diff --git a/server/resolvers/verify_email.go b/server/resolvers/verify_email.go index 7b57950..883ba08 100644 --- a/server/resolvers/verify_email.go +++ b/server/resolvers/verify_email.go @@ -28,12 +28,12 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m } // verify if token exists in db - claim, err := token.VerifyVerificationToken(params.Token) + claim, err := token.ParseJWTToken(params.Token) if err != nil { return res, fmt.Errorf(`invalid token`) } - user, err := db.Provider.GetUserByEmail(claim.Email) + user, err := db.Provider.GetUserByEmail(claim["email"].(string)) if err != nil { return res, err } diff --git a/server/token/auth_token.go b/server/token/auth_token.go index 0eb4542..1d8c988 100644 --- a/server/token/auth_token.go +++ b/server/token/auth_token.go @@ -62,7 +62,6 @@ func CreateAuthToken(user models.User, roles []string) (*Token, error) { // CreateRefreshToken util to create JWT token func CreateRefreshToken(user models.User, roles []string) (string, int64, error) { - t := jwt.New(jwt.GetSigningMethod(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType))) // expires in 1 year expiryBound := time.Hour * 8760 expiresAt := time.Now().Add(expiryBound).Unix() @@ -75,8 +74,7 @@ func CreateRefreshToken(user models.User, roles []string) (string, int64, error) "id": user.ID, } - t.Claims = customClaims - token, err := t.SignedString([]byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret))) + token, err := SignJWTToken(customClaims) if err != nil { return "", 0, err } @@ -86,9 +84,7 @@ func CreateRefreshToken(user models.User, roles []string) (string, int64, error) // CreateAccessToken util to create JWT token, based on // user information, roles config and CUSTOM_ACCESS_TOKEN_SCRIPT func CreateAccessToken(user models.User, roles []string) (string, int64, error) { - t := jwt.New(jwt.GetSigningMethod(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType))) expiryBound := time.Minute * 30 - expiresAt := time.Now().Add(expiryBound).Unix() resUser := user.AsAPIUser() @@ -141,9 +137,7 @@ func CreateAccessToken(user models.User, roles []string) (string, int64, error) } } - t.Claims = customClaims - - token, err := t.SignedString([]byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret))) + token, err := SignJWTToken(customClaims) if err != nil { return "", 0, err } @@ -187,43 +181,13 @@ func GetFingerPrint(gc *gin.Context) (string, error) { return fingerPrint, nil } -// VerifyJWTToken helps in verifying the JWT token -func VerifyJWTToken(token string) (map[string]interface{}, error) { - var res map[string]interface{} - claims := jwt.MapClaims{} - - t, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { - return []byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret)), nil - }) - if err != nil { - return res, err - } - - if !t.Valid { - return res, fmt.Errorf(`invalid token`) - } - - // claim parses exp & iat into float 64 with e^10, - // but we expect it to be int64 - // hence we need to assert interface and convert to int64 - intExp := int64(claims["exp"].(float64)) - intIat := int64(claims["iat"].(float64)) - - data, _ := json.Marshal(claims) - json.Unmarshal(data, &res) - res["exp"] = intExp - res["iat"] = intIat - - return res, nil -} - func ValidateAccessToken(gc *gin.Context) (map[string]interface{}, error) { token, err := GetAccessToken(gc) if err != nil { return nil, err } - claims, err := VerifyJWTToken(token) + claims, err := ParseJWTToken(token) if err != nil { return nil, err } diff --git a/server/token/jwt.go b/server/token/jwt.go new file mode 100644 index 0000000..6c517b1 --- /dev/null +++ b/server/token/jwt.go @@ -0,0 +1,83 @@ +package token + +import ( + "errors" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/envstore" + "github.com/golang-jwt/jwt" +) + +// SignJWTToken common util to sing jwt token +func SignJWTToken(claims jwt.MapClaims) (string, error) { + jwtType := envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + signingMethod := jwt.GetSigningMethod(jwtType) + t := jwt.New(signingMethod) + t.Claims = claims + + switch signingMethod { + case jwt.SigningMethodHS256, jwt.SigningMethodHS384, jwt.SigningMethodHS512: + return t.SignedString([]byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret))) + case jwt.SigningMethodRS256, jwt.SigningMethodRS384, jwt.SigningMethodRS512: + key, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPrivateKey))) + if err != nil { + return "", err + } + return t.SignedString(key) + case jwt.SigningMethodES256, jwt.SigningMethodES384, jwt.SigningMethodES512: + key, err := jwt.ParseECPrivateKeyFromPEM([]byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPrivateKey))) + if err != nil { + return "", err + } + return t.SignedString(key) + default: + return "", errors.New("unsupported signing method") + } +} + +// ParseJWTToken common util to parse jwt token +func ParseJWTToken(token string) (jwt.MapClaims, error) { + jwtType := envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + signingMethod := jwt.GetSigningMethod(jwtType) + + var err error + var claims jwt.MapClaims + + switch signingMethod { + case jwt.SigningMethodHS256, jwt.SigningMethodHS384, jwt.SigningMethodHS512: + _, err = jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + return []byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret)), nil + }) + case jwt.SigningMethodRS256, jwt.SigningMethodRS384, jwt.SigningMethodRS512: + _, err = jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) { + key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey))) + if err != nil { + return nil, err + } + return key, nil + }) + case jwt.SigningMethodES256, jwt.SigningMethodES384, jwt.SigningMethodES512: + _, err = jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) { + key, err := jwt.ParseECPublicKeyFromPEM([]byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey))) + if err != nil { + return nil, err + } + return key, nil + }) + default: + err = errors.New("unsupported signing method") + } + if err != nil { + return claims, err + } + + // claim parses exp & iat into float 64 with e^10, + // but we expect it to be int64 + // hence we need to assert interface and convert to int64 + intExp := int64(claims["exp"].(float64)) + intIat := int64(claims["iat"].(float64)) + claims["exp"] = intExp + claims["iat"] = intIat + + return claims, nil +} diff --git a/server/token/verification_token.go b/server/token/verification_token.go index 5e70bfb..8b0bdcc 100644 --- a/server/token/verification_token.go +++ b/server/token/verification_token.go @@ -8,44 +8,16 @@ import ( "github.com/golang-jwt/jwt" ) -// VerificationRequestToken is the user info that is stored in the JWT of verification request -type VerificationRequestToken struct { - Email string `json:"email"` - Host string `json:"host"` - RedirectURL string `json:"redirect_url"` -} - -// CustomClaim is the custom claim that is stored in the JWT of verification request -type CustomClaim struct { - *jwt.StandardClaims - TokenType string `json:"token_type"` - VerificationRequestToken -} - // CreateVerificationToken creates a verification JWT token func CreateVerificationToken(email, tokenType, hostname string) (string, error) { - t := jwt.New(jwt.GetSigningMethod(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType))) - - t.Claims = &CustomClaim{ - &jwt.StandardClaims{ - ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), - }, - tokenType, - VerificationRequestToken{Email: email, Host: hostname, RedirectURL: envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAppURL)}, + claims := jwt.MapClaims{ + "exp": time.Now().Add(time.Minute * 30).Unix(), + "iat": time.Now().Unix(), + "token_type": tokenType, + "email": email, + "host": hostname, + "redirect_url": envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAppURL), } - return t.SignedString([]byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret))) -} - -// VerifyVerificationToken verifies the verification JWT token -func VerifyVerificationToken(token string) (*CustomClaim, error) { - claims := &CustomClaim{} - _, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { - return []byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret)), nil - }) - if err != nil { - return claims, err - } - - return claims, nil + return SignJWTToken(claims) }