fix: tests

This commit is contained in:
Lakhan Samani 2022-04-22 19:56:55 +05:30
parent aaf0831793
commit 961f2271c1
9 changed files with 94 additions and 27 deletions

View File

@ -58,20 +58,44 @@ func NewProvider() (*provider, error) {
return nil, err return nil, err
} }
userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nick_name text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id, email))", KeySpace, models.Collections.User) userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nickname text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id))", KeySpace, models.Collections.User)
err = session.Query(userCollectionQuery).Exec() err = session.Query(userCollectionQuery).Exec()
if err != nil { if err != nil {
log.Println("Unable to create user collection:", err) log.Println("Unable to create user collection:", err)
return nil, err return nil, err
} }
userIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_user_email ON %s.%s (email)", KeySpace, models.Collections.User)
err = session.Query(userIndexQuery).Exec()
if err != nil {
log.Println("Unable to create user index:", err)
return nil, err
}
// token is reserved keyword in cassandra, hence we need to use jwt_token // token is reserved keyword in cassandra, hence we need to use jwt_token
verificationRequestCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, jwt_token text, identifier text, expires_at bigint, email text, nonce text, redirect_uri text, created_at bigint, updated_at bigint, PRIMARY KEY (id, identifier, email))", KeySpace, models.Collections.VerificationRequest) verificationRequestCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, jwt_token text, identifier text, expires_at bigint, email text, nonce text, redirect_uri text, created_at bigint, updated_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.VerificationRequest)
err = session.Query(verificationRequestCollectionQuery).Exec() err = session.Query(verificationRequestCollectionQuery).Exec()
if err != nil { if err != nil {
log.Println("Unable to create verification request collection:", err) log.Println("Unable to create verification request collection:", err)
return nil, err return nil, err
} }
verificationRequestIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_email ON %s.%s (email)", KeySpace, models.Collections.VerificationRequest)
err = session.Query(verificationRequestIndexQuery).Exec()
if err != nil {
log.Println("Unable to create verification_requests index:", err)
return nil, err
}
verificationRequestIndexQuery = fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_identifier ON %s.%s (identifier)", KeySpace, models.Collections.VerificationRequest)
err = session.Query(verificationRequestIndexQuery).Exec()
if err != nil {
log.Println("Unable to create verification_requests index:", err)
return nil, err
}
verificationRequestIndexQuery = fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_jwt_token ON %s.%s (jwt_token)", KeySpace, models.Collections.VerificationRequest)
err = session.Query(verificationRequestIndexQuery).Exec()
if err != nil {
log.Println("Unable to create verification_requests index:", err)
return nil, err
}
return &provider{ return &provider{
db: session, db: session,

View File

@ -32,20 +32,31 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
if err != nil { if err != nil {
return user, err return user, err
} }
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
decoder.UseNumber()
userMap := map[string]interface{}{} userMap := map[string]interface{}{}
json.Unmarshal(bytes, &userMap) err = decoder.Decode(&userMap)
if err != nil {
return user, err
}
fields := "(" fields := "("
values := "(" values := "("
for key, value := range userMap { for key, value := range userMap {
if value != nil { if value != nil {
fields += key + "," if key == "_id" {
fields += "id,"
} else {
fields += key + ","
}
valueType := reflect.TypeOf(value) valueType := reflect.TypeOf(value)
if valueType.Kind() == reflect.String { if valueType.Name() == "string" {
values += "'" + value.(string) + "'," values += fmt.Sprintf("'%s',", value.(string))
} else { } else {
values += fmt.Sprintf("%v", value) + "," values += fmt.Sprintf("%v,", value)
} }
} }
} }
@ -53,7 +64,7 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
fields = fields[:len(fields)-1] + ")" fields = fields[:len(fields)-1] + ")"
values = values[:len(values)-1] + ")" values = values[:len(values)-1] + ")"
query := fmt.Sprintf("INSERT INTO %s %s VALUES %s", KeySpace+"."+models.Collections.User, fields, values) query := fmt.Sprintf("INSERT INTO %s %s VALUES %s IF NOT EXISTS", KeySpace+"."+models.Collections.User, fields, values)
err = p.db.Query(query).Exec() err = p.db.Query(query).Exec()
if err != nil { if err != nil {
@ -66,26 +77,46 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
// UpdateUser to update user information in database // UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) { func (p *provider) UpdateUser(user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix()
bytes, err := json.Marshal(user) bytes, err := json.Marshal(user)
if err != nil { if err != nil {
return user, err return user, err
} }
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
decoder.UseNumber()
userMap := map[string]interface{}{} userMap := map[string]interface{}{}
json.Unmarshal(bytes, &userMap) err = decoder.Decode(&userMap)
if err != nil {
return user, err
}
updateFields := "" updateFields := ""
for key, value := range userMap { for key, value := range userMap {
if value != nil { if value != nil && key != "_id" {
valueType := reflect.TypeOf(value) }
if valueType.Kind() == reflect.String {
updateFields += key + " = '" + value.(string) + "'," if key == "_id" {
} else { continue
updateFields += key + " = " + fmt.Sprintf("%v", value) + "," }
}
if value == nil {
updateFields += fmt.Sprintf("%s = null,", key)
continue
}
valueType := reflect.TypeOf(value)
if valueType.Name() == "string" {
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
} else {
updateFields += fmt.Sprintf("%s = %v, ", key, value)
} }
} }
updateFields = strings.Trim(updateFields, " ")
updateFields = strings.TrimSuffix(updateFields, ",")
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, updateFields, user.ID) query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, updateFields, user.ID)
err = p.db.Query(query).Exec() err = p.db.Query(query).Exec()
if err != nil { if err != nil {
return user, err return user, err
@ -97,8 +128,8 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
// DeleteUser to delete user information from database // DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error { func (p *provider) DeleteUser(user models.User) error {
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, user.ID) query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, user.ID)
err := p.db.Query(query).Exec()
return p.db.Query(query).Exec() return err
} }
// ListUsers to get list of users from database // ListUsers to get list of users from database
@ -114,7 +145,7 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
// there is no offset in cassandra // there is no offset in cassandra
// so we fetch till limit + offset // so we fetch till limit + offset
// and return the results from offset to limit // and return the results from offset to limit
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s ORDER BY created_at DESC LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset) query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset)
scanner := p.db.Query(query).Iter().Scanner() scanner := p.db.Query(query).Iter().Scanner()
counter := int64(0) counter := int64(0)

