added dynamic params and seprate the update logic for common use

This commit is contained in:
manoj 2022-12-17 15:13:45 +05:30
parent 2d968309bb
commit 1b2483d47f
18 changed files with 182 additions and 195 deletions

View File

@ -1,8 +1,6 @@
package db package db
import ( import (
"fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
@ -21,7 +19,6 @@ var Provider providers.Provider
func InitDB() error { func InitDB() error {
var err error var err error
fmt.Println("isCouchbaseDB::: InitDB")
envs := memorystore.RequiredEnvStoreObj.GetRequiredEnv() envs := memorystore.RequiredEnvStoreObj.GetRequiredEnv()
@ -80,13 +77,11 @@ func InitDB() error {
if isCouchbaseDB { if isCouchbaseDB {
log.Info("Initializing CouchbaseDB Driver for: ", envs.DatabaseType) log.Info("Initializing CouchbaseDB Driver for: ", envs.DatabaseType)
Provider, err = couchbase.NewProvider() Provider, err = couchbase.NewProvider()
fmt.Println("isCouchbaseDB", Provider)
if err != nil { if err != nil {
log.Fatal("Failed to initialize Couchbase driver: ", err) log.Fatal("Failed to initialize Couchbase driver: ", err)
return err return err
} }
} }
fmt.Println("isCouchbaseDB:::", Provider)
return nil return nil
} }

View File

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"reflect"
"strings" "strings"
"time" "time"
@ -39,7 +38,6 @@ func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.Em
// UpdateEmailTemplate to update EmailTemplate // UpdateEmailTemplate to update EmailTemplate
func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) {
scope := p.db.Scope("_default")
bytes, err := json.Marshal(emailTemplate) bytes, err := json.Marshal(emailTemplate)
if err != nil { if err != nil {
return nil, err return nil, err
@ -53,33 +51,15 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models
return nil, err return nil, err
} }
updateFields := "" updateFields, params := GetSetFields(emailTemplateMap)
for key, value := range emailTemplateMap { params["emailId"] = emailTemplate.ID
if key == "_id" {
continue
}
if key == "_key" { query := fmt.Sprintf("UPDATE auth._default.%s SET %s WHERE _id = $emailId", models.Collections.EmailTemplate, updateFields)
continue
}
if value == nil { _, err = p.db.Scope("_default").Query(query, &gocb.QueryOptions{
updateFields += fmt.Sprintf("%s = null,", key) Context: ctx,
continue NamedParameters: params,
} })
valueType := reflect.TypeOf(value)
if valueType.Name() == "string" {
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
} else {
updateFields += fmt.Sprintf("%s = %v, ", key, value)
}
}
updateFields = strings.Trim(updateFields, " ")
updateFields = strings.TrimSuffix(updateFields, ",")
query := fmt.Sprintf("UPDATE auth._default.%s SET %s WHERE _id = '%s'", models.Collections.EmailTemplate, updateFields, emailTemplate.ID)
_, err = scope.Query(query, &gocb.QueryOptions{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -89,14 +69,17 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models
// ListEmailTemplates to list EmailTemplate // ListEmailTemplates to list EmailTemplate
func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) { func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) {
emailTemplates := []*model.EmailTemplate{} emailTemplates := []*model.EmailTemplate{}
// r := p.db.Collection(models.Collections.User).
paginationClone := pagination paginationClone := pagination
scope := p.db.Scope("_default") scope := p.db.Scope("_default")
userQuery := fmt.Sprintf("SELECT _id, event_name, subject, design, template, created_at, updated_at FROM auth._default.%s ORDER BY _id OFFSET %d LIMIT %d", models.Collections.EmailTemplate, paginationClone.Offset, paginationClone.Limit)
_, paginationClone.Total = GetTotalDocs(ctx, scope, models.Collections.EmailTemplate)
userQuery := fmt.Sprintf("SELECT _id, event_name, subject, design, template, created_at, updated_at FROM auth._default.%s ORDER BY _id OFFSET $1 LIMIT $2", models.Collections.EmailTemplate)
queryResult, err := scope.Query(userQuery, &gocb.QueryOptions{ queryResult, err := scope.Query(userQuery, &gocb.QueryOptions{
Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
PositionalParameters: []interface{}{paginationClone.Offset, paginationClone.Limit},
}) })
if err != nil { if err != nil {
@ -126,11 +109,13 @@ func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagin
// GetEmailTemplateByID to get EmailTemplate by id // GetEmailTemplateByID to get EmailTemplate by id
func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) { func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) {
emailTemplate := models.EmailTemplate{} emailTemplate := models.EmailTemplate{}
time.Sleep(200 * time.Millisecond)
scope := p.db.Scope("_default") query := fmt.Sprintf(`SELECT _id, event_name, subject, design, template, created_at, updated_at FROM auth._default.%s WHERE _id = $1 LIMIT 1`, models.Collections.EmailTemplate)
query := fmt.Sprintf(`SELECT _id, event_name, subject, design, template, created_at, updated_at FROM auth._default.%s WHERE _id = '%s' LIMIT 1`, models.Collections.EmailTemplate, emailTemplateID) q, err := p.db.Scope("_default").Query(query, &gocb.QueryOptions{
q, err := scope.Query(query, &gocb.QueryOptions{}) Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
PositionalParameters: []interface{}{emailTemplateID},
})
if err != nil { if err != nil {
return nil, err return nil, err
@ -147,12 +132,12 @@ func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID str
// GetEmailTemplateByEventName to get EmailTemplate by event_name // GetEmailTemplateByEventName to get EmailTemplate by event_name
func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) { func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) {
emailTemplate := models.EmailTemplate{} emailTemplate := models.EmailTemplate{}
time.Sleep(200 * time.Millisecond)
scope := p.db.Scope("_default") scope := p.db.Scope("_default")
query := fmt.Sprintf("SELECT _id, event_name, subject, design, template, created_at, updated_at FROM auth._default.%s WHERE event_name=$1 LIMIT 1", models.Collections.EmailTemplate) query := fmt.Sprintf("SELECT _id, event_name, subject, design, template, created_at, updated_at FROM auth._default.%s WHERE event_name=$1 LIMIT 1", models.Collections.EmailTemplate)
q, err := scope.Query(query, &gocb.QueryOptions{ q, err := scope.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
PositionalParameters: []interface{}{eventName}, PositionalParameters: []interface{}{eventName},
}) })
@ -161,7 +146,6 @@ func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName st
} }
err = q.One(&emailTemplate) err = q.One(&emailTemplate)
time.Sleep(20 * time.Second)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -34,8 +34,11 @@ func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, e
env.UpdatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix()
scope := p.db.Scope("_default") scope := p.db.Scope("_default")
updateEnvQuery := fmt.Sprintf("UPDATE auth._default.%s SET env = '%s', updated_at = %d WHERE _id = '%s'", models.Collections.Env, env.EnvData, env.UpdatedAt, env.ID) updateEnvQuery := fmt.Sprintf("UPDATE auth._default.%s SET env = $1, updated_at = $2 WHERE _id = $3", models.Collections.Env)
_, err := scope.Query(updateEnvQuery, &gocb.QueryOptions{}) _, err := scope.Query(updateEnvQuery, &gocb.QueryOptions{
Context: ctx,
PositionalParameters: []interface{}{env.EnvData, env.UpdatedAt, env.UpdatedAt, env.ID},
})
if err != nil { if err != nil {
return env, err return env, err
@ -49,7 +52,9 @@ func (p *provider) GetEnv(ctx context.Context) (models.Env, error) {
var env models.Env var env models.Env
scope := p.db.Scope("_default") scope := p.db.Scope("_default")
query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM auth._default.%s LIMIT 1", models.Collections.Env) query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM auth._default.%s LIMIT 1", models.Collections.Env)
q, err := scope.Query(query, &gocb.QueryOptions{}) q, err := scope.Query(query, &gocb.QueryOptions{
Context: ctx,
})
if err != nil { if err != nil {
return env, err return env, err
} }

View File

@ -12,29 +12,6 @@ import (
// UpsertOTP to add or update otp // UpsertOTP to add or update otp
func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models.OTP, error) { func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models.OTP, error) {
// otp, _ = p.GetOTPByEmail(ctx, otp.Email)
// if otp == nil {
// id := uuid.NewString()
// otp = &models.OTP{
// ID: id,
// Key: id,
// Otp: otp.Otp,
// Email: otp.Email,
// ExpiresAt: otp.ExpiresAt,
// CreatedAt: time.Now().Unix(),
// }
// }
// otp.UpdatedAt = time.Now().Unix()
// unsertOpt := gocb.UpsertOptions{
// Context: ctx,
// }
// _, err := p.db.Collection(models.Collections.OTP).Upsert(otp.ID, otp, &unsertOpt)
// if err != nil {
// return nil, err
// }
// return otp, nil
otp, _ := p.GetOTPByEmail(ctx, otpParam.Email) otp, _ := p.GetOTPByEmail(ctx, otpParam.Email)
shouldCreate := false shouldCreate := false
@ -63,9 +40,10 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models
return otp, err return otp, err
} }
} else { } else {
query := fmt.Sprintf(`UPDATE auth._default.%s SET otp="%s", expires_at=%d, updated_at=%d WHERE _id="%s"`, models.Collections.OTP, otp.Otp, otp.ExpiresAt, otp.UpdatedAt, otp.ID) query := fmt.Sprintf(`UPDATE auth._default.%s SET otp=$1, expires_at=$2, updated_at=$3 WHERE _id=$4`, models.Collections.OTP)
scope := p.db.Scope("_default") _, err := p.db.Scope("_default").Query(query, &gocb.QueryOptions{
_, err := scope.Query(query, &gocb.QueryOptions{}) PositionalParameters: []interface{}{otp.Otp, otp.ExpiresAt, otp.UpdatedAt, otp.ID},
})
if err != nil { if err != nil {
return otp, err return otp, err
} }
@ -76,9 +54,10 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models
// GetOTPByEmail to get otp for a given email address // GetOTPByEmail to get otp for a given email address
func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) { func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) {
otp := models.OTP{} otp := models.OTP{}
query := fmt.Sprintf(`SELECT _id, email, otp, expires_at, created_at, updated_at FROM auth._default.%s WHERE email = '%s' LIMIT 1`, models.Collections.OTP, emailAddress) query := fmt.Sprintf(`SELECT _id, email, otp, expires_at, created_at, updated_at FROM auth._default.%s WHERE email = $1 LIMIT 1`, models.Collections.OTP)
q, err := p.db.Scope("_default").Query(query, &gocb.QueryOptions{ q, err := p.db.Scope("_default").Query(query, &gocb.QueryOptions{
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
PositionalParameters: []interface{}{emailAddress},
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -1,7 +1,6 @@
package couchbase package couchbase
import ( import (
"fmt"
"os" "os"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
@ -22,7 +21,6 @@ func NewProvider() (*provider, error) {
dbURL := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseURL dbURL := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseURL
userName := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseUsername userName := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseUsername
password := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabasePassword password := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabasePassword
fmt.Println("dbURL", dbURL, userName, password)
opts := gocb.ClusterOptions{ opts := gocb.ClusterOptions{
Username: userName, Username: userName,
Password: password, Password: password,
@ -33,7 +31,6 @@ func NewProvider() (*provider, error) {
return nil, err return nil, err
} }
bucket := cluster.Bucket(bucketName) bucket := cluster.Bucket(bucketName)
// fmt.Println("1 called in oprovuider")
// v := reflect.ValueOf(models.Collections) // v := reflect.ValueOf(models.Collections)
// fmt.Println("called in v", v) // fmt.Println("called in v", v)

View File

@ -0,0 +1,65 @@
package couchbase
import (
"context"
"fmt"
"reflect"
"strings"
"github.com/couchbase/gocb/v2"
)
func GetSetFields(webhookMap map[string]interface{}) (string, map[string]interface{}) {
params := make(map[string]interface{}, 1)
updateFields := ""
for key, value := range webhookMap {
if key == "_id" {
continue
}
if key == "_key" {
continue
}
if value == nil {
updateFields += fmt.Sprintf("%s=$%s,", key, key)
params[key] = "null"
continue
}
valueType := reflect.TypeOf(value)
if valueType.Name() == "string" {
updateFields += fmt.Sprintf("%s = $%s, ", key, key)
params[key] = value.(string)
} else {
updateFields += fmt.Sprintf("%s = $%s, ", key, key)
params[key] = value
}
}
updateFields = strings.Trim(updateFields, " ")
updateFields = strings.TrimSuffix(updateFields, ",")
return updateFields, params
}
func GetTotalDocs(ctx context.Context, scope *gocb.Scope, collection string) (error, int64) {
totalDocs := TotalDocs{}
countQuery := fmt.Sprintf("SELECT COUNT(*) as Total FROM auth._default.%s", collection)
queryRes, err := scope.Query(countQuery, &gocb.QueryOptions{
Context: ctx,
})
queryRes.One(&totalDocs)
if err != nil {
return err, totalDocs.Total
}
return nil, totalDocs.Total
}
type TotalDocs struct {
Total int64
}

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"reflect"
"time" "time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
@ -59,14 +58,7 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
removeOpt := gocb.RemoveOptions{ removeOpt := gocb.RemoveOptions{
Context: ctx, Context: ctx,
} }
_, err := p.db.Collection(models.Collections.User).Remove(user.ID, &removeOpt) _, err := p.db.Collection(models.Collections.User).Remove(user.ID, &removeOpt)
// query := fmt.Sprintf("INSERT INTO %s %s VALUES %s IF NOT EXISTS", KeySpace+"."+models.Collections.User, fields, values)
// sessionCollection := p.db.Collection(models.Collections.Session).Queue()
// _, err = sessionCollection.DeleteMany(ctx, bson.M{"user_id": user.ID}, options.Delete())
// if err != nil {
// return err
// }
if err != nil { if err != nil {
return err return err
} }
@ -76,17 +68,18 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
// ListUsers to get list of users from database // ListUsers to get list of users from database
func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) { func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
users := []*model.User{} users := []*model.User{}
// r := p.db.Collection(models.Collections.User).
paginationClone := pagination paginationClone := pagination
scope := p.db.Scope("_default")
userQuery := fmt.Sprintf("SELECT _id, email, email_verified_at, `password`, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM auth._default.%s ORDER BY id OFFSET $1 LIMIT $2", models.Collections.User)
inventoryScope := p.db.Scope("_default") queryResult, err := scope.Query(userQuery, &gocb.QueryOptions{
userQuery := fmt.Sprintf("SELECT _id, email, email_verified_at, `password`, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM auth._default.%s ORDER BY id OFFSET %d LIMIT %d", models.Collections.User, paginationClone.Offset, paginationClone.Limit)
queryResult, err := inventoryScope.Query(userQuery, &gocb.QueryOptions{
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
Context: ctx, Context: ctx,
PositionalParameters: []interface{}{paginationClone.Offset, paginationClone.Limit},
}) })
_, paginationClone.Total = GetTotalDocs(ctx, scope, models.Collections.User)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -114,11 +107,11 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (
// GetUserByEmail to get user information from database using email address // GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) { func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
user := models.User{} user := models.User{}
scope := p.db.Scope("_default") query := fmt.Sprintf("SELECT _id, email, email_verified_at, `password`, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM auth._default.%s WHERE email = $1 LIMIT 1", models.Collections.User)
query := fmt.Sprintf("SELECT _id, email, email_verified_at, `password`, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM auth._default.%s WHERE email = '%s' LIMIT 1", models.Collections.User, email) q, err := p.db.Scope("_default").Query(query, &gocb.QueryOptions{
q, err := scope.Query(query, &gocb.QueryOptions{
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
Context: ctx, Context: ctx,
PositionalParameters: []interface{}{email},
}) })
if err != nil { if err != nil {
@ -135,11 +128,11 @@ func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.Use
// GetUserByID to get user information from database using user ID // GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) { func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
user := models.User{} user := models.User{}
scope := p.db.Scope("_default") query := fmt.Sprintf("SELECT _id, email, email_verified_at, `password`, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM auth._default.%s WHERE _id = $1 LIMIT 1", models.Collections.User)
query := fmt.Sprintf("SELECT _id, email, email_verified_at, `password`, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM auth._default.%s WHERE _id = '%s' LIMIT 1", models.Collections.User, id) q, err := p.db.Scope("_default").Query(query, &gocb.QueryOptions{
q, err := scope.Query(query, &gocb.QueryOptions{
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
Context: ctx, Context: ctx,
PositionalParameters: []interface{}{id},
}) })
if err != nil { if err != nil {
return user, err return user, err
@ -157,39 +150,18 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err
func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error { func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error {
// set updated_at time for all users // set updated_at time for all users
data["updated_at"] = time.Now().Unix() data["updated_at"] = time.Now().Unix()
inventoryScope := p.db.Scope("_default")
upf := "" updateFields, params := GetSetFields(data)
for key, value := range data {
if key == "_id" {
continue
}
if key == "_key" {
continue
}
if value == nil {
upf += fmt.Sprintf("%s = null,", key)
continue
}
valueType := reflect.TypeOf(value)
if valueType.Name() == "string" {
upf += fmt.Sprintf("%s = '%s', ", key, value.(string))
} else {
upf += fmt.Sprintf("%s = %v, ", key, value)
}
}
updateFields := removeLastRune(upf)
if ids != nil && len(ids) > 0 { if ids != nil && len(ids) > 0 {
for _, v := range ids { for _, id := range ids {
userQuery := fmt.Sprintf("UPDATE auth._default.%s SET %s WHERE _id = '%s'", models.Collections.User, updateFields, v) params["id"] = id
userQuery := fmt.Sprintf("UPDATE auth._default.%s SET %s WHERE _id = $id", models.Collections.User, updateFields)
_, err := inventoryScope.Query(userQuery, &gocb.QueryOptions{ _, err := p.db.Scope("_default").Query(userQuery, &gocb.QueryOptions{
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
Context: ctx, Context: ctx,
NamedParameters: params,
}) })
if err != nil { if err != nil {
return err return err
@ -197,9 +169,10 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{},
} }
} else { } else {
userQuery := fmt.Sprintf("UPDATE auth._default.%s SET %s WHERE _id IS NOT NULL", models.Collections.User, updateFields) userQuery := fmt.Sprintf("UPDATE auth._default.%s SET %s WHERE _id IS NOT NULL", models.Collections.User, updateFields)
_, err := inventoryScope.Query(userQuery, &gocb.QueryOptions{ _, err := p.db.Scope("_default").Query(userQuery, &gocb.QueryOptions{
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
Context: ctx, Context: ctx,
NamedParameters: params,
}) })
if err != nil { if err != nil {
return err return err
@ -208,7 +181,3 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{},
return nil return nil
} }
func removeLastRune(s string) string {
return s[:len(s)-2]
}

View File

@ -60,11 +60,8 @@ func (p *provider) GetVerificationRequestByToken(ctx context.Context, token stri
// GetVerificationRequestByEmail to get verification request by email from database // GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) {
scope := p.db.Scope("_default")
time.Sleep(200 * time.Millisecond)
query := fmt.Sprintf("SELECT _id, identifier, token, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM auth._default.%s WHERE email=$1 AND identifier=$2 LIMIT 1", models.Collections.VerificationRequest) query := fmt.Sprintf("SELECT _id, identifier, token, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM auth._default.%s WHERE email=$1 AND identifier=$2 LIMIT 1", models.Collections.VerificationRequest)
queryResult, err := scope.Query(query, &gocb.QueryOptions{ queryResult, err := p.db.Scope("_default").Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
PositionalParameters: []interface{}{email, identifier}, PositionalParameters: []interface{}{email, identifier},
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
@ -89,10 +86,13 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination mode
scope := p.db.Scope("_default") scope := p.db.Scope("_default")
paginationClone := pagination paginationClone := pagination
query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM auth._default.%s OFFSET %d LIMIT %d", models.Collections.VerificationRequest, paginationClone.Offset, paginationClone.Limit) _, paginationClone.Total = GetTotalDocs(ctx, scope, models.Collections.VerificationRequest)
query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM auth._default.%s OFFSET $1 LIMIT $2", models.Collections.VerificationRequest)
queryResult, err := scope.Query(query, &gocb.QueryOptions{ queryResult, err := scope.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
PositionalParameters: []interface{}{paginationClone.Offset, paginationClone.Limit},
}) })
if err != nil { if err != nil {

View File

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"reflect"
"strings" "strings"
"time" "time"
@ -37,6 +36,9 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
// UpdateWebhook to update webhook // UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
// params := make(map[string]interface{}, 1)
// params["webhook_id"] = webhook.ID
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
scope := p.db.Scope("_default") scope := p.db.Scope("_default")
@ -53,35 +55,13 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*
return nil, err return nil, err
} }
updateFields := "" updateFields, params := GetSetFields(webhookMap)
for key, value := range webhookMap {
if key == "_id" {
continue
}
if key == "_key" { query := fmt.Sprintf(`UPDATE auth._default.%s SET %s WHERE _id='%s'`, models.Collections.Webhook, updateFields, webhook.ID)
continue
}
if value == nil {
updateFields += fmt.Sprintf("%s = null,", key)
continue
}
valueType := reflect.TypeOf(value)
if valueType.Name() == "string" {
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
} else {
updateFields += fmt.Sprintf("%s = %v, ", key, value)
}
}
updateFields = strings.Trim(updateFields, " ")
updateFields = strings.TrimSuffix(updateFields, ",")
query := fmt.Sprintf("UPDATE auth._default.%s SET %s WHERE _id = '%s'", models.Collections.Webhook, updateFields, webhook.ID)
_, err = scope.Query(query, &gocb.QueryOptions{ _, err = scope.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, NamedParameters: params,
}) })
if err != nil { if err != nil {
@ -94,13 +74,21 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*
// ListWebhooks to list webhook // ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
webhooks := []*model.Webhook{} webhooks := []*model.Webhook{}
scope := p.db.Scope("_default")
paginationClone := pagination paginationClone := pagination
query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM auth._default.%s OFFSET %d LIMIT %d", models.Collections.Webhook, paginationClone.Offset, paginationClone.Limit) scope := p.db.Scope("_default")
params := make(map[string]interface{}, 1)
params["offset"] = paginationClone.Offset
params["limit"] = paginationClone.Limit
query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM auth._default.%s OFFSET $offset LIMIT $limit", models.Collections.Webhook)
_, paginationClone.Total = GetTotalDocs(ctx, scope, models.Collections.Webhook)
queryResult, err := scope.Query(query, &gocb.QueryOptions{ queryResult, err := scope.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
NamedParameters: params,
}) })
if err != nil { if err != nil {
@ -129,10 +117,13 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook var webhook models.Webhook
scope := p.db.Scope("_default") scope := p.db.Scope("_default")
query := fmt.Sprintf(`SELECT _id, event_name, endpoint, headers, enabled, created_at, updated_at FROM auth._default.%s WHERE _id = '%s' LIMIT 1`, models.Collections.Webhook, webhookID) params := make(map[string]interface{}, 1)
params["_id"] = webhookID
query := fmt.Sprintf(`SELECT _id, event_name, endpoint, headers, enabled, created_at, updated_at FROM auth._default.%s WHERE _id=$_id LIMIT 1`, models.Collections.Webhook)
q, err := scope.Query(query, &gocb.QueryOptions{ q, err := scope.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
NamedParameters: params,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -149,11 +140,14 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model
// GetWebhookByEventName to get webhook by event_name // GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
var webhook models.Webhook var webhook models.Webhook
params := make(map[string]interface{}, 1)
params["event_name"] = eventName
scope := p.db.Scope("_default") scope := p.db.Scope("_default")
query := fmt.Sprintf(`SELECT _id, event_name, endpoint, headers, enabled, created_at, updated_at FROM auth._default.%s WHERE event_name = '%s' LIMIT 1`, models.Collections.Webhook, eventName) query := fmt.Sprintf(`SELECT _id, event_name, endpoint, headers, enabled, created_at, updated_at FROM auth._default.%s WHERE event_name=$event_name LIMIT 1`, models.Collections.Webhook)
q, err := scope.Query(query, &gocb.QueryOptions{ q, err := scope.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
NamedParameters: params,
}) })
if err != nil { if err != nil {
@ -170,8 +164,9 @@ func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string)
// DeleteWebhook to delete webhook // DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error { func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
fmt.Println("trying to dlete webhooks logs", webhook.EventName)
scope := p.db.Scope("_default") scope := p.db.Scope("_default")
params := make(map[string]interface{}, 1)
params["webhook_id"] = webhook.ID
removeOpt := gocb.RemoveOptions{ removeOpt := gocb.RemoveOptions{
Context: ctx, Context: ctx,
} }
@ -181,11 +176,11 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er
return err return err
} }
query := fmt.Sprintf(`DELETE FROM auth._default.%s WHERE webhook_id=%s`, models.Collections.WebhookLog, webhook.ID) query := fmt.Sprintf(`DELETE FROM auth._default.%s WHERE webhook_id=$webhook_id`, models.Collections.WebhookLog)
fmt.Println("")
_, err = scope.Query(query, &gocb.QueryOptions{ _, err = scope.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
NamedParameters: params,
}) })
if err != nil { if err != nil {
return err return err

View File

@ -35,13 +35,30 @@ func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookL
// ListWebhookLogs to list webhook logs // ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) { func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) {
var query string
var err error
webhookLogs := []*model.WebhookLog{} webhookLogs := []*model.WebhookLog{}
params := make(map[string]interface{}, 1)
scope := p.db.Scope("_default") scope := p.db.Scope("_default")
paginationClone := pagination paginationClone := pagination
query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM auth._default.%s OFFSET %d LIMIT %d", models.Collections.Env, paginationClone.Offset, paginationClone.Limit)
params["webhookID"] = webhookID
params["offset"] = paginationClone.Offset
params["limit"] = paginationClone.Limit
_, paginationClone.Total = GetTotalDocs(ctx, scope, models.Collections.WebhookLog)
if webhookID != "" {
query = fmt.Sprintf(`SELECT _id, http_status, response, request, webhook_id, created_at, updated_at FROM auth._default.%s WHERE webhook_id=$webhookID`, models.Collections.WebhookLog)
} else {
query = fmt.Sprintf("SELECT _id, http_status, response, request, webhook_id, created_at, updated_at FROM auth._default.%s OFFSET $offset LIMIT $limit", models.Collections.WebhookLog)
}
queryResult, err := scope.Query(query, &gocb.QueryOptions{ queryResult, err := scope.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
NamedParameters: params,
}) })
if err != nil { if err != nil {

View File

@ -74,7 +74,6 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR
log := log.WithFields(log.Fields{ log := log.WithFields(log.Fields{
"email": params.Email, "email": params.Email,
}) })
time.Sleep(500 * time.Millisecond)
// find user with email // find user with email
existingUser, err := db.Provider.GetUserByEmail(ctx, params.Email) existingUser, err := db.Provider.GetUserByEmail(ctx, params.Email)
if err != nil { if err != nil {

View File

@ -2,7 +2,6 @@ package test
import ( import (
"testing" "testing"
"time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
@ -22,7 +21,6 @@ func forgotPasswordTest(t *testing.T, s TestSetup) {
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
time.Sleep(500 * time.Millisecond)
_, err = resolvers.ForgotPasswordResolver(ctx, model.ForgotPasswordInput{ _, err = resolvers.ForgotPasswordResolver(ctx, model.ForgotPasswordInput{
Email: email, Email: email,
}) })

View File

@ -5,6 +5,7 @@ import (
"os" "os"
"strings" "strings"
"testing" "testing"
"time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
@ -36,7 +37,7 @@ func TestResolvers(t *testing.T) {
} else { } else {
t.Log("waiting for docker containers to start...") t.Log("waiting for docker containers to start...")
// wait for docker containers to spun up // wait for docker containers to spun up
// time.Sleep(30 * time.Second) time.Sleep(30 * time.Second)
} }
testDb := "authorizer_test" testDb := "authorizer_test"

View File

@ -3,7 +3,6 @@ package test
import ( import (
"fmt" "fmt"
"testing" "testing"
"time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/crypto"
@ -23,7 +22,6 @@ func revokeAccessTest(t *testing.T, s TestSetup) {
Email: email, Email: email,
}) })
assert.NoError(t, err) assert.NoError(t, err)
time.Sleep(4 * time.Second)
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin)
verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{
Token: verificationRequest.Token, Token: verificationRequest.Token,

View File

@ -1,7 +1,6 @@
package test package test
import ( import (
"fmt"
"testing" "testing"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
@ -56,9 +55,6 @@ func signupTests(t *testing.T, s TestSetup) {
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
fmt.Println("err", err)
fmt.Println("res", res)
assert.NotNil(t, err, "should throw duplicate email error") assert.NotNil(t, err, "should throw duplicate email error")
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup)

View File

@ -108,10 +108,8 @@ func testSetup() TestSetup {
memorystore.Provider.UpdateEnvVariable(constants.EnvKeySmtpPassword, "test") memorystore.Provider.UpdateEnvVariable(constants.EnvKeySmtpPassword, "test")
memorystore.Provider.UpdateEnvVariable(constants.EnvKeySenderEmail, "info@yopmail.com") memorystore.Provider.UpdateEnvVariable(constants.EnvKeySenderEmail, "info@yopmail.com")
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyProtectedRoles, "admin") memorystore.Provider.UpdateEnvVariable(constants.EnvKeyProtectedRoles, "admin")
fmt.Println("called test suite before")
err = db.InitDB() err = db.InitDB()
fmt.Println("called test suite")
if err != nil { if err != nil {
log.Fatal("Error loading db: ", err) log.Fatal("Error loading db: ", err)
} }

View File

@ -3,7 +3,6 @@ package test
import ( import (
"fmt" "fmt"
"testing" "testing"
"time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
@ -42,12 +41,10 @@ func updateAllUsersTest(t *testing.T, s TestSetup) {
Offset: 0, Offset: 0,
}) })
assert.NoError(t, err) assert.NoError(t, err)
time.Sleep(500 * time.Millisecond)
for _, u := range listUsers.Users { for _, u := range listUsers.Users {
assert.True(t, refs.BoolValue(u.IsMultiFactorAuthEnabled)) assert.True(t, refs.BoolValue(u.IsMultiFactorAuthEnabled))
} }
time.Sleep(1 * time.Second)
// // update few users // // update few users
updateIds := []string{listUsers.Users[0].ID, listUsers.Users[1].ID} updateIds := []string{listUsers.Users[0].ID, listUsers.Users[1].ID}

View File

@ -3,7 +3,6 @@ package test
import ( import (
"fmt" "fmt"
"testing" "testing"
"time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/crypto"
@ -15,7 +14,6 @@ import (
) )
func webhookLogsTest(t *testing.T, s TestSetup) { func webhookLogsTest(t *testing.T, s TestSetup) {
time.Sleep(30 * time.Second) // add sleep for webhooklogs to get generated as they are async
t.Helper() t.Helper()
t.Run("should get webhook logs", func(t *testing.T) { t.Run("should get webhook logs", func(t *testing.T) {
req, ctx := createContext(s) req, ctx := createContext(s)
@ -25,11 +23,7 @@ func webhookLogsTest(t *testing.T, s TestSetup) {
assert.NoError(t, err) assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h))
time.Sleep(1 * time.Second)
webhookLogs, err := resolvers.WebhookLogsResolver(ctx, nil) webhookLogs, err := resolvers.WebhookLogsResolver(ctx, nil)
fmt.Printf("webhookLogs=========== %+v \n", webhookLogs.WebhookLogs)
time.Sleep(20 * time.Second)
fmt.Println("total documents found", len(webhookLogs.WebhookLogs))
assert.NoError(t, err) assert.NoError(t, err)
assert.Greater(t, len(webhookLogs.WebhookLogs), 1) assert.Greater(t, len(webhookLogs.WebhookLogs), 1)