feat: add helper for updating all users

This commit is contained in:
Lakhan Samani
2022-08-02 14:12:36 +05:30
parent 236045ac54
commit 587828b59b
41 changed files with 629 additions and 210 deletions

View File

@@ -16,21 +16,27 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models
shouldCreate := false
if otp == nil {
shouldCreate = true
otp.ID = uuid.New().String()
otp.Key = otp.ID
otp.CreatedAt = time.Now().Unix()
otp = &models.OTP{
ID: uuid.NewString(),
Otp: otpParam.Otp,
Email: otpParam.Email,
ExpiresAt: otpParam.ExpiresAt,
CreatedAt: time.Now().Unix(),
UpdatedAt: time.Now().Unix(),
}
} else {
otp = otpParam
otp.Otp = otpParam.Otp
otp.ExpiresAt = otpParam.ExpiresAt
}
otp.UpdatedAt = time.Now().Unix()
query := ""
if shouldCreate {
query = fmt.Sprintf(`INSERT INTO %s (id, email, otp, expires_at, created_at, updated_at) VALUES ('%s', '%s', '%s', %d, %d, %d)`, KeySpace+"."+models.Collections.OTP, otp.ID, otp.Email, otp.Otp, otp.ExpiresAt, otp.CreatedAt, otp.UpdatedAt)
} else {
query = fmt.Sprintf(`UPDATE %s SET otp = '%s', expires_at = %d, updated_at = %d WHERE email = '%s'`, KeySpace+"."+models.Collections.OTP, otp.Otp, otp.ExpiresAt, otp.UpdatedAt, otp.Email)
query = fmt.Sprintf(`UPDATE %s SET otp = '%s', expires_at = %d, updated_at = %d WHERE id = '%s'`, KeySpace+"."+models.Collections.OTP, otp.Otp, otp.ExpiresAt, otp.UpdatedAt, otp.ID)
}
err := p.db.Query(query).Exec()
if err != nil {
return nil, err

View File

@@ -13,6 +13,7 @@ import (
"github.com/authorizerdev/authorizer/server/memorystore"
"github.com/gocql/gocql"
cansandraDriver "github.com/gocql/gocql"
log "github.com/sirupsen/logrus"
)
type provider struct {
@@ -99,6 +100,7 @@ func NewProvider() (*provider, error) {
cassandraClient.Consistency = gocql.LocalQuorum
cassandraClient.ConnectTimeout = 10 * time.Second
cassandraClient.ProtoVersion = 4
cassandraClient.Timeout = 30 * time.Minute // for large data
session, err := cassandraClient.CreateSession()
if err != nil {
@@ -160,10 +162,11 @@ func NewProvider() (*provider, error) {
return nil, err
}
// add is_multi_factor_auth_enabled on users table
userTableAlterQuery := fmt.Sprintf(`ALTER TABLE %s.%s ADD is_multi_factor_auth_enabled boolean;`, KeySpace, models.Collections.User)
userTableAlterQuery := fmt.Sprintf(`ALTER TABLE %s.%s ADD is_multi_factor_auth_enabled boolean`, KeySpace, models.Collections.User)
err = session.Query(userTableAlterQuery).Exec()
if err != nil {
return nil, err
log.Debug("Failed to alter table as column exists: ", err)
// return nil, err
}
// token is reserved keyword in cassandra, hence we need to use jwt_token

View File

@@ -107,7 +107,7 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use
}
if value == nil {
updateFields += fmt.Sprintf("%s = null,", key)
updateFields += fmt.Sprintf("%s = null, ", key)
continue
}
@@ -122,7 +122,6 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use
updateFields = strings.TrimSuffix(updateFields, ",")
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, updateFields, user.ID)
err = p.db.Query(query).Exec()
if err != nil {
return user, err
@@ -173,14 +172,14 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (
// there is no offset in cassandra
// so we fetch till limit + offset
// and return the results from offset to limit
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset)
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset)
scanner := p.db.Query(query).Iter().Scanner()
counter := int64(0)
for scanner.Next() {
if counter >= pagination.Offset {
var user models.User
err := scanner.Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt)
err := scanner.Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
return nil, err
}
@@ -197,8 +196,8 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (
// GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.User, email)
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt)
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.User, email)
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
return user, err
}
@@ -208,10 +207,95 @@ func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.Use
// GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
var user models.User
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, id)
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt)
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, id)
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
return user, err
}
return user, nil
}
// UpdateUsers to update multiple users, with parameters of user IDs slice
// If ids set to nil / empty all the users will be updated
func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error {
// set updated_at time for all users
data["updated_at"] = time.Now().Unix()
updateFields := ""
for key, value := range data {
if key == "_id" {
continue
}
if key == "_key" {
continue
}
if value == nil {
updateFields += fmt.Sprintf("%s = null,", key)
continue
}
valueType := reflect.TypeOf(value)
if valueType.Name() == "string" {
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
} else {
updateFields += fmt.Sprintf("%s = %v, ", key, value)
}
}
updateFields = strings.Trim(updateFields, " ")
updateFields = strings.TrimSuffix(updateFields, ",")
query := ""
if ids != nil && len(ids) > 0 {
idsString := ""
for _, id := range ids {
idsString += fmt.Sprintf("'%s', ", id)
}
idsString = strings.Trim(idsString, " ")
idsString = strings.TrimSuffix(idsString, ",")
query = fmt.Sprintf("UPDATE %s SET %s WHERE id IN (%s)", KeySpace+"."+models.Collections.User, updateFields, idsString)
err := p.db.Query(query).Exec()
if err != nil {
return err
}
} else {
// get all ids
getUserIDsQuery := fmt.Sprintf(`SELECT id FROM %s`, KeySpace+"."+models.Collections.User)
scanner := p.db.Query(getUserIDsQuery).Iter().Scanner()
// only 100 ids are allowed in 1 query
// hence we need create multiple update queries
idsString := ""
idsStringArray := []string{idsString}
counter := 1
for scanner.Next() {
var id string
err := scanner.Scan(&id)
if err == nil {
idsString += fmt.Sprintf("'%s', ", id)
}
counter++
if counter > 100 {
idsStringArray = append(idsStringArray, idsString)
counter = 1
idsString = ""
} else {
// update the last index of array when count is less than 100
idsStringArray[len(idsStringArray)-1] = idsString
}
}
for _, idStr := range idsStringArray {
idStr = strings.Trim(idStr, " ")
idStr = strings.TrimSuffix(idStr, ",")
query = fmt.Sprintf("UPDATE %s SET %s WHERE id IN (%s)", KeySpace+"."+models.Collections.User, updateFields, idStr)
err := p.db.Query(query).Exec()
if err != nil {
return err
}
}
}
return nil
}