authorizer/server/db/providers/cassandradb/user.go

191 lines
6.5 KiB
Go
Raw Normal View History

2022-04-21 07:06:22 +00:00
package cassandradb
2022-03-19 12:11:27 +00:00
import (
2022-04-22 11:15:49 +00:00
"encoding/json"
"fmt"
"reflect"
2022-03-19 12:11:27 +00:00
"strings"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
2022-05-29 11:52:46 +00:00
"github.com/authorizerdev/authorizer/server/memorystore"
2022-04-22 11:15:49 +00:00
"github.com/gocql/gocql"
2022-03-19 12:11:27 +00:00
"github.com/google/uuid"
)
// AddUser to save user information in database
func (p *provider) AddUser(user models.User) (models.User, error) {
if user.ID == "" {
user.ID = uuid.New().String()
}
if user.Roles == "" {
2022-05-31 02:44:03 +00:00
defaultRoles, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles)
2022-05-29 11:52:46 +00:00
if err != nil {
return user, err
}
2022-05-31 02:44:03 +00:00
user.Roles = defaultRoles
2022-03-19 12:11:27 +00:00
}
user.CreatedAt = time.Now().Unix()
user.UpdatedAt = time.Now().Unix()
2022-04-22 11:15:49 +00:00
bytes, err := json.Marshal(user)
if err != nil {
return user, err
}
2022-04-22 14:26:55 +00:00
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
decoder.UseNumber()
2022-04-22 11:15:49 +00:00
userMap := map[string]interface{}{}
2022-04-22 14:26:55 +00:00
err = decoder.Decode(&userMap)
if err != nil {
return user, err
}
2022-04-22 11:15:49 +00:00
fields := "("
values := "("
for key, value := range userMap {
if value != nil {
2022-04-22 14:26:55 +00:00
if key == "_id" {
fields += "id,"
} else {
fields += key + ","
}
2022-04-22 11:15:49 +00:00
valueType := reflect.TypeOf(value)
2022-04-22 14:26:55 +00:00
if valueType.Name() == "string" {
values += fmt.Sprintf("'%s',", value.(string))
2022-04-22 11:15:49 +00:00
} else {
2022-04-22 14:26:55 +00:00
values += fmt.Sprintf("%v,", value)
2022-04-22 11:15:49 +00:00
}
}
}
fields = fields[:len(fields)-1] + ")"
values = values[:len(values)-1] + ")"
2022-04-22 14:26:55 +00:00
query := fmt.Sprintf("INSERT INTO %s %s VALUES %s IF NOT EXISTS", KeySpace+"."+models.Collections.User, fields, values)
2022-04-22 11:15:49 +00:00
err = p.db.Query(query).Exec()
if err != nil {
return user, err
}
2022-03-19 12:11:27 +00:00
return user, nil
}
// UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix()
2022-04-22 14:26:55 +00:00
2022-04-22 11:15:49 +00:00
bytes, err := json.Marshal(user)
if err != nil {
return user, err
}
2022-04-22 14:26:55 +00:00
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
decoder.UseNumber()
2022-04-22 11:15:49 +00:00
userMap := map[string]interface{}{}
2022-04-22 14:26:55 +00:00
err = decoder.Decode(&userMap)
if err != nil {
return user, err
}
2022-04-22 11:15:49 +00:00
updateFields := ""
for key, value := range userMap {
2022-04-22 14:26:55 +00:00
if key == "_id" {
continue
}
if value == nil {
updateFields += fmt.Sprintf("%s = null,", key)
continue
}
valueType := reflect.TypeOf(value)
if valueType.Name() == "string" {
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
} else {
updateFields += fmt.Sprintf("%s = %v, ", key, value)
2022-04-22 11:15:49 +00:00
}
}
2022-04-22 14:26:55 +00:00
updateFields = strings.Trim(updateFields, " ")
updateFields = strings.TrimSuffix(updateFields, ",")
2022-04-22 11:15:49 +00:00
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, updateFields, user.ID)
2022-04-22 14:26:55 +00:00
2022-04-22 11:15:49 +00:00
err = p.db.Query(query).Exec()
if err != nil {
return user, err
}
2022-03-19 12:11:27 +00:00
return user, nil
}
// DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error {
2022-04-22 11:15:49 +00:00
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, user.ID)
2022-04-22 14:26:55 +00:00
err := p.db.Query(query).Exec()
return err
2022-03-19 12:11:27 +00:00
}
// ListUsers to get list of users from database
func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) {
2022-04-22 11:15:49 +00:00
responseUsers := []*model.User{}
paginationClone := pagination
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.User)
err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total)
if err != nil {
return nil, err
}
// there is no offset in cassandra
// so we fetch till limit + offset
// and return the results from offset to limit
2022-04-22 14:26:55 +00:00
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset)
2022-04-22 11:15:49 +00:00
scanner := p.db.Query(query).Iter().Scanner()
counter := int64(0)
for scanner.Next() {
if counter >= pagination.Offset {
var user models.User
err := scanner.Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
return nil, err
}
responseUsers = append(responseUsers, user.AsAPIUser())
}
counter++
}
return &model.Users{
Users: responseUsers,
Pagination: &paginationClone,
}, nil
2022-03-19 12:11:27 +00:00
}
// GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(email string) (models.User, error) {
var user models.User
2022-04-22 11:15:49 +00:00
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, email)
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
return user, err
}
2022-03-19 12:11:27 +00:00
return user, nil
}
// GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(id string) (models.User, error) {
var user models.User
2022-04-22 11:15:49 +00:00
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, id)
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
return user, err
}
2022-03-19 12:11:27 +00:00
return user, nil
}