From 9a284c03ca1639315ed88bf79b30cc76ab923e4c Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Mon, 3 Apr 2023 10:26:27 +0530 Subject: [PATCH 1/3] fix: redis session --- server/go.mod | 1 + server/go.sum | 6 +++++ .../memorystore/providers/inmemory/store.go | 6 ----- .../providers/inmemory/stores/env_store.go | 4 ++++ .../inmemory/stores/session_store.go | 14 ++++++++--- .../providers/inmemory/stores/state_store.go | 2 ++ server/memorystore/providers/providers.go | 2 -- .../memorystore/providers/redis/provider.go | 6 ++--- server/memorystore/providers/redis/store.go | 24 ++++++------------- 9 files changed, 34 insertions(+), 31 deletions(-) diff --git a/server/go.mod b/server/go.mod index 1402756..9e8db75 100644 --- a/server/go.mod +++ b/server/go.mod @@ -21,6 +21,7 @@ require ( github.com/joho/godotenv v1.3.0 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/pelletier/go-toml/v2 v2.0.5 // indirect + github.com/redis/go-redis/v9 v9.0.3 // indirect github.com/robertkrimen/otto v0.0.0-20211024170158-b87d35c0b86f github.com/sirupsen/logrus v1.8.1 github.com/stretchr/testify v1.8.0 diff --git a/server/go.sum b/server/go.sum index e6c51a5..98b2ca9 100644 --- a/server/go.sum +++ b/server/go.sum @@ -58,11 +58,15 @@ github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYE github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= +github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cenkalti/backoff/v4 v4.1.2 h1:6Yo7N8UP2K6LWZnW94DLVSSrbobcWdVzAYOisuDPIFo= github.com/cenkalti/backoff/v4 v4.1.2/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -295,6 +299,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k= +github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/robertkrimen/otto v0.0.0-20211024170158-b87d35c0b86f h1:a7clxaGmmqtdNTXyvrp/lVO/Gnkzlhc/+dLs5v965GM= diff --git a/server/memorystore/providers/inmemory/store.go b/server/memorystore/providers/inmemory/store.go index befd77f..85f46c5 100644 --- a/server/memorystore/providers/inmemory/store.go +++ b/server/memorystore/providers/inmemory/store.go @@ -13,12 +13,6 @@ func (c *provider) SetUserSession(userId, key, token string) error { return nil } -// GetAllUserSessions returns all the user sessions token from the in-memory store. -func (c *provider) GetAllUserSessions(userId string) (map[string]string, error) { - data := c.sessionStore.GetAll(userId) - return data, nil -} - // GetUserSession returns value for given session token func (c *provider) GetUserSession(userId, sessionToken string) (string, error) { return c.sessionStore.Get(userId, sessionToken), nil diff --git a/server/memorystore/providers/inmemory/stores/env_store.go b/server/memorystore/providers/inmemory/stores/env_store.go index d6ffe6a..46a909a 100644 --- a/server/memorystore/providers/inmemory/stores/env_store.go +++ b/server/memorystore/providers/inmemory/stores/env_store.go @@ -31,11 +31,15 @@ func (e *EnvStore) UpdateStore(store map[string]interface{}) { // GetStore returns the env store func (e *EnvStore) GetStore() map[string]interface{} { + e.mutex.Lock() + defer e.mutex.Unlock() return e.store } // Get returns the value of the key in evn store func (e *EnvStore) Get(key string) interface{} { + e.mutex.Lock() + defer e.mutex.Unlock() return e.store[key] } diff --git a/server/memorystore/providers/inmemory/stores/session_store.go b/server/memorystore/providers/inmemory/stores/session_store.go index d035312..627d37e 100644 --- a/server/memorystore/providers/inmemory/stores/session_store.go +++ b/server/memorystore/providers/inmemory/stores/session_store.go @@ -5,23 +5,31 @@ import ( "sync" ) +// SessionEntry is the struct for entry stored in store +type SessionEntry struct { + Value string + ExpiresAt int64 +} + // SessionStore struct to store the env variables type SessionStore struct { mutex sync.Mutex - store map[string]map[string]string + store map[string]map[string]*SessionEntry } // NewSessionStore create a new session store func NewSessionStore() *SessionStore { return &SessionStore{ mutex: sync.Mutex{}, - store: make(map[string]map[string]string), + store: make(map[string]map[string]*SessionEntry), } } // Get returns the value of the key in state store func (s *SessionStore) Get(key, subKey string) string { - return s.store[key][subKey] + s.mutex.Lock() + defer s.mutex.Unlock() + return s.store[key][subKey].Value } // Set sets the value of the key in state store diff --git a/server/memorystore/providers/inmemory/stores/state_store.go b/server/memorystore/providers/inmemory/stores/state_store.go index 2ba8417..2189a11 100644 --- a/server/memorystore/providers/inmemory/stores/state_store.go +++ b/server/memorystore/providers/inmemory/stores/state_store.go @@ -20,6 +20,8 @@ func NewStateStore() *StateStore { // Get returns the value of the key in state store func (s *StateStore) Get(key string) string { + s.mutex.Lock() + defer s.mutex.Unlock() return s.store[key] } diff --git a/server/memorystore/providers/providers.go b/server/memorystore/providers/providers.go index a4816c5..7388a97 100644 --- a/server/memorystore/providers/providers.go +++ b/server/memorystore/providers/providers.go @@ -4,8 +4,6 @@ package providers type Provider interface { // SetUserSession sets the user session for given user identifier in form recipe:user_id SetUserSession(userId, key, token string) error - // GetAllUserSessions returns all the user sessions from the session store - GetAllUserSessions(userId string) (map[string]string, error) // GetUserSession returns the session token for given token GetUserSession(userId, key string) (string, error) // DeleteUserSession deletes the user session diff --git a/server/memorystore/providers/redis/provider.go b/server/memorystore/providers/redis/provider.go index a91a300..084f207 100644 --- a/server/memorystore/providers/redis/provider.go +++ b/server/memorystore/providers/redis/provider.go @@ -5,7 +5,7 @@ import ( "strings" "time" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" log "github.com/sirupsen/logrus" ) @@ -17,10 +17,11 @@ type RedisClient interface { HMGet(ctx context.Context, key string, fields ...string) *redis.SliceCmd HSet(ctx context.Context, key string, values ...interface{}) *redis.IntCmd HGet(ctx context.Context, key, field string) *redis.StringCmd - HGetAll(ctx context.Context, key string) *redis.StringStringMapCmd + HGetAll(ctx context.Context, key string) *redis.MapStringStringCmd Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd Get(ctx context.Context, key string) *redis.StringCmd Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd + Keys(ctx context.Context, pattern string) *redis.StringSliceCmd } type provider struct { @@ -70,7 +71,6 @@ func NewRedisProvider(redisURL string) (*provider, error) { log.Debug("error connecting to redis: ", err) return nil, err } - return &provider{ ctx: ctx, store: rdb, diff --git a/server/memorystore/providers/redis/store.go b/server/memorystore/providers/redis/store.go index 21e3ecc..7a1c7a1 100644 --- a/server/memorystore/providers/redis/store.go +++ b/server/memorystore/providers/redis/store.go @@ -1,6 +1,7 @@ package redis import ( + "fmt" "strconv" "github.com/authorizerdev/authorizer/server/constants" @@ -16,28 +17,17 @@ var ( // SetUserSession sets the user session for given user identifier in form recipe:user_id func (c *provider) SetUserSession(userId, key, token string) error { - err := c.store.HSet(c.ctx, userId, key, token).Err() + err := c.store.Set(c.ctx, fmt.Sprintf("%s:%s", userId, key), token, 0).Err() if err != nil { - log.Debug("Error saving to redis: ", err) + log.Debug("Error saving user session to redis: ", err) return err } return nil } -// GetAllUserSessions returns all the user session token from the redis store. -func (c *provider) GetAllUserSessions(userID string) (map[string]string, error) { - data, err := c.store.HGetAll(c.ctx, userID).Result() - if err != nil { - log.Debug("error getting all user sessions from redis store: ", err) - return nil, err - } - - return data, nil -} - // GetUserSession returns the user session from redis store. func (c *provider) GetUserSession(userId, key string) (string, error) { - data, err := c.store.HGet(c.ctx, userId, key).Result() + data, err := c.store.Get(c.ctx, fmt.Sprintf("%s:%s", userId, key)).Result() if err != nil { return "", err } @@ -46,15 +36,15 @@ func (c *provider) GetUserSession(userId, key string) (string, error) { // DeleteUserSession deletes the user session from redis store. func (c *provider) DeleteUserSession(userId, key string) error { - if err := c.store.HDel(c.ctx, userId, constants.TokenTypeSessionToken+"_"+key).Err(); err != nil { + if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeSessionToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) return err } - if err := c.store.HDel(c.ctx, userId, constants.TokenTypeAccessToken+"_"+key).Err(); err != nil { + if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeAccessToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) return err } - if err := c.store.HDel(c.ctx, userId, constants.TokenTypeRefreshToken+"_"+key).Err(); err != nil { + if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeRefreshToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) return err } From 02c0ebb9c40ceea22486ac83f25b12371924ac87 Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Sat, 8 Apr 2023 13:06:15 +0530 Subject: [PATCH 2/3] fix: session storage --- server/handlers/authorize.go | 10 +- server/handlers/logout.go | 9 +- server/handlers/oauth_callback.go | 6 +- server/handlers/token.go | 6 +- server/handlers/verify_email.go | 6 +- .../providers/inmemory/provider_test.go | 14 +++ .../memorystore/providers/inmemory/store.go | 30 ++--- .../inmemory/stores/session_store.go | 66 ++++++---- .../memorystore/providers/provider_tests.go | 115 ++++++++++++++++++ server/memorystore/providers/providers.go | 2 +- .../memorystore/providers/redis/provider.go | 1 - .../providers/redis/provider_test.go | 15 +++ server/memorystore/providers/redis/store.go | 66 +++++----- server/resolvers/login.go | 6 +- server/resolvers/mobile_login.go | 6 +- server/resolvers/mobile_signup.go | 6 +- server/resolvers/session.go | 6 +- server/resolvers/signup.go | 7 +- server/resolvers/verify_email.go | 6 +- server/resolvers/verify_otp.go | 6 +- server/test/validate_jwt_token_test.go | 6 +- server/token/auth_token.go | 32 ++--- 22 files changed, 290 insertions(+), 137 deletions(-) create mode 100644 server/memorystore/providers/inmemory/provider_test.go create mode 100644 server/memorystore/providers/provider_tests.go create mode 100644 server/memorystore/providers/redis/provider_test.go diff --git a/server/handlers/authorize.go b/server/handlers/authorize.go index bad2363..d5e9a28 100644 --- a/server/handlers/authorize.go +++ b/server/handlers/authorize.go @@ -194,7 +194,7 @@ func AuthorizeHandler() gin.HandlerFunc { // rollover the session for security go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce) if responseType == constants.ResponseTypeCode { - newSessionTokenData, newSessionToken, err := token.CreateSessionToken(user, nonce, claims.Roles, scope, claims.LoginMethod) + newSessionTokenData, newSessionToken, newSessionExpiresAt, err := token.CreateSessionToken(user, nonce, claims.Roles, scope, claims.LoginMethod) if err != nil { log.Debug("CreateSessionToken failed: ", err) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) @@ -215,7 +215,7 @@ func AuthorizeHandler() gin.HandlerFunc { return } - if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+newSessionTokenData.Nonce, newSessionToken); err != nil { + if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+newSessionTokenData.Nonce, newSessionToken, newSessionExpiresAt); err != nil { log.Debug("SetUserSession failed: ", err) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) return @@ -271,13 +271,13 @@ func AuthorizeHandler() gin.HandlerFunc { return } - if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+nonce, authToken.FingerPrintHash); err != nil { + if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+nonce, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt); err != nil { log.Debug("SetUserSession failed: ", err) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) return } - if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+nonce, authToken.AccessToken.Token); err != nil { + if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+nonce, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt); err != nil { log.Debug("SetUserSession failed: ", err) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) return @@ -305,7 +305,7 @@ func AuthorizeHandler() gin.HandlerFunc { if authToken.RefreshToken != nil { res["refresh_token"] = authToken.RefreshToken.Token params += "&refresh_token=" + authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } if responseMode == constants.ResponseModeQuery { diff --git a/server/handlers/logout.go b/server/handlers/logout.go index bf1d69e..fbac22d 100644 --- a/server/handlers/logout.go +++ b/server/handlers/logout.go @@ -47,7 +47,14 @@ func LogoutHandler() gin.HandlerFunc { return } - memorystore.Provider.DeleteUserSession(sessionData.Subject, sessionData.Nonce) + userID := sessionData.Subject + loginMethod := sessionData.LoginMethod + sessionToken := userID + if loginMethod != "" { + sessionToken = loginMethod + ":" + userID + } + + memorystore.Provider.DeleteUserSession(sessionToken, sessionData.Nonce) cookie.DeleteSession(gc) if redirectURL != "" { diff --git a/server/handlers/oauth_callback.go b/server/handlers/oauth_callback.go index c206458..1b682f5 100644 --- a/server/handlers/oauth_callback.go +++ b/server/handlers/oauth_callback.go @@ -249,12 +249,12 @@ func OAuthCallbackHandler() gin.HandlerFunc { sessionKey := provider + ":" + user.ID cookie.SetSession(ctx, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { params += `&refresh_token=` + authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } go func() { diff --git a/server/handlers/token.go b/server/handlers/token.go index 8249d6f..2b38f7e 100644 --- a/server/handlers/token.go +++ b/server/handlers/token.go @@ -247,8 +247,8 @@ func TokenHandler() gin.HandlerFunc { return } - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) cookie.SetSession(gc, authToken.FingerPrintHash) expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() @@ -266,7 +266,7 @@ func TokenHandler() gin.HandlerFunc { if authToken.RefreshToken != nil { res["refresh_token"] = authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } gc.JSON(http.StatusOK, res) diff --git a/server/handlers/verify_email.go b/server/handlers/verify_email.go index cf5ec1a..91a2128 100644 --- a/server/handlers/verify_email.go +++ b/server/handlers/verify_email.go @@ -154,12 +154,12 @@ func VerifyEmailHandler() gin.HandlerFunc { sessionKey := loginMethod + ":" + user.ID cookie.SetSession(c, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { params = params + `&refresh_token=` + authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } if redirectURL == "" { diff --git a/server/memorystore/providers/inmemory/provider_test.go b/server/memorystore/providers/inmemory/provider_test.go new file mode 100644 index 0000000..99446a1 --- /dev/null +++ b/server/memorystore/providers/inmemory/provider_test.go @@ -0,0 +1,14 @@ +package inmemory + +import ( + "testing" + + "github.com/authorizerdev/authorizer/server/memorystore/providers" + "github.com/stretchr/testify/assert" +) + +func TestInMemoryProvider(t *testing.T) { + p, err := NewInMemoryProvider() + assert.NoError(t, err) + providers.ProviderTests(t, p) +} diff --git a/server/memorystore/providers/inmemory/store.go b/server/memorystore/providers/inmemory/store.go index 85f46c5..4a8e8ce 100644 --- a/server/memorystore/providers/inmemory/store.go +++ b/server/memorystore/providers/inmemory/store.go @@ -8,39 +8,31 @@ import ( ) // SetUserSession sets the user session for given user identifier in form recipe:user_id -func (c *provider) SetUserSession(userId, key, token string) error { - c.sessionStore.Set(userId, key, token) +func (c *provider) SetUserSession(userId, key, token string, expiration int64) error { + c.sessionStore.Set(userId, key, token, expiration) return nil } // GetUserSession returns value for given session token func (c *provider) GetUserSession(userId, sessionToken string) (string, error) { - return c.sessionStore.Get(userId, sessionToken), nil + val := c.sessionStore.Get(userId, sessionToken) + if val == "" { + return "", fmt.Errorf("Not found") + } + return val, nil } // DeleteAllUserSessions deletes all the user sessions from in-memory store. func (c *provider) DeleteAllUserSessions(userId string) error { - namespaces := []string{ - constants.AuthRecipeMethodBasicAuth, - constants.AuthRecipeMethodMagicLinkLogin, - constants.AuthRecipeMethodApple, - constants.AuthRecipeMethodFacebook, - constants.AuthRecipeMethodGithub, - constants.AuthRecipeMethodGoogle, - constants.AuthRecipeMethodLinkedIn, - constants.AuthRecipeMethodTwitter, - constants.AuthRecipeMethodMicrosoft, - } - - for _, namespace := range namespaces { - c.sessionStore.RemoveAll(namespace + ":" + userId) - } + c.sessionStore.RemoveAll(userId) return nil } // DeleteUserSession deletes the user session from the in-memory store. func (c *provider) DeleteUserSession(userId, sessionToken string) error { - c.sessionStore.Remove(userId, sessionToken) + c.sessionStore.Remove(userId, constants.TokenTypeSessionToken+"_"+sessionToken) + c.sessionStore.Remove(userId, constants.TokenTypeAccessToken+"_"+sessionToken) + c.sessionStore.Remove(userId, constants.TokenTypeRefreshToken+"_"+sessionToken) return nil } diff --git a/server/memorystore/providers/inmemory/stores/session_store.go b/server/memorystore/providers/inmemory/stores/session_store.go index 627d37e..d3d429d 100644 --- a/server/memorystore/providers/inmemory/stores/session_store.go +++ b/server/memorystore/providers/inmemory/stores/session_store.go @@ -1,8 +1,15 @@ package stores import ( + "fmt" "strings" "sync" + "time" +) + +const ( + // Maximum entries to keep in session storage + maxCacheSize = 1000 ) // SessionEntry is the struct for entry stored in store @@ -13,15 +20,16 @@ type SessionEntry struct { // SessionStore struct to store the env variables type SessionStore struct { - mutex sync.Mutex - store map[string]map[string]*SessionEntry + mutex sync.Mutex + store map[string]*SessionEntry + itemsToEvict []string } // NewSessionStore create a new session store func NewSessionStore() *SessionStore { return &SessionStore{ mutex: sync.Mutex{}, - store: make(map[string]map[string]*SessionEntry), + store: make(map[string]*SessionEntry), } } @@ -29,53 +37,59 @@ func NewSessionStore() *SessionStore { func (s *SessionStore) Get(key, subKey string) string { s.mutex.Lock() defer s.mutex.Unlock() - return s.store[key][subKey].Value + currentTime := time.Now().Unix() + k := fmt.Sprintf("%s:%s", key, subKey) + if v, ok := s.store[k]; ok { + if v.ExpiresAt > currentTime { + return v.Value + } + s.itemsToEvict = append(s.itemsToEvict, k) + } + return "" } // Set sets the value of the key in state store -func (s *SessionStore) Set(key string, subKey, value string) { +func (s *SessionStore) Set(key string, subKey, value string, expiration int64) { s.mutex.Lock() defer s.mutex.Unlock() - - if _, ok := s.store[key]; !ok { - s.store[key] = make(map[string]string) + k := fmt.Sprintf("%s:%s", key, subKey) + if _, ok := s.store[k]; !ok { + s.store[k] = &SessionEntry{ + Value: value, + ExpiresAt: expiration, + // TODO add expire time + } + } + s.store[k] = &SessionEntry{ + Value: value, + ExpiresAt: expiration, + // TODO add expire time } - s.store[key][subKey] = value } // RemoveAll all values for given key func (s *SessionStore) RemoveAll(key string) { s.mutex.Lock() defer s.mutex.Unlock() - - delete(s.store, key) + for k := range s.store { + if strings.Contains(k, key) { + delete(s.store, k) + } + } } // Remove value for given key and subkey func (s *SessionStore) Remove(key, subKey string) { s.mutex.Lock() defer s.mutex.Unlock() - if _, ok := s.store[key]; ok { - delete(s.store[key], subKey) - } -} - -// Get all the values for given key -func (s *SessionStore) GetAll(key string) map[string]string { - s.mutex.Lock() - defer s.mutex.Unlock() - - if _, ok := s.store[key]; !ok { - s.store[key] = make(map[string]string) - } - return s.store[key] + k := fmt.Sprintf("%s:%s", key, subKey) + delete(s.store, k) } // RemoveByNamespace to delete session for a given namespace example google,github func (s *SessionStore) RemoveByNamespace(namespace string) error { s.mutex.Lock() defer s.mutex.Unlock() - for key := range s.store { if strings.Contains(key, namespace+":") { delete(s.store, key) diff --git a/server/memorystore/providers/provider_tests.go b/server/memorystore/providers/provider_tests.go new file mode 100644 index 0000000..e569fe8 --- /dev/null +++ b/server/memorystore/providers/provider_tests.go @@ -0,0 +1,115 @@ +package providers + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// ProviderTests runs all provider tests +func ProviderTests(t *testing.T, p Provider) { + + err := p.SetUserSession("auth_provider:123", "session_token_key", "test_hash123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + err = p.SetUserSession("auth_provider:123", "access_token_key", "test_jwt123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + // Same user multiple session + err = p.SetUserSession("auth_provider:123", "session_token_key1", "test_hash1123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + err = p.SetUserSession("auth_provider:123", "access_token_key1", "test_jwt1123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + // Different user session + err = p.SetUserSession("auth_provider:124", "session_token_key", "test_hash124", time.Now().Add(5*time.Second).Unix()) + assert.NoError(t, err) + err = p.SetUserSession("auth_provider:124", "access_token_key", "test_jwt124", time.Now().Add(5*time.Second).Unix()) + assert.NoError(t, err) + // Different provider session + err = p.SetUserSession("auth_provider1:124", "session_token_key", "test_hash124", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + err = p.SetUserSession("auth_provider1:124", "access_token_key", "test_jwt124", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + // Different provider session + err = p.SetUserSession("auth_provider1:123", "session_token_key", "test_hash1123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + err = p.SetUserSession("auth_provider1:123", "access_token_key", "test_jwt1123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + // Get session + key, err := p.GetUserSession("auth_provider:123", "session_token_key") + assert.NoError(t, err) + assert.Equal(t, "test_hash123", key) + key, err = p.GetUserSession("auth_provider:123", "access_token_key") + assert.NoError(t, err) + assert.Equal(t, "test_jwt123", key) + key, err = p.GetUserSession("auth_provider:124", "session_token_key") + assert.NoError(t, err) + assert.Equal(t, "test_hash124", key) + key, err = p.GetUserSession("auth_provider:124", "access_token_key") + assert.NoError(t, err) + assert.Equal(t, "test_jwt124", key) + // Expire some tokens and make sure they are empty + time.Sleep(5 * time.Second) + key, err = p.GetUserSession("auth_provider:124", "session_token_key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:124", "access_token_key") + assert.Empty(t, key) + assert.Error(t, err) + // Delete user session + err = p.DeleteUserSession("auth_provider:123", "key") + assert.NoError(t, err) + err = p.DeleteUserSession("auth_provider:123", "key") + assert.NoError(t, err) + key, err = p.GetUserSession("auth_provider:123", "key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:123", "access_token_key") + assert.Empty(t, key) + assert.Error(t, err) + // Delete all user session + err = p.DeleteAllUserSessions("123") + assert.NoError(t, err) + err = p.DeleteAllUserSessions("123") + assert.NoError(t, err) + key, err = p.GetUserSession("auth_provider:123", "session_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:123", "access_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:123", "session_token_key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:123", "access_token_key") + assert.Empty(t, key) + assert.Error(t, err) + // Delete namespace + err = p.DeleteSessionForNamespace("auth_provider") + assert.NoError(t, err) + err = p.DeleteSessionForNamespace("auth_provider1") + assert.NoError(t, err) + key, err = p.GetUserSession("auth_provider:123", "session_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:123", "access_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:123", "session_token_key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:123", "access_token_key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:124", "session_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:124", "access_token_key1") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:124", "session_token_key") + assert.Empty(t, key) + assert.Error(t, err) + key, err = p.GetUserSession("auth_provider1:124", "access_token_key") + assert.Empty(t, key) + assert.Error(t, err) +} diff --git a/server/memorystore/providers/providers.go b/server/memorystore/providers/providers.go index 7388a97..db58aa7 100644 --- a/server/memorystore/providers/providers.go +++ b/server/memorystore/providers/providers.go @@ -3,7 +3,7 @@ package providers // Provider defines current memory store provider type Provider interface { // SetUserSession sets the user session for given user identifier in form recipe:user_id - SetUserSession(userId, key, token string) error + SetUserSession(userId, key, token string, expiration int64) error // GetUserSession returns the session token for given token GetUserSession(userId, key string) (string, error) // DeleteUserSession deletes the user session diff --git a/server/memorystore/providers/redis/provider.go b/server/memorystore/providers/redis/provider.go index 084f207..894a75e 100644 --- a/server/memorystore/providers/redis/provider.go +++ b/server/memorystore/providers/redis/provider.go @@ -32,7 +32,6 @@ type provider struct { // NewRedisProvider returns a new redis provider func NewRedisProvider(redisURL string) (*provider, error) { redisURLHostPortsList := strings.Split(redisURL, ",") - if len(redisURLHostPortsList) > 1 { opt, err := redis.ParseURL(redisURLHostPortsList[0]) if err != nil { diff --git a/server/memorystore/providers/redis/provider_test.go b/server/memorystore/providers/redis/provider_test.go new file mode 100644 index 0000000..280616c --- /dev/null +++ b/server/memorystore/providers/redis/provider_test.go @@ -0,0 +1,15 @@ +package redis + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/authorizerdev/authorizer/server/memorystore/providers" +) + +func TestRedisProvider(t *testing.T) { + p, err := NewRedisProvider("redis://127.0.0.1:6379") + assert.NoError(t, err) + providers.ProviderTests(t, p) +} diff --git a/server/memorystore/providers/redis/store.go b/server/memorystore/providers/redis/store.go index 7a1c7a1..bceb187 100644 --- a/server/memorystore/providers/redis/store.go +++ b/server/memorystore/providers/redis/store.go @@ -3,6 +3,7 @@ package redis import ( "fmt" "strconv" + "time" "github.com/authorizerdev/authorizer/server/constants" log "github.com/sirupsen/logrus" @@ -16,8 +17,11 @@ var ( ) // SetUserSession sets the user session for given user identifier in form recipe:user_id -func (c *provider) SetUserSession(userId, key, token string) error { - err := c.store.Set(c.ctx, fmt.Sprintf("%s:%s", userId, key), token, 0).Err() +func (c *provider) SetUserSession(userId, key, token string, expiration int64) error { + currentTime := time.Now() + expireTime := time.Unix(expiration, 0) + duration := expireTime.Sub(currentTime) + err := c.store.Set(c.ctx, fmt.Sprintf("%s:%s", userId, key), token, duration).Err() if err != nil { log.Debug("Error saving user session to redis: ", err) return err @@ -38,37 +42,35 @@ func (c *provider) GetUserSession(userId, key string) (string, error) { func (c *provider) DeleteUserSession(userId, key string) error { if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeSessionToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) - return err + fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeSessionToken, key) + // continue } if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeAccessToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) - return err + fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeAccessToken, key) + // continue } if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeRefreshToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) - return err + fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeRefreshToken, key) + // continue } return nil } // DeleteAllUserSessions deletes all the user session from redis func (c *provider) DeleteAllUserSessions(userID string) error { - namespaces := []string{ - constants.AuthRecipeMethodBasicAuth, - constants.AuthRecipeMethodMagicLinkLogin, - constants.AuthRecipeMethodApple, - constants.AuthRecipeMethodFacebook, - constants.AuthRecipeMethodGithub, - constants.AuthRecipeMethodGoogle, - constants.AuthRecipeMethodLinkedIn, - constants.AuthRecipeMethodTwitter, - constants.AuthRecipeMethodMicrosoft, + res := c.store.Keys(c.ctx, fmt.Sprintf("*%s*", userID)) + if res.Err() != nil { + log.Debug("Error getting all user sessions from redis: ", res.Err()) + return res.Err() } - for _, namespace := range namespaces { - err := c.store.Del(c.ctx, namespace+":"+userID).Err() + keys := res.Val() + for _, key := range keys { + err := c.store.Del(c.ctx, key).Err() if err != nil { log.Debug("Error deleting all user sessions from redis: ", err) - return err + continue } } return nil @@ -76,27 +78,19 @@ func (c *provider) DeleteAllUserSessions(userID string) error { // DeleteSessionForNamespace to delete session for a given namespace example google,github func (c *provider) DeleteSessionForNamespace(namespace string) error { - var cursor uint64 - for { - keys := []string{} - keys, cursor, err := c.store.Scan(c.ctx, cursor, namespace+":*", 0).Result() + res := c.store.Keys(c.ctx, fmt.Sprintf("%s:*", namespace)) + if res.Err() != nil { + log.Debug("Error getting all user sessions from redis: ", res.Err()) + return res.Err() + } + keys := res.Val() + for _, key := range keys { + err := c.store.Del(c.ctx, key).Err() if err != nil { - log.Debugf("Error scanning keys for %s namespace: %s", namespace, err.Error()) - return err - } - - for _, key := range keys { - err := c.store.Del(c.ctx, key).Err() - if err != nil { - log.Debugf("Error deleting sessions for %s namespace: %s", namespace, err.Error()) - return err - } - } - if cursor == 0 { // no more keys - break + log.Debug("Error deleting all user sessions from redis: ", err) + continue } } - return nil } diff --git a/server/resolvers/login.go b/server/resolvers/login.go index 4bae30a..28a2289 100644 --- a/server/resolvers/login.go +++ b/server/resolvers/login.go @@ -193,12 +193,12 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes cookie.SetSession(gc, authToken.FingerPrintHash) sessionStoreKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } go func() { diff --git a/server/resolvers/mobile_login.go b/server/resolvers/mobile_login.go index b6302b4..9da0a53 100644 --- a/server/resolvers/mobile_login.go +++ b/server/resolvers/mobile_login.go @@ -195,12 +195,12 @@ func MobileLoginResolver(ctx context.Context, params model.MobileLoginInput) (*m cookie.SetSession(gc, authToken.FingerPrintHash) sessionStoreKey := constants.AuthRecipeMethodMobileBasicAuth + ":" + user.ID - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } go func() { diff --git a/server/resolvers/mobile_signup.go b/server/resolvers/mobile_signup.go index e6b0027..94d5b03 100644 --- a/server/resolvers/mobile_signup.go +++ b/server/resolvers/mobile_signup.go @@ -249,12 +249,12 @@ func MobileSignupResolver(ctx context.Context, params *model.MobileSignUpInput) sessionKey := constants.AuthRecipeMethodMobileBasicAuth + ":" + user.ID cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } go func() { diff --git a/server/resolvers/session.go b/server/resolvers/session.go index 79ea012..b3454b6 100644 --- a/server/resolvers/session.go +++ b/server/resolvers/session.go @@ -99,12 +99,12 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod } cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } return res, nil } diff --git a/server/resolvers/signup.go b/server/resolvers/signup.go index 43f2a96..433f801 100644 --- a/server/resolvers/signup.go +++ b/server/resolvers/signup.go @@ -91,7 +91,6 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR } inputRoles := []string{} - if len(params.Roles) > 0 { // check if roles exists rolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyRoles) @@ -293,12 +292,12 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } go func() { diff --git a/server/resolvers/verify_email.go b/server/resolvers/verify_email.go index 47b4429..d1fd81d 100644 --- a/server/resolvers/verify_email.go +++ b/server/resolvers/verify_email.go @@ -150,12 +150,12 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m sessionKey := loginMethod + ":" + user.ID cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } return res, nil } diff --git a/server/resolvers/verify_otp.go b/server/resolvers/verify_otp.go index 016678d..80080d9 100644 --- a/server/resolvers/verify_otp.go +++ b/server/resolvers/verify_otp.go @@ -123,12 +123,12 @@ func VerifyOtpResolver(ctx context.Context, params model.VerifyOTPRequest) (*mod sessionKey := loginMethod + ":" + user.ID cookie.SetSession(gc, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } return res, nil } diff --git a/server/test/validate_jwt_token_test.go b/server/test/validate_jwt_token_test.go index e2fcf8c..d2ab257 100644 --- a/server/test/validate_jwt_token_test.go +++ b/server/test/validate_jwt_token_test.go @@ -55,11 +55,11 @@ func validateJwtTokenTest(t *testing.T, s TestSetup) { authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth, nonce, "") assert.NoError(t, err) assert.NotNil(t, authToken) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt) if authToken.RefreshToken != nil { - memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) + memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt) } t.Run(`should validate the access token`, func(t *testing.T) { diff --git a/server/token/auth_token.go b/server/token/auth_token.go index deeda0a..6d2c942 100644 --- a/server/token/auth_token.go +++ b/server/token/auth_token.go @@ -30,11 +30,13 @@ type JWTToken struct { // Token object to hold the finger print and refresh token information type Token struct { - FingerPrint string `json:"fingerprint"` - FingerPrintHash string `json:"fingerprint_hash"` - RefreshToken *JWTToken `json:"refresh_token"` - AccessToken *JWTToken `json:"access_token"` - IDToken *JWTToken `json:"id_token"` + FingerPrint string `json:"fingerprint"` + // Session Token + FingerPrintHash string `json:"fingerprint_hash"` + SessionTokenExpiresAt int64 `json:"expires_at"` + RefreshToken *JWTToken `json:"refresh_token"` + AccessToken *JWTToken `json:"access_token"` + IDToken *JWTToken `json:"id_token"` } // SessionData @@ -51,7 +53,7 @@ type SessionData struct { // CreateAuthToken creates a new auth token when userlogs in func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, loginMethod, nonce string, code string) (*Token, error) { hostname := parsers.GetHost(gc) - _, fingerPrintHash, err := CreateSessionToken(user, nonce, roles, scope, loginMethod) + _, fingerPrintHash, sessionTokenExpiresAt, err := CreateSessionToken(user, nonce, roles, scope, loginMethod) if err != nil { return nil, err } @@ -82,10 +84,11 @@ func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, l } res := &Token{ - FingerPrint: nonce, - FingerPrintHash: fingerPrintHash, - AccessToken: &JWTToken{Token: accessToken, ExpiresAt: accessTokenExpiresAt}, - IDToken: &JWTToken{Token: idToken, ExpiresAt: idTokenExpiresAt}, + FingerPrint: nonce, + FingerPrintHash: fingerPrintHash, + SessionTokenExpiresAt: sessionTokenExpiresAt, + AccessToken: &JWTToken{Token: accessToken, ExpiresAt: accessTokenExpiresAt}, + IDToken: &JWTToken{Token: idToken, ExpiresAt: idTokenExpiresAt}, } if utils.StringSliceContains(scope, "offline_access") { @@ -101,7 +104,8 @@ func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, l } // CreateSessionToken creates a new session token -func CreateSessionToken(user models.User, nonce string, roles, scope []string, loginMethod string) (*SessionData, string, error) { +func CreateSessionToken(user models.User, nonce string, roles, scope []string, loginMethod string) (*SessionData, string, int64, error) { + expiresAt := time.Now().AddDate(1, 0, 0).Unix() fingerPrintMap := &SessionData{ Nonce: nonce, Roles: roles, @@ -109,15 +113,15 @@ func CreateSessionToken(user models.User, nonce string, roles, scope []string, l Scope: scope, LoginMethod: loginMethod, IssuedAt: time.Now().Unix(), - ExpiresAt: time.Now().AddDate(1, 0, 0).Unix(), + ExpiresAt: expiresAt, } fingerPrintBytes, _ := json.Marshal(fingerPrintMap) fingerPrintHash, err := crypto.EncryptAES(string(fingerPrintBytes)) if err != nil { - return nil, "", err + return nil, "", 0, err } - return fingerPrintMap, fingerPrintHash, nil + return fingerPrintMap, fingerPrintHash, expiresAt, nil } // CreateRefreshToken util to create JWT token From 428a0be3db8ec9fb1a5b5007f1d12d6c08541eb1 Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Sat, 8 Apr 2023 18:02:53 +0530 Subject: [PATCH 3/3] feat: add cache clear --- .../inmemory/stores/session_store.go | 73 +++++++++++++++---- server/memorystore/providers/redis/store.go | 3 - 2 files changed, 59 insertions(+), 17 deletions(-) diff --git a/server/memorystore/providers/inmemory/stores/session_store.go b/server/memorystore/providers/inmemory/stores/session_store.go index d3d429d..5226e43 100644 --- a/server/memorystore/providers/inmemory/stores/session_store.go +++ b/server/memorystore/providers/inmemory/stores/session_store.go @@ -2,6 +2,7 @@ package stores import ( "fmt" + "sort" "strings" "sync" "time" @@ -10,6 +11,8 @@ import ( const ( // Maximum entries to keep in session storage maxCacheSize = 1000 + // Cache clear interval + clearInterval = 10 * time.Minute ) // SessionEntry is the struct for entry stored in store @@ -20,30 +23,64 @@ type SessionEntry struct { // SessionStore struct to store the env variables type SessionStore struct { - mutex sync.Mutex - store map[string]*SessionEntry - itemsToEvict []string + wg sync.WaitGroup + mutex sync.RWMutex + store map[string]*SessionEntry + // stores expireTime: key to remove data when cache is full + // map is sorted by key so older most entry can be deleted first + keyIndex map[int64]string + stop chan struct{} } // NewSessionStore create a new session store func NewSessionStore() *SessionStore { - return &SessionStore{ - mutex: sync.Mutex{}, - store: make(map[string]*SessionEntry), + store := &SessionStore{ + mutex: sync.RWMutex{}, + store: make(map[string]*SessionEntry), + keyIndex: make(map[int64]string), + stop: make(chan struct{}), + } + store.wg.Add(1) + go func() { + defer store.wg.Done() + store.clean() + }() + return store +} + +func (s *SessionStore) clean() { + t := time.NewTicker(clearInterval) + defer t.Stop() + for { + select { + case <-s.stop: + return + case <-t.C: + s.mutex.Lock() + currentTime := time.Now().Unix() + for k, v := range s.store { + if v.ExpiresAt < currentTime { + delete(s.store, k) + delete(s.keyIndex, v.ExpiresAt) + } + } + s.mutex.Unlock() + } } } // Get returns the value of the key in state store func (s *SessionStore) Get(key, subKey string) string { - s.mutex.Lock() - defer s.mutex.Unlock() + s.mutex.RLock() + defer s.mutex.RUnlock() currentTime := time.Now().Unix() k := fmt.Sprintf("%s:%s", key, subKey) if v, ok := s.store[k]; ok { if v.ExpiresAt > currentTime { return v.Value } - s.itemsToEvict = append(s.itemsToEvict, k) + // Delete expired items + delete(s.store, k) } return "" } @@ -54,17 +91,25 @@ func (s *SessionStore) Set(key string, subKey, value string, expiration int64) { defer s.mutex.Unlock() k := fmt.Sprintf("%s:%s", key, subKey) if _, ok := s.store[k]; !ok { - s.store[k] = &SessionEntry{ - Value: value, - ExpiresAt: expiration, - // TODO add expire time + // check if there is enough space in cache + // else delete entries based on FIFO + if len(s.store) == maxCacheSize { + // remove older most entry + sortedKeys := []int64{} + for ik := range s.keyIndex { + sortedKeys = append(sortedKeys, ik) + } + sort.Slice(sortedKeys, func(i, j int) bool { return sortedKeys[i] < sortedKeys[j] }) + itemToRemove := sortedKeys[0] + delete(s.store, s.keyIndex[itemToRemove]) + delete(s.keyIndex, itemToRemove) } } s.store[k] = &SessionEntry{ Value: value, ExpiresAt: expiration, - // TODO add expire time } + s.keyIndex[expiration] = k } // RemoveAll all values for given key diff --git a/server/memorystore/providers/redis/store.go b/server/memorystore/providers/redis/store.go index bceb187..058e95e 100644 --- a/server/memorystore/providers/redis/store.go +++ b/server/memorystore/providers/redis/store.go @@ -42,17 +42,14 @@ func (c *provider) GetUserSession(userId, key string) (string, error) { func (c *provider) DeleteUserSession(userId, key string) error { if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeSessionToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) - fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeSessionToken, key) // continue } if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeAccessToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) - fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeAccessToken, key) // continue } if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeRefreshToken+"_"+key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) - fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeRefreshToken, key) // continue } return nil