View File

@ -2,6 +2,7 @@ package cassandradb
import ( import (
"fmt" "fmt"
"log"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@ -31,6 +32,7 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE jwt_token = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, token) query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE jwt_token = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, token)
err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt) err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt)
if err != nil { if err != nil {
return verificationRequest, err return verificationRequest, err
@ -41,7 +43,8 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
// GetVerificationRequestByEmail to get verification request by email from database // GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE email = '%s' AND identifier = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, email, identifier) query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE email = '%s' AND identifier = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.VerificationRequest, email, identifier)
err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt) err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt)
if err != nil { if err != nil {
return verificationRequest, err return verificationRequest, err
@ -58,13 +61,15 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.VerificationRequest) totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.VerificationRequest)
err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total) err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total)
if err != nil { if err != nil {
log.Println("Error while quering verification request", err)
return nil, err return nil, err
} }
// there is no offset in cassandra // there is no offset in cassandra
// so we fetch till limit + offset // so we fetch till limit + offset
// and return the results from offset to limit // and return the results from offset to limit
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s ORDER BY created_at DESC LIMIT %d`, KeySpace+"."+models.Collections.VerificationRequest, pagination.Limit+pagination.Offset) query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s LIMIT %d`, KeySpace+"."+models.Collections.VerificationRequest, pagination.Limit+pagination.Offset)
scanner := p.db.Query(query).Iter().Scanner() scanner := p.db.Query(query).Iter().Scanner()
counter := int64(0) counter := int64(0)
for scanner.Next() { for scanner.Next() {
@ -72,6 +77,7 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
err := scanner.Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt) err := scanner.Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt)
if err != nil { if err != nil {
log.Println("Error while parsing verification request", err)
return nil, err return nil, err
} }
verificationRequests = append(verificationRequests, verificationRequest.AsAPIVerificationRequest()) verificationRequests = append(verificationRequests, verificationRequest.AsAPIVerificationRequest())

