From 43359f1dbaaa1b66dc2db8398d58c4603ce51095 Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Sun, 29 May 2022 17:22:46 +0530 Subject: [PATCH] fix: update store method till handlers --- server/constants/cookie.go | 8 + server/constants/env.go | 14 +- server/cookie/admin_cookie.go | 9 +- server/cookie/cookie.go | 13 +- server/crypto/aes.go | 30 ++- server/crypto/common.go | 40 +++- server/db/providers/arangodb/provider.go | 19 +- server/db/providers/arangodb/user.go | 8 +- server/db/providers/cassandradb/provider.go | 59 ++++-- server/db/providers/cassandradb/user.go | 8 +- server/db/providers/mongodb/provider.go | 24 ++- server/db/providers/mongodb/user.go | 8 +- server/db/providers/provider_template/user.go | 8 +- server/db/providers/sql/provider.go | 23 ++- server/db/providers/sql/user.go | 8 +- server/email/email.go | 52 ++++- server/email/forgot_password_email.go | 21 +- server/email/invite_email.go | 15 +- server/email/verification_email.go | 15 +- server/env/env.go | 191 +++++++++--------- server/env/persist_env.go | 105 +++++----- server/envstore/store.go | 111 ---------- server/handlers/app.go | 21 +- server/handlers/authorize.go | 3 +- server/handlers/dashboard.go | 6 +- server/handlers/jwks.go | 13 +- server/handlers/oauth_callback.go | 28 ++- server/handlers/oauth_login.go | 50 ++++- server/handlers/openid_config.go | 4 +- server/handlers/revoke.go | 3 +- server/handlers/token.go | 7 +- server/memorystore/memory_store.go | 20 ++ .../providers/inmemory/envstore.go | 41 ++++ .../providers/inmemory/provider.go | 9 +- .../inmemory/{inmemory.go => store.go} | 48 ++++- server/memorystore/providers/providers.go | 17 +- server/memorystore/providers/redis/reids.go | 85 -------- server/memorystore/providers/redis/store.go | 162 +++++++++++++++ server/memorystore/required_env_store.go | 3 +- server/oauth/oauth.go | 46 ++++- server/resolvers/validate_jwt_token.go | 4 +- server/token/auth_token.go | 16 +- server/utils/common.go | 11 + 43 files changed, 882 insertions(+), 504 deletions(-) create mode 100644 server/constants/cookie.go delete mode 100644 server/envstore/store.go create mode 100644 server/memorystore/providers/inmemory/envstore.go rename server/memorystore/providers/inmemory/{inmemory.go => store.go} (51%) delete mode 100644 server/memorystore/providers/redis/reids.go create mode 100644 server/memorystore/providers/redis/store.go diff --git a/server/constants/cookie.go b/server/constants/cookie.go new file mode 100644 index 0000000..71320a9 --- /dev/null +++ b/server/constants/cookie.go @@ -0,0 +1,8 @@ +package constants + +const ( + // AppCookieName is the name of the cookie that is used to store the application token + AppCookieName = "cookie" + // AdminCookieName is the name of the cookie that is used to store the admin token + AdminCookieName = "authorizer-admin" +) diff --git a/server/constants/env.go b/server/constants/env.go index b73048b..afca3f9 100644 --- a/server/constants/env.go +++ b/server/constants/env.go @@ -5,11 +5,11 @@ var VERSION = "0.0.1" const ( // Envstore identifier // StringStore string store identifier - StringStoreIdentifier = "stringStore" - // BoolStore bool store identifier - BoolStoreIdentifier = "boolStore" - // SliceStore slice store identifier - SliceStoreIdentifier = "sliceStore" + // StringStoreIdentifier = "stringStore" + // // BoolStore bool store identifier + // BoolStoreIdentifier = "boolStore" + // // SliceStore slice store identifier + // SliceStoreIdentifier = "sliceStore" // EnvKeyEnv key for env variable ENV EnvKeyEnv = "ENV" @@ -68,10 +68,6 @@ const ( EnvKeyAppURL = "APP_URL" // EnvKeyRedisURL key for env variable REDIS_URL EnvKeyRedisURL = "REDIS_URL" - // EnvKeyCookieName key for env variable COOKIE_NAME - EnvKeyCookieName = "COOKIE_NAME" - // EnvKeyAdminCookieName key for env variable ADMIN_COOKIE_NAME - EnvKeyAdminCookieName = "ADMIN_COOKIE_NAME" // EnvKeyResetPasswordURL key for env variable RESET_PASSWORD_URL EnvKeyResetPasswordURL = "RESET_PASSWORD_URL" // EnvKeyDisableEmailVerification key for env variable DISABLE_EMAIL_VERIFICATION diff --git a/server/cookie/admin_cookie.go b/server/cookie/admin_cookie.go index 58f2c56..22b87f7 100644 --- a/server/cookie/admin_cookie.go +++ b/server/cookie/admin_cookie.go @@ -4,7 +4,6 @@ import ( "net/url" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/utils" "github.com/gin-gonic/gin" ) @@ -15,13 +14,12 @@ func SetAdminCookie(gc *gin.Context, token string) { httpOnly := true hostname := utils.GetHost(gc) host, _ := utils.GetHostParts(hostname) - - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), token, 3600, "/", host, secure, httpOnly) + gc.SetCookie(constants.AdminCookieName, token, 3600, "/", host, secure, httpOnly) } // GetAdminCookie gets the admin cookie from the request func GetAdminCookie(gc *gin.Context) (string, error) { - cookie, err := gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName)) + cookie, err := gc.Request.Cookie(constants.AdminCookieName) if err != nil { return "", err } @@ -41,6 +39,5 @@ func DeleteAdminCookie(gc *gin.Context) { httpOnly := true hostname := utils.GetHost(gc) host, _ := utils.GetHostParts(hostname) - - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), "", -1, "/", host, secure, httpOnly) + gc.SetCookie(constants.AdminCookieName, "", -1, "/", host, secure, httpOnly) } diff --git a/server/cookie/cookie.go b/server/cookie/cookie.go index 54600af..562b886 100644 --- a/server/cookie/cookie.go +++ b/server/cookie/cookie.go @@ -5,7 +5,6 @@ import ( "net/url" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/utils" "github.com/gin-gonic/gin" ) @@ -25,8 +24,8 @@ func SetSession(gc *gin.Context, sessionID string) { year := 60 * 60 * 24 * 365 gc.SetSameSite(http.SameSiteNoneMode) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session", sessionID, year, "/", host, secure, httpOnly) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session_domain", sessionID, year, "/", domain, secure, httpOnly) + gc.SetCookie(constants.AppCookieName+"_session", sessionID, year, "/", host, secure, httpOnly) + gc.SetCookie(constants.AppCookieName+"_session_domain", sessionID, year, "/", domain, secure, httpOnly) } // DeleteSession sets session cookies to expire @@ -41,17 +40,17 @@ func DeleteSession(gc *gin.Context) { } gc.SetSameSite(http.SameSiteNoneMode) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session", "", -1, "/", host, secure, httpOnly) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session_domain", "", -1, "/", domain, secure, httpOnly) + gc.SetCookie(constants.AppCookieName+"_session", "", -1, "/", host, secure, httpOnly) + gc.SetCookie(constants.AppCookieName+"_session_domain", "", -1, "/", domain, secure, httpOnly) } // GetSession gets the session cookie from context func GetSession(gc *gin.Context) (string, error) { var cookie *http.Cookie var err error - cookie, err = gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + "_session") + cookie, err = gc.Request.Cookie(constants.AppCookieName + "_session") if err != nil { - cookie, err = gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + "_session_domain") + cookie, err = gc.Request.Cookie(constants.AppCookieName + "_session_domain") if err != nil { return "", err } diff --git a/server/crypto/aes.go b/server/crypto/aes.go index 8d06ffb..422f694 100644 --- a/server/crypto/aes.go +++ b/server/crypto/aes.go @@ -7,14 +7,18 @@ import ( "io" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) var bytes = []byte{35, 46, 57, 24, 85, 35, 24, 74, 87, 35, 88, 98, 66, 32, 14, 0o5} // EncryptAES method is to encrypt or hide any classified text func EncryptAES(text string) (string, error) { - key := []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey)) + k, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey) + if err != nil { + return "", err + } + key := []byte(k) block, err := aes.NewCipher(key) if err != nil { return "", err @@ -28,7 +32,11 @@ func EncryptAES(text string) (string, error) { // DecryptAES method is to extract back the encrypted text func DecryptAES(text string) (string, error) { - key := []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey)) + k, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey) + if err != nil { + return "", err + } + key := []byte(k) block, err := aes.NewCipher(key) if err != nil { return "", err @@ -46,9 +54,13 @@ func DecryptAES(text string) (string, error) { // EncryptAESEnv encrypts data using AES algorithm // kept for the backward compatibility of env data encryption func EncryptAESEnv(text []byte) ([]byte, error) { - key := []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey)) - c, err := aes.NewCipher(key) var res []byte + k, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey) + if err != nil { + return res, err + } + key := []byte(k) + c, err := aes.NewCipher(key) if err != nil { return res, err } @@ -81,9 +93,13 @@ func EncryptAESEnv(text []byte) ([]byte, error) { // DecryptAES decrypts data using AES algorithm // Kept for the backward compatibility of env data decryption func DecryptAESEnv(ciphertext []byte) ([]byte, error) { - key := []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey)) - c, err := aes.NewCipher(key) var res []byte + k, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey) + if err != nil { + return res, err + } + key := []byte(k) + c, err := aes.NewCipher(key) if err != nil { return res, err } diff --git a/server/crypto/common.go b/server/crypto/common.go index 35af515..91aed06 100644 --- a/server/crypto/common.go +++ b/server/crypto/common.go @@ -5,7 +5,7 @@ import ( "encoding/json" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "golang.org/x/crypto/bcrypt" "gopkg.in/square/go-jose.v2" ) @@ -37,20 +37,35 @@ func GetPubJWK(algo, keyID string, publicKey interface{}) (string, error) { // this is called while initializing app / when env is updated func GenerateJWKBasedOnEnv() (string, error) { jwk := "" - algo := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) - clientID := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) + algo, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + if err != nil { + return jwk, err + } + clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID) + if err != nil { + return jwk, err + } + + jwtSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret) + if err != nil { + return jwk, err + } - var err error // check if jwt secret is provided if IsHMACA(algo) { - jwk, err = GetPubJWK(algo, clientID, []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret))) + jwk, err = GetPubJWK(algo, clientID, []byte(jwtSecret)) if err != nil { return "", err } } + jwtPublicKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey) + if err != nil { + return jwk, err + } + if IsRSA(algo) { - publicKeyInstance, err := ParseRsaPublicKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey)) + publicKeyInstance, err := ParseRsaPublicKeyFromPemStr(jwtPublicKey) if err != nil { return "", err } @@ -62,7 +77,11 @@ func GenerateJWKBasedOnEnv() (string, error) { } if IsECDSA(algo) { - publicKeyInstance, err := ParseEcdsaPublicKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey)) + jwtPublicKey, err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey) + if err != nil { + return jwk, err + } + publicKeyInstance, err := ParseEcdsaPublicKeyFromPemStr(jwtPublicKey) if err != nil { return "", err } @@ -77,13 +96,16 @@ func GenerateJWKBasedOnEnv() (string, error) { } // EncryptEnvData is used to encrypt the env data -func EncryptEnvData(data envstore.Store) (string, error) { +func EncryptEnvData(data map[string]interface{}) (string, error) { jsonBytes, err := json.Marshal(data) if err != nil { return "", err } - storeData := envstore.EnvStoreObj.GetEnvStoreClone() + storeData, err := memorystore.Provider.GetEnvStore() + if err != nil { + return "", err + } err = json.Unmarshal(jsonBytes, &storeData) if err != nil { diff --git a/server/db/providers/arangodb/provider.go b/server/db/providers/arangodb/provider.go index 92c007c..a9c693e 100644 --- a/server/db/providers/arangodb/provider.go +++ b/server/db/providers/arangodb/provider.go @@ -8,7 +8,7 @@ import ( "github.com/arangodb/go-driver/http" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) type provider struct { @@ -22,8 +22,12 @@ type provider struct { // NewProvider to initialize arangodb connection func NewProvider() (*provider, error) { ctx := context.Background() + dbURL, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL) + if err != nil { + return nil, err + } conn, err := http.NewConnection(http.ConnectionConfig{ - Endpoints: []string{envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)}, + Endpoints: []string{dbURL}, }) if err != nil { return nil, err @@ -37,16 +41,19 @@ func NewProvider() (*provider, error) { } var arangodb driver.Database - - arangodb_exists, err := arangoClient.DatabaseExists(nil, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName)) + dbName, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName) + if err != nil { + return nil, err + } + arangodb_exists, err := arangoClient.DatabaseExists(nil, dbName) if arangodb_exists { - arangodb, err = arangoClient.Database(nil, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName)) + arangodb, err = arangoClient.Database(nil, dbName) if err != nil { return nil, err } } else { - arangodb, err = arangoClient.CreateDatabase(nil, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName), nil) + arangodb, err = arangoClient.CreateDatabase(nil, dbName, nil) if err != nil { return nil, err } diff --git a/server/db/providers/arangodb/user.go b/server/db/providers/arangodb/user.go index fc466a4..8c303a0 100644 --- a/server/db/providers/arangodb/user.go +++ b/server/db/providers/arangodb/user.go @@ -10,8 +10,8 @@ import ( arangoDriver "github.com/arangodb/go-driver" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/google/uuid" ) @@ -22,7 +22,11 @@ func (p *provider) AddUser(user models.User) (models.User, error) { } if user.Roles == "" { - user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + return user, err + } + user.Roles = strings.Join(defaultRoles, ",") } user.CreatedAt = time.Now().Unix() diff --git a/server/db/providers/cassandradb/provider.go b/server/db/providers/cassandradb/provider.go index e7bf3b0..0dcc07e 100644 --- a/server/db/providers/cassandradb/provider.go +++ b/server/db/providers/cassandradb/provider.go @@ -9,7 +9,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/gocql/gocql" cansandraDriver "github.com/gocql/gocql" ) @@ -23,15 +23,25 @@ var KeySpace string // NewProvider to initialize arangodb connection func NewProvider() (*provider, error) { - dbURL := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL) + dbURL, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL) + if err != nil { + return nil, err + } if dbURL == "" { - dbURL = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseHost) - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabasePort) != "" { - dbURL = fmt.Sprintf("%s:%s", dbURL, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabasePort)) + dbURL, err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseHost) + dbPort, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabasePort) + if err != nil { + return nil, err + } + if dbPort != "" { + dbURL = fmt.Sprintf("%s:%s", dbURL, dbPort) } } - KeySpace = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName) + KeySpace, err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName) + if err != nil || KeySpace == "" { + KeySpace = constants.EnvKeyDatabaseName + } clusterURL := []string{} if strings.Contains(dbURL, ",") { clusterURL = strings.Split(dbURL, ",") @@ -39,25 +49,48 @@ func NewProvider() (*provider, error) { clusterURL = append(clusterURL, dbURL) } cassandraClient := cansandraDriver.NewCluster(clusterURL...) - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseUsername) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabasePassword) != "" { + dbUsername, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseUsername) + if err != nil { + return nil, err + } + dbPassword, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabasePassword) + if err != nil { + return nil, err + } + + if dbUsername != "" && dbPassword != "" { cassandraClient.Authenticator = &cansandraDriver.PasswordAuthenticator{ - Username: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseUsername), - Password: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabasePassword), + Username: dbUsername, + Password: dbPassword, } } - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCert) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCACert) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCertKey) != "" { - certString, err := crypto.DecryptB64(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCert)) + dbCert, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCert) + if err != nil { + return nil, err + } + + dbCACert, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCACert) + if err != nil { + return nil, err + } + + dbCertKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCertKey) + if err != nil { + return nil, err + } + if dbCert != "" && dbCACert != "" && dbCertKey != "" { + certString, err := crypto.DecryptB64(dbCert) if err != nil { return nil, err } - keyString, err := crypto.DecryptB64(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCertKey)) + keyString, err := crypto.DecryptB64(dbCertKey) if err != nil { return nil, err } - caString, err := crypto.DecryptB64(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCACert)) + caString, err := crypto.DecryptB64(dbCACert) if err != nil { return nil, err } diff --git a/server/db/providers/cassandradb/user.go b/server/db/providers/cassandradb/user.go index 09b7476..68fa36e 100644 --- a/server/db/providers/cassandradb/user.go +++ b/server/db/providers/cassandradb/user.go @@ -9,8 +9,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/gocql/gocql" "github.com/google/uuid" ) @@ -22,7 +22,11 @@ func (p *provider) AddUser(user models.User) (models.User, error) { } if user.Roles == "" { - user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + return user, err + } + user.Roles = strings.Join(defaultRoles, ",") } user.CreatedAt = time.Now().Unix() diff --git a/server/db/providers/mongodb/provider.go b/server/db/providers/mongodb/provider.go index d29fca1..b3ae62f 100644 --- a/server/db/providers/mongodb/provider.go +++ b/server/db/providers/mongodb/provider.go @@ -6,7 +6,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" @@ -19,7 +19,11 @@ type provider struct { // NewProvider to initialize mongodb connection func NewProvider() (*provider, error) { - mongodbOptions := options.Client().ApplyURI(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)) + dbURL, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL) + if err != nil { + return nil, err + } + mongodbOptions := options.Client().ApplyURI(dbURL) maxWait := time.Duration(5 * time.Second) mongodbOptions.ConnectTimeout = &maxWait mongoClient, err := mongo.NewClient(mongodbOptions) @@ -37,18 +41,22 @@ func NewProvider() (*provider, error) { return nil, err } - mongodb := mongoClient.Database(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName), options.Database()) + dbName, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName) + if err != nil { + return nil, err + } + mongodb := mongoClient.Database(dbName, options.Database()) mongodb.CreateCollection(ctx, models.Collections.User, options.CreateCollection()) userCollection := mongodb.Collection(models.Collections.User, options.Collection()) userCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ - mongo.IndexModel{ + { Keys: bson.M{"email": 1}, Options: options.Index().SetUnique(true).SetSparse(true), }, }, options.CreateIndexes()) userCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ - mongo.IndexModel{ + { Keys: bson.M{"phone_number": 1}, Options: options.Index().SetUnique(true).SetSparse(true).SetPartialFilterExpression(map[string]interface{}{ "phone_number": map[string]string{"$type": "string"}, @@ -59,13 +67,13 @@ func NewProvider() (*provider, error) { mongodb.CreateCollection(ctx, models.Collections.VerificationRequest, options.CreateCollection()) verificationRequestCollection := mongodb.Collection(models.Collections.VerificationRequest, options.Collection()) verificationRequestCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ - mongo.IndexModel{ + { Keys: bson.M{"email": 1, "identifier": 1}, Options: options.Index().SetUnique(true).SetSparse(true), }, }, options.CreateIndexes()) verificationRequestCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ - mongo.IndexModel{ + { Keys: bson.M{"token": 1}, Options: options.Index().SetSparse(true), }, @@ -74,7 +82,7 @@ func NewProvider() (*provider, error) { mongodb.CreateCollection(ctx, models.Collections.Session, options.CreateCollection()) sessionCollection := mongodb.Collection(models.Collections.Session, options.Collection()) sessionCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ - mongo.IndexModel{ + { Keys: bson.M{"user_id": 1}, Options: options.Index().SetSparse(true), }, diff --git a/server/db/providers/mongodb/user.go b/server/db/providers/mongodb/user.go index af6c799..f1a5f73 100644 --- a/server/db/providers/mongodb/user.go +++ b/server/db/providers/mongodb/user.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" @@ -20,7 +20,11 @@ func (p *provider) AddUser(user models.User) (models.User, error) { } if user.Roles == "" { - user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + return user, err + } + user.Roles = strings.Join(defaultRoles, ",") } user.CreatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix() diff --git a/server/db/providers/provider_template/user.go b/server/db/providers/provider_template/user.go index 07f6a06..3f45421 100644 --- a/server/db/providers/provider_template/user.go +++ b/server/db/providers/provider_template/user.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/google/uuid" ) @@ -18,7 +18,11 @@ func (p *provider) AddUser(user models.User) (models.User, error) { } if user.Roles == "" { - user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + return user, err + } + user.Roles = strings.Join(defaultRoles, ",") } user.CreatedAt = time.Now().Unix() diff --git a/server/db/providers/sql/provider.go b/server/db/providers/sql/provider.go index 279b707..20f19d0 100644 --- a/server/db/providers/sql/provider.go +++ b/server/db/providers/sql/provider.go @@ -7,7 +7,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" @@ -41,15 +41,26 @@ func NewProvider() (*provider, error) { TablePrefix: models.Prefix, }, } - switch envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) { + + dbType, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) + if err != nil { + return nil, err + } + + dbURL, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL) + if err != nil { + return nil, err + } + + switch dbType { case constants.DbTypePostgres, constants.DbTypeYugabyte: - sqlDB, err = gorm.Open(postgres.Open(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)), ormConfig) + sqlDB, err = gorm.Open(postgres.Open(dbURL), ormConfig) case constants.DbTypeSqlite: - sqlDB, err = gorm.Open(sqlite.Open(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)), ormConfig) + sqlDB, err = gorm.Open(sqlite.Open(dbURL), ormConfig) case constants.DbTypeMysql, constants.DbTypeMariaDB: - sqlDB, err = gorm.Open(mysql.Open(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)), ormConfig) + sqlDB, err = gorm.Open(mysql.Open(dbURL), ormConfig) case constants.DbTypeSqlserver: - sqlDB, err = gorm.Open(sqlserver.Open(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)), ormConfig) + sqlDB, err = gorm.Open(sqlserver.Open(dbURL), ormConfig) } if err != nil { diff --git a/server/db/providers/sql/user.go b/server/db/providers/sql/user.go index ef295c6..4c57c5d 100644 --- a/server/db/providers/sql/user.go +++ b/server/db/providers/sql/user.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/google/uuid" "gorm.io/gorm/clause" ) @@ -19,7 +19,11 @@ func (p *provider) AddUser(user models.User) (models.User, error) { } if user.Roles == "" { - user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + return user, err + } + user.Roles = strings.Join(defaultRoles, ",") } user.CreatedAt = time.Now().Unix() diff --git a/server/email/email.go b/server/email/email.go index b8e6d80..4234eff 100644 --- a/server/email/email.go +++ b/server/email/email.go @@ -11,7 +11,7 @@ import ( gomail "gopkg.in/mail.v2" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // addEmailTemplate is used to add html template in email body @@ -33,17 +33,57 @@ func addEmailTemplate(a string, b map[string]interface{}, templateName string) s // SendMail function to send mail func SendMail(to []string, Subject, bodyMessage string) error { // dont trigger email sending in case of test - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEnv) == "test" { + envKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEnv) + if err != nil { + return err + } + if envKey == "test" { return nil } m := gomail.NewMessage() - m.SetHeader("From", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeySenderEmail)) + senderEmail, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeySenderEmail) + if err != nil { + log.Errorf("Error while getting sender email from env variable: %v", err) + return err + } + + smtpPort, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeySmtpPort) + if err != nil { + log.Errorf("Error while getting smtp port from env variable: %v", err) + return err + } + + smtpHost, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeySmtpHost) + if err != nil { + log.Errorf("Error while getting smtp host from env variable: %v", err) + return err + } + + smtpUsername, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeySmtpUsername) + if err != nil { + log.Errorf("Error while getting smtp username from env variable: %v", err) + return err + } + + smtpPassword, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeySmtpPassword) + if err != nil { + log.Errorf("Error while getting smtp password from env variable: %v", err) + return err + } + + isProd, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsProd) + if err != nil { + log.Errorf("Error while getting env variable: %v", err) + return err + } + + m.SetHeader("From", senderEmail) m.SetHeader("To", to...) m.SetHeader("Subject", Subject) m.SetBody("text/html", bodyMessage) - port, _ := strconv.Atoi(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeySmtpPort)) - d := gomail.NewDialer(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeySmtpHost), port, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeySmtpUsername), envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeySmtpPassword)) - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEnv) == "development" { + port, _ := strconv.Atoi(smtpPort) + d := gomail.NewDialer(smtpHost, port, smtpUsername, smtpPassword) + if !isProd { d.TLSConfig = &tls.Config{InsecureSkipVerify: true} } if err := d.DialAndSend(m); err != nil { diff --git a/server/email/forgot_password_email.go b/server/email/forgot_password_email.go index 1e06437..aabd6a9 100644 --- a/server/email/forgot_password_email.go +++ b/server/email/forgot_password_email.go @@ -2,14 +2,19 @@ package email import ( "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // SendForgotPasswordMail to send forgot password email func SendForgotPasswordMail(toEmail, token, hostname string) error { - resetPasswordUrl := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyResetPasswordURL) + resetPasswordUrl, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyResetPasswordURL) + if err != nil { + return err + } if resetPasswordUrl == "" { - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyResetPasswordURL, hostname+"/app/reset-password") + if err := memorystore.Provider.UpdateEnvVariable(constants.EnvKeyResetPasswordURL, hostname+"/app/reset-password"); err != nil { + return err + } } // The receiver needs to be in slice as the receive supports multiple receiver @@ -103,8 +108,14 @@ func SendForgotPasswordMail(toEmail, token, hostname string) error { ` data := make(map[string]interface{}, 3) - data["org_logo"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) - data["org_name"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + data["org_logo"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) + if err != nil { + return err + } + data["org_name"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + if err != nil { + return err + } data["verification_url"] = resetPasswordUrl + "?token=" + token message = addEmailTemplate(message, data, "reset_password_email.tmpl") diff --git a/server/email/invite_email.go b/server/email/invite_email.go index 8689353..ef561a6 100644 --- a/server/email/invite_email.go +++ b/server/email/invite_email.go @@ -4,7 +4,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // InviteEmail to send invite email @@ -99,13 +99,20 @@ func InviteEmail(toEmail, token, verificationURL, redirectURI string) error { ` data := make(map[string]interface{}, 3) - data["org_logo"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) - data["org_name"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + var err error + data["org_logo"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) + if err != nil { + return err + } + data["org_name"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + if err != nil { + return err + } data["verification_url"] = verificationURL + "?token=" + token + "&redirect_uri=" + redirectURI message = addEmailTemplate(message, data, "invite_email.tmpl") // bodyMessage := sender.WriteHTMLEmail(Receiver, Subject, message) - err := SendMail(Receiver, Subject, message) + err = SendMail(Receiver, Subject, message) if err != nil { log.Warn("error sending email: ", err) } diff --git a/server/email/verification_email.go b/server/email/verification_email.go index dd73657..dded5ef 100644 --- a/server/email/verification_email.go +++ b/server/email/verification_email.go @@ -4,7 +4,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // SendVerificationMail to send verification email @@ -99,13 +99,20 @@ func SendVerificationMail(toEmail, token, hostname string) error { ` data := make(map[string]interface{}, 3) - data["org_logo"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) - data["org_name"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + var err error + data["org_logo"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) + if err != nil { + return err + } + data["org_name"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + if err != nil { + return err + } data["verification_url"] = hostname + "/verify_email?token=" + token message = addEmailTemplate(message, data, "verify_email.tmpl") // bodyMessage := sender.WriteHTMLEmail(Receiver, Subject, message) - err := SendMail(Receiver, Subject, message) + err = SendMail(Receiver, Subject, message) if err != nil { log.Warn("error sending email: ", err) } diff --git a/server/env/env.go b/server/env/env.go index 8a29ec9..ac652cf 100644 --- a/server/env/env.go +++ b/server/env/env.go @@ -10,7 +10,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/utils" ) @@ -20,90 +20,94 @@ func InitAllEnv() error { if err != nil { log.Info("No env data found in db, using local clone of env data") // get clone of current store - envData = envstore.EnvStoreObj.GetEnvStoreClone() + envData, err = memorystore.Provider.GetEnvStore() + if err != nil { + log.Debug("Error while getting env data from memorystore: ", err) + return err + } } - clientID := envData.StringEnv[constants.EnvKeyClientID] + clientID := envData[constants.EnvKeyClientID].(string) // unique client id for each instance if clientID == "" { clientID = uuid.New().String() - envData.StringEnv[constants.EnvKeyClientID] = clientID + envData[constants.EnvKeyClientID] = clientID } - clientSecret := envData.StringEnv[constants.EnvKeyClientSecret] + clientSecret := envData[constants.EnvKeyClientSecret] // unique client id for each instance if clientSecret == "" { clientSecret = uuid.New().String() - envData.StringEnv[constants.EnvKeyClientSecret] = clientSecret + envData[constants.EnvKeyClientSecret] = clientSecret } - if envData.StringEnv[constants.EnvKeyEnv] == "" { - envData.StringEnv[constants.EnvKeyEnv] = os.Getenv(constants.EnvKeyEnv) - if envData.StringEnv[constants.EnvKeyEnv] == "" { - envData.StringEnv[constants.EnvKeyEnv] = "production" + if envData[constants.EnvKeyEnv] == "" { + envData[constants.EnvKeyEnv] = os.Getenv(constants.EnvKeyEnv) + if envData[constants.EnvKeyEnv] == "" { + envData[constants.EnvKeyEnv] = "production" } - if envData.StringEnv[constants.EnvKeyEnv] == "production" { - envData.BoolEnv[constants.EnvKeyIsProd] = true + if envData[constants.EnvKeyEnv] == "production" { + envData[constants.EnvKeyIsProd] = true } else { - envData.BoolEnv[constants.EnvKeyIsProd] = false + envData[constants.EnvKeyIsProd] = false } } - if envData.StringEnv[constants.EnvKeyAppURL] == "" { - envData.StringEnv[constants.EnvKeyAppURL] = os.Getenv(constants.EnvKeyAppURL) + if envData[constants.EnvKeyAppURL] == "" { + envData[constants.EnvKeyAppURL] = os.Getenv(constants.EnvKeyAppURL) } - if envData.StringEnv[constants.EnvKeyAuthorizerURL] == "" { - envData.StringEnv[constants.EnvKeyAuthorizerURL] = os.Getenv(constants.EnvKeyAuthorizerURL) + if envData[constants.EnvKeyAuthorizerURL] == "" { + envData[constants.EnvKeyAuthorizerURL] = os.Getenv(constants.EnvKeyAuthorizerURL) } - if envData.StringEnv[constants.EnvKeyPort] == "" { - envData.StringEnv[constants.EnvKeyPort] = os.Getenv(constants.EnvKeyPort) - if envData.StringEnv[constants.EnvKeyPort] == "" { - envData.StringEnv[constants.EnvKeyPort] = "8080" + if envData[constants.EnvKeyPort] == "" { + envData[constants.EnvKeyPort] = os.Getenv(constants.EnvKeyPort) + if envData[constants.EnvKeyPort] == "" { + envData[constants.EnvKeyPort] = "8080" } } - if envData.StringEnv[constants.EnvKeyAccessTokenExpiryTime] == "" { - envData.StringEnv[constants.EnvKeyAccessTokenExpiryTime] = os.Getenv(constants.EnvKeyAccessTokenExpiryTime) - if envData.StringEnv[constants.EnvKeyAccessTokenExpiryTime] == "" { - envData.StringEnv[constants.EnvKeyAccessTokenExpiryTime] = "30m" + if envData[constants.EnvKeyAccessTokenExpiryTime] == "" { + envData[constants.EnvKeyAccessTokenExpiryTime] = os.Getenv(constants.EnvKeyAccessTokenExpiryTime) + if envData[constants.EnvKeyAccessTokenExpiryTime] == "" { + envData[constants.EnvKeyAccessTokenExpiryTime] = "30m" } } - if envData.StringEnv[constants.EnvKeyAdminSecret] == "" { - envData.StringEnv[constants.EnvKeyAdminSecret] = os.Getenv(constants.EnvKeyAdminSecret) + if envData[constants.EnvKeyAdminSecret] == "" { + envData[constants.EnvKeyAdminSecret] = os.Getenv(constants.EnvKeyAdminSecret) } - if envData.StringEnv[constants.EnvKeySmtpHost] == "" { - envData.StringEnv[constants.EnvKeySmtpHost] = os.Getenv(constants.EnvKeySmtpHost) + if envData[constants.EnvKeySmtpHost] == "" { + envData[constants.EnvKeySmtpHost] = os.Getenv(constants.EnvKeySmtpHost) } - if envData.StringEnv[constants.EnvKeySmtpPort] == "" { - envData.StringEnv[constants.EnvKeySmtpPort] = os.Getenv(constants.EnvKeySmtpPort) + if envData[constants.EnvKeySmtpPort] == "" { + envData[constants.EnvKeySmtpPort] = os.Getenv(constants.EnvKeySmtpPort) } - if envData.StringEnv[constants.EnvKeySmtpUsername] == "" { - envData.StringEnv[constants.EnvKeySmtpUsername] = os.Getenv(constants.EnvKeySmtpUsername) + if envData[constants.EnvKeySmtpUsername] == "" { + envData[constants.EnvKeySmtpUsername] = os.Getenv(constants.EnvKeySmtpUsername) } - if envData.StringEnv[constants.EnvKeySmtpPassword] == "" { - envData.StringEnv[constants.EnvKeySmtpPassword] = os.Getenv(constants.EnvKeySmtpPassword) + if envData[constants.EnvKeySmtpPassword] == "" { + envData[constants.EnvKeySmtpPassword] = os.Getenv(constants.EnvKeySmtpPassword) } - if envData.StringEnv[constants.EnvKeySenderEmail] == "" { - envData.StringEnv[constants.EnvKeySenderEmail] = os.Getenv(constants.EnvKeySenderEmail) + if envData[constants.EnvKeySenderEmail] == "" { + envData[constants.EnvKeySenderEmail] = os.Getenv(constants.EnvKeySenderEmail) } - algo := envData.StringEnv[constants.EnvKeyJwtType] + algo := envData[constants.EnvKeyJwtType].(string) if algo == "" { - envData.StringEnv[constants.EnvKeyJwtType] = os.Getenv(constants.EnvKeyJwtType) - if envData.StringEnv[constants.EnvKeyJwtType] == "" { - envData.StringEnv[constants.EnvKeyJwtType] = "RS256" - algo = envData.StringEnv[constants.EnvKeyJwtType] + envData[constants.EnvKeyJwtType] = os.Getenv(constants.EnvKeyJwtType) + if envData[constants.EnvKeyJwtType] == "" { + envData[constants.EnvKeyJwtType] = "RS256" + algo = envData[constants.EnvKeyJwtType].(string) } else { - algo = envData.StringEnv[constants.EnvKeyJwtType] + algo = envData[constants.EnvKeyJwtType].(string) if !crypto.IsHMACA(algo) && !crypto.IsRSA(algo) && !crypto.IsECDSA(algo) { log.Debug("Invalid JWT Algorithm") return errors.New("invalid JWT_TYPE") @@ -112,10 +116,10 @@ func InitAllEnv() error { } if crypto.IsHMACA(algo) { - if envData.StringEnv[constants.EnvKeyJwtSecret] == "" { - envData.StringEnv[constants.EnvKeyJwtSecret] = os.Getenv(constants.EnvKeyJwtSecret) - if envData.StringEnv[constants.EnvKeyJwtSecret] == "" { - envData.StringEnv[constants.EnvKeyJwtSecret], _, err = crypto.NewHMACKey(algo, clientID) + if envData[constants.EnvKeyJwtSecret] == "" { + envData[constants.EnvKeyJwtSecret] = os.Getenv(constants.EnvKeyJwtSecret) + if envData[constants.EnvKeyJwtSecret] == "" { + envData[constants.EnvKeyJwtSecret], _, err = crypto.NewHMACKey(algo, clientID) if err != nil { return err } @@ -126,11 +130,11 @@ func InitAllEnv() error { if crypto.IsRSA(algo) || crypto.IsECDSA(algo) { privateKey, publicKey := "", "" - if envData.StringEnv[constants.EnvKeyJwtPrivateKey] == "" { + if envData[constants.EnvKeyJwtPrivateKey] == "" { privateKey = os.Getenv(constants.EnvKeyJwtPrivateKey) } - if envData.StringEnv[constants.EnvKeyJwtPublicKey] == "" { + if envData[constants.EnvKeyJwtPublicKey] == "" { publicKey = os.Getenv(constants.EnvKeyJwtPublicKey) } @@ -174,76 +178,69 @@ func InitAllEnv() error { } } - envData.StringEnv[constants.EnvKeyJwtPrivateKey] = privateKey - envData.StringEnv[constants.EnvKeyJwtPublicKey] = publicKey + envData[constants.EnvKeyJwtPrivateKey] = privateKey + envData[constants.EnvKeyJwtPublicKey] = publicKey } - if envData.StringEnv[constants.EnvKeyJwtRoleClaim] == "" { - envData.StringEnv[constants.EnvKeyJwtRoleClaim] = os.Getenv(constants.EnvKeyJwtRoleClaim) + if envData[constants.EnvKeyJwtRoleClaim] == "" { + envData[constants.EnvKeyJwtRoleClaim] = os.Getenv(constants.EnvKeyJwtRoleClaim) - if envData.StringEnv[constants.EnvKeyJwtRoleClaim] == "" { - envData.StringEnv[constants.EnvKeyJwtRoleClaim] = "role" + if envData[constants.EnvKeyJwtRoleClaim] == "" { + envData[constants.EnvKeyJwtRoleClaim] = "role" } } - if envData.StringEnv[constants.EnvKeyCustomAccessTokenScript] == "" { - envData.StringEnv[constants.EnvKeyCustomAccessTokenScript] = os.Getenv(constants.EnvKeyCustomAccessTokenScript) + if envData[constants.EnvKeyCustomAccessTokenScript] == "" { + envData[constants.EnvKeyCustomAccessTokenScript] = os.Getenv(constants.EnvKeyCustomAccessTokenScript) } - if envData.StringEnv[constants.EnvKeyRedisURL] == "" { - envData.StringEnv[constants.EnvKeyRedisURL] = os.Getenv(constants.EnvKeyRedisURL) + if envData[constants.EnvKeyRedisURL] == "" { + envData[constants.EnvKeyRedisURL] = os.Getenv(constants.EnvKeyRedisURL) } - if envData.StringEnv[constants.EnvKeyCookieName] == "" { - envData.StringEnv[constants.EnvKeyCookieName] = os.Getenv(constants.EnvKeyCookieName) - if envData.StringEnv[constants.EnvKeyCookieName] == "" { - envData.StringEnv[constants.EnvKeyCookieName] = "authorizer" - } + if envData[constants.EnvKeyGoogleClientID] == "" { + envData[constants.EnvKeyGoogleClientID] = os.Getenv(constants.EnvKeyGoogleClientID) } - if envData.StringEnv[constants.EnvKeyGoogleClientID] == "" { - envData.StringEnv[constants.EnvKeyGoogleClientID] = os.Getenv(constants.EnvKeyGoogleClientID) + if envData[constants.EnvKeyGoogleClientSecret] == "" { + envData[constants.EnvKeyGoogleClientSecret] = os.Getenv(constants.EnvKeyGoogleClientSecret) } - if envData.StringEnv[constants.EnvKeyGoogleClientSecret] == "" { - envData.StringEnv[constants.EnvKeyGoogleClientSecret] = os.Getenv(constants.EnvKeyGoogleClientSecret) + if envData[constants.EnvKeyGithubClientID] == "" { + envData[constants.EnvKeyGithubClientID] = os.Getenv(constants.EnvKeyGithubClientID) } - if envData.StringEnv[constants.EnvKeyGithubClientID] == "" { - envData.StringEnv[constants.EnvKeyGithubClientID] = os.Getenv(constants.EnvKeyGithubClientID) + if envData[constants.EnvKeyGithubClientSecret] == "" { + envData[constants.EnvKeyGithubClientSecret] = os.Getenv(constants.EnvKeyGithubClientSecret) } - if envData.StringEnv[constants.EnvKeyGithubClientSecret] == "" { - envData.StringEnv[constants.EnvKeyGithubClientSecret] = os.Getenv(constants.EnvKeyGithubClientSecret) + if envData[constants.EnvKeyFacebookClientID] == "" { + envData[constants.EnvKeyFacebookClientID] = os.Getenv(constants.EnvKeyFacebookClientID) } - if envData.StringEnv[constants.EnvKeyFacebookClientID] == "" { - envData.StringEnv[constants.EnvKeyFacebookClientID] = os.Getenv(constants.EnvKeyFacebookClientID) + if envData[constants.EnvKeyFacebookClientSecret] == "" { + envData[constants.EnvKeyFacebookClientSecret] = os.Getenv(constants.EnvKeyFacebookClientSecret) } - if envData.StringEnv[constants.EnvKeyFacebookClientSecret] == "" { - envData.StringEnv[constants.EnvKeyFacebookClientSecret] = os.Getenv(constants.EnvKeyFacebookClientSecret) + if envData[constants.EnvKeyResetPasswordURL] == "" { + envData[constants.EnvKeyResetPasswordURL] = strings.TrimPrefix(os.Getenv(constants.EnvKeyResetPasswordURL), "/") } - if envData.StringEnv[constants.EnvKeyResetPasswordURL] == "" { - envData.StringEnv[constants.EnvKeyResetPasswordURL] = strings.TrimPrefix(os.Getenv(constants.EnvKeyResetPasswordURL), "/") - } - - envData.BoolEnv[constants.EnvKeyDisableBasicAuthentication] = os.Getenv(constants.EnvKeyDisableBasicAuthentication) == "true" - envData.BoolEnv[constants.EnvKeyDisableEmailVerification] = os.Getenv(constants.EnvKeyDisableEmailVerification) == "true" - envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = os.Getenv(constants.EnvKeyDisableMagicLinkLogin) == "true" - envData.BoolEnv[constants.EnvKeyDisableLoginPage] = os.Getenv(constants.EnvKeyDisableLoginPage) == "true" - envData.BoolEnv[constants.EnvKeyDisableSignUp] = os.Getenv(constants.EnvKeyDisableSignUp) == "true" + envData[constants.EnvKeyDisableBasicAuthentication] = os.Getenv(constants.EnvKeyDisableBasicAuthentication) == "true" + envData[constants.EnvKeyDisableEmailVerification] = os.Getenv(constants.EnvKeyDisableEmailVerification) == "true" + envData[constants.EnvKeyDisableMagicLinkLogin] = os.Getenv(constants.EnvKeyDisableMagicLinkLogin) == "true" + envData[constants.EnvKeyDisableLoginPage] = os.Getenv(constants.EnvKeyDisableLoginPage) == "true" + envData[constants.EnvKeyDisableSignUp] = os.Getenv(constants.EnvKeyDisableSignUp) == "true" // no need to add nil check as its already done above - if envData.StringEnv[constants.EnvKeySmtpHost] == "" || envData.StringEnv[constants.EnvKeySmtpUsername] == "" || envData.StringEnv[constants.EnvKeySmtpPassword] == "" || envData.StringEnv[constants.EnvKeySenderEmail] == "" && envData.StringEnv[constants.EnvKeySmtpPort] == "" { - envData.BoolEnv[constants.EnvKeyDisableEmailVerification] = true - envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = true + if envData[constants.EnvKeySmtpHost] == "" || envData[constants.EnvKeySmtpUsername] == "" || envData[constants.EnvKeySmtpPassword] == "" || envData[constants.EnvKeySenderEmail] == "" && envData[constants.EnvKeySmtpPort] == "" { + envData[constants.EnvKeyDisableEmailVerification] = true + envData[constants.EnvKeyDisableMagicLinkLogin] = true } - if envData.BoolEnv[constants.EnvKeyDisableEmailVerification] { - envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = true + if envData[constants.EnvKeyDisableEmailVerification].(bool) { + envData[constants.EnvKeyDisableMagicLinkLogin] = true } allowedOriginsSplit := strings.Split(os.Getenv(constants.EnvKeyAllowedOrigins), ",") @@ -272,7 +269,7 @@ func InitAllEnv() error { allowedOrigins = []string{"*"} } - envData.SliceEnv[constants.EnvKeyAllowedOrigins] = allowedOrigins + envData[constants.EnvKeyAllowedOrigins] = allowedOrigins rolesEnv := strings.TrimSpace(os.Getenv(constants.EnvKeyRoles)) rolesSplit := strings.Split(rolesEnv, ",") @@ -315,18 +312,18 @@ func InitAllEnv() error { return errors.New(`invalid DEFAULT_ROLE environment variable. It can be one from give ROLES environment variable value`) } - envData.SliceEnv[constants.EnvKeyRoles] = roles - envData.SliceEnv[constants.EnvKeyDefaultRoles] = defaultRoles - envData.SliceEnv[constants.EnvKeyProtectedRoles] = protectedRoles + envData[constants.EnvKeyRoles] = roles + envData[constants.EnvKeyDefaultRoles] = defaultRoles + envData[constants.EnvKeyProtectedRoles] = protectedRoles if os.Getenv(constants.EnvKeyOrganizationName) != "" { - envData.StringEnv[constants.EnvKeyOrganizationName] = os.Getenv(constants.EnvKeyOrganizationName) + envData[constants.EnvKeyOrganizationName] = os.Getenv(constants.EnvKeyOrganizationName) } if os.Getenv(constants.EnvKeyOrganizationLogo) != "" { - envData.StringEnv[constants.EnvKeyOrganizationLogo] = os.Getenv(constants.EnvKeyOrganizationLogo) + envData[constants.EnvKeyOrganizationLogo] = os.Getenv(constants.EnvKeyOrganizationLogo) } - envstore.EnvStoreObj.UpdateEnvStore(envData) + memorystore.Provider.UpdateEnvStore(envData) return nil } diff --git a/server/env/persist_env.go b/server/env/persist_env.go index 9b3c23e..134aca6 100644 --- a/server/env/persist_env.go +++ b/server/env/persist_env.go @@ -13,13 +13,13 @@ import ( "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/utils" ) // GetEnvData returns the env data from database -func GetEnvData() (envstore.Store, error) { - var result envstore.Store +func GetEnvData() (map[string]interface{}, error) { + var result map[string]interface{} env, err := db.Provider.GetEnv() // config not found in db if err != nil { @@ -34,7 +34,7 @@ func GetEnvData() (envstore.Store, error) { return result, err } - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEncryptionKey, decryptedEncryptionKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, decryptedEncryptionKey) b64DecryptedConfig, err := crypto.DecryptB64(env.EnvData) if err != nil { @@ -64,10 +64,16 @@ func PersistEnv() error { if err != nil { // AES encryption needs 32 bit key only, so we chop off last 4 characters from 36 bit uuid hash := uuid.New().String()[:36-4] - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEncryptionKey, hash) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, hash) encodedHash := crypto.EncryptB64(hash) - encryptedConfig, err := crypto.EncryptEnvData(envstore.EnvStoreObj.GetEnvStoreClone()) + res, err := memorystore.Provider.GetEnvStore() + if err != nil { + log.Debug("Error while getting env store: ", err) + return err + } + + encryptedConfig, err := crypto.EncryptEnvData(res) if err != nil { log.Debug("Error while encrypting env data: ", err) return err @@ -93,7 +99,7 @@ func PersistEnv() error { return err } - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEncryptionKey, decryptedEncryptionKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, decryptedEncryptionKey) b64DecryptedConfig, err := crypto.DecryptB64(env.EnvData) if err != nil { @@ -108,7 +114,7 @@ func PersistEnv() error { } // temp store variable - var storeData envstore.Store + storeData := map[string]interface{}{} err = json.Unmarshal(decryptedConfigs, &storeData) if err != nil { @@ -120,71 +126,72 @@ func PersistEnv() error { // give that higher preference and update db, but we don't recommend it hasChanged := false - - for key, value := range storeData.StringEnv { + for key, value := range storeData { // don't override unexposed envs + // check only for derivative keys + // No need to check for ENCRYPTION_KEY which special key we use for encrypting config data + // as we have removed it from json if key != constants.EnvKeyEncryptionKey { - // check only for derivative keys - // No need to check for ENCRYPTION_KEY which special key we use for encrypting config data - // as we have removed it from json envValue := strings.TrimSpace(os.Getenv(key)) - - // env is not empty if envValue != "" { - if value != envValue { - storeData.StringEnv[key] = envValue - hasChanged = true + switch key { + case constants.EnvKeyRoles, constants.EnvKeyDefaultRoles, constants.EnvKeyProtectedRoles: + envStringArr := strings.Split(envValue, ",") + originalValue := utils.ConvertInterfaceToStringSlice(value) + if !utils.IsStringArrayEqual(originalValue, envStringArr) { + storeData[key] = envStringArr + hasChanged = true + } + + break + case constants.EnvKeyIsProd, constants.EnvKeyDisableBasicAuthentication, constants.EnvKeyDisableEmailVerification, constants.EnvKeyDisableLoginPage, constants.EnvKeyDisableMagicLinkLogin, constants.EnvKeyDisableSignUp: + if envValueBool, err := strconv.ParseBool(envValue); err == nil { + if value.(bool) != envValueBool { + storeData[key] = envValueBool + hasChanged = true + } + } + + break + default: + if value.(string) != envValue { + storeData[key] = envValue + hasChanged = true + } + + break } } } } - for key, value := range storeData.BoolEnv { - envValue := strings.TrimSpace(os.Getenv(key)) - // env is not empty - if envValue != "" { - envValueBool, _ := strconv.ParseBool(envValue) - if value != envValueBool { - storeData.BoolEnv[key] = envValueBool - hasChanged = true - } - } - } - - for key, value := range storeData.SliceEnv { - envValue := strings.TrimSpace(os.Getenv(key)) - // env is not empty - if envValue != "" { - envStringArr := strings.Split(envValue, ",") - if !utils.IsStringArrayEqual(value, envStringArr) { - storeData.SliceEnv[key] = envStringArr - hasChanged = true - } - } - } - // handle derivative cases like disabling email verification & magic login // in case SMTP is off but env is set to true - if storeData.StringEnv[constants.EnvKeySmtpHost] == "" || storeData.StringEnv[constants.EnvKeySmtpUsername] == "" || storeData.StringEnv[constants.EnvKeySmtpPassword] == "" || storeData.StringEnv[constants.EnvKeySenderEmail] == "" && storeData.StringEnv[constants.EnvKeySmtpPort] == "" { - if !storeData.BoolEnv[constants.EnvKeyDisableEmailVerification] { - storeData.BoolEnv[constants.EnvKeyDisableEmailVerification] = true + if storeData[constants.EnvKeySmtpHost] == "" || storeData[constants.EnvKeySmtpUsername] == "" || storeData[constants.EnvKeySmtpPassword] == "" || storeData[constants.EnvKeySenderEmail] == "" && storeData[constants.EnvKeySmtpPort] == "" { + if !storeData[constants.EnvKeyDisableEmailVerification].(bool) { + storeData[constants.EnvKeyDisableEmailVerification] = true hasChanged = true } - if !storeData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] { - storeData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = true + if !storeData[constants.EnvKeyDisableMagicLinkLogin].(bool) { + storeData[constants.EnvKeyDisableMagicLinkLogin] = true hasChanged = true } } - envstore.EnvStoreObj.UpdateEnvStore(storeData) + err = memorystore.Provider.UpdateEnvStore(storeData) + if err != nil { + log.Debug("Error while updating env store: ", err) + return err + } + jwk, err := crypto.GenerateJWKBasedOnEnv() if err != nil { log.Debug("Error while generating JWK: ", err) return err } // updating jwk - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJWK, jwk) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJWK, jwk) if hasChanged { encryptedConfig, err := crypto.EncryptEnvData(storeData) diff --git a/server/envstore/store.go b/server/envstore/store.go deleted file mode 100644 index 1140087..0000000 --- a/server/envstore/store.go +++ /dev/null @@ -1,111 +0,0 @@ -package envstore - -import ( - "sync" - - "github.com/authorizerdev/authorizer/server/constants" -) - -// Store data structure -type Store struct { - StringEnv map[string]string `json:"string_env"` - BoolEnv map[string]bool `json:"bool_env"` - SliceEnv map[string][]string `json:"slice_env"` -} - -// EnvStore struct -type EnvStore struct { - mutex sync.Mutex - store *Store -} - -var defaultStore = &EnvStore{ - store: &Store{ - StringEnv: map[string]string{ - constants.EnvKeyAdminCookieName: "authorizer-admin", - constants.EnvKeyJwtRoleClaim: "role", - constants.EnvKeyOrganizationName: "Authorizer", - constants.EnvKeyOrganizationLogo: "https://www.authorizer.dev/images/logo.png", - }, - BoolEnv: map[string]bool{ - constants.EnvKeyDisableBasicAuthentication: false, - constants.EnvKeyDisableMagicLinkLogin: false, - constants.EnvKeyDisableEmailVerification: false, - constants.EnvKeyDisableLoginPage: false, - constants.EnvKeyDisableSignUp: false, - }, - SliceEnv: map[string][]string{}, - }, -} - -// EnvStoreObj.GetBoolStoreEnvVariable global variable for EnvStore -var EnvStoreObj = defaultStore - -// UpdateEnvStore to update the whole env store object -func (e *EnvStore) UpdateEnvStore(store Store) { - e.mutex.Lock() - defer e.mutex.Unlock() - // just override the keys + new keys - - for key, value := range store.StringEnv { - e.store.StringEnv[key] = value - } - - for key, value := range store.BoolEnv { - e.store.BoolEnv[key] = value - } - - for key, value := range store.SliceEnv { - e.store.SliceEnv[key] = value - } -} - -// UpdateEnvVariable to update the particular env variable -func (e *EnvStore) UpdateEnvVariable(storeIdentifier, key string, value interface{}) { - e.mutex.Lock() - defer e.mutex.Unlock() - switch storeIdentifier { - case constants.StringStoreIdentifier: - e.store.StringEnv[key] = value.(string) - case constants.BoolStoreIdentifier: - e.store.BoolEnv[key] = value.(bool) - case constants.SliceStoreIdentifier: - e.store.SliceEnv[key] = value.([]string) - } -} - -// GetStringStoreEnvVariable to get the env variable from string store object -func (e *EnvStore) GetStringStoreEnvVariable(key string) string { - // e.mutex.Lock() - // defer e.mutex.Unlock() - return e.store.StringEnv[key] -} - -// GetBoolStoreEnvVariable to get the env variable from bool store object -func (e *EnvStore) GetBoolStoreEnvVariable(key string) bool { - // e.mutex.Lock() - // defer e.mutex.Unlock() - return e.store.BoolEnv[key] -} - -// GetSliceStoreEnvVariable to get the env variable from slice store object -func (e *EnvStore) GetSliceStoreEnvVariable(key string) []string { - // e.mutex.Lock() - // defer e.mutex.Unlock() - return e.store.SliceEnv[key] -} - -// GetEnvStoreClone to get clone of current env store object -func (e *EnvStore) GetEnvStoreClone() Store { - e.mutex.Lock() - defer e.mutex.Unlock() - - result := *e.store - return result -} - -func (e *EnvStore) ResetStore() { - e.mutex.Lock() - defer e.mutex.Unlock() - e.store = defaultStore.store -} diff --git a/server/handlers/app.go b/server/handlers/app.go index d855db7..1827d54 100644 --- a/server/handlers/app.go +++ b/server/handlers/app.go @@ -8,7 +8,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/utils" ) @@ -23,7 +23,7 @@ type State struct { func AppHandler() gin.HandlerFunc { return func(c *gin.Context) { hostname := utils.GetHost(c) - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableLoginPage) { + if isLoginPageDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableLoginPage); err != nil || isLoginPageDisabled { log.Debug("Login page is disabled") c.JSON(400, gin.H{"error": "login page is not enabled"}) return @@ -58,14 +58,27 @@ func AppHandler() gin.HandlerFunc { log.Debug("Failed to push file path: ", err) } } + + orgName, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + if err != nil { + log.Debug("Failed to get organization name") + c.JSON(400, gin.H{"error": "failed to get organization name"}) + return + } + orgLogo, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) + if err != nil { + log.Debug("Failed to get organization logo") + c.JSON(400, gin.H{"error": "failed to get organization logo"}) + return + } c.HTML(http.StatusOK, "app.tmpl", gin.H{ "data": map[string]interface{}{ "authorizerURL": hostname, "redirectURL": redirect_uri, "scope": scope, "state": state, - "organizationName": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName), - "organizationLogo": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo), + "organizationName": orgName, + "organizationLogo": orgLogo, }, }) } diff --git a/server/handlers/authorize.go b/server/handlers/authorize.go index 2ff450a..1fa84cc 100644 --- a/server/handlers/authorize.go +++ b/server/handlers/authorize.go @@ -13,7 +13,6 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" ) @@ -80,7 +79,7 @@ func AuthorizeHandler() gin.HandlerFunc { return } - if clientID != envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) { + if client, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID); client != clientID || err != nil { if isQuery { gc.Redirect(http.StatusFound, loginURL) } else { diff --git a/server/handlers/dashboard.go b/server/handlers/dashboard.go index 7eb7dce..55d1534 100644 --- a/server/handlers/dashboard.go +++ b/server/handlers/dashboard.go @@ -4,7 +4,7 @@ import ( "net/http" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/gin-gonic/gin" ) @@ -12,8 +12,8 @@ import ( func DashboardHandler() gin.HandlerFunc { return func(c *gin.Context) { isOnboardingCompleted := false - - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) != "" { + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + if err != nil || adminSecret != "" { isOnboardingCompleted = true } diff --git a/server/handlers/jwks.go b/server/handlers/jwks.go index 2e13dc2..7a2cc54 100644 --- a/server/handlers/jwks.go +++ b/server/handlers/jwks.go @@ -7,14 +7,21 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) func JWKsHandler() gin.HandlerFunc { return func(c *gin.Context) { var data map[string]string - jwk := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJWK) - err := json.Unmarshal([]byte(jwk), &data) + jwk, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJWK) + if err != nil { + log.Debug("Error getting JWK from memorystore: ", err) + c.JSON(500, gin.H{ + "error": err.Error(), + }) + return + } + err = json.Unmarshal([]byte(jwk), &data) if err != nil { log.Debug("Failed to parse JWK: ", err) c.JSON(500, gin.H{ diff --git a/server/handlers/oauth_callback.go b/server/handlers/oauth_callback.go index 6898331..bffbdcf 100644 --- a/server/handlers/oauth_callback.go +++ b/server/handlers/oauth_callback.go @@ -19,7 +19,6 @@ import ( "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/token" @@ -32,8 +31,8 @@ func OAuthCallbackHandler() gin.HandlerFunc { provider := c.Param("oauth_provider") state := c.Request.FormValue("state") - sessionState := memorystore.Provider.GetState(state) - if sessionState == "" { + sessionState, err := memorystore.Provider.GetState(state) + if sessionState == "" || err != nil { log.Debug("Invalid oauth state: ", state) c.JSON(400, gin.H{"error": "invalid oauth state"}) } @@ -52,7 +51,6 @@ func OAuthCallbackHandler() gin.HandlerFunc { inputRoles := strings.Split(sessionSplit[2], ",") scopes := strings.Split(sessionSplit[3], ",") - var err error user := models.User{} code := c.Request.FormValue("code") switch provider { @@ -77,7 +75,13 @@ func OAuthCallbackHandler() gin.HandlerFunc { log := log.WithField("user", user.Email) if err != nil { - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp) { + isSignupDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp) + if err != nil { + log.Debug("Failed to get signup disabled env variable: ", err) + c.JSON(400, gin.H{"error": err.Error()}) + return + } + if isSignupDisabled { log.Debug("Failed to signup as disabled") c.JSON(400, gin.H{"error": "signup is disabled for this instance"}) return @@ -87,7 +91,12 @@ func OAuthCallbackHandler() gin.HandlerFunc { // make sure inputRoles don't include protected roles hasProtectedRole := false for _, ir := range inputRoles { - if utils.StringSliceContains(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyProtectedRoles), ir) { + protectedRoles, err := memorystore.Provider.GetSliceStoreEnvVariable(constants.EnvKeyProtectedRoles) + if err != nil { + log.Debug("Failed to get protected roles: ", err) + protectedRoles = []string{} + } + if utils.StringSliceContains(protectedRoles, ir) { hasProtectedRole = true } } @@ -140,7 +149,12 @@ func OAuthCallbackHandler() gin.HandlerFunc { // check if it contains protected unassigned role hasProtectedRole := false for _, ur := range unasignedRoles { - if utils.StringSliceContains(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyProtectedRoles), ur) { + protectedRoles, err := memorystore.Provider.GetSliceStoreEnvVariable(constants.EnvKeyProtectedRoles) + if err != nil { + log.Debug("Failed to get protected roles: ", err) + protectedRoles = []string{} + } + if utils.StringSliceContains(protectedRoles, ur) { hasProtectedRole = true } } diff --git a/server/handlers/oauth_login.go b/server/handlers/oauth_login.go index 2b5948e..cf7c640 100644 --- a/server/handlers/oauth_login.go +++ b/server/handlers/oauth_login.go @@ -8,7 +8,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/utils" @@ -56,7 +55,16 @@ func OAuthLoginHandler() gin.HandlerFunc { // use protected roles verification for admin login only. // though if not associated with user, it will be rejected from oauth_callback - if !utils.IsValidRoles(rolesSplit, append([]string{}, append(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyRoles), envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyProtectedRoles)...)...)) { + roles, err := memorystore.Provider.GetSliceStoreEnvVariable(constants.EnvKeyRoles) + if err != nil { + log.Debug("Error getting roles: ", err) + } + protectedRoles, err := memorystore.Provider.GetSliceStoreEnvVariable(constants.EnvKeyProtectedRoles) + if err != nil { + log.Debug("Error getting protected roles: ", err) + } + + if !utils.IsValidRoles(rolesSplit, append([]string{}, append(roles, protectedRoles...)...)) { log.Debug("Invalid roles: ", roles) c.JSON(400, gin.H{ "error": "invalid role", @@ -64,7 +72,16 @@ func OAuthLoginHandler() gin.HandlerFunc { return } } else { - roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + log.Debug("Error getting default roles: ", err) + c.JSON(400, gin.H{ + "error": "invalid role", + }) + return + } + roles = strings.Join(defaultRoles, ",") + } oauthStateString := state + "___" + redirectURI + "___" + roles + "___" + strings.Join(scope, ",") @@ -78,7 +95,14 @@ func OAuthLoginHandler() gin.HandlerFunc { isProviderConfigured = false break } - memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGoogle) + err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGoogle) + if err != nil { + log.Debug("Error setting state: ", err) + c.JSON(500, gin.H{ + "error": "internal server error", + }) + return + } // during the init of OAuthProvider authorizer url might be empty oauth.OAuthProviders.GoogleConfig.RedirectURL = hostname + "/oauth_callback/google" url := oauth.OAuthProviders.GoogleConfig.AuthCodeURL(oauthStateString) @@ -89,7 +113,14 @@ func OAuthLoginHandler() gin.HandlerFunc { isProviderConfigured = false break } - memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGithub) + err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGithub) + if err != nil { + log.Debug("Error setting state: ", err) + c.JSON(500, gin.H{ + "error": "internal server error", + }) + return + } oauth.OAuthProviders.GithubConfig.RedirectURL = hostname + "/oauth_callback/github" url := oauth.OAuthProviders.GithubConfig.AuthCodeURL(oauthStateString) c.Redirect(http.StatusTemporaryRedirect, url) @@ -99,7 +130,14 @@ func OAuthLoginHandler() gin.HandlerFunc { isProviderConfigured = false break } - memorystore.Provider.SetState(oauthStateString, constants.SignupMethodFacebook) + err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodFacebook) + if err != nil { + log.Debug("Error setting state: ", err) + c.JSON(500, gin.H{ + "error": "internal server error", + }) + return + } oauth.OAuthProviders.FacebookConfig.RedirectURL = hostname + "/oauth_callback/facebook" url := oauth.OAuthProviders.FacebookConfig.AuthCodeURL(oauthStateString) c.Redirect(http.StatusTemporaryRedirect, url) diff --git a/server/handlers/openid_config.go b/server/handlers/openid_config.go index 5b98d03..a7b4f8c 100644 --- a/server/handlers/openid_config.go +++ b/server/handlers/openid_config.go @@ -4,7 +4,7 @@ import ( "github.com/gin-gonic/gin" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/utils" ) @@ -12,7 +12,7 @@ import ( func OpenIDConfigurationHandler() gin.HandlerFunc { return func(c *gin.Context) { issuer := utils.GetHost(c) - jwtType := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + jwtType, _ := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtType) c.JSON(200, gin.H{ "issuer": issuer, diff --git a/server/handlers/revoke.go b/server/handlers/revoke.go index a63457c..9cc5b07 100644 --- a/server/handlers/revoke.go +++ b/server/handlers/revoke.go @@ -8,7 +8,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/memorystore" ) @@ -37,7 +36,7 @@ func RevokeHandler() gin.HandlerFunc { return } - if clientID != envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) { + if client, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID); client != clientID || err != nil { log.Debug("Client ID is invalid: ", clientID) gc.JSON(http.StatusBadRequest, gin.H{ "error": "invalid_client_id", diff --git a/server/handlers/token.go b/server/handlers/token.go index 6fc2275..4bcbe83 100644 --- a/server/handlers/token.go +++ b/server/handlers/token.go @@ -13,7 +13,6 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" ) @@ -62,7 +61,7 @@ func TokenHandler() gin.HandlerFunc { return } - if clientID != envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) { + if client, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID); clientID != client || err != nil { log.Debug("Client ID is invalid: ", clientID) gc.JSON(http.StatusBadRequest, gin.H{ "error": "invalid_client_id", @@ -98,8 +97,8 @@ func TokenHandler() gin.HandlerFunc { encryptedCode := strings.ReplaceAll(base64.URLEncoding.EncodeToString(hash.Sum(nil)), "+", "-") encryptedCode = strings.ReplaceAll(encryptedCode, "/", "_") encryptedCode = strings.ReplaceAll(encryptedCode, "=", "") - sessionData := memorystore.Provider.GetState(encryptedCode) - if sessionData == "" { + sessionData, err := memorystore.Provider.GetState(encryptedCode) + if sessionData == "" || err != nil { log.Debug("Session data is empty") gc.JSON(http.StatusBadRequest, gin.H{ "error": "invalid_code_verifier", diff --git a/server/memorystore/memory_store.go b/server/memorystore/memory_store.go index ad66b8b..a3c43c8 100644 --- a/server/memorystore/memory_store.go +++ b/server/memorystore/memory_store.go @@ -3,6 +3,7 @@ package memorystore import ( log "github.com/sirupsen/logrus" + "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/memorystore/providers" "github.com/authorizerdev/authorizer/server/memorystore/providers/inmemory" "github.com/authorizerdev/authorizer/server/memorystore/providers/redis" @@ -15,6 +16,20 @@ var Provider providers.Provider func InitMemStore() error { var err error + defaultEnvs := map[string]interface{}{ + // string envs + constants.EnvKeyJwtRoleClaim: "role", + constants.EnvKeyOrganizationName: "Authorizer", + constants.EnvKeyOrganizationLogo: "https://www.authorizer.dev/images/logo.png", + + // boolean envs + constants.EnvKeyDisableBasicAuthentication: false, + constants.EnvKeyDisableMagicLinkLogin: false, + constants.EnvKeyDisableEmailVerification: false, + constants.EnvKeyDisableLoginPage: false, + constants.EnvKeyDisableSignUp: false, + } + redisURL := RequiredEnvStoreObj.GetRequiredEnv().RedisURL if redisURL != "" { log.Info("Initializing Redis memory store") @@ -23,6 +38,9 @@ func InitMemStore() error { return err } + // set default envs in redis + Provider.UpdateEnvStore(defaultEnvs) + return nil } @@ -32,5 +50,7 @@ func InitMemStore() error { if err != nil { return err } + // set default envs in local env + Provider.UpdateEnvStore(defaultEnvs) return nil } diff --git a/server/memorystore/providers/inmemory/envstore.go b/server/memorystore/providers/inmemory/envstore.go new file mode 100644 index 0000000..a290280 --- /dev/null +++ b/server/memorystore/providers/inmemory/envstore.go @@ -0,0 +1,41 @@ +package inmemory + +import "sync" + +// EnvStore struct to store the env variables +type EnvStore struct { + mutex sync.Mutex + store map[string]interface{} +} + +// UpdateEnvStore to update the whole env store object +func (e *EnvStore) UpdateStore(store map[string]interface{}) { + e.mutex.Lock() + defer e.mutex.Unlock() + // just override the keys + new keys + + for key, value := range store { + e.store[key] = value + } +} + +// GetStore returns the env store +func (e *EnvStore) GetStore() map[string]interface{} { + e.mutex.Lock() + defer e.mutex.Unlock() + return e.store +} + +// Get returns the value of the key in evn store +func (s *EnvStore) Get(key string) interface{} { + s.mutex.Lock() + defer s.mutex.Unlock() + return s.store[key] +} + +// Set sets the value of the key in env store +func (s *EnvStore) Set(key string, value interface{}) { + s.mutex.Lock() + defer s.mutex.Unlock() + s.store[key] = value +} diff --git a/server/memorystore/providers/inmemory/provider.go b/server/memorystore/providers/inmemory/provider.go index 767e2fa..0dec662 100644 --- a/server/memorystore/providers/inmemory/provider.go +++ b/server/memorystore/providers/inmemory/provider.go @@ -1,11 +1,14 @@ package inmemory -import "sync" +import ( + "sync" +) type provider struct { mutex sync.Mutex sessionStore map[string]map[string]string stateStore map[string]string + envStore *EnvStore } // NewInMemoryStore returns a new in-memory store. @@ -14,5 +17,9 @@ func NewInMemoryProvider() (*provider, error) { mutex: sync.Mutex{}, sessionStore: map[string]map[string]string{}, stateStore: map[string]string{}, + envStore: &EnvStore{ + mutex: sync.Mutex{}, + store: map[string]interface{}{}, + }, }, nil } diff --git a/server/memorystore/providers/inmemory/inmemory.go b/server/memorystore/providers/inmemory/store.go similarity index 51% rename from server/memorystore/providers/inmemory/inmemory.go rename to server/memorystore/providers/inmemory/store.go index 72c7d59..f935f84 100644 --- a/server/memorystore/providers/inmemory/inmemory.go +++ b/server/memorystore/providers/inmemory/store.go @@ -1,6 +1,10 @@ package inmemory -import "strings" +import ( + "strings" + + "github.com/authorizerdev/authorizer/server/utils" +) // ClearStore clears the in-memory store. func (c *provider) ClearStore() error { @@ -42,14 +46,13 @@ func (c *provider) DeleteAllUserSession(userId string) error { func (c *provider) SetState(key, state string) error { c.mutex.Lock() defer c.mutex.Unlock() - c.stateStore[key] = state return nil } // GetState gets the state from the in-memory store. -func (c *provider) GetState(key string) string { +func (c *provider) GetState(key string) (string, error) { c.mutex.Lock() defer c.mutex.Unlock() @@ -58,15 +61,50 @@ func (c *provider) GetState(key string) string { state = stateVal } - return state + return state, nil } // RemoveState removes the state from the in-memory store. func (c *provider) RemoveState(key string) error { c.mutex.Lock() defer c.mutex.Unlock() - delete(c.stateStore, key) return nil } + +// UpdateEnvStore to update the whole env store object +func (c *provider) UpdateEnvStore(store map[string]interface{}) error { + c.envStore.UpdateStore(store) + return nil +} + +// GetEnvStore returns the env store object +func (c *provider) GetEnvStore() (map[string]interface{}, error) { + return c.envStore.GetStore(), nil +} + +// UpdateEnvVariable to update the particular env variable +func (c *provider) UpdateEnvVariable(key string, value interface{}) error { + c.envStore.Set(key, value) + return nil +} + +// GetStringStoreEnvVariable to get the env variable from string store object +func (c *provider) GetStringStoreEnvVariable(key string) (string, error) { + res := c.envStore.Get(key) + return res.(string), nil +} + +// GetBoolStoreEnvVariable to get the env variable from bool store object +func (c *provider) GetBoolStoreEnvVariable(key string) (bool, error) { + res := c.envStore.Get(key) + return res.(bool), nil +} + +// GetSliceStoreEnvVariable to get the env variable from slice store object +func (c *provider) GetSliceStoreEnvVariable(key string) ([]string, error) { + res := c.envStore.Get(key) + resSlice := utils.ConvertInterfaceToStringSlice(res) + return resSlice, nil +} diff --git a/server/memorystore/providers/providers.go b/server/memorystore/providers/providers.go index 3270bb4..df4507a 100644 --- a/server/memorystore/providers/providers.go +++ b/server/memorystore/providers/providers.go @@ -11,7 +11,22 @@ type Provider interface { // SetState sets the login state (key, value form) in the session store SetState(key, state string) error // GetState returns the state from the session store - GetState(key string) string + GetState(key string) (string, error) // RemoveState removes the social login state from the session store RemoveState(key string) error + + // methods for env store + + // UpdateEnvStore to update the whole env store object + UpdateEnvStore(store map[string]interface{}) error + // GetEnvStore() returns the env store object + GetEnvStore() (map[string]interface{}, error) + // UpdateEnvVariable to update the particular env variable + UpdateEnvVariable(key string, value interface{}) error + // GetStringStoreEnvVariable to get the string env variable from env store + GetStringStoreEnvVariable(key string) (string, error) + // GetBoolStoreEnvVariable to get the bool env variable from env store + GetBoolStoreEnvVariable(key string) (bool, error) + // GetSliceStoreEnvVariable to get the string slice env variable from env store + GetSliceStoreEnvVariable(key string) ([]string, error) } diff --git a/server/memorystore/providers/redis/reids.go b/server/memorystore/providers/redis/reids.go deleted file mode 100644 index 427ae9d..0000000 --- a/server/memorystore/providers/redis/reids.go +++ /dev/null @@ -1,85 +0,0 @@ -package redis - -import ( - "strings" - - log "github.com/sirupsen/logrus" -) - -// ClearStore clears the redis store for authorizer related tokens -func (c *provider) ClearStore() error { - err := c.store.Del(c.ctx, "authorizer_*").Err() - if err != nil { - log.Debug("Error clearing redis store: ", err) - return err - } - - return nil -} - -// GetUserSessions returns all the user session token from the redis store. -func (c *provider) GetUserSessions(userID string) map[string]string { - data, err := c.store.HGetAll(c.ctx, "*").Result() - if err != nil { - log.Debug("error getting token from redis store: ", err) - } - - res := map[string]string{} - for k, v := range data { - split := strings.Split(v, "@") - if split[1] == userID { - res[k] = split[0] - } - } - - return res -} - -// DeleteAllUserSession deletes all the user session from redis -func (c *provider) DeleteAllUserSession(userId string) error { - sessions := c.GetUserSessions(userId) - for k, v := range sessions { - if k == "token" { - err := c.store.Del(c.ctx, v).Err() - if err != nil { - log.Debug("Error deleting redis token: ", err) - return err - } - } - } - - return nil -} - -// SetState sets the state in redis store. -func (c *provider) SetState(key, value string) error { - err := c.store.Set(c.ctx, "authorizer_"+key, value, 0).Err() - if err != nil { - log.Debug("Error saving redis token: ", err) - return err - } - - return nil -} - -// GetState gets the state from redis store. -func (c *provider) GetState(key string) string { - state := "" - state, err := c.store.Get(c.ctx, "authorizer_"+key).Result() - if err != nil { - log.Debug("error getting token from redis store: ", err) - } - - return state -} - -// RemoveState removes the state from redis store. -func (c *provider) RemoveState(key string) error { - err := c.store.Del(c.ctx, "authorizer_"+key).Err() - if err != nil { - log.Fatalln("Error deleting redis token: ", err) - return err - } - - return nil -} diff --git a/server/memorystore/providers/redis/store.go b/server/memorystore/providers/redis/store.go new file mode 100644 index 0000000..0390436 --- /dev/null +++ b/server/memorystore/providers/redis/store.go @@ -0,0 +1,162 @@ +package redis + +import ( + "strings" + + log "github.com/sirupsen/logrus" +) + +var ( + // session store prefix + sessionStorePrefix = "authorizer_session_" + // env store prefix + envStorePrefix = "authorizer_env_" +) + +// ClearStore clears the redis store for authorizer related tokens +func (c *provider) ClearStore() error { + err := c.store.Del(c.ctx, sessionStorePrefix+"*").Err() + if err != nil { + log.Debug("Error clearing redis store: ", err) + return err + } + + return nil +} + +// GetUserSessions returns all the user session token from the redis store. +func (c *provider) GetUserSessions(userID string) map[string]string { + data, err := c.store.HGetAll(c.ctx, "*").Result() + if err != nil { + log.Debug("error getting token from redis store: ", err) + } + + res := map[string]string{} + for k, v := range data { + split := strings.Split(v, "@") + if split[1] == userID { + res[k] = split[0] + } + } + + return res +} + +// DeleteAllUserSession deletes all the user session from redis +func (c *provider) DeleteAllUserSession(userId string) error { + sessions := c.GetUserSessions(userId) + for k, v := range sessions { + if k == "token" { + err := c.store.Del(c.ctx, v).Err() + if err != nil { + log.Debug("Error deleting redis token: ", err) + return err + } + } + } + + return nil +} + +// SetState sets the state in redis store. +func (c *provider) SetState(key, value string) error { + err := c.store.Set(c.ctx, sessionStorePrefix+key, value, 0).Err() + if err != nil { + log.Debug("Error saving redis token: ", err) + return err + } + + return nil +} + +// GetState gets the state from redis store. +func (c *provider) GetState(key string) (string, error) { + var res string + err := c.store.Get(c.ctx, sessionStorePrefix+key).Scan(&res) + if err != nil { + log.Debug("error getting token from redis store: ", err) + } + + return res, err +} + +// RemoveState removes the state from redis store. +func (c *provider) RemoveState(key string) error { + err := c.store.Del(c.ctx, sessionStorePrefix+key).Err() + if err != nil { + log.Fatalln("Error deleting redis token: ", err) + return err + } + + return nil +} + +// UpdateEnvStore to update the whole env store object +func (c *provider) UpdateEnvStore(store map[string]interface{}) error { + for key, value := range store { + err := c.store.Set(c.ctx, envStorePrefix+key, value, 0).Err() + if err != nil { + log.Debug("Error saving redis token: ", err) + return err + } + } + return nil +} + +// GetEnvStore returns the whole env store object +func (c *provider) GetEnvStore() (map[string]interface{}, error) { + var res map[string]interface{} + err := c.store.HGetAll(c.ctx, envStorePrefix+"*").Scan(res) + if err != nil { + log.Debug("error getting token from redis store: ", err) + return nil, err + } + + return res, nil +} + +// UpdateEnvVariable to update the particular env variable +func (c *provider) UpdateEnvVariable(key string, value interface{}) error { + err := c.store.Set(c.ctx, envStorePrefix+key, value, 0).Err() + if err != nil { + log.Debug("Error saving redis token: ", err) + return err + } + return nil +} + +// GetStringStoreEnvVariable to get the string env variable from env store +func (c *provider) GetStringStoreEnvVariable(key string) (string, error) { + var res string + err := c.store.Get(c.ctx, envStorePrefix+key).Scan(&res) + if err != nil { + log.Debug("error getting token from redis store: ", err) + return "", err + } + + return res, nil +} + +// GetBoolStoreEnvVariable to get the bool env variable from env store +func (c *provider) GetBoolStoreEnvVariable(key string) (bool, error) { + var res bool + err := c.store.Get(c.ctx, envStorePrefix+key).Scan(res) + if err != nil { + log.Debug("error getting token from redis store: ", err) + return false, err + } + + return res, nil +} + +// GetSliceStoreEnvVariable to get the string slice env variable from env store +func (c *provider) GetSliceStoreEnvVariable(key string) ([]string, error) { + var res []string + err := c.store.Get(c.ctx, envStorePrefix+key).Scan(&res) + if err != nil { + log.Debug("error getting token from redis store: ", err) + return nil, err + } + + return res, nil +} diff --git a/server/memorystore/required_env_store.go b/server/memorystore/required_env_store.go index d2780f7..a6c44c7 100644 --- a/server/memorystore/required_env_store.go +++ b/server/memorystore/required_env_store.go @@ -10,7 +10,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/utils" ) @@ -96,7 +95,7 @@ func InitRequiredEnv() error { } } - if strings.TrimSpace(dbURL) == "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL) == "" { + if strings.TrimSpace(dbURL) == "" { if utils.ARG_DB_URL != nil && *utils.ARG_DB_URL != "" { dbURL = strings.TrimSpace(*utils.ARG_DB_URL) } diff --git a/server/oauth/oauth.go b/server/oauth/oauth.go index 3618a9a..27bfa69 100644 --- a/server/oauth/oauth.go +++ b/server/oauth/oauth.go @@ -9,7 +9,7 @@ import ( githubOAuth2 "golang.org/x/oauth2/github" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // OAuthProviders is a struct that contains reference all the OAuth providers @@ -34,32 +34,58 @@ var ( // InitOAuth initializes the OAuth providers based on EnvData func InitOAuth() error { ctx := context.Background() - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientID) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientSecret) != "" { + googleClientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientID) + if err != nil { + googleClientID = "" + } + googleClientSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientSecret) + if err != nil { + googleClientSecret = "" + } + if googleClientID != "" && googleClientSecret != "" { p, err := oidc.NewProvider(ctx, "https://accounts.google.com") if err != nil { return err } OIDCProviders.GoogleOIDC = p OAuthProviders.GoogleConfig = &oauth2.Config{ - ClientID: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientID), - ClientSecret: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientSecret), + ClientID: googleClientID, + ClientSecret: googleClientSecret, RedirectURL: "/oauth_callback/google", Endpoint: OIDCProviders.GoogleOIDC.Endpoint(), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } } - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGithubClientID) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGithubClientSecret) != "" { + + githubClientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGithubClientID) + if err != nil { + githubClientID = "" + } + githubClientSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGithubClientSecret) + if err != nil { + githubClientSecret = "" + } + if githubClientID != "" && githubClientSecret != "" { OAuthProviders.GithubConfig = &oauth2.Config{ - ClientID: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGithubClientID), - ClientSecret: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGithubClientSecret), + ClientID: githubClientID, + ClientSecret: githubClientSecret, RedirectURL: "/oauth_callback/github", Endpoint: githubOAuth2.Endpoint, } } - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientID) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientID) != "" { + + facebookClientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientID) + if err != nil { + facebookClientID = "" + } + facebookClientSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientSecret) + if err != nil { + facebookClientSecret = "" + } + if facebookClientID != "" && facebookClientSecret != "" { OAuthProviders.FacebookConfig = &oauth2.Config{ - ClientID: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientID), - ClientSecret: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientSecret), + ClientID: facebookClientID, + ClientSecret: facebookClientSecret, RedirectURL: "/oauth_callback/facebook", Endpoint: facebookOAuth2.Endpoint, Scopes: []string{"public_profile", "email"}, diff --git a/server/resolvers/validate_jwt_token.go b/server/resolvers/validate_jwt_token.go index 0fba3b9..eda7918 100644 --- a/server/resolvers/validate_jwt_token.go +++ b/server/resolvers/validate_jwt_token.go @@ -38,8 +38,8 @@ func ValidateJwtTokenResolver(ctx context.Context, params model.ValidateJWTToken nonce := "" // access_token and refresh_token should be validated from session store as well if tokenType == "access_token" || tokenType == "refresh_token" { - savedSession := memorystore.Provider.GetState(params.Token) - if savedSession == "" { + savedSession, err := memorystore.Provider.GetState(params.Token) + if savedSession == "" || err != nil { return &model.ValidateJWTTokenResponse{ IsValid: false, }, nil diff --git a/server/token/auth_token.go b/server/token/auth_token.go index 1ef6700..1e74baa 100644 --- a/server/token/auth_token.go +++ b/server/token/auth_token.go @@ -186,8 +186,8 @@ func ValidateAccessToken(gc *gin.Context, accessToken string) (map[string]interf return res, fmt.Errorf(`unauthorized`) } - savedSession := memorystore.Provider.GetState(accessToken) - if savedSession == "" { + savedSession, err := memorystore.Provider.GetState(accessToken) + if savedSession == "" || err != nil { return res, fmt.Errorf(`unauthorized`) } @@ -196,7 +196,7 @@ func ValidateAccessToken(gc *gin.Context, accessToken string) (map[string]interf userID := savedSessionSplit[1] hostname := utils.GetHost(gc) - res, err := ParseJWTToken(accessToken, hostname, nonce, userID) + res, err = ParseJWTToken(accessToken, hostname, nonce, userID) if err != nil { return res, err } @@ -216,8 +216,8 @@ func ValidateRefreshToken(gc *gin.Context, refreshToken string) (map[string]inte return res, fmt.Errorf(`unauthorized`) } - savedSession := memorystore.Provider.GetState(refreshToken) - if savedSession == "" { + savedSession, err := memorystore.Provider.GetState(refreshToken) + if savedSession == "" || err != nil { return res, fmt.Errorf(`unauthorized`) } @@ -226,7 +226,7 @@ func ValidateRefreshToken(gc *gin.Context, refreshToken string) (map[string]inte userID := savedSessionSplit[1] hostname := utils.GetHost(gc) - res, err := ParseJWTToken(refreshToken, hostname, nonce, userID) + res, err = ParseJWTToken(refreshToken, hostname, nonce, userID) if err != nil { return res, err } @@ -243,8 +243,8 @@ func ValidateBrowserSession(gc *gin.Context, encryptedSession string) (*SessionD return nil, fmt.Errorf(`unauthorized`) } - savedSession := memorystore.Provider.GetState(encryptedSession) - if savedSession == "" { + savedSession, err := memorystore.Provider.GetState(encryptedSession) + if savedSession == "" || err != nil { return nil, fmt.Errorf(`unauthorized`) } diff --git a/server/utils/common.go b/server/utils/common.go index badd7ea..86156be 100644 --- a/server/utils/common.go +++ b/server/utils/common.go @@ -47,3 +47,14 @@ func ConvertInterfaceToSlice(slice interface{}) []interface{} { return ret } + +// ConvertInterfaceToStringSlice to convert interface to string slice +func ConvertInterfaceToStringSlice(slice interface{}) []string { + data := slice.([]interface{}) + var resSlice []string + + for _, v := range data { + resSlice = append(resSlice, v.(string)) + } + return resSlice +}