diff --git a/Makefile b/Makefile index b5539f7..ea87ffa 100644 --- a/Makefile +++ b/Makefile @@ -11,14 +11,26 @@ clean: rm -rf build test: rm -rf server/test/test.db && rm -rf test.db && cd server && go clean --testcache && TEST_DBS="sqlite" go test -p 1 -v ./test +test-mongodb: + docker run -d --name authorizer_mongodb_db -p 27017:27017 mongo:4.4.15 + cd server && go clean --testcache && TEST_DBS="mongodb" go test -p 1 -v ./test + docker rm -vf authorizer_mongodb_db +test-scylladb: + docker run -d --name authorizer_scylla_db -p 9042:9042 scylladb/scylla + cd server && go clean --testcache && TEST_DBS="scylladb" go test -p 1 -v ./test + docker rm -vf authorizer_scylla_db +test-arangodb: + docker run -d --name authorizer_arangodb -p 8529:8529 -e ARANGO_NO_AUTH=1 arangodb/arangodb:3.8.4 + cd server && go clean --testcache && TEST_DBS="arangodb" go test -p 1 -v ./test + docker rm -vf authorizer_arangodb test-all-db: rm -rf server/test/test.db && rm -rf test.db docker run -d --name authorizer_scylla_db -p 9042:9042 scylladb/scylla docker run -d --name authorizer_mongodb_db -p 27017:27017 mongo:4.4.15 docker run -d --name authorizer_arangodb -p 8529:8529 -e ARANGO_NO_AUTH=1 arangodb/arangodb:3.8.4 cd server && go clean --testcache && TEST_DBS="sqlite,mongodb,arangodb,scylladb" go test -p 1 -v ./test - docker rm -vf authorizer_mongodb_db docker rm -vf authorizer_scylla_db + docker rm -vf authorizer_mongodb_db docker rm -vf authorizer_arangodb generate: cd server && go get github.com/99designs/gqlgen/cmd@v0.14.0 && go run github.com/99designs/gqlgen generate diff --git a/server/constants/env.go b/server/constants/env.go index dc71bc6..2b26688 100644 --- a/server/constants/env.go +++ b/server/constants/env.go @@ -119,6 +119,8 @@ const ( EnvKeyDisableRedisForEnv = "DISABLE_REDIS_FOR_ENV" // EnvKeyDisableStrongPassword key for env variable DISABLE_STRONG_PASSWORD EnvKeyDisableStrongPassword = "DISABLE_STRONG_PASSWORD" + // EnvKeyEnforceMultiFactorAuthentication is key for env variable ENFORCE_MULTI_FACTOR_AUTHENTICATION + EnvKeyEnforceMultiFactorAuthentication = "ENFORCE_MULTI_FACTOR_AUTHENTICATION" // Slice variables // EnvKeyRoles key for env variable ROLES diff --git a/server/db/models/user.go b/server/db/models/user.go index f5fc61d..bc5bc04 100644 --- a/server/db/models/user.go +++ b/server/db/models/user.go @@ -38,12 +38,12 @@ func (user *User) AsAPIUser() *model.User { isEmailVerified := user.EmailVerifiedAt != nil isPhoneVerified := user.PhoneNumberVerifiedAt != nil - id := user.ID - if strings.Contains(id, Collections.WebhookLog+"/") { - id = strings.TrimPrefix(id, Collections.WebhookLog+"/") - } + // id := user.ID + // if strings.Contains(id, Collections.User+"/") { + // id = strings.TrimPrefix(id, Collections.User+"/") + // } return &model.User{ - ID: id, + ID: user.ID, Email: user.Email, EmailVerified: isEmailVerified, SignupMethods: user.SignupMethods, diff --git a/server/db/models/verification_requests.go b/server/db/models/verification_requests.go index 3e13a7f..992d9d8 100644 --- a/server/db/models/verification_requests.go +++ b/server/db/models/verification_requests.go @@ -25,8 +25,8 @@ type VerificationRequest struct { func (v *VerificationRequest) AsAPIVerificationRequest() *model.VerificationRequest { id := v.ID - if strings.Contains(id, Collections.WebhookLog+"/") { - id = strings.TrimPrefix(id, Collections.WebhookLog+"/") + if strings.Contains(id, Collections.VerificationRequest+"/") { + id = strings.TrimPrefix(id, Collections.VerificationRequest+"/") } return &model.VerificationRequest{ diff --git a/server/db/providers/arangodb/email_template.go b/server/db/providers/arangodb/email_template.go index 4e64762..70dd474 100644 --- a/server/db/providers/arangodb/email_template.go +++ b/server/db/providers/arangodb/email_template.go @@ -16,6 +16,7 @@ import ( func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { if emailTemplate.ID == "" { emailTemplate.ID = uuid.New().String() + emailTemplate.Key = emailTemplate.ID } emailTemplate.Key = emailTemplate.ID diff --git a/server/db/providers/arangodb/env.go b/server/db/providers/arangodb/env.go index 2c884d4..29687a8 100644 --- a/server/db/providers/arangodb/env.go +++ b/server/db/providers/arangodb/env.go @@ -15,6 +15,7 @@ import ( func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) { if env.ID == "" { env.ID = uuid.New().String() + env.Key = env.ID } env.CreatedAt = time.Now().Unix() diff --git a/server/db/providers/arangodb/otp.go b/server/db/providers/arangodb/otp.go index 076990b..29f265a 100644 --- a/server/db/providers/arangodb/otp.go +++ b/server/db/providers/arangodb/otp.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/arangodb/go-driver" "github.com/authorizerdev/authorizer/server/db/models" "github.com/google/uuid" ) @@ -14,32 +15,38 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models otp, _ := p.GetOTPByEmail(ctx, otpParam.Email) shouldCreate := false if otp == nil { + id := uuid.NewString() + otp = &models.OTP{ + ID: id, + Key: id, + Otp: otpParam.Otp, + Email: otpParam.Email, + ExpiresAt: otpParam.ExpiresAt, + CreatedAt: time.Now().Unix(), + } shouldCreate = true - otp.ID = uuid.New().String() - otp.Key = otp.ID - otp.CreatedAt = time.Now().Unix() } else { - otp = otpParam + otp.Otp = otpParam.Otp + otp.ExpiresAt = otpParam.ExpiresAt } otp.UpdatedAt = time.Now().Unix() otpCollection, _ := p.db.Collection(ctx, models.Collections.OTP) + var meta driver.DocumentMeta + var err error if shouldCreate { - _, err := otpCollection.CreateDocument(ctx, otp) - if err != nil { - return nil, err - } + meta, err = otpCollection.CreateDocument(ctx, otp) } else { - meta, err := otpCollection.UpdateDocument(ctx, otp.Key, otp) - if err != nil { - return nil, err - } - - otp.Key = meta.Key - otp.ID = meta.ID.String() + meta, err = otpCollection.UpdateDocument(ctx, otp.Key, otp) } + if err != nil { + return nil, err + } + + otp.Key = meta.Key + otp.ID = meta.ID.String() return otp, nil } diff --git a/server/db/providers/arangodb/session.go b/server/db/providers/arangodb/session.go index 96896e5..9bc46ca 100644 --- a/server/db/providers/arangodb/session.go +++ b/server/db/providers/arangodb/session.go @@ -12,6 +12,7 @@ import ( func (p *provider) AddSession(ctx context.Context, session models.Session) error { if session.ID == "" { session.ID = uuid.New().String() + session.Key = session.ID } session.CreatedAt = time.Now().Unix() diff --git a/server/db/providers/arangodb/user.go b/server/db/providers/arangodb/user.go index abc3ec0..945de33 100644 --- a/server/db/providers/arangodb/user.go +++ b/server/db/providers/arangodb/user.go @@ -2,22 +2,26 @@ package arangodb import ( "context" + "encoding/json" "fmt" + "strings" "time" "github.com/arangodb/go-driver" arangoDriver "github.com/arangodb/go-driver" + "github.com/google/uuid" + "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/memorystore" - "github.com/google/uuid" ) // AddUser to save user information in database func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) { if user.ID == "" { user.ID = uuid.New().String() + user.Key = user.ID } if user.Roles == "" { @@ -65,7 +69,7 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error { query := fmt.Sprintf(`FOR d IN %s FILTER d.user_id == @user_id REMOVE { _key: d._key } IN %s`, models.Collections.Session, models.Collections.Session) bindVars := map[string]interface{}{ - "user_id": user.ID, + "user_id": user.Key, } cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { @@ -174,3 +178,36 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err return user, nil } + +// UpdateUsers to update multiple users, with parameters of user IDs slice +// If ids set to nil / empty all the users will be updated +func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error { + // set updated_at time for all users + data["updated_at"] = time.Now().Unix() + + userInfoBytes, err := json.Marshal(data) + if err != nil { + return err + } + + query := "" + if ids != nil && len(ids) > 0 { + keysArray := "" + for _, id := range ids { + keysArray += fmt.Sprintf("'%s', ", id) + } + keysArray = strings.Trim(keysArray, " ") + keysArray = strings.TrimSuffix(keysArray, ",") + query = fmt.Sprintf("FOR u IN %s FILTER u._id IN [%s] UPDATE u._key with %s IN %s", models.Collections.User, keysArray, string(userInfoBytes), models.Collections.User) + } else { + query = fmt.Sprintf("FOR u IN %s UPDATE u._key with %s IN %s", models.Collections.User, string(userInfoBytes), models.Collections.User) + } + + _, err = p.db.Query(ctx, query, nil) + + if err != nil { + return err + } + + return nil +} diff --git a/server/db/providers/arangodb/verification_requests.go b/server/db/providers/arangodb/verification_requests.go index a1dbfa2..8722bad 100644 --- a/server/db/providers/arangodb/verification_requests.go +++ b/server/db/providers/arangodb/verification_requests.go @@ -15,6 +15,7 @@ import ( func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { if verificationRequest.ID == "" { verificationRequest.ID = uuid.New().String() + verificationRequest.Key = verificationRequest.ID } verificationRequest.CreatedAt = time.Now().Unix() diff --git a/server/db/providers/arangodb/webhook.go b/server/db/providers/arangodb/webhook.go index 302eb61..2fd62da 100644 --- a/server/db/providers/arangodb/webhook.go +++ b/server/db/providers/arangodb/webhook.go @@ -16,6 +16,7 @@ import ( func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { if webhook.ID == "" { webhook.ID = uuid.New().String() + webhook.Key = webhook.ID } webhook.Key = webhook.ID diff --git a/server/db/providers/arangodb/webhook_log.go b/server/db/providers/arangodb/webhook_log.go index bc758c4..35565e1 100644 --- a/server/db/providers/arangodb/webhook_log.go +++ b/server/db/providers/arangodb/webhook_log.go @@ -16,6 +16,7 @@ import ( func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) { if webhookLog.ID == "" { webhookLog.ID = uuid.New().String() + webhookLog.Key = webhookLog.ID } webhookLog.Key = webhookLog.ID diff --git a/server/db/providers/cassandradb/otp.go b/server/db/providers/cassandradb/otp.go index 7ead206..bfe481d 100644 --- a/server/db/providers/cassandradb/otp.go +++ b/server/db/providers/cassandradb/otp.go @@ -16,21 +16,27 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models shouldCreate := false if otp == nil { shouldCreate = true - otp.ID = uuid.New().String() - otp.Key = otp.ID - otp.CreatedAt = time.Now().Unix() + otp = &models.OTP{ + ID: uuid.NewString(), + Otp: otpParam.Otp, + Email: otpParam.Email, + ExpiresAt: otpParam.ExpiresAt, + CreatedAt: time.Now().Unix(), + UpdatedAt: time.Now().Unix(), + } } else { - otp = otpParam + otp.Otp = otpParam.Otp + otp.ExpiresAt = otpParam.ExpiresAt } otp.UpdatedAt = time.Now().Unix() query := "" - if shouldCreate { query = fmt.Sprintf(`INSERT INTO %s (id, email, otp, expires_at, created_at, updated_at) VALUES ('%s', '%s', '%s', %d, %d, %d)`, KeySpace+"."+models.Collections.OTP, otp.ID, otp.Email, otp.Otp, otp.ExpiresAt, otp.CreatedAt, otp.UpdatedAt) } else { - query = fmt.Sprintf(`UPDATE %s SET otp = '%s', expires_at = %d, updated_at = %d WHERE email = '%s'`, KeySpace+"."+models.Collections.OTP, otp.Otp, otp.ExpiresAt, otp.UpdatedAt, otp.Email) + query = fmt.Sprintf(`UPDATE %s SET otp = '%s', expires_at = %d, updated_at = %d WHERE id = '%s'`, KeySpace+"."+models.Collections.OTP, otp.Otp, otp.ExpiresAt, otp.UpdatedAt, otp.ID) } + err := p.db.Query(query).Exec() if err != nil { return nil, err diff --git a/server/db/providers/cassandradb/provider.go b/server/db/providers/cassandradb/provider.go index 80a9cb6..329dad2 100644 --- a/server/db/providers/cassandradb/provider.go +++ b/server/db/providers/cassandradb/provider.go @@ -13,6 +13,7 @@ import ( "github.com/authorizerdev/authorizer/server/memorystore" "github.com/gocql/gocql" cansandraDriver "github.com/gocql/gocql" + log "github.com/sirupsen/logrus" ) type provider struct { @@ -99,6 +100,7 @@ func NewProvider() (*provider, error) { cassandraClient.Consistency = gocql.LocalQuorum cassandraClient.ConnectTimeout = 10 * time.Second cassandraClient.ProtoVersion = 4 + cassandraClient.Timeout = 30 * time.Minute // for large data session, err := cassandraClient.CreateSession() if err != nil { @@ -160,10 +162,11 @@ func NewProvider() (*provider, error) { return nil, err } // add is_multi_factor_auth_enabled on users table - userTableAlterQuery := fmt.Sprintf(`ALTER TABLE %s.%s ADD is_multi_factor_auth_enabled boolean;`, KeySpace, models.Collections.User) + userTableAlterQuery := fmt.Sprintf(`ALTER TABLE %s.%s ADD is_multi_factor_auth_enabled boolean`, KeySpace, models.Collections.User) err = session.Query(userTableAlterQuery).Exec() if err != nil { - return nil, err + log.Debug("Failed to alter table as column exists: ", err) + // return nil, err } // token is reserved keyword in cassandra, hence we need to use jwt_token diff --git a/server/db/providers/cassandradb/user.go b/server/db/providers/cassandradb/user.go index 9489cdd..4da7ec9 100644 --- a/server/db/providers/cassandradb/user.go +++ b/server/db/providers/cassandradb/user.go @@ -107,7 +107,7 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use } if value == nil { - updateFields += fmt.Sprintf("%s = null,", key) + updateFields += fmt.Sprintf("%s = null, ", key) continue } @@ -122,7 +122,6 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use updateFields = strings.TrimSuffix(updateFields, ",") query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, updateFields, user.ID) - err = p.db.Query(query).Exec() if err != nil { return user, err @@ -173,14 +172,14 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) ( // there is no offset in cassandra // so we fetch till limit + offset // and return the results from offset to limit - query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset) + query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset) scanner := p.db.Query(query).Iter().Scanner() counter := int64(0) for scanner.Next() { if counter >= pagination.Offset { var user models.User - err := scanner.Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt) + err := scanner.Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt) if err != nil { return nil, err } @@ -197,8 +196,8 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) ( // GetUserByEmail to get user information from database using email address func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) { var user models.User - query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.User, email) - err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt) + query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.User, email) + err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt) if err != nil { return user, err } @@ -208,10 +207,95 @@ func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.Use // GetUserByID to get user information from database using user ID func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) { var user models.User - query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, id) - err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt) + query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, id) + err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt) if err != nil { return user, err } return user, nil } + +// UpdateUsers to update multiple users, with parameters of user IDs slice +// If ids set to nil / empty all the users will be updated +func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error { + // set updated_at time for all users + data["updated_at"] = time.Now().Unix() + + updateFields := "" + for key, value := range data { + if key == "_id" { + continue + } + + if key == "_key" { + continue + } + + if value == nil { + updateFields += fmt.Sprintf("%s = null,", key) + continue + } + + valueType := reflect.TypeOf(value) + if valueType.Name() == "string" { + updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string)) + } else { + updateFields += fmt.Sprintf("%s = %v, ", key, value) + } + } + updateFields = strings.Trim(updateFields, " ") + updateFields = strings.TrimSuffix(updateFields, ",") + + query := "" + if ids != nil && len(ids) > 0 { + idsString := "" + for _, id := range ids { + idsString += fmt.Sprintf("'%s', ", id) + } + idsString = strings.Trim(idsString, " ") + idsString = strings.TrimSuffix(idsString, ",") + query = fmt.Sprintf("UPDATE %s SET %s WHERE id IN (%s)", KeySpace+"."+models.Collections.User, updateFields, idsString) + err := p.db.Query(query).Exec() + if err != nil { + return err + } + } else { + // get all ids + getUserIDsQuery := fmt.Sprintf(`SELECT id FROM %s`, KeySpace+"."+models.Collections.User) + scanner := p.db.Query(getUserIDsQuery).Iter().Scanner() + // only 100 ids are allowed in 1 query + // hence we need create multiple update queries + idsString := "" + idsStringArray := []string{idsString} + counter := 1 + for scanner.Next() { + var id string + err := scanner.Scan(&id) + if err == nil { + idsString += fmt.Sprintf("'%s', ", id) + } + counter++ + if counter > 100 { + idsStringArray = append(idsStringArray, idsString) + counter = 1 + idsString = "" + } else { + // update the last index of array when count is less than 100 + idsStringArray[len(idsStringArray)-1] = idsString + } + } + + for _, idStr := range idsStringArray { + idStr = strings.Trim(idStr, " ") + idStr = strings.TrimSuffix(idStr, ",") + query = fmt.Sprintf("UPDATE %s SET %s WHERE id IN (%s)", KeySpace+"."+models.Collections.User, updateFields, idStr) + err := p.db.Query(query).Exec() + if err != nil { + return err + } + } + + } + + return nil +} diff --git a/server/db/providers/mongodb/otp.go b/server/db/providers/mongodb/otp.go index bbf0426..d6ff2df 100644 --- a/server/db/providers/mongodb/otp.go +++ b/server/db/providers/mongodb/otp.go @@ -11,19 +11,33 @@ import ( ) // UpsertOTP to add or update otp -func (p *provider) UpsertOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { - if otp.ID == "" { - otp.ID = uuid.New().String() - } - - otp.Key = otp.ID - if otp.CreatedAt <= 0 { - otp.CreatedAt = time.Now().Unix() +func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models.OTP, error) { + otp, _ := p.GetOTPByEmail(ctx, otpParam.Email) + shouldCreate := false + if otp == nil { + id := uuid.NewString() + otp = &models.OTP{ + ID: id, + Key: id, + Otp: otpParam.Otp, + Email: otpParam.Email, + ExpiresAt: otpParam.ExpiresAt, + CreatedAt: time.Now().Unix(), + } + shouldCreate = true + } else { + otp.Otp = otpParam.Otp + otp.ExpiresAt = otpParam.ExpiresAt } otp.UpdatedAt = time.Now().Unix() - otpCollection := p.db.Collection(models.Collections.OTP, options.Collection()) - _, err := otpCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": otp.ID}}, bson.M{"$set": otp}, options.MergeUpdateOptions().SetUpsert(true)) + + var err error + if shouldCreate { + _, err = otpCollection.InsertOne(ctx, otp) + } else { + _, err = otpCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": otp.ID}}, bson.M{"$set": otp}, options.MergeUpdateOptions()) + } if err != nil { return nil, err } diff --git a/server/db/providers/mongodb/user.go b/server/db/providers/mongodb/user.go index 7518fd9..6e90a40 100644 --- a/server/db/providers/mongodb/user.go +++ b/server/db/providers/mongodb/user.go @@ -9,7 +9,9 @@ import ( "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/google/uuid" + log "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) @@ -129,3 +131,27 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err return user, nil } + +// UpdateUsers to update multiple users, with parameters of user IDs slice +// If ids set to nil / empty all the users will be updated +func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error { + // set updated_at time for all users + data["updated_at"] = time.Now().Unix() + + userCollection := p.db.Collection(models.Collections.User, options.Collection()) + + var res *mongo.UpdateResult + var err error + if ids != nil && len(ids) > 0 { + res, err = userCollection.UpdateMany(ctx, bson.M{"_id": bson.M{"$in": ids}}, bson.M{"$set": data}) + } else { + res, err = userCollection.UpdateMany(ctx, bson.M{}, bson.M{"$set": data}) + } + + if err != nil { + return err + } else { + log.Info("Updated users: ", res.ModifiedCount) + } + return nil +} diff --git a/server/db/providers/provider_template/user.go b/server/db/providers/provider_template/user.go index 00b2db8..2b167db 100644 --- a/server/db/providers/provider_template/user.go +++ b/server/db/providers/provider_template/user.go @@ -60,3 +60,12 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err return user, nil } + +// UpdateUsers to update multiple users, with parameters of user IDs slice +// If ids set to nil / empty all the users will be updated +func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error { + // set updated_at time for all users + data["updated_at"] = time.Now().Unix() + + return nil +} diff --git a/server/db/providers/providers.go b/server/db/providers/providers.go index da72190..a578396 100644 --- a/server/db/providers/providers.go +++ b/server/db/providers/providers.go @@ -20,6 +20,9 @@ type Provider interface { GetUserByEmail(ctx context.Context, email string) (models.User, error) // GetUserByID to get user information from database using user ID GetUserByID(ctx context.Context, id string) (models.User, error) + // UpdateUsers to update multiple users, with parameters of user IDs slice + // If ids set to nil / empty all the users will be updated + UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error // AddVerification to save verification request in database AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) diff --git a/server/db/providers/sql/provider.go b/server/db/providers/sql/provider.go index 394bff3..712f3d1 100644 --- a/server/db/providers/sql/provider.go +++ b/server/db/providers/sql/provider.go @@ -40,6 +40,7 @@ func NewProvider() (*provider, error) { NamingStrategy: schema.NamingStrategy{ TablePrefix: models.Prefix, }, + AllowGlobalUpdate: true, } dbType := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseType diff --git a/server/db/providers/sql/user.go b/server/db/providers/sql/user.go index c5953ce..c191935 100644 --- a/server/db/providers/sql/user.go +++ b/server/db/providers/sql/user.go @@ -9,6 +9,7 @@ import ( "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/google/uuid" + "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -121,3 +122,22 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err return user, nil } + +// UpdateUsers to update multiple users, with parameters of user IDs slice +// If ids set to nil / empty all the users will be updated +func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error { + // set updated_at time for all users + data["updated_at"] = time.Now().Unix() + + var res *gorm.DB + if ids != nil && len(ids) > 0 { + res = p.db.Model(&models.User{}).Where("id in ?", ids).Updates(data) + } else { + res = p.db.Model(&models.User{}).Updates(data) + } + + if res.Error != nil { + return res.Error + } + return nil +} diff --git a/server/env/env.go b/server/env/env.go index c492b13..e8fe863 100644 --- a/server/env/env.go +++ b/server/env/env.go @@ -84,6 +84,7 @@ func InitAllEnv() error { osDisableSignUp := os.Getenv(constants.EnvKeyDisableSignUp) osDisableRedisForEnv := os.Getenv(constants.EnvKeyDisableRedisForEnv) osDisableStrongPassword := os.Getenv(constants.EnvKeyDisableStrongPassword) + osEnforceMultiFactorAuthentication := os.Getenv(constants.EnvKeyEnforceMultiFactorAuthentication) // os slice vars osAllowedOrigins := os.Getenv(constants.EnvKeyAllowedOrigins) @@ -490,6 +491,19 @@ func InitAllEnv() error { } } + if _, ok := envData[constants.EnvKeyEnforceMultiFactorAuthentication]; !ok { + envData[constants.EnvKeyEnforceMultiFactorAuthentication] = osEnforceMultiFactorAuthentication == "true" + } + if osEnforceMultiFactorAuthentication != "" { + boolValue, err := strconv.ParseBool(osEnforceMultiFactorAuthentication) + if err != nil { + return err + } + if boolValue != envData[constants.EnvKeyEnforceMultiFactorAuthentication].(bool) { + envData[constants.EnvKeyEnforceMultiFactorAuthentication] = boolValue + } + } + // no need to add nil check as its already done above if envData[constants.EnvKeySmtpHost] == "" || envData[constants.EnvKeySmtpUsername] == "" || envData[constants.EnvKeySmtpPassword] == "" || envData[constants.EnvKeySenderEmail] == "" && envData[constants.EnvKeySmtpPort] == "" { envData[constants.EnvKeyDisableEmailVerification] = true @@ -501,6 +515,10 @@ func InitAllEnv() error { envData[constants.EnvKeyIsEmailServiceEnabled] = true } + if envData[constants.EnvKeyEnforceMultiFactorAuthentication].(bool) && !envData[constants.EnvKeyIsEmailServiceEnabled].(bool) { + return errors.New("to enable multi factor authentication, please enable email service") + } + if envData[constants.EnvKeyDisableEmailVerification].(bool) { envData[constants.EnvKeyDisableMagicLinkLogin] = true } diff --git a/server/env/persist_env.go b/server/env/persist_env.go index d783b93..e4ec275 100644 --- a/server/env/persist_env.go +++ b/server/env/persist_env.go @@ -201,7 +201,7 @@ func PersistEnv() error { envValue := strings.TrimSpace(os.Getenv(key)) if envValue != "" { switch key { - case constants.EnvKeyIsProd, constants.EnvKeyDisableBasicAuthentication, constants.EnvKeyDisableEmailVerification, constants.EnvKeyDisableLoginPage, constants.EnvKeyDisableMagicLinkLogin, constants.EnvKeyDisableSignUp, constants.EnvKeyDisableRedisForEnv, constants.EnvKeyDisableStrongPassword, constants.EnvKeyIsEmailServiceEnabled: + case constants.EnvKeyIsProd, constants.EnvKeyDisableBasicAuthentication, constants.EnvKeyDisableEmailVerification, constants.EnvKeyDisableLoginPage, constants.EnvKeyDisableMagicLinkLogin, constants.EnvKeyDisableSignUp, constants.EnvKeyDisableRedisForEnv, constants.EnvKeyDisableStrongPassword, constants.EnvKeyIsEmailServiceEnabled, constants.EnvKeyEnforceMultiFactorAuthentication: if envValueBool, err := strconv.ParseBool(envValue); err == nil { if value.(bool) != envValueBool { storeData[key] = envValueBool diff --git a/server/go.mod b/server/go.mod index 13e3e52..98dcc94 100644 --- a/server/go.mod +++ b/server/go.mod @@ -5,11 +5,12 @@ go 1.16 require ( github.com/99designs/gqlgen v0.14.0 github.com/arangodb/go-driver v1.2.1 + github.com/coreos/etcd v3.3.27+incompatible github.com/coreos/go-oidc/v3 v3.1.0 github.com/gin-gonic/gin v1.7.2 github.com/go-playground/validator/v10 v10.8.0 // indirect github.com/go-redis/redis/v8 v8.11.0 - github.com/gocql/gocql v1.0.0 + github.com/gocql/gocql v1.2.0 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/protobuf v1.5.2 // indirect github.com/google/uuid v1.3.0 diff --git a/server/go.sum b/server/go.sum index 9392d6f..7a1bbcb 100644 --- a/server/go.sum +++ b/server/go.sum @@ -62,6 +62,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/etcd v3.3.27+incompatible h1:QIudLb9KeBsE5zyYxd1mjzRSkzLg9Wf9QlRwFgd6oTA= +github.com/coreos/etcd v3.3.27+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-iptables v0.4.3/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU= github.com/coreos/go-oidc/v3 v3.1.0 h1:6avEvcdvTa1qYsOZ6I5PRkSYHzpTNWgKYmaJfaYbrRw= github.com/coreos/go-oidc/v3 v3.1.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo= @@ -112,6 +114,8 @@ github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gocql/gocql v1.0.0 h1:UnbTERpP72VZ/viKE1Q1gPtmLvyTZTvuAstvSRydw/c= github.com/gocql/gocql v1.0.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= +github.com/gocql/gocql v1.2.0 h1:TZhsCd7fRuye4VyHr3WCvWwIQaZUmjsqnSIXK9FcVCE= +github.com/gocql/gocql v1.2.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index 44c181f..0d247f3 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -67,54 +67,55 @@ type ComplexityRoot struct { } Env struct { - AccessTokenExpiryTime func(childComplexity int) int - AdminSecret func(childComplexity int) int - AllowedOrigins func(childComplexity int) int - AppURL func(childComplexity int) int - AppleClientID func(childComplexity int) int - AppleClientSecret func(childComplexity int) int - ClientID func(childComplexity int) int - ClientSecret func(childComplexity int) int - CustomAccessTokenScript func(childComplexity int) int - DatabaseHost func(childComplexity int) int - DatabaseName func(childComplexity int) int - DatabasePassword func(childComplexity int) int - DatabasePort func(childComplexity int) int - DatabaseType func(childComplexity int) int - DatabaseURL func(childComplexity int) int - DatabaseUsername func(childComplexity int) int - DefaultRoles func(childComplexity int) int - DisableBasicAuthentication func(childComplexity int) int - DisableEmailVerification func(childComplexity int) int - DisableLoginPage func(childComplexity int) int - DisableMagicLinkLogin func(childComplexity int) int - DisableRedisForEnv func(childComplexity int) int - DisableSignUp func(childComplexity int) int - DisableStrongPassword func(childComplexity int) int - FacebookClientID func(childComplexity int) int - FacebookClientSecret func(childComplexity int) int - GithubClientID func(childComplexity int) int - GithubClientSecret func(childComplexity int) int - GoogleClientID func(childComplexity int) int - GoogleClientSecret func(childComplexity int) int - JwtPrivateKey func(childComplexity int) int - JwtPublicKey func(childComplexity int) int - JwtRoleClaim func(childComplexity int) int - JwtSecret func(childComplexity int) int - JwtType func(childComplexity int) int - LinkedinClientID func(childComplexity int) int - LinkedinClientSecret func(childComplexity int) int - OrganizationLogo func(childComplexity int) int - OrganizationName func(childComplexity int) int - ProtectedRoles func(childComplexity int) int - RedisURL func(childComplexity int) int - ResetPasswordURL func(childComplexity int) int - Roles func(childComplexity int) int - SMTPHost func(childComplexity int) int - SMTPPassword func(childComplexity int) int - SMTPPort func(childComplexity int) int - SMTPUsername func(childComplexity int) int - SenderEmail func(childComplexity int) int + AccessTokenExpiryTime func(childComplexity int) int + AdminSecret func(childComplexity int) int + AllowedOrigins func(childComplexity int) int + AppURL func(childComplexity int) int + AppleClientID func(childComplexity int) int + AppleClientSecret func(childComplexity int) int + ClientID func(childComplexity int) int + ClientSecret func(childComplexity int) int + CustomAccessTokenScript func(childComplexity int) int + DatabaseHost func(childComplexity int) int + DatabaseName func(childComplexity int) int + DatabasePassword func(childComplexity int) int + DatabasePort func(childComplexity int) int + DatabaseType func(childComplexity int) int + DatabaseURL func(childComplexity int) int + DatabaseUsername func(childComplexity int) int + DefaultRoles func(childComplexity int) int + DisableBasicAuthentication func(childComplexity int) int + DisableEmailVerification func(childComplexity int) int + DisableLoginPage func(childComplexity int) int + DisableMagicLinkLogin func(childComplexity int) int + DisableRedisForEnv func(childComplexity int) int + DisableSignUp func(childComplexity int) int + DisableStrongPassword func(childComplexity int) int + EnforceMultiFactorAuthentication func(childComplexity int) int + FacebookClientID func(childComplexity int) int + FacebookClientSecret func(childComplexity int) int + GithubClientID func(childComplexity int) int + GithubClientSecret func(childComplexity int) int + GoogleClientID func(childComplexity int) int + GoogleClientSecret func(childComplexity int) int + JwtPrivateKey func(childComplexity int) int + JwtPublicKey func(childComplexity int) int + JwtRoleClaim func(childComplexity int) int + JwtSecret func(childComplexity int) int + JwtType func(childComplexity int) int + LinkedinClientID func(childComplexity int) int + LinkedinClientSecret func(childComplexity int) int + OrganizationLogo func(childComplexity int) int + OrganizationName func(childComplexity int) int + ProtectedRoles func(childComplexity int) int + RedisURL func(childComplexity int) int + ResetPasswordURL func(childComplexity int) int + Roles func(childComplexity int) int + SMTPHost func(childComplexity int) int + SMTPPassword func(childComplexity int) int + SMTPPort func(childComplexity int) int + SMTPUsername func(childComplexity int) int + SenderEmail func(childComplexity int) int } Error struct { @@ -612,6 +613,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Env.DisableStrongPassword(childComplexity), true + case "Env.ENFORCE_MULTI_FACTOR_AUTHENTICATION": + if e.complexity.Env.EnforceMultiFactorAuthentication == nil { + break + } + + return e.complexity.Env.EnforceMultiFactorAuthentication(childComplexity), true + case "Env.FACEBOOK_CLIENT_ID": if e.complexity.Env.FacebookClientID == nil { break @@ -1957,6 +1965,7 @@ type Env { DISABLE_SIGN_UP: Boolean! DISABLE_REDIS_FOR_ENV: Boolean! DISABLE_STRONG_PASSWORD: Boolean! + ENFORCE_MULTI_FACTOR_AUTHENTICATION: Boolean! ROLES: [String!] PROTECTED_ROLES: [String!] DEFAULT_ROLES: [String!] @@ -2057,6 +2066,7 @@ input UpdateEnvInput { DISABLE_SIGN_UP: Boolean DISABLE_REDIS_FOR_ENV: Boolean DISABLE_STRONG_PASSWORD: Boolean + ENFORCE_MULTI_FACTOR_AUTHENTICATION: Boolean ROLES: [String!] PROTECTED_ROLES: [String!] DEFAULT_ROLES: [String!] @@ -4415,6 +4425,41 @@ func (ec *executionContext) _Env_DISABLE_STRONG_PASSWORD(ctx context.Context, fi return ec.marshalNBoolean2bool(ctx, field.Selections, res) } +func (ec *executionContext) _Env_ENFORCE_MULTI_FACTOR_AUTHENTICATION(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Env", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.EnforceMultiFactorAuthentication, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + fc.Result = res + return ec.marshalNBoolean2bool(ctx, field.Selections, res) +} + func (ec *executionContext) _Env_ROLES(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -11334,6 +11379,14 @@ func (ec *executionContext) unmarshalInputUpdateEnvInput(ctx context.Context, ob if err != nil { return it, err } + case "ENFORCE_MULTI_FACTOR_AUTHENTICATION": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("ENFORCE_MULTI_FACTOR_AUTHENTICATION")) + it.EnforceMultiFactorAuthentication, err = ec.unmarshalOBoolean2ᚖbool(ctx, v) + if err != nil { + return it, err + } case "ROLES": var err error @@ -12099,6 +12152,11 @@ func (ec *executionContext) _Env(ctx context.Context, sel ast.SelectionSet, obj if out.Values[i] == graphql.Null { invalids++ } + case "ENFORCE_MULTI_FACTOR_AUTHENTICATION": + out.Values[i] = ec._Env_ENFORCE_MULTI_FACTOR_AUTHENTICATION(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "ROLES": out.Values[i] = ec._Env_ROLES(ctx, field, obj) case "PROTECTED_ROLES": diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index b193c0d..3c5f883 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -54,54 +54,55 @@ type EmailTemplates struct { } type Env struct { - AccessTokenExpiryTime *string `json:"ACCESS_TOKEN_EXPIRY_TIME"` - AdminSecret *string `json:"ADMIN_SECRET"` - DatabaseName *string `json:"DATABASE_NAME"` - DatabaseURL *string `json:"DATABASE_URL"` - DatabaseType *string `json:"DATABASE_TYPE"` - DatabaseUsername *string `json:"DATABASE_USERNAME"` - DatabasePassword *string `json:"DATABASE_PASSWORD"` - DatabaseHost *string `json:"DATABASE_HOST"` - DatabasePort *string `json:"DATABASE_PORT"` - ClientID string `json:"CLIENT_ID"` - ClientSecret string `json:"CLIENT_SECRET"` - CustomAccessTokenScript *string `json:"CUSTOM_ACCESS_TOKEN_SCRIPT"` - SMTPHost *string `json:"SMTP_HOST"` - SMTPPort *string `json:"SMTP_PORT"` - SMTPUsername *string `json:"SMTP_USERNAME"` - SMTPPassword *string `json:"SMTP_PASSWORD"` - SenderEmail *string `json:"SENDER_EMAIL"` - JwtType *string `json:"JWT_TYPE"` - JwtSecret *string `json:"JWT_SECRET"` - JwtPrivateKey *string `json:"JWT_PRIVATE_KEY"` - JwtPublicKey *string `json:"JWT_PUBLIC_KEY"` - AllowedOrigins []string `json:"ALLOWED_ORIGINS"` - AppURL *string `json:"APP_URL"` - RedisURL *string `json:"REDIS_URL"` - ResetPasswordURL *string `json:"RESET_PASSWORD_URL"` - DisableEmailVerification bool `json:"DISABLE_EMAIL_VERIFICATION"` - DisableBasicAuthentication bool `json:"DISABLE_BASIC_AUTHENTICATION"` - DisableMagicLinkLogin bool `json:"DISABLE_MAGIC_LINK_LOGIN"` - DisableLoginPage bool `json:"DISABLE_LOGIN_PAGE"` - DisableSignUp bool `json:"DISABLE_SIGN_UP"` - DisableRedisForEnv bool `json:"DISABLE_REDIS_FOR_ENV"` - DisableStrongPassword bool `json:"DISABLE_STRONG_PASSWORD"` - Roles []string `json:"ROLES"` - ProtectedRoles []string `json:"PROTECTED_ROLES"` - DefaultRoles []string `json:"DEFAULT_ROLES"` - JwtRoleClaim *string `json:"JWT_ROLE_CLAIM"` - GoogleClientID *string `json:"GOOGLE_CLIENT_ID"` - GoogleClientSecret *string `json:"GOOGLE_CLIENT_SECRET"` - GithubClientID *string `json:"GITHUB_CLIENT_ID"` - GithubClientSecret *string `json:"GITHUB_CLIENT_SECRET"` - FacebookClientID *string `json:"FACEBOOK_CLIENT_ID"` - FacebookClientSecret *string `json:"FACEBOOK_CLIENT_SECRET"` - LinkedinClientID *string `json:"LINKEDIN_CLIENT_ID"` - LinkedinClientSecret *string `json:"LINKEDIN_CLIENT_SECRET"` - AppleClientID *string `json:"APPLE_CLIENT_ID"` - AppleClientSecret *string `json:"APPLE_CLIENT_SECRET"` - OrganizationName *string `json:"ORGANIZATION_NAME"` - OrganizationLogo *string `json:"ORGANIZATION_LOGO"` + AccessTokenExpiryTime *string `json:"ACCESS_TOKEN_EXPIRY_TIME"` + AdminSecret *string `json:"ADMIN_SECRET"` + DatabaseName *string `json:"DATABASE_NAME"` + DatabaseURL *string `json:"DATABASE_URL"` + DatabaseType *string `json:"DATABASE_TYPE"` + DatabaseUsername *string `json:"DATABASE_USERNAME"` + DatabasePassword *string `json:"DATABASE_PASSWORD"` + DatabaseHost *string `json:"DATABASE_HOST"` + DatabasePort *string `json:"DATABASE_PORT"` + ClientID string `json:"CLIENT_ID"` + ClientSecret string `json:"CLIENT_SECRET"` + CustomAccessTokenScript *string `json:"CUSTOM_ACCESS_TOKEN_SCRIPT"` + SMTPHost *string `json:"SMTP_HOST"` + SMTPPort *string `json:"SMTP_PORT"` + SMTPUsername *string `json:"SMTP_USERNAME"` + SMTPPassword *string `json:"SMTP_PASSWORD"` + SenderEmail *string `json:"SENDER_EMAIL"` + JwtType *string `json:"JWT_TYPE"` + JwtSecret *string `json:"JWT_SECRET"` + JwtPrivateKey *string `json:"JWT_PRIVATE_KEY"` + JwtPublicKey *string `json:"JWT_PUBLIC_KEY"` + AllowedOrigins []string `json:"ALLOWED_ORIGINS"` + AppURL *string `json:"APP_URL"` + RedisURL *string `json:"REDIS_URL"` + ResetPasswordURL *string `json:"RESET_PASSWORD_URL"` + DisableEmailVerification bool `json:"DISABLE_EMAIL_VERIFICATION"` + DisableBasicAuthentication bool `json:"DISABLE_BASIC_AUTHENTICATION"` + DisableMagicLinkLogin bool `json:"DISABLE_MAGIC_LINK_LOGIN"` + DisableLoginPage bool `json:"DISABLE_LOGIN_PAGE"` + DisableSignUp bool `json:"DISABLE_SIGN_UP"` + DisableRedisForEnv bool `json:"DISABLE_REDIS_FOR_ENV"` + DisableStrongPassword bool `json:"DISABLE_STRONG_PASSWORD"` + EnforceMultiFactorAuthentication bool `json:"ENFORCE_MULTI_FACTOR_AUTHENTICATION"` + Roles []string `json:"ROLES"` + ProtectedRoles []string `json:"PROTECTED_ROLES"` + DefaultRoles []string `json:"DEFAULT_ROLES"` + JwtRoleClaim *string `json:"JWT_ROLE_CLAIM"` + GoogleClientID *string `json:"GOOGLE_CLIENT_ID"` + GoogleClientSecret *string `json:"GOOGLE_CLIENT_SECRET"` + GithubClientID *string `json:"GITHUB_CLIENT_ID"` + GithubClientSecret *string `json:"GITHUB_CLIENT_SECRET"` + FacebookClientID *string `json:"FACEBOOK_CLIENT_ID"` + FacebookClientSecret *string `json:"FACEBOOK_CLIENT_SECRET"` + LinkedinClientID *string `json:"LINKEDIN_CLIENT_ID"` + LinkedinClientSecret *string `json:"LINKEDIN_CLIENT_SECRET"` + AppleClientID *string `json:"APPLE_CLIENT_ID"` + AppleClientSecret *string `json:"APPLE_CLIENT_SECRET"` + OrganizationName *string `json:"ORGANIZATION_NAME"` + OrganizationLogo *string `json:"ORGANIZATION_LOGO"` } type Error struct { @@ -249,45 +250,46 @@ type UpdateEmailTemplateRequest struct { } type UpdateEnvInput struct { - AccessTokenExpiryTime *string `json:"ACCESS_TOKEN_EXPIRY_TIME"` - AdminSecret *string `json:"ADMIN_SECRET"` - CustomAccessTokenScript *string `json:"CUSTOM_ACCESS_TOKEN_SCRIPT"` - OldAdminSecret *string `json:"OLD_ADMIN_SECRET"` - SMTPHost *string `json:"SMTP_HOST"` - SMTPPort *string `json:"SMTP_PORT"` - SMTPUsername *string `json:"SMTP_USERNAME"` - SMTPPassword *string `json:"SMTP_PASSWORD"` - SenderEmail *string `json:"SENDER_EMAIL"` - JwtType *string `json:"JWT_TYPE"` - JwtSecret *string `json:"JWT_SECRET"` - JwtPrivateKey *string `json:"JWT_PRIVATE_KEY"` - JwtPublicKey *string `json:"JWT_PUBLIC_KEY"` - AllowedOrigins []string `json:"ALLOWED_ORIGINS"` - AppURL *string `json:"APP_URL"` - ResetPasswordURL *string `json:"RESET_PASSWORD_URL"` - DisableEmailVerification *bool `json:"DISABLE_EMAIL_VERIFICATION"` - DisableBasicAuthentication *bool `json:"DISABLE_BASIC_AUTHENTICATION"` - DisableMagicLinkLogin *bool `json:"DISABLE_MAGIC_LINK_LOGIN"` - DisableLoginPage *bool `json:"DISABLE_LOGIN_PAGE"` - DisableSignUp *bool `json:"DISABLE_SIGN_UP"` - DisableRedisForEnv *bool `json:"DISABLE_REDIS_FOR_ENV"` - DisableStrongPassword *bool `json:"DISABLE_STRONG_PASSWORD"` - Roles []string `json:"ROLES"` - ProtectedRoles []string `json:"PROTECTED_ROLES"` - DefaultRoles []string `json:"DEFAULT_ROLES"` - JwtRoleClaim *string `json:"JWT_ROLE_CLAIM"` - GoogleClientID *string `json:"GOOGLE_CLIENT_ID"` - GoogleClientSecret *string `json:"GOOGLE_CLIENT_SECRET"` - GithubClientID *string `json:"GITHUB_CLIENT_ID"` - GithubClientSecret *string `json:"GITHUB_CLIENT_SECRET"` - FacebookClientID *string `json:"FACEBOOK_CLIENT_ID"` - FacebookClientSecret *string `json:"FACEBOOK_CLIENT_SECRET"` - LinkedinClientID *string `json:"LINKEDIN_CLIENT_ID"` - LinkedinClientSecret *string `json:"LINKEDIN_CLIENT_SECRET"` - AppleClientID *string `json:"APPLE_CLIENT_ID"` - AppleClientSecret *string `json:"APPLE_CLIENT_SECRET"` - OrganizationName *string `json:"ORGANIZATION_NAME"` - OrganizationLogo *string `json:"ORGANIZATION_LOGO"` + AccessTokenExpiryTime *string `json:"ACCESS_TOKEN_EXPIRY_TIME"` + AdminSecret *string `json:"ADMIN_SECRET"` + CustomAccessTokenScript *string `json:"CUSTOM_ACCESS_TOKEN_SCRIPT"` + OldAdminSecret *string `json:"OLD_ADMIN_SECRET"` + SMTPHost *string `json:"SMTP_HOST"` + SMTPPort *string `json:"SMTP_PORT"` + SMTPUsername *string `json:"SMTP_USERNAME"` + SMTPPassword *string `json:"SMTP_PASSWORD"` + SenderEmail *string `json:"SENDER_EMAIL"` + JwtType *string `json:"JWT_TYPE"` + JwtSecret *string `json:"JWT_SECRET"` + JwtPrivateKey *string `json:"JWT_PRIVATE_KEY"` + JwtPublicKey *string `json:"JWT_PUBLIC_KEY"` + AllowedOrigins []string `json:"ALLOWED_ORIGINS"` + AppURL *string `json:"APP_URL"` + ResetPasswordURL *string `json:"RESET_PASSWORD_URL"` + DisableEmailVerification *bool `json:"DISABLE_EMAIL_VERIFICATION"` + DisableBasicAuthentication *bool `json:"DISABLE_BASIC_AUTHENTICATION"` + DisableMagicLinkLogin *bool `json:"DISABLE_MAGIC_LINK_LOGIN"` + DisableLoginPage *bool `json:"DISABLE_LOGIN_PAGE"` + DisableSignUp *bool `json:"DISABLE_SIGN_UP"` + DisableRedisForEnv *bool `json:"DISABLE_REDIS_FOR_ENV"` + DisableStrongPassword *bool `json:"DISABLE_STRONG_PASSWORD"` + EnforceMultiFactorAuthentication *bool `json:"ENFORCE_MULTI_FACTOR_AUTHENTICATION"` + Roles []string `json:"ROLES"` + ProtectedRoles []string `json:"PROTECTED_ROLES"` + DefaultRoles []string `json:"DEFAULT_ROLES"` + JwtRoleClaim *string `json:"JWT_ROLE_CLAIM"` + GoogleClientID *string `json:"GOOGLE_CLIENT_ID"` + GoogleClientSecret *string `json:"GOOGLE_CLIENT_SECRET"` + GithubClientID *string `json:"GITHUB_CLIENT_ID"` + GithubClientSecret *string `json:"GITHUB_CLIENT_SECRET"` + FacebookClientID *string `json:"FACEBOOK_CLIENT_ID"` + FacebookClientSecret *string `json:"FACEBOOK_CLIENT_SECRET"` + LinkedinClientID *string `json:"LINKEDIN_CLIENT_ID"` + LinkedinClientSecret *string `json:"LINKEDIN_CLIENT_SECRET"` + AppleClientID *string `json:"APPLE_CLIENT_ID"` + AppleClientSecret *string `json:"APPLE_CLIENT_SECRET"` + OrganizationName *string `json:"ORGANIZATION_NAME"` + OrganizationLogo *string `json:"ORGANIZATION_LOGO"` } type UpdateProfileInput struct { diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index e3d6908..ddab91c 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -124,6 +124,7 @@ type Env { DISABLE_SIGN_UP: Boolean! DISABLE_REDIS_FOR_ENV: Boolean! DISABLE_STRONG_PASSWORD: Boolean! + ENFORCE_MULTI_FACTOR_AUTHENTICATION: Boolean! ROLES: [String!] PROTECTED_ROLES: [String!] DEFAULT_ROLES: [String!] @@ -224,6 +225,7 @@ input UpdateEnvInput { DISABLE_SIGN_UP: Boolean DISABLE_REDIS_FOR_ENV: Boolean DISABLE_STRONG_PASSWORD: Boolean + ENFORCE_MULTI_FACTOR_AUTHENTICATION: Boolean ROLES: [String!] PROTECTED_ROLES: [String!] DEFAULT_ROLES: [String!] diff --git a/server/memorystore/memory_store.go b/server/memorystore/memory_store.go index 9cbbbb4..dc7a195 100644 --- a/server/memorystore/memory_store.go +++ b/server/memorystore/memory_store.go @@ -25,13 +25,14 @@ func InitMemStore() error { constants.EnvKeyOrganizationLogo: "https://www.authorizer.dev/images/logo.png", // boolean envs - constants.EnvKeyDisableBasicAuthentication: false, - constants.EnvKeyDisableMagicLinkLogin: false, - constants.EnvKeyDisableEmailVerification: false, - constants.EnvKeyDisableLoginPage: false, - constants.EnvKeyDisableSignUp: false, - constants.EnvKeyDisableStrongPassword: false, - constants.EnvKeyIsEmailServiceEnabled: false, + constants.EnvKeyDisableBasicAuthentication: false, + constants.EnvKeyDisableMagicLinkLogin: false, + constants.EnvKeyDisableEmailVerification: false, + constants.EnvKeyDisableLoginPage: false, + constants.EnvKeyDisableSignUp: false, + constants.EnvKeyDisableStrongPassword: false, + constants.EnvKeyIsEmailServiceEnabled: false, + constants.EnvKeyEnforceMultiFactorAuthentication: false, } requiredEnvs := RequiredEnvStoreObj.GetRequiredEnv() diff --git a/server/memorystore/providers/inmemory/stores/session_store.go b/server/memorystore/providers/inmemory/stores/session_store.go index ad617af..d035312 100644 --- a/server/memorystore/providers/inmemory/stores/session_store.go +++ b/server/memorystore/providers/inmemory/stores/session_store.go @@ -39,6 +39,7 @@ func (s *SessionStore) Set(key string, subKey, value string) { func (s *SessionStore) RemoveAll(key string) { s.mutex.Lock() defer s.mutex.Unlock() + delete(s.store, key) } @@ -53,6 +54,9 @@ func (s *SessionStore) Remove(key, subKey string) { // Get all the values for given key func (s *SessionStore) GetAll(key string) map[string]string { + s.mutex.Lock() + defer s.mutex.Unlock() + if _, ok := s.store[key]; !ok { s.store[key] = make(map[string]string) } @@ -63,6 +67,7 @@ func (s *SessionStore) GetAll(key string) map[string]string { func (s *SessionStore) RemoveByNamespace(namespace string) error { s.mutex.Lock() defer s.mutex.Unlock() + for key := range s.store { if strings.Contains(key, namespace+":") { delete(s.store, key) diff --git a/server/memorystore/providers/redis/store.go b/server/memorystore/providers/redis/store.go index d6ee1df..4fb1206 100644 --- a/server/memorystore/providers/redis/store.go +++ b/server/memorystore/providers/redis/store.go @@ -160,7 +160,7 @@ func (c *provider) GetEnvStore() (map[string]interface{}, error) { return nil, err } for key, value := range data { - if key == constants.EnvKeyDisableBasicAuthentication || key == constants.EnvKeyDisableEmailVerification || key == constants.EnvKeyDisableLoginPage || key == constants.EnvKeyDisableMagicLinkLogin || key == constants.EnvKeyDisableRedisForEnv || key == constants.EnvKeyDisableSignUp || key == constants.EnvKeyDisableStrongPassword || key == constants.EnvKeyIsEmailServiceEnabled { + if key == constants.EnvKeyDisableBasicAuthentication || key == constants.EnvKeyDisableEmailVerification || key == constants.EnvKeyDisableLoginPage || key == constants.EnvKeyDisableMagicLinkLogin || key == constants.EnvKeyDisableRedisForEnv || key == constants.EnvKeyDisableSignUp || key == constants.EnvKeyDisableStrongPassword || key == constants.EnvKeyIsEmailServiceEnabled || key == constants.EnvKeyEnforceMultiFactorAuthentication { boolValue, err := strconv.ParseBool(value) if err != nil { return res, err diff --git a/server/resolvers/env.go b/server/resolvers/env.go index 3abda07..e97fac4 100644 --- a/server/resolvers/env.go +++ b/server/resolvers/env.go @@ -170,6 +170,7 @@ func EnvResolver(ctx context.Context) (*model.Env, error) { res.DisableLoginPage = store[constants.EnvKeyDisableLoginPage].(bool) res.DisableSignUp = store[constants.EnvKeyDisableSignUp].(bool) res.DisableStrongPassword = store[constants.EnvKeyDisableStrongPassword].(bool) + res.EnforceMultiFactorAuthentication = store[constants.EnvKeyEnforceMultiFactorAuthentication].(bool) return res, nil } diff --git a/server/resolvers/login.go b/server/resolvers/login.go index 7d1f28e..6325a6a 100644 --- a/server/resolvers/login.go +++ b/server/resolvers/login.go @@ -2,7 +2,6 @@ package resolvers import ( "context" - "errors" "fmt" "strings" "time" @@ -100,12 +99,13 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes scope = params.Scope } - if refs.BoolValue(user.IsMultiFactorAuthEnabled) { - isEnvServiceEnabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsEmailServiceEnabled) - if err != nil || !isEnvServiceEnabled { - log.Debug("Email service not enabled:") - return nil, errors.New("email service not enabled") - } + isEmailServiceEnabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsEmailServiceEnabled) + if err != nil || !isEmailServiceEnabled { + log.Debug("Email service not enabled: ", err) + } + + // If email service is not enabled continue the process in any way + if refs.BoolValue(user.IsMultiFactorAuthEnabled) && isEmailServiceEnabled { otp := utils.GenerateOTP() otpData, err := db.Provider.UpsertOTP(ctx, &models.OTP{ Email: user.Email, diff --git a/server/resolvers/resend_otp.go b/server/resolvers/resend_otp.go index 60367c1..edd7445 100644 --- a/server/resolvers/resend_otp.go +++ b/server/resolvers/resend_otp.go @@ -9,10 +9,12 @@ import ( log "github.com/sirupsen/logrus" + "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/email" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/refs" "github.com/authorizerdev/authorizer/server/utils" ) @@ -44,6 +46,12 @@ func ResendOTPResolver(ctx context.Context, params model.ResendOTPRequest) (*mod return nil, fmt.Errorf(`multi factor authentication not enabled`) } + isEmailServiceEnabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsEmailServiceEnabled) + if err != nil || !isEmailServiceEnabled { + log.Debug("Email service not enabled: ", err) + return nil, errors.New("email service not enabled") + } + // get otp by email otpData, err := db.Provider.GetOTPByEmail(ctx, params.Email) if err != nil { diff --git a/server/resolvers/update_env.go b/server/resolvers/update_env.go index 30abe9e..deecc20 100644 --- a/server/resolvers/update_env.go +++ b/server/resolvers/update_env.go @@ -270,8 +270,6 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model } } - go clearSessionIfRequired(currentData, updatedData) - // Update local store memorystore.Provider.UpdateEnvStore(updatedData) jwk, err := crypto.GenerateJWKBasedOnEnv() @@ -325,6 +323,8 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model return res, err } + go clearSessionIfRequired(currentData, updatedData) + res = &model.Response{ Message: "configurations updated successfully", } diff --git a/server/resolvers/update_profile.go b/server/resolvers/update_profile.go index 0a47376..ebc4718 100644 --- a/server/resolvers/update_profile.go +++ b/server/resolvers/update_profile.go @@ -96,7 +96,6 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) } if params.IsMultiFactorAuthEnabled != nil && refs.BoolValue(user.IsMultiFactorAuthEnabled) != refs.BoolValue(params.IsMultiFactorAuthEnabled) { - user.IsMultiFactorAuthEnabled = params.IsMultiFactorAuthEnabled if refs.BoolValue(params.IsMultiFactorAuthEnabled) { isEnvServiceEnabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsEmailServiceEnabled) if err != nil || !isEnvServiceEnabled { @@ -104,6 +103,8 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) return nil, errors.New("email service not enabled, so cannot enable multi factor authentication") } } + + user.IsMultiFactorAuthEnabled = params.IsMultiFactorAuthEnabled } isPasswordChanging := false diff --git a/server/test/resend_otp_test.go b/server/test/resend_otp_test.go index 3202d9e..2ba256c 100644 --- a/server/test/resend_otp_test.go +++ b/server/test/resend_otp_test.go @@ -14,9 +14,9 @@ import ( func resendOTPTest(t *testing.T, s TestSetup) { t.Helper() - t.Run(`should verify otp`, func(t *testing.T) { + t.Run(`should resend otp`, func(t *testing.T) { req, ctx := createContext(s) - email := "verify_otp." + s.TestInfo.Email + email := "resend_otp." + s.TestInfo.Email res, err := resolvers.SignupResolver(ctx, model.SignUpInput{ Email: email, Password: s.TestInfo.Password, diff --git a/server/test/resolvers_test.go b/server/test/resolvers_test.go index f10576e..17373fc 100644 --- a/server/test/resolvers_test.go +++ b/server/test/resolvers_test.go @@ -33,7 +33,7 @@ func TestResolvers(t *testing.T) { if utils.StringSliceContains(testDBs, constants.DbTypeSqlite) && len(testDBs) == 1 { // do nothing } else { - t.Log("waiting for docker containers to spun up") + t.Log("waiting for docker containers to start...") // wait for docker containers to spun up time.Sleep(30 * time.Second) } @@ -116,6 +116,8 @@ func TestResolvers(t *testing.T) { validateJwtTokenTest(t, s) verifyOTPTest(t, s) resendOTPTest(t, s) + + updateAllUsersTest(t, s) webhookLogsTest(t, s) // get logs after above resolver tests are done deleteWebhookTest(t, s) // delete webhooks (admin resolver) }) diff --git a/server/test/update_all_users_tests.go b/server/test/update_all_users_tests.go new file mode 100644 index 0000000..6473908 --- /dev/null +++ b/server/test/update_all_users_tests.go @@ -0,0 +1,67 @@ +package test + +import ( + "fmt" + "testing" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/db" + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/refs" + "github.com/authorizerdev/authorizer/server/utils" + "github.com/stretchr/testify/assert" +) + +func updateAllUsersTest(t *testing.T, s TestSetup) { + t.Helper() + t.Run("Should update all users", func(t *testing.T) { + _, ctx := createContext(s) + + users := []models.User{} + for i := 0; i < 10; i++ { + user := models.User{ + Email: fmt.Sprintf("update_all_user_%d_%s", i, s.TestInfo.Email), + SignupMethods: constants.AuthRecipeMethodBasicAuth, + Roles: "user", + } + users = append(users, user) + u, err := db.Provider.AddUser(ctx, user) + assert.NoError(t, err) + assert.NotNil(t, u) + } + + err := db.Provider.UpdateUsers(ctx, map[string]interface{}{ + "is_multi_factor_auth_enabled": true, + }, nil) + assert.NoError(t, err) + + listUsers, err := db.Provider.ListUsers(ctx, model.Pagination{ + Limit: 20, + Offset: 0, + }) + assert.NoError(t, err) + for _, u := range listUsers.Users { + assert.True(t, refs.BoolValue(u.IsMultiFactorAuthEnabled)) + } + + // // update few users + updateIds := []string{listUsers.Users[0].ID, listUsers.Users[1].ID} + err = db.Provider.UpdateUsers(ctx, map[string]interface{}{ + "is_multi_factor_auth_enabled": false, + }, updateIds) + assert.NoError(t, err) + + listUsers, err = db.Provider.ListUsers(ctx, model.Pagination{ + Limit: 20, + Offset: 0, + }) + for _, u := range listUsers.Users { + if utils.StringSliceContains(updateIds, u.ID) { + assert.False(t, refs.BoolValue(u.IsMultiFactorAuthEnabled)) + } else { + assert.True(t, refs.BoolValue(u.IsMultiFactorAuthEnabled)) + } + } + }) +} diff --git a/server/test/verify_otp_test.go b/server/test/verify_otp_test.go index afb7e2e..9e074cd 100644 --- a/server/test/verify_otp_test.go +++ b/server/test/verify_otp_test.go @@ -44,9 +44,11 @@ func verifyOTPTest(t *testing.T, s TestSetup) { // Using access token update profile s.GinContext.Request.Header.Set("Authorization", "Bearer "+refs.StringValue(verifyRes.AccessToken)) ctx = context.WithValue(req.Context(), "GinContextKey", s.GinContext) - _, err = resolvers.UpdateProfileResolver(ctx, model.UpdateProfileInput{ + updateProfileRes, err := resolvers.UpdateProfileResolver(ctx, model.UpdateProfileInput{ IsMultiFactorAuthEnabled: refs.NewBoolRef(true), }) + assert.NoError(t, err) + assert.NotEmpty(t, updateProfileRes.Message) // Login should not return error but access token should be empty as otp should have been sent loginRes, err = resolvers.LoginResolver(ctx, model.LoginInput{ diff --git a/server/utils/webhook.go b/server/utils/webhook.go index 041a542..acacfbf 100644 --- a/server/utils/webhook.go +++ b/server/utils/webhook.go @@ -11,6 +11,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/refs" log "github.com/sirupsen/logrus" ) @@ -52,6 +53,22 @@ func RegisterEvent(ctx context.Context, eventName string, authRecipe string, use return err } + // dont trigger webhook call in case of test + envKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEnv) + if err != nil { + return err + } + if envKey == constants.TestEnv { + db.Provider.AddWebhookLog(ctx, models.WebhookLog{ + HttpStatus: 200, + Request: string(requestBody), + Response: string(`{"message": "test"}`), + WebhookID: webhook.ID, + }) + + return nil + } + requestBytesBuffer := bytes.NewBuffer(requestBody) req, err := http.NewRequest("POST", refs.StringValue(webhook.Endpoint), requestBytesBuffer) if err != nil {