2023-11-16 13:00:54 +00:00
|
|
|
package cassandradb
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"reflect"
|
|
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/gocql/gocql"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
|
|
|
|
"github.com/authorizerdev/authorizer/server/db/models"
|
|
|
|
)
|
|
|
|
|
|
|
|
func (p *provider) AddAuthenticator(ctx context.Context, authenticators *models.Authenticator) (*models.Authenticator, error) {
|
|
|
|
exists, _ := p.GetAuthenticatorDetailsByUserId(ctx, authenticators.UserID, authenticators.Method)
|
|
|
|
if exists != nil {
|
|
|
|
return authenticators, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
if authenticators.ID == "" {
|
|
|
|
authenticators.ID = uuid.New().String()
|
|
|
|
}
|
|
|
|
|
|
|
|
authenticators.CreatedAt = time.Now().Unix()
|
|
|
|
authenticators.UpdatedAt = time.Now().Unix()
|
|
|
|
|
|
|
|
bytes, err := json.Marshal(authenticators)
|
|
|
|
if err != nil {
|
2024-04-02 09:55:11 +00:00
|
|
|
return nil, err
|
2023-11-16 13:00:54 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
|
|
|
|
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
|
|
|
|
decoder.UseNumber()
|
|
|
|
authenticatorsMap := map[string]interface{}{}
|
|
|
|
err = decoder.Decode(&authenticatorsMap)
|
|
|
|
if err != nil {
|
2024-04-02 09:55:11 +00:00
|
|
|
return nil, err
|
2023-11-16 13:00:54 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
fields := "("
|
|
|
|
values := "("
|
|
|
|
for key, value := range authenticatorsMap {
|
|
|
|
if value != nil {
|
|
|
|
if key == "_id" {
|
|
|
|
fields += "id,"
|
|
|
|
} else {
|
|
|
|
fields += key + ","
|
|
|
|
}
|
|
|
|
|
|
|
|
valueType := reflect.TypeOf(value)
|
|
|
|
if valueType.Name() == "string" {
|
|
|
|
values += fmt.Sprintf("'%s',", value.(string))
|
|
|
|
} else {
|
|
|
|
values += fmt.Sprintf("%v,", value)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fields = fields[:len(fields)-1] + ")"
|
|
|
|
values = values[:len(values)-1] + ")"
|
|
|
|
|
|
|
|
query := fmt.Sprintf("INSERT INTO %s %s VALUES %s IF NOT EXISTS", KeySpace+"."+models.Collections.Authenticators, fields, values)
|
|
|
|
err = p.db.Query(query).Exec()
|
|
|
|
if err != nil {
|
2024-04-02 09:55:11 +00:00
|
|
|
return nil, err
|
2023-11-16 13:00:54 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return authenticators, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (p *provider) UpdateAuthenticator(ctx context.Context, authenticators *models.Authenticator) (*models.Authenticator, error) {
|
|
|
|
authenticators.UpdatedAt = time.Now().Unix()
|
|
|
|
|
|
|
|
bytes, err := json.Marshal(authenticators)
|
|
|
|
if err != nil {
|
2024-04-02 09:55:11 +00:00
|
|
|
return nil, err
|
2023-11-16 13:00:54 +00:00
|
|
|
}
|
|
|
|
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
|
|
|
|
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
|
|
|
|
decoder.UseNumber()
|
|
|
|
authenticatorsMap := map[string]interface{}{}
|
|
|
|
err = decoder.Decode(&authenticatorsMap)
|
|
|
|
if err != nil {
|
2024-04-02 09:55:11 +00:00
|
|
|
return nil, err
|
2023-11-16 13:00:54 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
updateFields := ""
|
|
|
|
for key, value := range authenticatorsMap {
|
|
|
|
if key == "_id" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
if key == "_key" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
if value == nil {
|
|
|
|
updateFields += fmt.Sprintf("%s = null, ", key)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
valueType := reflect.TypeOf(value)
|
|
|
|
if valueType.Name() == "string" {
|
|
|
|
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
|
|
|
|
} else {
|
|
|
|
updateFields += fmt.Sprintf("%s = %v, ", key, value)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
updateFields = strings.Trim(updateFields, " ")
|
|
|
|
updateFields = strings.TrimSuffix(updateFields, ",")
|
|
|
|
|
|
|
|
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.Authenticators, updateFields, authenticators.ID)
|
|
|
|
err = p.db.Query(query).Exec()
|
|
|
|
if err != nil {
|
2024-04-02 09:55:11 +00:00
|
|
|
return nil, err
|
2023-11-16 13:00:54 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return authenticators, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (p *provider) GetAuthenticatorDetailsByUserId(ctx context.Context, userId string, authenticatorType string) (*models.Authenticator, error) {
|
|
|
|
var authenticators models.Authenticator
|
|
|
|
query := fmt.Sprintf("SELECT id, user_id, method, secret, recovery_codes, verified_at, created_at, updated_at FROM %s WHERE user_id = '%s' AND method = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.Authenticators, userId, authenticatorType)
|
|
|
|
err := p.db.Query(query).Consistency(gocql.One).Scan(&authenticators.ID, &authenticators.UserID, &authenticators.Method, &authenticators.Secret, &authenticators.RecoveryCodes, &authenticators.VerifiedAt, &authenticators.CreatedAt, &authenticators.UpdatedAt)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return &authenticators, nil
|
|
|
|
}
|