View File

@ -107,7 +107,7 @@ func InviteEmail(toEmail, token, verificationURL, redirectURI string) error {
err := SendMail(Receiver, Subject, message) err := SendMail(Receiver, Subject, message)
if err != nil { if err != nil {
log.Println("=> error sending email:", err) log.Println("error sending email:", err)
} }
return err return err
} }

View File

@ -107,7 +107,7 @@ func SendVerificationMail(toEmail, token, hostname string) error {
err := SendMail(Receiver, Subject, message) err := SendMail(Receiver, Subject, message)
if err != nil { if err != nil {
log.Println("=> error sending email:", err) log.Println("error sending email:", err)
} }
return err return err
} }

View File

@ -259,7 +259,7 @@ func processGithubUserInfo(code string) (models.User, error) {
GivenName: &firstName, GivenName: &firstName,
FamilyName: &lastName, FamilyName: &lastName,
Picture: &picture, Picture: &picture,
Email: userRawData["sub"], Email: userRawData["email"],
} }
return user, nil return user, nil

View File

@ -15,9 +15,9 @@ import (
func enableAccessTest(t *testing.T, s TestSetup) { func enableAccessTest(t *testing.T, s TestSetup) {
t.Helper() t.Helper()
t.Run(`should revoke access`, func(t *testing.T) { t.Run(`should enable access`, func(t *testing.T) {
req, ctx := createContext(s) req, ctx := createContext(s)
email := "revoke_access." + s.TestInfo.Email email := "enable_access." + s.TestInfo.Email
_, err := resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ _, err := resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{
Email: email, Email: email,
}) })
@ -45,7 +45,7 @@ func enableAccessTest(t *testing.T, s TestSetup) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, res.Message) assert.NotEmpty(t, res.Message)
// it should allow login with revoked access // it should allow login with enabled access
res, err = resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ res, err = resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{
Email: email, Email: email,
}) })

View File

@ -14,6 +14,7 @@ func TestResolvers(t *testing.T) {
constants.DbTypeSqlite: "../../data.db", constants.DbTypeSqlite: "../../data.db",
// constants.DbTypeArangodb: "http://localhost:8529", // constants.DbTypeArangodb: "http://localhost:8529",
// constants.DbTypeMongodb: "mongodb://localhost:27017", // constants.DbTypeMongodb: "mongodb://localhost:27017",
// constants.DbTypeCassandraDB: "127.0.0.1:9042",
} }
for dbType, dbURL := range databases { for dbType, dbURL := range databases {

View File

@ -46,6 +46,11 @@ func cleanData(email string) {
err = db.Provider.DeleteVerificationRequest(verificationRequest) err = db.Provider.DeleteVerificationRequest(verificationRequest)
} }
verificationRequest, err = db.Provider.GetVerificationRequestByEmail(email, constants.VerificationTypeMagicLinkLogin)
if err == nil {
err = db.Provider.DeleteVerificationRequest(verificationRequest)
}
dbUser, err := db.Provider.GetUserByEmail(email) dbUser, err := db.Provider.GetUserByEmail(email)
if err == nil { if err == nil {
db.Provider.DeleteUser(dbUser) db.Provider.DeleteUser(dbUser)