From 961f2271c1615653b5c465d3c58f36b3a860aff0 Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Fri, 22 Apr 2022 19:56:55 +0530 Subject: [PATCH] fix: tests --- server/db/providers/cassandradb/provider.go | 28 +++++++- server/db/providers/cassandradb/user.go | 65 ++++++++++++++----- .../cassandradb/verification_requests.go | 10 ++- server/email/invite_email.go | 2 +- server/email/verification_email.go | 2 +- server/handlers/oauth_callback.go | 2 +- server/test/enable_access_test.go | 6 +- server/test/resolvers_test.go | 1 + server/test/test.go | 5 ++ 9 files changed, 94 insertions(+), 27 deletions(-) diff --git a/server/db/providers/cassandradb/provider.go b/server/db/providers/cassandradb/provider.go index dc77aa3..8e12c40 100644 --- a/server/db/providers/cassandradb/provider.go +++ b/server/db/providers/cassandradb/provider.go @@ -58,20 +58,44 @@ func NewProvider() (*provider, error) { return nil, err } - userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nick_name text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id, email))", KeySpace, models.Collections.User) + userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nickname text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id))", KeySpace, models.Collections.User) err = session.Query(userCollectionQuery).Exec() if err != nil { log.Println("Unable to create user collection:", err) return nil, err } + userIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_user_email ON %s.%s (email)", KeySpace, models.Collections.User) + err = session.Query(userIndexQuery).Exec() + if err != nil { + log.Println("Unable to create user index:", err) + return nil, err + } // token is reserved keyword in cassandra, hence we need to use jwt_token - verificationRequestCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, jwt_token text, identifier text, expires_at bigint, email text, nonce text, redirect_uri text, created_at bigint, updated_at bigint, PRIMARY KEY (id, identifier, email))", KeySpace, models.Collections.VerificationRequest) + verificationRequestCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, jwt_token text, identifier text, expires_at bigint, email text, nonce text, redirect_uri text, created_at bigint, updated_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.VerificationRequest) err = session.Query(verificationRequestCollectionQuery).Exec() if err != nil { log.Println("Unable to create verification request collection:", err) return nil, err } + verificationRequestIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_email ON %s.%s (email)", KeySpace, models.Collections.VerificationRequest) + err = session.Query(verificationRequestIndexQuery).Exec() + if err != nil { + log.Println("Unable to create verification_requests index:", err) + return nil, err + } + verificationRequestIndexQuery = fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_identifier ON %s.%s (identifier)", KeySpace, models.Collections.VerificationRequest) + err = session.Query(verificationRequestIndexQuery).Exec() + if err != nil { + log.Println("Unable to create verification_requests index:", err) + return nil, err + } + verificationRequestIndexQuery = fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_jwt_token ON %s.%s (jwt_token)", KeySpace, models.Collections.VerificationRequest) + err = session.Query(verificationRequestIndexQuery).Exec() + if err != nil { + log.Println("Unable to create verification_requests index:", err) + return nil, err + } return &provider{ db: session, diff --git a/server/db/providers/cassandradb/user.go b/server/db/providers/cassandradb/user.go index 81813b2..09b7476 100644 --- a/server/db/providers/cassandradb/user.go +++ b/server/db/providers/cassandradb/user.go @@ -32,20 +32,31 @@ func (p *provider) AddUser(user models.User) (models.User, error) { if err != nil { return user, err } + + // use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling + decoder := json.NewDecoder(strings.NewReader(string(bytes))) + decoder.UseNumber() userMap := map[string]interface{}{} - json.Unmarshal(bytes, &userMap) + err = decoder.Decode(&userMap) + if err != nil { + return user, err + } fields := "(" values := "(" for key, value := range userMap { if value != nil { - fields += key + "," + if key == "_id" { + fields += "id," + } else { + fields += key + "," + } valueType := reflect.TypeOf(value) - if valueType.Kind() == reflect.String { - values += "'" + value.(string) + "'," + if valueType.Name() == "string" { + values += fmt.Sprintf("'%s',", value.(string)) } else { - values += fmt.Sprintf("%v", value) + "," + values += fmt.Sprintf("%v,", value) } } } @@ -53,7 +64,7 @@ func (p *provider) AddUser(user models.User) (models.User, error) { fields = fields[:len(fields)-1] + ")" values = values[:len(values)-1] + ")" - query := fmt.Sprintf("INSERT INTO %s %s VALUES %s", KeySpace+"."+models.Collections.User, fields, values) + query := fmt.Sprintf("INSERT INTO %s %s VALUES %s IF NOT EXISTS", KeySpace+"."+models.Collections.User, fields, values) err = p.db.Query(query).Exec() if err != nil { @@ -66,26 +77,46 @@ func (p *provider) AddUser(user models.User) (models.User, error) { // UpdateUser to update user information in database func (p *provider) UpdateUser(user models.User) (models.User, error) { user.UpdatedAt = time.Now().Unix() + bytes, err := json.Marshal(user) if err != nil { return user, err } + // use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling + decoder := json.NewDecoder(strings.NewReader(string(bytes))) + decoder.UseNumber() userMap := map[string]interface{}{} - json.Unmarshal(bytes, &userMap) + err = decoder.Decode(&userMap) + if err != nil { + return user, err + } updateFields := "" for key, value := range userMap { - if value != nil { - valueType := reflect.TypeOf(value) - if valueType.Kind() == reflect.String { - updateFields += key + " = '" + value.(string) + "'," - } else { - updateFields += key + " = " + fmt.Sprintf("%v", value) + "," - } + if value != nil && key != "_id" { + } + + if key == "_id" { + continue + } + + if value == nil { + updateFields += fmt.Sprintf("%s = null,", key) + continue + } + + valueType := reflect.TypeOf(value) + if valueType.Name() == "string" { + updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string)) + } else { + updateFields += fmt.Sprintf("%s = %v, ", key, value) } } + updateFields = strings.Trim(updateFields, " ") + updateFields = strings.TrimSuffix(updateFields, ",") query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, updateFields, user.ID) + err = p.db.Query(query).Exec() if err != nil { return user, err @@ -97,8 +128,8 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) { // DeleteUser to delete user information from database func (p *provider) DeleteUser(user models.User) error { query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, user.ID) - - return p.db.Query(query).Exec() + err := p.db.Query(query).Exec() + return err } // ListUsers to get list of users from database @@ -114,7 +145,7 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) // there is no offset in cassandra // so we fetch till limit + offset // and return the results from offset to limit - query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s ORDER BY created_at DESC LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset) + query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset) scanner := p.db.Query(query).Iter().Scanner() counter := int64(0) diff --git a/server/db/providers/cassandradb/verification_requests.go b/server/db/providers/cassandradb/verification_requests.go index de615b2..6c82462 100644 --- a/server/db/providers/cassandradb/verification_requests.go +++ b/server/db/providers/cassandradb/verification_requests.go @@ -2,6 +2,7 @@ package cassandradb import ( "fmt" + "log" "time" "github.com/authorizerdev/authorizer/server/db/models" @@ -31,6 +32,7 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) { var verificationRequest models.VerificationRequest query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE jwt_token = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, token) + err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt) if err != nil { return verificationRequest, err @@ -41,7 +43,8 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat // GetVerificationRequestByEmail to get verification request by email from database func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) { var verificationRequest models.VerificationRequest - query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE email = '%s' AND identifier = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, email, identifier) + query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE email = '%s' AND identifier = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.VerificationRequest, email, identifier) + err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt) if err != nil { return verificationRequest, err @@ -58,13 +61,15 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.VerificationRequest) err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total) if err != nil { + log.Println("Error while quering verification request", err) return nil, err } // there is no offset in cassandra // so we fetch till limit + offset // and return the results from offset to limit - query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s ORDER BY created_at DESC LIMIT %d`, KeySpace+"."+models.Collections.VerificationRequest, pagination.Limit+pagination.Offset) + query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s LIMIT %d`, KeySpace+"."+models.Collections.VerificationRequest, pagination.Limit+pagination.Offset) + scanner := p.db.Query(query).Iter().Scanner() counter := int64(0) for scanner.Next() { @@ -72,6 +77,7 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model var verificationRequest models.VerificationRequest err := scanner.Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt) if err != nil { + log.Println("Error while parsing verification request", err) return nil, err } verificationRequests = append(verificationRequests, verificationRequest.AsAPIVerificationRequest()) diff --git a/server/email/invite_email.go b/server/email/invite_email.go index 5cbd1c9..bdebd81 100644 --- a/server/email/invite_email.go +++ b/server/email/invite_email.go @@ -107,7 +107,7 @@ func InviteEmail(toEmail, token, verificationURL, redirectURI string) error { err := SendMail(Receiver, Subject, message) if err != nil { - log.Println("=> error sending email:", err) + log.Println("error sending email:", err) } return err } diff --git a/server/email/verification_email.go b/server/email/verification_email.go index bb0881f..c373151 100644 --- a/server/email/verification_email.go +++ b/server/email/verification_email.go @@ -107,7 +107,7 @@ func SendVerificationMail(toEmail, token, hostname string) error { err := SendMail(Receiver, Subject, message) if err != nil { - log.Println("=> error sending email:", err) + log.Println("error sending email:", err) } return err } diff --git a/server/handlers/oauth_callback.go b/server/handlers/oauth_callback.go index bfa4f00..936e618 100644 --- a/server/handlers/oauth_callback.go +++ b/server/handlers/oauth_callback.go @@ -259,7 +259,7 @@ func processGithubUserInfo(code string) (models.User, error) { GivenName: &firstName, FamilyName: &lastName, Picture: &picture, - Email: userRawData["sub"], + Email: userRawData["email"], } return user, nil diff --git a/server/test/enable_access_test.go b/server/test/enable_access_test.go index c54f91b..6d06153 100644 --- a/server/test/enable_access_test.go +++ b/server/test/enable_access_test.go @@ -15,9 +15,9 @@ import ( func enableAccessTest(t *testing.T, s TestSetup) { t.Helper() - t.Run(`should revoke access`, func(t *testing.T) { + t.Run(`should enable access`, func(t *testing.T) { req, ctx := createContext(s) - email := "revoke_access." + s.TestInfo.Email + email := "enable_access." + s.TestInfo.Email _, err := resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ Email: email, }) @@ -45,7 +45,7 @@ func enableAccessTest(t *testing.T, s TestSetup) { assert.NoError(t, err) assert.NotEmpty(t, res.Message) - // it should allow login with revoked access + // it should allow login with enabled access res, err = resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ Email: email, }) diff --git a/server/test/resolvers_test.go b/server/test/resolvers_test.go index 40812b1..513c1b0 100644 --- a/server/test/resolvers_test.go +++ b/server/test/resolvers_test.go @@ -14,6 +14,7 @@ func TestResolvers(t *testing.T) { constants.DbTypeSqlite: "../../data.db", // constants.DbTypeArangodb: "http://localhost:8529", // constants.DbTypeMongodb: "mongodb://localhost:27017", + // constants.DbTypeCassandraDB: "127.0.0.1:9042", } for dbType, dbURL := range databases { diff --git a/server/test/test.go b/server/test/test.go index ff0bdb3..c4cb14f 100644 --- a/server/test/test.go +++ b/server/test/test.go @@ -46,6 +46,11 @@ func cleanData(email string) { err = db.Provider.DeleteVerificationRequest(verificationRequest) } + verificationRequest, err = db.Provider.GetVerificationRequestByEmail(email, constants.VerificationTypeMagicLinkLogin) + if err == nil { + err = db.Provider.DeleteVerificationRequest(verificationRequest) + } + dbUser, err := db.Provider.GetUserByEmail(email) if err == nil { db.Provider.DeleteUser(dbUser)