2022-04-21 07:06:22 +00:00
|
|
|
package cassandradb
|
2022-03-19 12:11:27 +00:00
|
|
|
|
|
|
|
import (
|
2022-07-10 16:19:33 +00:00
|
|
|
"context"
|
2022-04-22 11:15:49 +00:00
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"reflect"
|
2022-03-19 12:11:27 +00:00
|
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/authorizerdev/authorizer/server/constants"
|
|
|
|
"github.com/authorizerdev/authorizer/server/db/models"
|
|
|
|
"github.com/authorizerdev/authorizer/server/graph/model"
|
2022-05-29 11:52:46 +00:00
|
|
|
"github.com/authorizerdev/authorizer/server/memorystore"
|
2022-12-21 17:44:24 +00:00
|
|
|
"github.com/authorizerdev/authorizer/server/refs"
|
2022-04-22 11:15:49 +00:00
|
|
|
"github.com/gocql/gocql"
|
2022-03-19 12:11:27 +00:00
|
|
|
"github.com/google/uuid"
|
|
|
|
)
|
|
|
|
|
|
|
|
// AddUser to save user information in database
|
2023-07-31 11:12:11 +00:00
|
|
|
func (p *provider) AddUser(ctx context.Context, user *models.User) (*models.User, error) {
|
2022-03-19 12:11:27 +00:00
|
|
|
if user.ID == "" {
|
|
|
|
user.ID = uuid.New().String()
|
|
|
|
}
|
|
|
|
|
|
|
|
if user.Roles == "" {
|
2022-05-31 02:44:03 +00:00
|
|
|
defaultRoles, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles)
|
2022-05-29 11:52:46 +00:00
|
|
|
if err != nil {
|
|
|
|
return user, err
|
|
|
|
}
|
2022-05-31 02:44:03 +00:00
|
|
|
user.Roles = defaultRoles
|
2022-03-19 12:11:27 +00:00
|
|
|
}
|
|
|
|
|
2022-12-21 17:44:24 +00:00
|
|
|
if user.PhoneNumber != nil && strings.TrimSpace(refs.StringValue(user.PhoneNumber)) != "" {
|
|
|
|
if u, _ := p.GetUserByPhoneNumber(ctx, refs.StringValue(user.PhoneNumber)); u != nil && u.ID != user.ID {
|
|
|
|
return user, fmt.Errorf("user with given phone number already exists")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-03-19 12:11:27 +00:00
|
|
|
user.CreatedAt = time.Now().Unix()
|
|
|
|
user.UpdatedAt = time.Now().Unix()
|
|
|
|
|
2022-04-22 11:15:49 +00:00
|
|
|
bytes, err := json.Marshal(user)
|
|
|
|
if err != nil {
|
|
|
|
return user, err
|
|
|
|
}
|
2022-04-22 14:26:55 +00:00
|
|
|
|
|
|
|
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
|
|
|
|
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
|
|
|
|
decoder.UseNumber()
|
2022-04-22 11:15:49 +00:00
|
|
|
userMap := map[string]interface{}{}
|
2022-04-22 14:26:55 +00:00
|
|
|
err = decoder.Decode(&userMap)
|
|
|
|
if err != nil {
|
|
|
|
return user, err
|
|
|
|
}
|
2022-04-22 11:15:49 +00:00
|
|
|
|
|
|
|
fields := "("
|
|
|
|
values := "("
|
|
|
|
for key, value := range userMap {
|
|
|
|
if value != nil {
|
2022-04-22 14:26:55 +00:00
|
|
|
if key == "_id" {
|
|
|
|
fields += "id,"
|
|
|
|
} else {
|
|
|
|
fields += key + ","
|
|
|
|
}
|
2022-04-22 11:15:49 +00:00
|
|
|
|
|
|
|
valueType := reflect.TypeOf(value)
|
2022-04-22 14:26:55 +00:00
|
|
|
if valueType.Name() == "string" {
|
|
|
|
values += fmt.Sprintf("'%s',", value.(string))
|
2022-04-22 11:15:49 +00:00
|
|
|
} else {
|
2022-04-22 14:26:55 +00:00
|
|
|
values += fmt.Sprintf("%v,", value)
|
2022-04-22 11:15:49 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fields = fields[:len(fields)-1] + ")"
|
|
|
|
values = values[:len(values)-1] + ")"
|
|
|
|
|
2022-04-22 14:26:55 +00:00
|
|
|
query := fmt.Sprintf("INSERT INTO %s %s VALUES %s IF NOT EXISTS", KeySpace+"."+models.Collections.User, fields, values)
|
2022-04-22 11:15:49 +00:00
|
|
|
err = p.db.Query(query).Exec()
|
|
|
|
if err != nil {
|
|
|
|
return user, err
|
|
|
|
}
|
|
|
|
|
2022-03-19 12:11:27 +00:00
|
|
|
return user, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// UpdateUser to update user information in database
|
2023-07-31 11:12:11 +00:00
|
|
|
func (p *provider) UpdateUser(ctx context.Context, user *models.User) (*models.User, error) {
|
2022-03-19 12:11:27 +00:00
|
|
|
user.UpdatedAt = time.Now().Unix()
|
2022-04-22 14:26:55 +00:00
|
|
|
|
2022-04-22 11:15:49 +00:00
|
|
|
bytes, err := json.Marshal(user)
|
|
|
|
if err != nil {
|
|
|
|
return user, err
|
|
|
|
}
|
2022-04-22 14:26:55 +00:00
|
|
|
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
|
|
|
|
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
|
|
|
|
decoder.UseNumber()
|
2022-04-22 11:15:49 +00:00
|
|
|
userMap := map[string]interface{}{}
|
2022-04-22 14:26:55 +00:00
|
|
|
err = decoder.Decode(&userMap)
|
|
|
|
if err != nil {
|
|
|
|
return user, err
|
|
|
|
}
|
2022-04-22 11:15:49 +00:00
|
|
|
|
|
|
|
updateFields := ""
|
|
|
|
for key, value := range userMap {
|
2022-04-22 14:26:55 +00:00
|
|
|
if key == "_id" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2022-07-12 06:18:42 +00:00
|
|
|
if key == "_key" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2022-04-22 14:26:55 +00:00
|
|
|
if value == nil {
|
2022-08-02 08:42:36 +00:00
|
|
|
updateFields += fmt.Sprintf("%s = null, ", key)
|
2022-04-22 14:26:55 +00:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
valueType := reflect.TypeOf(value)
|
|
|
|
if valueType.Name() == "string" {
|
|
|
|
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
|
|
|
|
} else {
|
|
|
|
updateFields += fmt.Sprintf("%s = %v, ", key, value)
|
2022-04-22 11:15:49 +00:00
|
|
|
}
|
|
|
|
}
|
2022-04-22 14:26:55 +00:00
|
|
|
updateFields = strings.Trim(updateFields, " ")
|
|
|
|
updateFields = strings.TrimSuffix(updateFields, ",")
|
2022-04-22 11:15:49 +00:00
|
|
|
|
|
|
|
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, updateFields, user.ID)
|
|
|
|
err = p.db.Query(query).Exec()
|
|
|
|
if err != nil {
|
|
|
|
return user, err
|
|
|
|
}
|
|
|
|
|
2022-03-19 12:11:27 +00:00
|
|
|
return user, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// DeleteUser to delete user information from database
|
2023-07-31 11:12:11 +00:00
|
|
|
func (p *provider) DeleteUser(ctx context.Context, user *models.User) error {
|
2022-04-22 11:15:49 +00:00
|
|
|
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, user.ID)
|
2022-04-22 14:26:55 +00:00
|
|
|
err := p.db.Query(query).Exec()
|
2022-07-12 03:12:32 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2022-07-12 06:18:42 +00:00
|
|
|
getSessionsQuery := fmt.Sprintf("SELECT id FROM %s WHERE user_id = '%s' ALLOW FILTERING", KeySpace+"."+models.Collections.Session, user.ID)
|
|
|
|
scanner := p.db.Query(getSessionsQuery).Iter().Scanner()
|
|
|
|
sessionIDs := ""
|
|
|
|
for scanner.Next() {
|
|
|
|
var wlID string
|
|
|
|
err = scanner.Scan(&wlID)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
sessionIDs += fmt.Sprintf("'%s',", wlID)
|
|
|
|
}
|
|
|
|
sessionIDs = strings.TrimSuffix(sessionIDs, ",")
|
|
|
|
deleteSessionQuery := fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", KeySpace+"."+models.Collections.Session, sessionIDs)
|
2022-07-12 03:12:32 +00:00
|
|
|
err = p.db.Query(deleteSessionQuery).Exec()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
2022-03-19 12:11:27 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// ListUsers to get list of users from database
|
2023-07-31 11:12:11 +00:00
|
|
|
func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) (*model.Users, error) {
|
2022-04-22 11:15:49 +00:00
|
|
|
responseUsers := []*model.User{}
|
|
|
|
paginationClone := pagination
|
|
|
|
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.User)
|
2023-08-01 10:39:17 +00:00
|
|
|
err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total)
|
2022-04-22 11:15:49 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// there is no offset in cassandra
|
|
|
|
// so we fetch till limit + offset
|
|
|
|
// and return the results from offset to limit
|
2022-08-02 08:42:36 +00:00
|
|
|
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset)
|
2022-04-22 11:15:49 +00:00
|
|
|
scanner := p.db.Query(query).Iter().Scanner()
|
|
|
|
counter := int64(0)
|
|
|
|
for scanner.Next() {
|
|
|
|
if counter >= pagination.Offset {
|
2023-08-01 10:39:17 +00:00
|
|
|
var user models.User
|
|
|
|
err := scanner.Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt)
|
2022-04-22 11:15:49 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
responseUsers = append(responseUsers, user.AsAPIUser())
|
|
|
|
}
|
|
|
|
counter++
|
|
|
|
}
|
|
|
|
return &model.Users{
|
2023-07-31 11:12:11 +00:00
|
|
|
Pagination: paginationClone,
|
2022-04-22 11:15:49 +00:00
|
|
|
Users: responseUsers,
|
|
|
|
}, nil
|
2022-03-19 12:11:27 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// GetUserByEmail to get user information from database using email address
|
2023-07-31 11:12:11 +00:00
|
|
|
func (p *provider) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
|
2023-08-01 10:39:17 +00:00
|
|
|
var user models.User
|
2022-08-02 08:42:36 +00:00
|
|
|
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.User, email)
|
|
|
|
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt)
|
2022-04-22 11:15:49 +00:00
|
|
|
if err != nil {
|
2023-08-01 10:39:17 +00:00
|
|
|
return nil, err
|
2022-04-22 11:15:49 +00:00
|
|
|
}
|
2023-08-01 10:39:17 +00:00
|
|
|
return &user, nil
|
2022-03-19 12:11:27 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// GetUserByID to get user information from database using user ID
|
2023-07-31 11:12:11 +00:00
|
|
|
func (p *provider) GetUserByID(ctx context.Context, id string) (*models.User, error) {
|
2023-08-01 10:39:17 +00:00
|
|
|
var user models.User
|
2022-08-02 08:42:36 +00:00
|
|
|
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, id)
|
|
|
|
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt)
|
2022-04-22 11:15:49 +00:00
|
|
|
if err != nil {
|
2023-08-01 10:39:17 +00:00
|
|
|
return nil, err
|
2022-04-22 11:15:49 +00:00
|
|
|
}
|
2023-08-01 10:39:17 +00:00
|
|
|
return &user, nil
|
2022-03-19 12:11:27 +00:00
|
|
|
}
|
2022-08-02 08:42:36 +00:00
|
|
|
|
|
|
|
// UpdateUsers to update multiple users, with parameters of user IDs slice
|
|
|
|
// If ids set to nil / empty all the users will be updated
|
|
|
|
func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error {
|
|
|
|
// set updated_at time for all users
|
|
|
|
data["updated_at"] = time.Now().Unix()
|
|
|
|
|
|
|
|
updateFields := ""
|
|
|
|
for key, value := range data {
|
|
|
|
if key == "_id" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
if key == "_key" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
if value == nil {
|
|
|
|
updateFields += fmt.Sprintf("%s = null,", key)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
valueType := reflect.TypeOf(value)
|
|
|
|
if valueType.Name() == "string" {
|
|
|
|
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
|
|
|
|
} else {
|
|
|
|
updateFields += fmt.Sprintf("%s = %v, ", key, value)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
updateFields = strings.Trim(updateFields, " ")
|
|
|
|
updateFields = strings.TrimSuffix(updateFields, ",")
|
|
|
|
query := ""
|
2023-07-31 11:12:11 +00:00
|
|
|
if len(ids) > 0 {
|
2022-08-02 08:42:36 +00:00
|
|
|
idsString := ""
|
|
|
|
for _, id := range ids {
|
|
|
|
idsString += fmt.Sprintf("'%s', ", id)
|
|
|
|
}
|
|
|
|
idsString = strings.Trim(idsString, " ")
|
|
|
|
idsString = strings.TrimSuffix(idsString, ",")
|
|
|
|
query = fmt.Sprintf("UPDATE %s SET %s WHERE id IN (%s)", KeySpace+"."+models.Collections.User, updateFields, idsString)
|
|
|
|
err := p.db.Query(query).Exec()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// get all ids
|
|
|
|
getUserIDsQuery := fmt.Sprintf(`SELECT id FROM %s`, KeySpace+"."+models.Collections.User)
|
|
|
|
scanner := p.db.Query(getUserIDsQuery).Iter().Scanner()
|
|
|
|
// only 100 ids are allowed in 1 query
|
|
|
|
// hence we need create multiple update queries
|
|
|
|
idsString := ""
|
|
|
|
idsStringArray := []string{idsString}
|
|
|
|
counter := 1
|
|
|
|
for scanner.Next() {
|
|
|
|
var id string
|
|
|
|
err := scanner.Scan(&id)
|
|
|
|
if err == nil {
|
|
|
|
idsString += fmt.Sprintf("'%s', ", id)
|
|
|
|
}
|
|
|
|
counter++
|
|
|
|
if counter > 100 {
|
|
|
|
idsStringArray = append(idsStringArray, idsString)
|
|
|
|
counter = 1
|
|
|
|
idsString = ""
|
|
|
|
} else {
|
|
|
|
// update the last index of array when count is less than 100
|
|
|
|
idsStringArray[len(idsStringArray)-1] = idsString
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, idStr := range idsStringArray {
|
|
|
|
idStr = strings.Trim(idStr, " ")
|
|
|
|
idStr = strings.TrimSuffix(idStr, ",")
|
|
|
|
query = fmt.Sprintf("UPDATE %s SET %s WHERE id IN (%s)", KeySpace+"."+models.Collections.User, updateFields, idStr)
|
|
|
|
err := p.db.Query(query).Exec()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
2022-12-21 17:44:24 +00:00
|
|
|
|
|
|
|
// GetUserByPhoneNumber to get user information from database using phone number
|
|
|
|
func (p *provider) GetUserByPhoneNumber(ctx context.Context, phoneNumber string) (*models.User, error) {
|
2023-08-01 10:39:17 +00:00
|
|
|
var user models.User
|
2022-12-21 17:44:24 +00:00
|
|
|
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s WHERE phone_number = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.User, phoneNumber)
|
2023-08-01 10:39:17 +00:00
|
|
|
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt)
|
2022-12-21 17:44:24 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2023-08-01 10:39:17 +00:00
|
|
|
return &user, nil
|
2022-12-21 17:44:24 +00:00
|
|
|
}
|