diff --git a/server/handlers/authorize.go b/server/handlers/authorize.go index bad2363..d5e9a28 100644 --- a/server/handlers/authorize.go +++ b/server/handlers/authorize.go @@ -194,7 +194,7 @@ func AuthorizeHandler() gin.HandlerFunc { // rollover the session for security go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce) if responseType == constants.ResponseTypeCode { - newSessionTokenData, newSessionToken, err := token.CreateSessionToken(user, nonce, claims.Roles, scope, claims.LoginMethod) + newSessionTokenData, newSessionToken, newSessionExpiresAt, err := token.CreateSessionToken(user, nonce, claims.Roles, scope, claims.LoginMethod) if err != nil { log.Debug("CreateSessionToken failed: ", err) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) @@ -215,7 +215,7 @@ func AuthorizeHandler() gin.HandlerFunc { return } - if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+newSessionTokenData.Nonce, newSessionToken); err != nil { + if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+newSessionTokenData.Nonce, newSessionToken, newSessionExpiresAt); err != nil { log.Debug("SetUserSession failed: ", err) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) return @@ -271,13 +271,13 @@ func AuthorizeHandler() gin.HandlerFunc { return } - if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+nonce, authToken.FingerPrintHash); err != nil { + if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+nonce, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt); err != nil { log.Debug("SetUserSession failed: ", err) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) return } - if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+nonce, authToken.AccessToken.Token); err != nil { + if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+nonce, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt); err != nil { log.Debug("SetUserSession failed: ", err) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) return @@ -305,7 +305,7 @@ func AuthorizeHandler() gin.HandlerFunc { if authToken.RefreshToken != nil { res["refresh_token"] = authToken.RefreshToken.Token params += "&refresh_token=" + authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } if responseMode == constants.ResponseModeQuery { diff --git a/server/handlers/logout.go b/server/handlers/logout.go index bf1d69e..fbac22d 100644 --- a/server/handlers/logout.go +++ b/server/handlers/logout.go @@ -47,7 +47,14 @@ func LogoutHandler() gin.HandlerFunc { return } - memorystore.Provider.DeleteUserSession(sessionData.Subject, sessionData.Nonce) + userID := sessionData.Subject + loginMethod := sessionData.LoginMethod + sessionToken := userID + if loginMethod != "" { + sessionToken = loginMethod + ":" + userID + } + + memorystore.Provider.DeleteUserSession(sessionToken, sessionData.Nonce) cookie.DeleteSession(gc) if redirectURL != "" { diff --git a/server/handlers/oauth_callback.go b/server/handlers/oauth_callback.go index c206458..1b682f5 100644 --- a/server/handlers/oauth_callback.go +++ b/server/handlers/oauth_callback.go @@ -249,12 +249,12 @@ func OAuthCallbackHandler() gin.HandlerFunc { sessionKey := provider + ":" + user.ID cookie.SetSession(ctx, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { params += `&refresh_token=` + authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } go func() { diff --git a/server/handlers/token.go b/server/handlers/token.go index 8249d6f..2b38f7e 100644 --- a/server/handlers/token.go +++ b/server/handlers/token.go @@ -247,8 +247,8 @@ func TokenHandler() gin.HandlerFunc { return } - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) cookie.SetSession(gc, authToken.FingerPrintHash) expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() @@ -266,7 +266,7 @@ func TokenHandler() gin.HandlerFunc { if authToken.RefreshToken != nil { res["refresh_token"] = authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } gc.JSON(http.StatusOK, res) diff --git a/server/handlers/verify_email.go b/server/handlers/verify_email.go index cf5ec1a..91a2128 100644 --- a/server/handlers/verify_email.go +++ b/server/handlers/verify_email.go @@ -154,12 +154,12 @@ func VerifyEmailHandler() gin.HandlerFunc { sessionKey := loginMethod + ":" + user.ID cookie.SetSession(c, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { params = params + `&refresh_token=` + authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } if redirectURL == "" { diff --git a/server/memorystore/providers/inmemory/provider_test.go b/server/memorystore/providers/inmemory/provider_test.go new file mode 100644 index 0000000..99446a1 --- /dev/null +++ b/server/memorystore/providers/inmemory/provider_test.go @@ -0,0 +1,14 @@ +package inmemory + +import ( + "testing" + + "github.com/authorizerdev/authorizer/server/memorystore/providers" + "github.com/stretchr/testify/assert" +) + +func TestInMemoryProvider(t *testing.T) { + p, err := NewInMemoryProvider() + assert.NoError(t, err) + providers.ProviderTests(t, p) +} diff --git a/server/memorystore/providers/inmemory/store.go b/server/memorystore/providers/inmemory/store.go index 85f46c5..4a8e8ce 100644 --- a/server/memorystore/providers/inmemory/store.go +++ b/server/memorystore/providers/inmemory/store.go @@ -8,39 +8,31 @@ import ( ) // SetUserSession sets the user session for given user identifier in form recipe:user_id -func (c *provider) SetUserSession(userId, key, token string) error { - c.sessionStore.Set(userId, key, token) +func (c *provider) SetUserSession(userId, key, token string, expiration int64) error { + c.sessionStore.Set(userId, key, token, expiration) return nil } // GetUserSession returns value for given session token func (c *provider) GetUserSession(userId, sessionToken string) (string, error) { - return c.sessionStore.Get(userId, sessionToken), nil + val := c.sessionStore.Get(userId, sessionToken) + if val == "" { + return "", fmt.Errorf("Not found") + } + return val, nil } // DeleteAllUserSessions deletes all the user sessions from in-memory store. func (c *provider) DeleteAllUserSessions(userId string) error { - namespaces := []string{ - constants.AuthRecipeMethodBasicAuth, - constants.AuthRecipeMethodMagicLinkLogin, - constants.AuthRecipeMethodApple, - constants.AuthRecipeMethodFacebook, - constants.AuthRecipeMethodGithub, - constants.AuthRecipeMethodGoogle, - constants.AuthRecipeMethodLinkedIn, - constants.AuthRecipeMethodTwitter, - constants.AuthRecipeMethodMicrosoft, - } - - for _, namespace := range namespaces { - c.sessionStore.RemoveAll(namespace + ":" + userId) - } + c.sessionStore.RemoveAll(userId) return nil } // DeleteUserSession deletes the user session from the in-memory store. func (c *provider) DeleteUserSession(userId, sessionToken string) error { - c.sessionStore.Remove(userId, sessionToken) + c.sessionStore.Remove(userId, constants.TokenTypeSessionToken+"_"+sessionToken) + c.sessionStore.Remove(userId, constants.TokenTypeAccessToken+"_"+sessionToken) + c.sessionStore.Remove(userId, constants.TokenTypeRefreshToken+"_"+sessionToken) return nil } diff --git a/server/memorystore/providers/inmemory/stores/session_store.go b/server/memorystore/providers/inmemory/stores/session_store.go index 627d37e..d3d429d 100644 --- a/server/memorystore/providers/inmemory/stores/session_store.go +++ b/server/memorystore/providers/inmemory/stores/session_store.go @@ -1,8 +1,15 @@ package stores import ( + "fmt" "strings" "sync" + "time" +) + +const ( + // Maximum entries to keep in session storage + maxCacheSize = 1000 ) // SessionEntry is the struct for entry stored in store @@ -13,15 +20,16 @@ type SessionEntry struct { // SessionStore struct to store the env variables type SessionStore struct { - mutex sync.Mutex - store map[string]map[string]*SessionEntry + mutex sync.Mutex + store map[string]*SessionEntry + itemsToEvict []string } // NewSessionStore create a new session store func NewSessionStore() *SessionStore { return &SessionStore{ mutex: sync.Mutex{}, - store: make(map[string]map[string]*SessionEntry), + store: make(map[string]*SessionEntry), } } @@ -29,53 +37,59 @@ func NewSessionStore() *SessionStore { func (s *SessionStore) Get(key, subKey string) string { s.mutex.Lock() defer s.mutex.Unlock() - return s.store[key][subKey].Value + currentTime := time.Now().Unix() + k := fmt.Sprintf("%s:%s", key, subKey) + if v, ok := s.store[k]; ok { + if v.ExpiresAt > currentTime { + return v.Value + } + s.itemsToEvict = append(s.itemsToEvict, k) + } + return "" } // Set sets the value of the key in state store -func (s *SessionStore) Set(key string, subKey, value string) { +func (s *SessionStore) Set(key string, subKey, value string, expiration int64) { s.mutex.Lock() defer s.mutex.Unlock() - - if _, ok := s.store[key]; !ok { - s.store[key] = make(map[string]string) + k := fmt.Sprintf("%s:%s", key, subKey) + if _, ok := s.store[k]; !ok { + s.store[k] = &SessionEntry{ + Value: value, + ExpiresAt: expiration, + // TODO add expire time + } + } + s.store[k] = &SessionEntry{ + Value: value, + ExpiresAt: expiration, + // TODO add expire time } - s.store[key][subKey] = value } // RemoveAll all values for given key func (s *SessionStore) RemoveAll(key string) { s.mutex.Lock() defer s.mutex.Unlock() - - delete(s.store, key) + for k := range s.store { + if strings.Contains(k, key) { + delete(s.store, k) + } + } } // Remove value for given key and subkey func (s *SessionStore) Remove(key, subKey string) { s.mutex.Lock() defer s.mutex.Unlock() - if _, ok := s.store[key]; ok { - delete(s.store[key], subKey) - } -} - -// Get all the values for given key -func (s *SessionStore) GetAll(key string) map[string]string { - s.mutex.Lock() - defer s.mutex.Unlock() - - if _, ok := s.store[key]; !ok { - s.store[key] = make(map[string]string) - } - return s.store[key] + k := fmt.Sprintf("%s:%s", key, subKey) + delete(s.store, k) } // RemoveByNamespace to delete session for a given namespace example google,github func (s *SessionStore) RemoveByNamespace(namespace string) error { s.mutex.Lock() defer s.mutex.Unlock() - for key := range s.store { if strings.Contains(key, namespace+":") { delete(s.store, key) diff --git a/server/memorystore/providers/provider_tests.go b/server/memorystore/providers/provider_tests.go new file mode 100644 index 0000000..e569fe8 --- /dev/null +++ b/server/memorystore/providers/provider_tests.go @@ -0,0 +1,115 @@ +package providers + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// ProviderTests runs all provider tests +func ProviderTests(t *testing.T, p Provider) { + + err := p.SetUserSession("auth_provider:123", "session_token_key", "test_hash123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + err = p.SetUserSession("auth_provider:123", "access_token_key", "test_jwt123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + // Same user multiple session + err = p.SetUserSession("auth_provider:123", "session_token_key1", "test_hash1123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + err = p.SetUserSession("auth_provider:123", "access_token_key1", "test_jwt1123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + // Different user session + err = p.SetUserSession("auth_provider:124", "session_token_key", "test_hash124", time.Now().Add(5*time.Second).Unix()) + assert.NoError(t, err) + err = p.SetUserSession("auth_provider:124", "access_token_key", "test_jwt124", time.Now().Add(5*time.Second).Unix()) + assert.NoError(t, err) + // Different provider session + err = p.SetUserSession("auth_provider1:124", "session_token_key", "test_hash124", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + err = p.SetUserSession("auth_provider1:124", "access_token_key", "test_jwt124", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + // Different provider session + err = p.SetUserSession("auth_provider1:123", "session_token_key", "test_hash1123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + err = p.SetUserSession("auth_provider1:123", "access_token_key", "test_jwt1123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + // Get session + key, err := p.GetUserSession("auth_provider:123", "session_token_key") + assert.NoError(t, err) + assert.Equal(t, "test_hash123", key) + key, err = p.GetUserSession("auth_provider:123", "access_token_key") + assert.NoError(t, err) + assert.Equal(t, "test_jwt123", key) + key, err = p.GetUserSession("auth_provider:124", "session_token_key") + assert.NoError(t, err) + assert.Equal(t, "test_hash124", key) + key, err = p.GetUserSession("auth_provider:124", "access_token_key") + assert.NoError(t, err) + assert.Equal(t, "test_jwt124", key) + // Expire some tokens and make sure they are empty + time.Sleep(5 * time.Second) + key, err = p.GetUserSession("auth_provider:124", "session_token_key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:124", "access_token_key") + assert.Empty(t, key) + assert.Error(t, err) + // Delete user session + err = p.DeleteUserSession("auth_provider:123", "key") + assert.NoError(t, err) + err = p.DeleteUserSession("auth_provider:123", "key") + assert.NoError(t, err) + key, err = p.GetUserSession("auth_provider:123", "key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:123", "access_token_key") + assert.Empty(t, key) + assert.Error(t, err) + // Delete all user session + err = p.DeleteAllUserSessions("123") + assert.NoError(t, err) + err = p.DeleteAllUserSessions("123") + assert.NoError(t, err) + key, err = p.GetUserSession("auth_provider:123", "session_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:123", "access_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:123", "session_token_key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:123", "access_token_key") + assert.Empty(t, key) + assert.Error(t, err) + // Delete namespace + err = p.DeleteSessionForNamespace("auth_provider") + assert.NoError(t, err) + err = p.DeleteSessionForNamespace("auth_provider1") + assert.NoError(t, err) + key, err = p.GetUserSession("auth_provider:123", "session_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:123", "access_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:123", "session_token_key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:123", "access_token_key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:124", "session_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:124", "access_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:124", "session_token_key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:124", "access_token_key") + assert.Empty(t, key) + assert.Error(t, err) +} diff --git a/server/memorystore/providers/providers.go b/server/memorystore/providers/providers.go index 7388a97..db58aa7 100644 --- a/server/memorystore/providers/providers.go +++ b/server/memorystore/providers/providers.go @@ -3,7 +3,7 @@ package providers // Provider defines current memory store provider type Provider interface { // SetUserSession sets the user session for given user identifier in form recipe:user_id - SetUserSession(userId, key, token string) error + SetUserSession(userId, key, token string, expiration int64) error // GetUserSession returns the session token for given token GetUserSession(userId, key string) (string, error) // DeleteUserSession deletes the user session diff --git a/server/memorystore/providers/redis/provider.go b/server/memorystore/providers/redis/provider.go index 084f207..894a75e 100644 --- a/server/memorystore/providers/redis/provider.go +++ b/server/memorystore/providers/redis/provider.go @@ -32,7 +32,6 @@ type provider struct { // NewRedisProvider returns a new redis provider func NewRedisProvider(redisURL string) (*provider, error) { redisURLHostPortsList := strings.Split(redisURL, ",") - if len(redisURLHostPortsList) > 1 { opt, err := redis.ParseURL(redisURLHostPortsList[0]) if err != nil { diff --git a/server/memorystore/providers/redis/provider_test.go b/server/memorystore/providers/redis/provider_test.go new file mode 100644 index 0000000..280616c --- /dev/null +++ b/server/memorystore/providers/redis/provider_test.go @@ -0,0 +1,15 @@ +package redis + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/authorizerdev/authorizer/server/memorystore/providers" +) + +func TestRedisProvider(t *testing.T) { + p, err := NewRedisProvider("redis://127.0.0.1:6379") + assert.NoError(t, err) + providers.ProviderTests(t, p) +} diff --git a/server/memorystore/providers/redis/store.go b/server/memorystore/providers/redis/store.go index 7a1c7a1..bceb187 100644 --- a/server/memorystore/providers/redis/store.go +++ b/server/memorystore/providers/redis/store.go @@ -3,6 +3,7 @@ package redis import ( "fmt" "strconv" + "time" "github.com/authorizerdev/authorizer/server/constants" log "github.com/sirupsen/logrus" @@ -16,8 +17,11 @@ var ( ) // SetUserSession sets the user session for given user identifier in form recipe:user_id -func (c *provider) SetUserSession(userId, key, token string) error { - err := c.store.Set(c.ctx, fmt.Sprintf("%s:%s", userId, key), token, 0).Err() +func (c *provider) SetUserSession(userId, key, token string, expiration int64) error { + currentTime := time.Now() + expireTime := time.Unix(expiration, 0) + duration := expireTime.Sub(currentTime) + err := c.store.Set(c.ctx, fmt.Sprintf("%s:%s", userId, key), token, duration).Err() if err != nil { log.Debug("Error saving user session to redis: ", err) return err @@ -38,37 +42,35 @@ func (c *provider) GetUserSession(userId, key string) (string, error) { func (c *provider) DeleteUserSession(userId, key string) error { if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeSessionToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) - return err + fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeSessionToken, key) + // continue } if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeAccessToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) - return err + fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeAccessToken, key) + // continue } if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeRefreshToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) - return err + fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeRefreshToken, key) + // continue } return nil } // DeleteAllUserSessions deletes all the user session from redis func (c *provider) DeleteAllUserSessions(userID string) error { - namespaces := []string{ - constants.AuthRecipeMethodBasicAuth, - constants.AuthRecipeMethodMagicLinkLogin, - constants.AuthRecipeMethodApple, - constants.AuthRecipeMethodFacebook, - constants.AuthRecipeMethodGithub, - constants.AuthRecipeMethodGoogle, - constants.AuthRecipeMethodLinkedIn, - constants.AuthRecipeMethodTwitter, - constants.AuthRecipeMethodMicrosoft, + res := c.store.Keys(c.ctx, fmt.Sprintf("*%s*", userID)) + if res.Err() != nil { + log.Debug("Error getting all user sessions from redis: ", res.Err()) + return res.Err() } - for _, namespace := range namespaces { - err := c.store.Del(c.ctx, namespace+":"+userID).Err() + keys := res.Val() + for _, key := range keys { + err := c.store.Del(c.ctx, key).Err() if err != nil { log.Debug("Error deleting all user sessions from redis: ", err) - return err + continue } } return nil @@ -76,27 +78,19 @@ func (c *provider) DeleteAllUserSessions(userID string) error { // DeleteSessionForNamespace to delete session for a given namespace example google,github func (c *provider) DeleteSessionForNamespace(namespace string) error { - var cursor uint64 - for { - keys := []string{} - keys, cursor, err := c.store.Scan(c.ctx, cursor, namespace+":*", 0).Result() + res := c.store.Keys(c.ctx, fmt.Sprintf("%s:*", namespace)) + if res.Err() != nil { + log.Debug("Error getting all user sessions from redis: ", res.Err()) + return res.Err() + } + keys := res.Val() + for _, key := range keys { + err := c.store.Del(c.ctx, key).Err() if err != nil { - log.Debugf("Error scanning keys for %s namespace: %s", namespace, err.Error()) - return err - } - - for _, key := range keys { - err := c.store.Del(c.ctx, key).Err() - if err != nil { - log.Debugf("Error deleting sessions for %s namespace: %s", namespace, err.Error()) - return err - } - } - if cursor == 0 { // no more keys - break + log.Debug("Error deleting all user sessions from redis: ", err) + continue } } - return nil } diff --git a/server/resolvers/login.go b/server/resolvers/login.go index 4bae30a..28a2289 100644 --- a/server/resolvers/login.go +++ b/server/resolvers/login.go @@ -193,12 +193,12 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes cookie.SetSession(gc, authToken.FingerPrintHash) sessionStoreKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } go func() { diff --git a/server/resolvers/mobile_login.go b/server/resolvers/mobile_login.go index b6302b4..9da0a53 100644 --- a/server/resolvers/mobile_login.go +++ b/server/resolvers/mobile_login.go @@ -195,12 +195,12 @@ func MobileLoginResolver(ctx context.Context, params model.MobileLoginInput) (*m cookie.SetSession(gc, authToken.FingerPrintHash) sessionStoreKey := constants.AuthRecipeMethodMobileBasicAuth + ":" + user.ID - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } go func() { diff --git a/server/resolvers/mobile_signup.go b/server/resolvers/mobile_signup.go index e6b0027..94d5b03 100644 --- a/server/resolvers/mobile_signup.go +++ b/server/resolvers/mobile_signup.go @@ -249,12 +249,12 @@ func MobileSignupResolver(ctx context.Context, params *model.MobileSignUpInput) sessionKey := constants.AuthRecipeMethodMobileBasicAuth + ":" + user.ID cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } go func() { diff --git a/server/resolvers/session.go b/server/resolvers/session.go index 79ea012..b3454b6 100644 --- a/server/resolvers/session.go +++ b/server/resolvers/session.go @@ -99,12 +99,12 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod } cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } return res, nil } diff --git a/server/resolvers/signup.go b/server/resolvers/signup.go index 43f2a96..433f801 100644 --- a/server/resolvers/signup.go +++ b/server/resolvers/signup.go @@ -91,7 +91,6 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR } inputRoles := []string{} - if len(params.Roles) > 0 { // check if roles exists rolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyRoles) @@ -293,12 +292,12 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } go func() { diff --git a/server/resolvers/verify_email.go b/server/resolvers/verify_email.go index 47b4429..d1fd81d 100644 --- a/server/resolvers/verify_email.go +++ b/server/resolvers/verify_email.go @@ -150,12 +150,12 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m sessionKey := loginMethod + ":" + user.ID cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } return res, nil } diff --git a/server/resolvers/verify_otp.go b/server/resolvers/verify_otp.go index 016678d..80080d9 100644 --- a/server/resolvers/verify_otp.go +++ b/server/resolvers/verify_otp.go @@ -123,12 +123,12 @@ func VerifyOtpResolver(ctx context.Context, params model.VerifyOTPRequest) (*mod sessionKey := loginMethod + ":" + user.ID cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } return res, nil } diff --git a/server/test/validate_jwt_token_test.go b/server/test/validate_jwt_token_test.go index e2fcf8c..d2ab257 100644 --- a/server/test/validate_jwt_token_test.go +++ b/server/test/validate_jwt_token_test.go @@ -55,11 +55,11 @@ func validateJwtTokenTest(t *testing.T, s TestSetup) { authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth, nonce, "") assert.NoError(t, err) assert.NotNil(t, authToken) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } t.Run(`should validate the access token`, func(t *testing.T) { diff --git a/server/token/auth_token.go b/server/token/auth_token.go index deeda0a..6d2c942 100644 --- a/server/token/auth_token.go +++ b/server/token/auth_token.go @@ -30,11 +30,13 @@ type JWTToken struct { // Token object to hold the finger print and refresh token information type Token struct { - FingerPrint string `json:"fingerprint"` - FingerPrintHash string `json:"fingerprint_hash"` - RefreshToken *JWTToken `json:"refresh_token"` - AccessToken *JWTToken `json:"access_token"` - IDToken *JWTToken `json:"id_token"` + FingerPrint string `json:"fingerprint"` + // Session Token + FingerPrintHash string `json:"fingerprint_hash"` + SessionTokenExpiresAt int64 `json:"expires_at"` + RefreshToken *JWTToken `json:"refresh_token"` + AccessToken *JWTToken `json:"access_token"` + IDToken *JWTToken `json:"id_token"` } // SessionData @@ -51,7 +53,7 @@ type SessionData struct { // CreateAuthToken creates a new auth token when userlogs in func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, loginMethod, nonce string, code string) (*Token, error) { hostname := parsers.GetHost(gc) - _, fingerPrintHash, err := CreateSessionToken(user, nonce, roles, scope, loginMethod) + _, fingerPrintHash, sessionTokenExpiresAt, err := CreateSessionToken(user, nonce, roles, scope, loginMethod) if err != nil { return nil, err } @@ -82,10 +84,11 @@ func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, l } res := &Token{ - FingerPrint: nonce, - FingerPrintHash: fingerPrintHash, - AccessToken: &JWTToken{Token: accessToken, ExpiresAt: accessTokenExpiresAt}, - IDToken: &JWTToken{Token: idToken, ExpiresAt: idTokenExpiresAt}, + FingerPrint: nonce, + FingerPrintHash: fingerPrintHash, + SessionTokenExpiresAt: sessionTokenExpiresAt, + AccessToken: &JWTToken{Token: accessToken, ExpiresAt: accessTokenExpiresAt}, + IDToken: &JWTToken{Token: idToken, ExpiresAt: idTokenExpiresAt}, } if utils.StringSliceContains(scope, "offline_access") { @@ -101,7 +104,8 @@ func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, l } // CreateSessionToken creates a new session token -func CreateSessionToken(user models.User, nonce string, roles, scope []string, loginMethod string) (*SessionData, string, error) { +func CreateSessionToken(user models.User, nonce string, roles, scope []string, loginMethod string) (*SessionData, string, int64, error) { + expiresAt := time.Now().AddDate(1, 0, 0).Unix() fingerPrintMap := &SessionData{ Nonce: nonce, Roles: roles, @@ -109,15 +113,15 @@ func CreateSessionToken(user models.User, nonce string, roles, scope []string, l Scope: scope, LoginMethod: loginMethod, IssuedAt: time.Now().Unix(), - ExpiresAt: time.Now().AddDate(1, 0, 0).Unix(), + ExpiresAt: expiresAt, } fingerPrintBytes, _ := json.Marshal(fingerPrintMap) fingerPrintHash, err := crypto.EncryptAES(string(fingerPrintBytes)) if err != nil { - return nil, "", err + return nil, "", 0, err } - return fingerPrintMap, fingerPrintHash, nil + return fingerPrintMap, fingerPrintHash, expiresAt, nil } // CreateRefreshToken util to create JWT token