diff --git a/server/constants/auth_methods.go b/server/constants/auth_methods.go new file mode 100644 index 0000000..272bf53 --- /dev/null +++ b/server/constants/auth_methods.go @@ -0,0 +1,18 @@ +package constants + +const ( + // AuthRecipeMethodBasicAuth is the basic_auth auth method + AuthRecipeMethodBasicAuth = "basic_auth" + // AuthRecipeMethodMagicLinkLogin is the magic_link_login auth method + AuthRecipeMethodMagicLinkLogin = "magic_link_login" + // AuthRecipeMethodGoogle is the google auth method + AuthRecipeMethodGoogle = "google" + // AuthRecipeMethodGithub is the github auth method + AuthRecipeMethodGithub = "github" + // AuthRecipeMethodFacebook is the facebook auth method + AuthRecipeMethodFacebook = "facebook" + // AuthRecipeMethodLinkedin is the linkedin auth method + AuthRecipeMethodLinkedIn = "linkedin" + // AuthRecipeMethodApple is the apple auth method + AuthRecipeMethodApple = "apple" +) diff --git a/server/constants/signup_methods.go b/server/constants/signup_methods.go deleted file mode 100644 index 279ccae..0000000 --- a/server/constants/signup_methods.go +++ /dev/null @@ -1,18 +0,0 @@ -package constants - -const ( - // SignupMethodBasicAuth is the basic_auth signup method - SignupMethodBasicAuth = "basic_auth" - // SignupMethodMagicLinkLogin is the magic_link_login signup method - SignupMethodMagicLinkLogin = "magic_link_login" - // SignupMethodGoogle is the google signup method - SignupMethodGoogle = "google" - // SignupMethodGithub is the github signup method - SignupMethodGithub = "github" - // SignupMethodFacebook is the facebook signup method - SignupMethodFacebook = "facebook" - // SignupMethodLinkedin is the linkedin signup method - SignupMethodLinkedIn = "linkedin" - // SignupMethodApple is the apple signup method - SignupMethodApple = "apple" -) diff --git a/server/handlers/authorize.go b/server/handlers/authorize.go index 3a8b6ab..372e6e1 100644 --- a/server/handlers/authorize.go +++ b/server/handlers/authorize.go @@ -218,13 +218,18 @@ func AuthorizeHandler() gin.HandlerFunc { return } + sessionKey := user.ID + if claims.LoginMethod != "" { + sessionKey = claims.LoginMethod + ":" + user.ID + } + // if user is logged in - // based on the response type, generate the response + // based on the response type code, generate the response if isResponseTypeCode { // rollover the session for security - go memorystore.Provider.DeleteUserSession(user.ID, claims.Nonce) + go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce) nonce := uuid.New().String() - newSessionTokenData, newSessionToken, err := token.CreateSessionToken(user, nonce, claims.Roles, scope) + newSessionTokenData, newSessionToken, err := token.CreateSessionToken(user, nonce, claims.Roles, scope, claims.LoginMethod) if err != nil { if isQuery { gc.Redirect(http.StatusFound, loginURL) @@ -262,7 +267,7 @@ func AuthorizeHandler() gin.HandlerFunc { if isResponseTypeToken { // rollover the session for security - authToken, err := token.CreateAuthToken(gc, user, claims.Roles, scope) + authToken, err := token.CreateAuthToken(gc, user, claims.Roles, scope, claims.LoginMethod) if err != nil { if isQuery { gc.Redirect(http.StatusFound, loginURL) @@ -280,9 +285,10 @@ func AuthorizeHandler() gin.HandlerFunc { } return } - go memorystore.Provider.DeleteUserSession(user.ID, claims.Nonce) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + + go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) cookie.SetSession(gc, authToken.FingerPrintHash) expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() @@ -305,7 +311,7 @@ func AuthorizeHandler() gin.HandlerFunc { if authToken.RefreshToken != nil { res["refresh_token"] = authToken.RefreshToken.Token params += "&refresh_token=" + authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) } if isQuery { diff --git a/server/handlers/oauth_callback.go b/server/handlers/oauth_callback.go index b3020ea..780f7d8 100644 --- a/server/handlers/oauth_callback.go +++ b/server/handlers/oauth_callback.go @@ -57,15 +57,15 @@ func OAuthCallbackHandler() gin.HandlerFunc { user := models.User{} code := c.Request.FormValue("code") switch provider { - case constants.SignupMethodGoogle: + case constants.AuthRecipeMethodGoogle: user, err = processGoogleUserInfo(code) - case constants.SignupMethodGithub: + case constants.AuthRecipeMethodGithub: user, err = processGithubUserInfo(code) - case constants.SignupMethodFacebook: + case constants.AuthRecipeMethodFacebook: user, err = processFacebookUserInfo(code) - case constants.SignupMethodLinkedIn: + case constants.AuthRecipeMethodLinkedIn: user, err = processLinkedInUserInfo(code) - case constants.SignupMethodApple: + case constants.AuthRecipeMethodApple: user, err = processAppleUserInfo(code) default: log.Info("Invalid oauth provider") @@ -192,7 +192,7 @@ func OAuthCallbackHandler() gin.HandlerFunc { } } - authToken, err := token.CreateAuthToken(c, user, inputRoles, scopes) + authToken, err := token.CreateAuthToken(c, user, inputRoles, scopes, provider) if err != nil { log.Debug("Failed to create auth token: ", err) c.JSON(500, gin.H{"error": err.Error()}) @@ -205,13 +205,14 @@ func OAuthCallbackHandler() gin.HandlerFunc { params := "access_token=" + authToken.AccessToken.Token + "&token_type=bearer&expires_in=" + strconv.FormatInt(expiresIn, 10) + "&state=" + stateValue + "&id_token=" + authToken.IDToken.Token + sessionKey := provider + ":" + user.ID cookie.SetSession(c, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) if authToken.RefreshToken != nil { params = params + `&refresh_token=` + authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) } go db.Provider.AddSession(models.Session{ diff --git a/server/handlers/oauth_login.go b/server/handlers/oauth_login.go index 235d794..4109162 100644 --- a/server/handlers/oauth_login.go +++ b/server/handlers/oauth_login.go @@ -100,13 +100,13 @@ func OAuthLoginHandler() gin.HandlerFunc { provider := c.Param("oauth_provider") isProviderConfigured := true switch provider { - case constants.SignupMethodGoogle: + case constants.AuthRecipeMethodGoogle: if oauth.OAuthProviders.GoogleConfig == nil { log.Debug("Google OAuth provider is not configured") isProviderConfigured = false break } - err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGoogle) + err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodGoogle) if err != nil { log.Debug("Error setting state: ", err) c.JSON(500, gin.H{ @@ -115,16 +115,16 @@ func OAuthLoginHandler() gin.HandlerFunc { return } // during the init of OAuthProvider authorizer url might be empty - oauth.OAuthProviders.GoogleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodGoogle + oauth.OAuthProviders.GoogleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodGoogle url := oauth.OAuthProviders.GoogleConfig.AuthCodeURL(oauthStateString) c.Redirect(http.StatusTemporaryRedirect, url) - case constants.SignupMethodGithub: + case constants.AuthRecipeMethodGithub: if oauth.OAuthProviders.GithubConfig == nil { log.Debug("Github OAuth provider is not configured") isProviderConfigured = false break } - err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGithub) + err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodGithub) if err != nil { log.Debug("Error setting state: ", err) c.JSON(500, gin.H{ @@ -132,16 +132,16 @@ func OAuthLoginHandler() gin.HandlerFunc { }) return } - oauth.OAuthProviders.GithubConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodGithub + oauth.OAuthProviders.GithubConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodGithub url := oauth.OAuthProviders.GithubConfig.AuthCodeURL(oauthStateString) c.Redirect(http.StatusTemporaryRedirect, url) - case constants.SignupMethodFacebook: + case constants.AuthRecipeMethodFacebook: if oauth.OAuthProviders.FacebookConfig == nil { log.Debug("Facebook OAuth provider is not configured") isProviderConfigured = false break } - err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodFacebook) + err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodFacebook) if err != nil { log.Debug("Error setting state: ", err) c.JSON(500, gin.H{ @@ -149,16 +149,16 @@ func OAuthLoginHandler() gin.HandlerFunc { }) return } - oauth.OAuthProviders.FacebookConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodFacebook + oauth.OAuthProviders.FacebookConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodFacebook url := oauth.OAuthProviders.FacebookConfig.AuthCodeURL(oauthStateString) c.Redirect(http.StatusTemporaryRedirect, url) - case constants.SignupMethodLinkedIn: + case constants.AuthRecipeMethodLinkedIn: if oauth.OAuthProviders.LinkedInConfig == nil { log.Debug("Linkedin OAuth provider is not configured") isProviderConfigured = false break } - err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodLinkedIn) + err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodLinkedIn) if err != nil { log.Debug("Error setting state: ", err) c.JSON(500, gin.H{ @@ -166,16 +166,16 @@ func OAuthLoginHandler() gin.HandlerFunc { }) return } - oauth.OAuthProviders.LinkedInConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodLinkedIn + oauth.OAuthProviders.LinkedInConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodLinkedIn url := oauth.OAuthProviders.LinkedInConfig.AuthCodeURL(oauthStateString) c.Redirect(http.StatusTemporaryRedirect, url) - case constants.SignupMethodApple: + case constants.AuthRecipeMethodApple: if oauth.OAuthProviders.AppleConfig == nil { log.Debug("Apple OAuth provider is not configured") isProviderConfigured = false break } - err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodApple) + err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodApple) if err != nil { log.Debug("Error setting state: ", err) c.JSON(500, gin.H{ @@ -183,7 +183,7 @@ func OAuthLoginHandler() gin.HandlerFunc { }) return } - oauth.OAuthProviders.AppleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodApple + oauth.OAuthProviders.AppleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodApple // there is scope encoding issue with oauth2 and how apple expects, hence added scope manually // check: https://github.com/golang/oauth2/issues/449 url := oauth.OAuthProviders.AppleConfig.AuthCodeURL(oauthStateString, oauth2.SetAuthURLParam("response_mode", "form_post")) + "&scope=name email" diff --git a/server/handlers/revoke.go b/server/handlers/revoke.go index 6e71ee3..faeb663 100644 --- a/server/handlers/revoke.go +++ b/server/handlers/revoke.go @@ -56,7 +56,14 @@ func RevokeHandler() gin.HandlerFunc { return } - memorystore.Provider.DeleteUserSession(claims["sub"].(string), claims["nonce"].(string)) + userID := claims["sub"].(string) + loginMethod := claims["login_method"] + sessionToken := userID + if loginMethod != nil && loginMethod != "" { + sessionToken = loginMethod.(string) + ":" + userID + } + + memorystore.Provider.DeleteUserSession(sessionToken, claims["nonce"].(string)) gc.JSON(http.StatusOK, gin.H{ "message": "Token revoked successfully", diff --git a/server/handlers/token.go b/server/handlers/token.go index 4740466..e808dbb 100644 --- a/server/handlers/token.go +++ b/server/handlers/token.go @@ -72,6 +72,9 @@ func TokenHandler() gin.HandlerFunc { var userID string var roles, scope []string + loginMethod := "" + sessionKey := "" + if isAuthorizationCodeGrant { if codeVerifier == "" { @@ -134,8 +137,13 @@ func TokenHandler() gin.HandlerFunc { userID = claims.Subject roles = claims.Roles scope = claims.Scope + loginMethod = claims.LoginMethod // rollover the session for security - go memorystore.Provider.DeleteUserSession(userID, claims.Nonce) + sessionKey = userID + if loginMethod != "" { + sessionKey = loginMethod + ":" + userID + } + go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce) } else { // validate refresh token if refreshToken == "" { @@ -155,6 +163,7 @@ func TokenHandler() gin.HandlerFunc { }) } userID = claims["sub"].(string) + loginMethod := claims["login_method"] rolesInterface := claims["roles"].([]interface{}) scopeInterface := claims["scope"].([]interface{}) for _, v := range rolesInterface { @@ -163,8 +172,22 @@ func TokenHandler() gin.HandlerFunc { for _, v := range scopeInterface { scope = append(scope, v.(string)) } + + sessionKey = userID + if loginMethod != nil && loginMethod != "" { + sessionKey = loginMethod.(string) + ":" + sessionKey + } // remove older refresh token and rotate it for security - go memorystore.Provider.DeleteUserSession(userID, claims["nonce"].(string)) + go memorystore.Provider.DeleteUserSession(sessionKey, claims["nonce"].(string)) + } + + if sessionKey == "" { + log.Debug("Error getting sessionKey: ", sessionKey, loginMethod) + gc.JSON(http.StatusUnauthorized, gin.H{ + "error": "unauthorized", + "error_description": "User not found", + }) + return } user, err := db.Provider.GetUserByID(userID) @@ -177,7 +200,7 @@ func TokenHandler() gin.HandlerFunc { return } - authToken, err := token.CreateAuthToken(gc, user, roles, scope) + authToken, err := token.CreateAuthToken(gc, user, roles, scope, loginMethod) if err != nil { log.Debug("Error creating auth token: ", err) gc.JSON(http.StatusUnauthorized, gin.H{ @@ -186,8 +209,8 @@ func TokenHandler() gin.HandlerFunc { }) return } - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) cookie.SetSession(gc, authToken.FingerPrintHash) expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() @@ -205,7 +228,7 @@ func TokenHandler() gin.HandlerFunc { if authToken.RefreshToken != nil { res["refresh_token"] = authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) } gc.JSON(http.StatusOK, res) diff --git a/server/handlers/verify_email.go b/server/handlers/verify_email.go index 3eb7324..4e9d7d7 100644 --- a/server/handlers/verify_email.go +++ b/server/handlers/verify_email.go @@ -92,7 +92,11 @@ func VerifyEmailHandler() gin.HandlerFunc { } else { scope = strings.Split(scopeString, " ") } - authToken, err := token.CreateAuthToken(c, user, roles, scope) + loginMethod := constants.AuthRecipeMethodBasicAuth + if verificationRequest.Identifier == constants.VerificationTypeMagicLinkLogin { + loginMethod = constants.AuthRecipeMethodMagicLinkLogin + } + authToken, err := token.CreateAuthToken(c, user, roles, scope, loginMethod) if err != nil { log.Debug("Error creating auth token: ", err) errorRes["error_description"] = err.Error() @@ -107,13 +111,14 @@ func VerifyEmailHandler() gin.HandlerFunc { params := "access_token=" + authToken.AccessToken.Token + "&token_type=bearer&expires_in=" + strconv.FormatInt(expiresIn, 10) + "&state=" + state + "&id_token=" + authToken.IDToken.Token + sessionKey := loginMethod + ":" + user.ID cookie.SetSession(c, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) if authToken.RefreshToken != nil { params = params + `&refresh_token=` + authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) } if redirectURL == "" { diff --git a/server/memorystore/providers/inmemory/store.go b/server/memorystore/providers/inmemory/store.go index 21046ce..ebd0c06 100644 --- a/server/memorystore/providers/inmemory/store.go +++ b/server/memorystore/providers/inmemory/store.go @@ -26,24 +26,34 @@ func (c *provider) GetUserSession(userId, sessionToken string) (string, error) { // DeleteAllUserSessions deletes all the user sessions from in-memory store. func (c *provider) DeleteAllUserSessions(userId string) error { - if os.Getenv("ENV") != constants.TestEnv { - c.mutex.Lock() - defer c.mutex.Unlock() + namespaces := []string{ + constants.AuthRecipeMethodBasicAuth, + constants.AuthRecipeMethodMagicLinkLogin, + constants.AuthRecipeMethodApple, + constants.AuthRecipeMethodFacebook, + constants.AuthRecipeMethodGithub, + constants.AuthRecipeMethodGoogle, + constants.AuthRecipeMethodLinkedIn, + } + + for _, namespace := range namespaces { + c.sessionStore.RemoveAll(namespace + ":" + userId) } - c.sessionStore.RemoveAll(userId) return nil } // DeleteUserSession deletes the user session from the in-memory store. func (c *provider) DeleteUserSession(userId, sessionToken string) error { - if os.Getenv("ENV") != constants.TestEnv { - c.mutex.Lock() - defer c.mutex.Unlock() - } c.sessionStore.Remove(userId, sessionToken) return nil } +// DeleteSessionForNamespace to delete session for a given namespace example google,github +func (c *provider) DeleteSessionForNamespace(namespace string) error { + c.sessionStore.RemoveByNamespace(namespace) + return nil +} + // SetState sets the state in the in-memory store. func (c *provider) SetState(key, state string) error { if os.Getenv("ENV") != constants.TestEnv { diff --git a/server/memorystore/providers/inmemory/stores/session_store.go b/server/memorystore/providers/inmemory/stores/session_store.go index c75a3d7..d702fa0 100644 --- a/server/memorystore/providers/inmemory/stores/session_store.go +++ b/server/memorystore/providers/inmemory/stores/session_store.go @@ -2,6 +2,7 @@ package stores import ( "os" + "strings" "sync" "github.com/authorizerdev/authorizer/server/constants" @@ -65,3 +66,18 @@ func (s *SessionStore) GetAll(key string) map[string]string { } return s.store[key] } + +// RemoveByNamespace to delete session for a given namespace example google,github +func (s *SessionStore) RemoveByNamespace(namespace string) error { + if os.Getenv("ENV") != constants.TestEnv { + s.mutex.Lock() + defer s.mutex.Unlock() + } + + for key := range s.store { + if strings.Contains(key, namespace+":") { + delete(s.store, key) + } + } + return nil +} diff --git a/server/memorystore/providers/providers.go b/server/memorystore/providers/providers.go index f3b2471..4953edb 100644 --- a/server/memorystore/providers/providers.go +++ b/server/memorystore/providers/providers.go @@ -12,6 +12,8 @@ type Provider interface { DeleteUserSession(userId, key string) error // DeleteAllSessions deletes all the sessions from the session store DeleteAllUserSessions(userId string) error + // DeleteSessionForNamespace deletes the session for a given namespace + DeleteSessionForNamespace(namespace string) error // SetState sets the login state (key, value form) in the session store SetState(key, state string) error diff --git a/server/memorystore/providers/redis/store.go b/server/memorystore/providers/redis/store.go index f70365d..36b4b0c 100644 --- a/server/memorystore/providers/redis/store.go +++ b/server/memorystore/providers/redis/store.go @@ -63,11 +63,48 @@ func (c *provider) DeleteUserSession(userId, key string) error { // DeleteAllUserSessions deletes all the user session from redis func (c *provider) DeleteAllUserSessions(userID string) error { - err := c.store.Del(c.ctx, userID).Err() - if err != nil { - log.Debug("Error deleting all user sessions from redis: ", err) - return err + namespaces := []string{ + constants.AuthRecipeMethodBasicAuth, + constants.AuthRecipeMethodMagicLinkLogin, + constants.AuthRecipeMethodApple, + constants.AuthRecipeMethodFacebook, + constants.AuthRecipeMethodGithub, + constants.AuthRecipeMethodGoogle, + constants.AuthRecipeMethodLinkedIn, } + for _, namespace := range namespaces { + err := c.store.Del(c.ctx, namespace+":"+userID).Err() + if err != nil { + log.Debug("Error deleting all user sessions from redis: ", err) + return err + } + } + return nil +} + +// DeleteSessionForNamespace to delete session for a given namespace example google,github +func (c *provider) DeleteSessionForNamespace(namespace string) error { + var cursor uint64 + for { + keys := []string{} + keys, cursor, err := c.store.Scan(c.ctx, cursor, namespace+":*", 0).Result() + if err != nil { + log.Debugf("Error scanning keys for %s namespace: %s", namespace, err.Error()) + return err + } + + for _, key := range keys { + err := c.store.Del(c.ctx, key).Err() + if err != nil { + log.Debugf("Error deleting sessions for %s namespace: %s", namespace, err.Error()) + return err + } + } + if cursor == 0 { // no more keys + break + } + } + return nil } diff --git a/server/resolvers/invite_members.go b/server/resolvers/invite_members.go index c454dc8..7b77ca4 100644 --- a/server/resolvers/invite_members.go +++ b/server/resolvers/invite_members.go @@ -129,11 +129,11 @@ func InviteMembersResolver(ctx context.Context, params model.InviteMemberInput) // use magic link login if that option is on if !isMagicLinkLoginDisabled { - user.SignupMethods = constants.SignupMethodMagicLinkLogin + user.SignupMethods = constants.AuthRecipeMethodMagicLinkLogin verificationRequest.Identifier = constants.VerificationTypeMagicLinkLogin } else { // use basic authentication if that option is on - user.SignupMethods = constants.SignupMethodBasicAuth + user.SignupMethods = constants.AuthRecipeMethodBasicAuth verificationRequest.Identifier = constants.VerificationTypeForgotPassword verifyEmailURL = appURL + "/setup-password" diff --git a/server/resolvers/login.go b/server/resolvers/login.go index 993ffce..4cffee1 100644 --- a/server/resolvers/login.go +++ b/server/resolvers/login.go @@ -56,7 +56,7 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes return res, fmt.Errorf(`user access has been revoked`) } - if !strings.Contains(user.SignupMethods, constants.SignupMethodBasicAuth) { + if !strings.Contains(user.SignupMethods, constants.AuthRecipeMethodBasicAuth) { log.Debug("User signup method is not basic auth") return res, fmt.Errorf(`user has not signed up email & password`) } @@ -97,7 +97,7 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes scope = params.Scope } - authToken, err := token.CreateAuthToken(gc, user, roles, scope) + authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth) if err != nil { log.Debug("Failed to create auth token", err) return res, err @@ -117,12 +117,13 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes } cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + sessionStoreKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) } go db.Provider.AddSession(models.Session{ diff --git a/server/resolvers/logout.go b/server/resolvers/logout.go index 4dd8269..54ce873 100644 --- a/server/resolvers/logout.go +++ b/server/resolvers/logout.go @@ -41,7 +41,12 @@ func LogoutResolver(ctx context.Context) (*model.Response, error) { return nil, err } - memorystore.Provider.DeleteUserSession(sessionData.Subject, sessionData.Nonce) + sessionKey := sessionData.Subject + if sessionData.LoginMethod != "" { + sessionKey = sessionData.LoginMethod + ":" + sessionData.Subject + } + + memorystore.Provider.DeleteUserSession(sessionKey, sessionData.Nonce) cookie.DeleteSession(gc) res := &model.Response{ diff --git a/server/resolvers/magic_link_login.go b/server/resolvers/magic_link_login.go index 541d622..caae37a 100644 --- a/server/resolvers/magic_link_login.go +++ b/server/resolvers/magic_link_login.go @@ -70,7 +70,7 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu return res, fmt.Errorf(`signup is disabled for this instance`) } - user.SignupMethods = constants.SignupMethodMagicLinkLogin + user.SignupMethods = constants.AuthRecipeMethodMagicLinkLogin // define roles for new user if len(params.Roles) > 0 { // check if roles exists @@ -158,8 +158,8 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu } signupMethod := existingUser.SignupMethods - if !strings.Contains(signupMethod, constants.SignupMethodMagicLinkLogin) { - signupMethod = signupMethod + "," + constants.SignupMethodMagicLinkLogin + if !strings.Contains(signupMethod, constants.AuthRecipeMethodMagicLinkLogin) { + signupMethod = signupMethod + "," + constants.AuthRecipeMethodMagicLinkLogin } user.SignupMethods = signupMethod diff --git a/server/resolvers/reset_password.go b/server/resolvers/reset_password.go index d6f0780..e747830 100644 --- a/server/resolvers/reset_password.go +++ b/server/resolvers/reset_password.go @@ -82,8 +82,8 @@ func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput) user.Password = &password signupMethod := user.SignupMethods - if !strings.Contains(signupMethod, constants.SignupMethodBasicAuth) { - signupMethod = signupMethod + "," + constants.SignupMethodBasicAuth + if !strings.Contains(signupMethod, constants.AuthRecipeMethodBasicAuth) { + signupMethod = signupMethod + "," + constants.AuthRecipeMethodBasicAuth } user.SignupMethods = signupMethod diff --git a/server/resolvers/session.go b/server/resolvers/session.go index e2f5061..903f8cb 100644 --- a/server/resolvers/session.go +++ b/server/resolvers/session.go @@ -70,14 +70,18 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod scope = params.Scope } - authToken, err := token.CreateAuthToken(gc, user, claimRoles, scope) + authToken, err := token.CreateAuthToken(gc, user, claimRoles, scope, claims.LoginMethod) if err != nil { log.Debug("Failed to create auth token: ", err) return res, err } // rollover the session for security - go memorystore.Provider.DeleteUserSession(userID, claims.Nonce) + sessionKey := userID + if claims.LoginMethod != "" { + sessionKey = claims.LoginMethod + ":" + userID + } + go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce) expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() if expiresIn <= 0 { @@ -93,12 +97,12 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod } cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) } return res, nil } diff --git a/server/resolvers/signup.go b/server/resolvers/signup.go index 4db06e4..4f7051e 100644 --- a/server/resolvers/signup.go +++ b/server/resolvers/signup.go @@ -157,7 +157,7 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR user.Picture = params.Picture } - user.SignupMethods = constants.SignupMethodBasicAuth + user.SignupMethods = constants.AuthRecipeMethodBasicAuth isEmailVerificationDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) if err != nil { log.Debug("Error getting email verification disabled: ", err) @@ -219,7 +219,7 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR scope = params.Scope } - authToken, err := token.CreateAuthToken(gc, user, roles, scope) + authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth) if err != nil { log.Debug("Failed to create auth token: ", err) return res, err @@ -243,13 +243,14 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR User: userToReturn, } + sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) } } diff --git a/server/resolvers/update_env.go b/server/resolvers/update_env.go index d7023a0..41fe5fa 100644 --- a/server/resolvers/update_env.go +++ b/server/resolvers/update_env.go @@ -21,6 +21,54 @@ import ( "github.com/authorizerdev/authorizer/server/utils" ) +// check if login methods have been disabled +// remove the session tokens for those methods +func clearSessionIfRequired(currentData, updatedData map[string]interface{}) { + isCurrentBasicAuthEnabled := !currentData[constants.EnvKeyDisableBasicAuthentication].(bool) + isCurrentMagicLinkLoginEnabled := !currentData[constants.EnvKeyDisableMagicLinkLogin].(bool) + isCurrentAppleLoginEnabled := currentData[constants.EnvKeyAppleClientID] != nil && currentData[constants.EnvKeyAppleClientSecret] != nil && currentData[constants.EnvKeyAppleClientID].(string) != "" && currentData[constants.EnvKeyAppleClientSecret].(string) != "" + isCurrentFacebookLoginEnabled := currentData[constants.EnvKeyFacebookClientID] != nil && currentData[constants.EnvKeyFacebookClientSecret] != nil && currentData[constants.EnvKeyFacebookClientID].(string) != "" && currentData[constants.EnvKeyFacebookClientSecret].(string) != "" + isCurrentGoogleLoginEnabled := currentData[constants.EnvKeyGoogleClientID] != nil && currentData[constants.EnvKeyGoogleClientSecret] != nil && currentData[constants.EnvKeyGoogleClientID].(string) != "" && currentData[constants.EnvKeyGoogleClientSecret].(string) != "" + isCurrentGithubLoginEnabled := currentData[constants.EnvKeyGithubClientID] != nil && currentData[constants.EnvKeyGithubClientSecret] != nil && currentData[constants.EnvKeyGithubClientID].(string) != "" && currentData[constants.EnvKeyGithubClientSecret].(string) != "" + isCurrentLinkedInLoginEnabled := currentData[constants.EnvKeyLinkedInClientID] != nil && currentData[constants.EnvKeyLinkedInClientSecret] != nil && currentData[constants.EnvKeyLinkedInClientID].(string) != "" && currentData[constants.EnvKeyLinkedInClientSecret].(string) != "" + + isUpdatedBasicAuthEnabled := !updatedData[constants.EnvKeyDisableBasicAuthentication].(bool) + isUpdatedMagicLinkLoginEnabled := !updatedData[constants.EnvKeyDisableMagicLinkLogin].(bool) + isUpdatedAppleLoginEnabled := updatedData[constants.EnvKeyAppleClientID] != nil && updatedData[constants.EnvKeyAppleClientSecret] != nil && updatedData[constants.EnvKeyAppleClientID].(string) != "" && updatedData[constants.EnvKeyAppleClientSecret].(string) != "" + isUpdatedFacebookLoginEnabled := updatedData[constants.EnvKeyFacebookClientID] != nil && updatedData[constants.EnvKeyFacebookClientSecret] != nil && updatedData[constants.EnvKeyFacebookClientID].(string) != "" && updatedData[constants.EnvKeyFacebookClientSecret].(string) != "" + isUpdatedGoogleLoginEnabled := updatedData[constants.EnvKeyGoogleClientID] != nil && updatedData[constants.EnvKeyGoogleClientSecret] != nil && updatedData[constants.EnvKeyGoogleClientID].(string) != "" && updatedData[constants.EnvKeyGoogleClientSecret].(string) != "" + isUpdatedGithubLoginEnabled := updatedData[constants.EnvKeyGithubClientID] != nil && updatedData[constants.EnvKeyGithubClientSecret] != nil && updatedData[constants.EnvKeyGithubClientID].(string) != "" && updatedData[constants.EnvKeyGithubClientSecret].(string) != "" + isUpdatedLinkedInLoginEnabled := updatedData[constants.EnvKeyLinkedInClientID] != nil && updatedData[constants.EnvKeyLinkedInClientSecret] != nil && updatedData[constants.EnvKeyLinkedInClientID].(string) != "" && updatedData[constants.EnvKeyLinkedInClientSecret].(string) != "" + + if isCurrentBasicAuthEnabled && !isUpdatedBasicAuthEnabled { + memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodBasicAuth) + } + + if isCurrentMagicLinkLoginEnabled && !isUpdatedMagicLinkLoginEnabled { + memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodMagicLinkLogin) + } + + if isCurrentAppleLoginEnabled && !isUpdatedAppleLoginEnabled { + memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodApple) + } + + if isCurrentFacebookLoginEnabled && !isUpdatedFacebookLoginEnabled { + memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodFacebook) + } + + if isCurrentGoogleLoginEnabled && !isUpdatedGoogleLoginEnabled { + memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodGoogle) + } + + if isCurrentGithubLoginEnabled && !isUpdatedGithubLoginEnabled { + memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodGithub) + } + + if isCurrentLinkedInLoginEnabled && !isUpdatedLinkedInLoginEnabled { + memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodLinkedIn) + } +} + // UpdateEnvResolver is a resolver for update config mutation // This is admin only mutation func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model.Response, error) { @@ -37,12 +85,19 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model return res, fmt.Errorf("unauthorized") } - updatedData, err := memorystore.Provider.GetEnvStore() + currentData, err := memorystore.Provider.GetEnvStore() if err != nil { log.Debug("Failed to get env store: ", err) return res, err } + // clone currentData in new var + // that will be updated based on the req + updatedData := make(map[string]interface{}) + for key, val := range currentData { + updatedData[key] = val + } + isJWTUpdated := false algo := updatedData[constants.EnvKeyJwtType].(string) if params.JwtType != nil { @@ -210,6 +265,8 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model } } + go clearSessionIfRequired(currentData, updatedData) + // Update local store memorystore.Provider.UpdateEnvStore(updatedData) jwk, err := crypto.GenerateJWKBasedOnEnv() @@ -224,12 +281,6 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model return res, err } - // TODO check how to update session store based on env change. - // err = sessionstore.InitSession() - // if err != nil { - // log.Debug("Failed to init session store: ", err) - // return res, err - // } err = oauth.InitOAuth() if err != nil { return res, err diff --git a/server/resolvers/validate_jwt_token.go b/server/resolvers/validate_jwt_token.go index 4139826..dff4622 100644 --- a/server/resolvers/validate_jwt_token.go +++ b/server/resolvers/validate_jwt_token.go @@ -50,7 +50,12 @@ func ValidateJwtTokenResolver(ctx context.Context, params model.ValidateJWTToken // access_token and refresh_token should be validated from session store as well if tokenType == constants.TokenTypeAccessToken || tokenType == constants.TokenTypeRefreshToken { nonce = claims["nonce"].(string) - token, err := memorystore.Provider.GetUserSession(userID, tokenType+"_"+claims["nonce"].(string)) + loginMethod := claims["login_method"] + sessionKey := userID + if loginMethod != nil && loginMethod != "" { + sessionKey = loginMethod.(string) + ":" + userID + } + token, err := memorystore.Provider.GetUserSession(sessionKey, tokenType+"_"+claims["nonce"].(string)) if err != nil || token == "" { log.Debug("Failed to get user session: ", err) return nil, errors.New("invalid token") diff --git a/server/resolvers/verify_email.go b/server/resolvers/verify_email.go index e1d5827..a10d83d 100644 --- a/server/resolvers/verify_email.go +++ b/server/resolvers/verify_email.go @@ -73,9 +73,14 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m return res, err } + loginMethod := constants.AuthRecipeMethodBasicAuth + if loginMethod == constants.VerificationTypeMagicLinkLogin { + loginMethod = constants.AuthRecipeMethodMagicLinkLogin + } + roles := strings.Split(user.Roles, ",") scope := []string{"openid", "email", "profile"} - authToken, err := token.CreateAuthToken(gc, user, roles, scope) + authToken, err := token.CreateAuthToken(gc, user, roles, scope, loginMethod) if err != nil { log.Debug("Failed to create auth token: ", err) return res, err @@ -100,13 +105,14 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m User: user.AsAPIUser(), } + sessionKey := loginMethod + ":" + user.ID cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) } return res, nil } diff --git a/server/test/logout_test.go b/server/test/logout_test.go index d0d6228..2b7a897 100644 --- a/server/test/logout_test.go +++ b/server/test/logout_test.go @@ -36,7 +36,13 @@ func logoutTests(t *testing.T, s TestSetup) { assert.NoError(t, err) assert.NotEmpty(t, claims) - sessionToken, err := memorystore.Provider.GetUserSession(verifyRes.User.ID, constants.TokenTypeSessionToken+"_"+claims["nonce"].(string)) + loginMethod := claims["login_method"] + sessionKey := verifyRes.User.ID + if loginMethod != nil && loginMethod != "" { + sessionKey = loginMethod.(string) + ":" + verifyRes.User.ID + } + + sessionToken, err := memorystore.Provider.GetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+claims["nonce"].(string)) assert.NoError(t, err) assert.NotEmpty(t, sessionToken) diff --git a/server/test/profile_test.go b/server/test/profile_test.go index 8f5a283..8f5592c 100644 --- a/server/test/profile_test.go +++ b/server/test/profile_test.go @@ -37,9 +37,9 @@ func profileTests(t *testing.T, s TestSetup) { ctx = context.WithValue(req.Context(), "GinContextKey", s.GinContext) profileRes, err := resolvers.ProfileResolver(ctx) assert.Nil(t, err) + assert.NotNil(t, profileRes) s.GinContext.Request.Header.Set("Authorization", "") - - newEmail := *&profileRes.Email + newEmail := profileRes.Email assert.Equal(t, email, newEmail, "emails should be equal") cleanData(email) diff --git a/server/test/session_test.go b/server/test/session_test.go index 474be37..4d2e8df 100644 --- a/server/test/session_test.go +++ b/server/test/session_test.go @@ -41,7 +41,8 @@ func sessionTests(t *testing.T, s TestSetup) { assert.NoError(t, err) assert.NotEmpty(t, claims) - sessionToken, err := memorystore.Provider.GetUserSession(verifyRes.User.ID, constants.TokenTypeSessionToken+"_"+claims["nonce"].(string)) + sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + verifyRes.User.ID + sessionToken, err := memorystore.Provider.GetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+claims["nonce"].(string)) assert.NoError(t, err) assert.NotEmpty(t, sessionToken) diff --git a/server/test/validate_jwt_token_test.go b/server/test/validate_jwt_token_test.go index c353e6e..0d5358a 100644 --- a/server/test/validate_jwt_token_test.go +++ b/server/test/validate_jwt_token_test.go @@ -50,12 +50,13 @@ func validateJwtTokenTest(t *testing.T, s TestSetup) { roles := []string{"user"} gc, err := utils.GinContextFromContext(ctx) assert.NoError(t, err) - authToken, err := token.CreateAuthToken(gc, user, roles, scope) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID + authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) if authToken.RefreshToken != nil { - memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) } t.Run(`should validate the access token`, func(t *testing.T) { diff --git a/server/token/auth_token.go b/server/token/auth_token.go index 0076ff9..f108647 100644 --- a/server/token/auth_token.go +++ b/server/token/auth_token.go @@ -38,23 +38,25 @@ type Token struct { // SessionData type SessionData struct { - Subject string `json:"sub"` - Roles []string `json:"roles"` - Scope []string `json:"scope"` - Nonce string `json:"nonce"` - IssuedAt int64 `json:"iat"` - ExpiresAt int64 `json:"exp"` + Subject string `json:"sub"` + Roles []string `json:"roles"` + Scope []string `json:"scope"` + Nonce string `json:"nonce"` + IssuedAt int64 `json:"iat"` + ExpiresAt int64 `json:"exp"` + LoginMethod string `json:"login_method"` } // CreateSessionToken creates a new session token -func CreateSessionToken(user models.User, nonce string, roles, scope []string) (*SessionData, string, error) { +func CreateSessionToken(user models.User, nonce string, roles, scope []string, loginMethod string) (*SessionData, string, error) { fingerPrintMap := &SessionData{ - Nonce: nonce, - Roles: roles, - Subject: user.ID, - Scope: scope, - IssuedAt: time.Now().Unix(), - ExpiresAt: time.Now().AddDate(1, 0, 0).Unix(), + Nonce: nonce, + Roles: roles, + Subject: user.ID, + Scope: scope, + LoginMethod: loginMethod, + IssuedAt: time.Now().Unix(), + ExpiresAt: time.Now().AddDate(1, 0, 0).Unix(), } fingerPrintBytes, _ := json.Marshal(fingerPrintMap) fingerPrintHash, err := crypto.EncryptAES(string(fingerPrintBytes)) @@ -66,19 +68,19 @@ func CreateSessionToken(user models.User, nonce string, roles, scope []string) ( } // CreateAuthToken creates a new auth token when userlogs in -func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string) (*Token, error) { +func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, loginMethod string) (*Token, error) { hostname := parsers.GetHost(gc) nonce := uuid.New().String() - _, fingerPrintHash, err := CreateSessionToken(user, nonce, roles, scope) + _, fingerPrintHash, err := CreateSessionToken(user, nonce, roles, scope, loginMethod) if err != nil { return nil, err } - accessToken, accessTokenExpiresAt, err := CreateAccessToken(user, roles, scope, hostname, nonce) + accessToken, accessTokenExpiresAt, err := CreateAccessToken(user, roles, scope, hostname, nonce, loginMethod) if err != nil { return nil, err } - idToken, idTokenExpiresAt, err := CreateIDToken(user, roles, hostname, nonce) + idToken, idTokenExpiresAt, err := CreateIDToken(user, roles, hostname, nonce, loginMethod) if err != nil { return nil, err } @@ -91,7 +93,7 @@ func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string) ( } if utils.StringSliceContains(scope, "offline_access") { - refreshToken, refreshTokenExpiresAt, err := CreateRefreshToken(user, roles, scope, hostname, nonce) + refreshToken, refreshTokenExpiresAt, err := CreateRefreshToken(user, roles, scope, hostname, nonce, loginMethod) if err != nil { return nil, err } @@ -103,7 +105,7 @@ func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string) ( } // CreateRefreshToken util to create JWT token -func CreateRefreshToken(user models.User, roles, scopes []string, hostname, nonce string) (string, int64, error) { +func CreateRefreshToken(user models.User, roles, scopes []string, hostname, nonce, loginMethod string) (string, int64, error) { // expires in 1 year expiryBound := time.Hour * 8760 expiresAt := time.Now().Add(expiryBound).Unix() @@ -112,15 +114,16 @@ func CreateRefreshToken(user models.User, roles, scopes []string, hostname, nonc return "", 0, err } customClaims := jwt.MapClaims{ - "iss": hostname, - "aud": clientID, - "sub": user.ID, - "exp": expiresAt, - "iat": time.Now().Unix(), - "token_type": constants.TokenTypeRefreshToken, - "roles": roles, - "scope": scopes, - "nonce": nonce, + "iss": hostname, + "aud": clientID, + "sub": user.ID, + "exp": expiresAt, + "iat": time.Now().Unix(), + "token_type": constants.TokenTypeRefreshToken, + "roles": roles, + "scope": scopes, + "nonce": nonce, + "login_method": loginMethod, } token, err := SignJWTToken(customClaims) @@ -133,7 +136,7 @@ func CreateRefreshToken(user models.User, roles, scopes []string, hostname, nonc // CreateAccessToken util to create JWT token, based on // user information, roles config and CUSTOM_ACCESS_TOKEN_SCRIPT -func CreateAccessToken(user models.User, roles, scopes []string, hostName, nonce string) (string, int64, error) { +func CreateAccessToken(user models.User, roles, scopes []string, hostName, nonce, loginMethod string) (string, int64, error) { expireTime, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAccessTokenExpiryTime) if err != nil { return "", 0, err @@ -150,15 +153,16 @@ func CreateAccessToken(user models.User, roles, scopes []string, hostName, nonce return "", 0, err } customClaims := jwt.MapClaims{ - "iss": hostName, - "aud": clientID, - "nonce": nonce, - "sub": user.ID, - "exp": expiresAt, - "iat": time.Now().Unix(), - "token_type": constants.TokenTypeAccessToken, - "scope": scopes, - "roles": roles, + "iss": hostName, + "aud": clientID, + "nonce": nonce, + "sub": user.ID, + "exp": expiresAt, + "iat": time.Now().Unix(), + "token_type": constants.TokenTypeAccessToken, + "scope": scopes, + "roles": roles, + "login_method": loginMethod, } token, err := SignJWTToken(customClaims) @@ -205,7 +209,13 @@ func ValidateAccessToken(gc *gin.Context, accessToken string) (map[string]interf userID := res["sub"].(string) nonce := res["nonce"].(string) - token, err := memorystore.Provider.GetUserSession(userID, constants.TokenTypeAccessToken+"_"+nonce) + loginMethod := res["login_method"] + sessionKey := userID + if loginMethod != nil && loginMethod != "" { + sessionKey = loginMethod.(string) + ":" + userID + } + + token, err := memorystore.Provider.GetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+nonce) if nonce == "" || err != nil { return res, fmt.Errorf(`unauthorized`) } @@ -241,7 +251,13 @@ func ValidateRefreshToken(gc *gin.Context, refreshToken string) (map[string]inte userID := res["sub"].(string) nonce := res["nonce"].(string) - token, err := memorystore.Provider.GetUserSession(userID, constants.TokenTypeRefreshToken+"_"+nonce) + loginMethod := res["login_method"] + sessionKey := userID + if loginMethod != nil && loginMethod != "" { + sessionKey = loginMethod.(string) + ":" + userID + } + + token, err := memorystore.Provider.GetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+nonce) if nonce == "" || err != nil { return res, fmt.Errorf(`unauthorized`) } @@ -278,7 +294,12 @@ func ValidateBrowserSession(gc *gin.Context, encryptedSession string) (*SessionD return nil, err } - token, err := memorystore.Provider.GetUserSession(res.Subject, constants.TokenTypeSessionToken+"_"+res.Nonce) + sessionStoreKey := res.Subject + if res.LoginMethod != "" { + sessionStoreKey = res.LoginMethod + ":" + res.Subject + } + + token, err := memorystore.Provider.GetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+res.Nonce) if token == "" || err != nil { log.Debug("invalid browser session:", err) return nil, fmt.Errorf(`unauthorized`) @@ -297,7 +318,7 @@ func ValidateBrowserSession(gc *gin.Context, encryptedSession string) (*SessionD // CreateIDToken util to create JWT token, based on // user information, roles config and CUSTOM_ACCESS_TOKEN_SCRIPT -func CreateIDToken(user models.User, roles []string, hostname, nonce string) (string, int64, error) { +func CreateIDToken(user models.User, roles []string, hostname, nonce, loginMethod string) (string, int64, error) { expireTime, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAccessTokenExpiryTime) if err != nil { return "", 0, err @@ -332,6 +353,7 @@ func CreateIDToken(user models.User, roles []string, hostname, nonce string) (st "iat": time.Now().Unix(), "token_type": constants.TokenTypeIdentityToken, "allowed_roles": strings.Split(user.Roles, ","), + "login_method": loginMethod, claimKey: roles, }