fix: cassandra + mongo + arangodb issues with webhook

This commit is contained in:
Lakhan Samani 2022-07-12 11:48:42 +05:30
parent bfbeb6add2
commit 6b57bce6d9
19 changed files with 167 additions and 94 deletions

View File

@ -38,8 +38,13 @@ func (user *User) AsAPIUser() *model.User {
email := user.Email email := user.Email
createdAt := user.CreatedAt createdAt := user.CreatedAt
updatedAt := user.UpdatedAt updatedAt := user.UpdatedAt
id := user.ID
if strings.Contains(id, Collections.WebhookLog+"/") {
id = strings.TrimPrefix(id, Collections.WebhookLog+"/")
}
return &model.User{ return &model.User{
ID: user.ID, ID: id,
Email: user.Email, Email: user.Email,
EmailVerified: isEmailVerified, EmailVerified: isEmailVerified,
SignupMethods: user.SignupMethods, SignupMethods: user.SignupMethods,

View File

@ -1,6 +1,10 @@
package models package models
import "github.com/authorizerdev/authorizer/server/graph/model" import (
"strings"
"github.com/authorizerdev/authorizer/server/graph/model"
)
// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation // Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation
@ -27,8 +31,13 @@ func (v *VerificationRequest) AsAPIVerificationRequest() *model.VerificationRequ
redirectURI := v.RedirectURI redirectURI := v.RedirectURI
expires := v.ExpiresAt expires := v.ExpiresAt
identifier := v.Identifier identifier := v.Identifier
id := v.ID
if strings.Contains(id, Collections.WebhookLog+"/") {
id = strings.TrimPrefix(id, Collections.WebhookLog+"/")
}
return &model.VerificationRequest{ return &model.VerificationRequest{
ID: v.ID, ID: id,
Token: &token, Token: &token,
Identifier: &identifier, Identifier: &identifier,
Expires: &expires, Expires: &expires,

View File

@ -2,6 +2,7 @@ package models
import ( import (
"encoding/json" "encoding/json"
"strings"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
) )
@ -23,8 +24,13 @@ type Webhook struct {
func (w *Webhook) AsAPIWebhook() *model.Webhook { func (w *Webhook) AsAPIWebhook() *model.Webhook {
headersMap := make(map[string]interface{}) headersMap := make(map[string]interface{})
json.Unmarshal([]byte(w.Headers), &headersMap) json.Unmarshal([]byte(w.Headers), &headersMap)
id := w.ID
if strings.Contains(id, Collections.Webhook+"/") {
id = strings.TrimPrefix(id, Collections.Webhook+"/")
}
return &model.Webhook{ return &model.Webhook{
ID: w.ID, ID: id,
EventName: &w.EventName, EventName: &w.EventName,
Endpoint: &w.EndPoint, Endpoint: &w.EndPoint,
Headers: headersMap, Headers: headersMap,

View File

@ -1,6 +1,10 @@
package models package models
import "github.com/authorizerdev/authorizer/server/graph/model" import (
"strings"
"github.com/authorizerdev/authorizer/server/graph/model"
)
// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation // Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation
@ -17,8 +21,12 @@ type WebhookLog struct {
} }
func (w *WebhookLog) AsAPIWebhookLog() *model.WebhookLog { func (w *WebhookLog) AsAPIWebhookLog() *model.WebhookLog {
id := w.ID
if strings.Contains(id, Collections.WebhookLog+"/") {
id = strings.TrimPrefix(id, Collections.WebhookLog+"/")
}
return &model.WebhookLog{ return &model.WebhookLog{
ID: w.ID, ID: id,
HTTPStatus: &w.HttpStatus, HTTPStatus: &w.HttpStatus,
Response: &w.Response, Response: &w.Response,
Request: &w.Request, Request: &w.Request,

View File

@ -63,9 +63,9 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
return err return err
} }
query := fmt.Sprintf(`FOR d IN %s FILTER d.user_id == @userId REMOVE { _key: d._key } IN %s`, models.Collections.Session, models.Collections.Session) query := fmt.Sprintf(`FOR d IN %s FILTER d.user_id == @user_id REMOVE { _key: d._key } IN %s`, models.Collections.Session, models.Collections.Session)
bindVars := map[string]interface{}{ bindVars := map[string]interface{}{
"userId": user.ID, "user_id": user.ID,
} }
cursor, err := p.db.Query(ctx, query, bindVars) cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil { if err != nil {

View File

@ -83,7 +83,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
// GetWebhookByID to get webhook by id // GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook var webhook models.Webhook
query := fmt.Sprintf("FOR d in %s FILTER d._id == @webhook_id RETURN d", models.Collections.Webhook) query := fmt.Sprintf("FOR d in %s FILTER d._key == @webhook_id RETURN d", models.Collections.Webhook)
bindVars := map[string]interface{}{ bindVars := map[string]interface{}{
"webhook_id": webhookID, "webhook_id": webhookID,
} }
@ -146,9 +146,9 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er
return err return err
} }
query := fmt.Sprintf("FOR d in %s FILTER d.event_id == @event_id REMOVE { _key: d._key }", models.Collections.WebhookLog) query := fmt.Sprintf("FOR d IN %s FILTER d.webhook_id == @webhook_id REMOVE { _key: d._key } IN %s", models.Collections.WebhookLog, models.Collections.WebhookLog)
bindVars := map[string]interface{}{ bindVars := map[string]interface{}{
"event_id": webhook.ID, "webhook_id": webhook.ID,
} }
cursor, err := p.db.Query(ctx, query, bindVars) cursor, err := p.db.Query(ctx, query, bindVars)

View File

@ -37,11 +37,12 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit) query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit)
if webhookID != "" { if webhookID != "" {
query = fmt.Sprintf("FOR d in %s FILTER d.webhook_id == @webhookID SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit) query = fmt.Sprintf("FOR d in %s FILTER d.webhook_id == @webhook_id SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit)
bindVariables = map[string]interface{}{ bindVariables = map[string]interface{}{
"webhook_id": webhookID, "webhook_id": webhookID,
} }
} }
sctx := driver.WithQueryFullCount(ctx) sctx := driver.WithQueryFullCount(ctx)
cursor, err := p.db.Query(sctx, query, bindVariables) cursor, err := p.db.Query(sctx, query, bindVariables)
if err != nil { if err != nil {

View File

@ -143,6 +143,11 @@ func NewProvider() (*provider, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
sessionIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_session_user_id ON %s.%s (user_id)", KeySpace, models.Collections.Session)
err = session.Query(sessionIndexQuery).Exec()
if err != nil {
return nil, err
}
userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nickname text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id))", KeySpace, models.Collections.User) userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nickname text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id))", KeySpace, models.Collections.User)
err = session.Query(userCollectionQuery).Exec() err = session.Query(userCollectionQuery).Exec()
@ -177,7 +182,7 @@ func NewProvider() (*provider, error) {
return nil, err return nil, err
} }
webhookCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, event_name text, endpoint text, enabled boolean, updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.Webhook) webhookCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, event_name text, endpoint text, enabled boolean, headers text, updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.Webhook)
err = session.Query(webhookCollectionQuery).Exec() err = session.Query(webhookCollectionQuery).Exec()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -102,6 +102,10 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use
continue continue
} }
if key == "_key" {
continue
}
if value == nil { if value == nil {
updateFields += fmt.Sprintf("%s = null,", key) updateFields += fmt.Sprintf("%s = null,", key)
continue continue
@ -135,7 +139,19 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
return err return err
} }
deleteSessionQuery := fmt.Sprintf("DELETE FROM %s WHERE user_id = '%s'", KeySpace+"."+models.Collections.Session, user.ID) getSessionsQuery := fmt.Sprintf("SELECT id FROM %s WHERE user_id = '%s' ALLOW FILTERING", KeySpace+"."+models.Collections.Session, user.ID)
scanner := p.db.Query(getSessionsQuery).Iter().Scanner()
sessionIDs := ""
for scanner.Next() {
var wlID string
err = scanner.Scan(&wlID)
if err != nil {
return err
}
sessionIDs += fmt.Sprintf("'%s',", wlID)
}
sessionIDs = strings.TrimSuffix(sessionIDs, ",")
deleteSessionQuery := fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", KeySpace+"."+models.Collections.Session, sessionIDs)
err = p.db.Query(deleteSessionQuery).Exec() err = p.db.Query(deleteSessionQuery).Exec()
if err != nil { if err != nil {
return err return err
@ -181,7 +197,7 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (
// GetUserByEmail to get user information from database using email address // GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) { func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User var user models.User
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, email) query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.User, email)
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt) err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt)
if err != nil { if err != nil {
return user, err return user, err

View File

@ -24,6 +24,11 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
webhook.CreatedAt = time.Now().Unix() webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
existingHook, _ := p.GetWebhookByEventName(ctx, webhook.EventName)
if existingHook != nil {
return nil, fmt.Errorf("Webhook with %s event_name already exists", webhook.EventName)
}
insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_name, endpoint, headers, enabled, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %t, %d, %d)", KeySpace+"."+models.Collections.Webhook, webhook.ID, webhook.EventName, webhook.EndPoint, webhook.Headers, webhook.Enabled, webhook.CreatedAt, webhook.UpdatedAt) insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_name, endpoint, headers, enabled, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %t, %d, %d)", KeySpace+"."+models.Collections.Webhook, webhook.ID, webhook.EventName, webhook.EndPoint, webhook.Headers, webhook.Enabled, webhook.CreatedAt, webhook.UpdatedAt)
err := p.db.Query(insertQuery).Exec() err := p.db.Query(insertQuery).Exec()
if err != nil { if err != nil {
@ -56,6 +61,10 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*
continue continue
} }
if key == "_key" {
continue
}
if value == nil { if value == nil {
updateFields += fmt.Sprintf("%s = null,", key) updateFields += fmt.Sprintf("%s = null,", key)
continue continue
@ -72,7 +81,6 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*
updateFields = strings.TrimSuffix(updateFields, ",") updateFields = strings.TrimSuffix(updateFields, ",")
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.Webhook, updateFields, webhook.ID) query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.Webhook, updateFields, webhook.ID)
err = p.db.Query(query).Exec() err = p.db.Query(query).Exec()
if err != nil { if err != nil {
return nil, err return nil, err
@ -130,7 +138,7 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model
// GetWebhookByEventName to get webhook by event_name // GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
var webhook models.Webhook var webhook models.Webhook
query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name = '%s' LIMIT 1`, KeySpace+"."+models.Collections.Webhook, eventName) query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.Webhook, eventName)
err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt)
if err != nil { if err != nil {
return nil, err return nil, err
@ -146,7 +154,19 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er
return err return err
} }
query = fmt.Sprintf("DELETE FROM %s WHERE webhook_id = '%s'", KeySpace+"."+models.Collections.WebhookLog, webhook.ID) getWebhookLogQuery := fmt.Sprintf("SELECT id FROM %s WHERE webhook_id = '%s' ALLOW FILTERING", KeySpace+"."+models.Collections.WebhookLog, webhook.ID)
scanner := p.db.Query(getWebhookLogQuery).Iter().Scanner()
webhookLogIDs := ""
for scanner.Next() {
var wlID string
err = scanner.Scan(&wlID)
if err != nil {
return err
}
webhookLogIDs += fmt.Sprintf("'%s',", wlID)
}
webhookLogIDs = strings.TrimSuffix(webhookLogIDs, ",")
query = fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", KeySpace+"."+models.Collections.WebhookLog, webhookLogIDs)
err = p.db.Query(query).Exec() err = p.db.Query(query).Exec()
return err return err
} }

View File

@ -40,8 +40,8 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat
query := fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.WebhookLog, pagination.Limit+pagination.Offset) query := fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.WebhookLog, pagination.Limit+pagination.Offset)
if webhookID != "" { if webhookID != "" {
totalCountQuery = fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE webhook_id='%s'`, KeySpace+"."+models.Collections.WebhookLog, webhookID) totalCountQuery = fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE webhook_id='%s' ALLOW FILTERING`, KeySpace+"."+models.Collections.WebhookLog, webhookID)
query = fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s WHERE webhook_id = '%s' LIMIT %d", KeySpace+"."+models.Collections.WebhookLog, webhookID, pagination.Limit+pagination.Offset) query = fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s WHERE webhook_id = '%s' LIMIT %d ALLOW FILTERING", KeySpace+"."+models.Collections.WebhookLog, webhookID, pagination.Limit+pagination.Offset)
} }
err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total) err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total)

View File

@ -111,7 +111,7 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er
} }
webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection()) webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection())
_, err = webhookLogCollection.DeleteOne(nil, bson.M{"webhook_id": webhook.ID}, options.Delete()) _, err = webhookLogCollection.DeleteMany(nil, bson.M{"webhook_id": webhook.ID}, options.Delete())
if err != nil { if err != nil {
return err return err
} }

View File

@ -113,7 +113,7 @@ func PersistEnv() error {
ctx := context.Background() ctx := context.Background()
env, err := db.Provider.GetEnv(ctx) env, err := db.Provider.GetEnv(ctx)
// config not found in db // config not found in db
if err != nil { if err != nil || env.EnvData == "" {
// AES encryption needs 32 bit key only, so we chop off last 4 characters from 36 bit uuid // AES encryption needs 32 bit key only, so we chop off last 4 characters from 36 bit uuid
hash := uuid.New().String()[:36-4] hash := uuid.New().String()[:36-4]
err := memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, hash) err := memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, hash)
@ -174,7 +174,7 @@ func PersistEnv() error {
err = json.Unmarshal(decryptedConfigs, &storeData) err = json.Unmarshal(decryptedConfigs, &storeData)
if err != nil { if err != nil {
log.Debug("Error while unmarshalling env data: ", err) log.Debug("Error while un-marshalling env data: ", err)
return err return err
} }

View File

@ -1,10 +1,7 @@
package stores package stores
import ( import (
"os"
"sync" "sync"
"github.com/authorizerdev/authorizer/server/constants"
) )
// EnvStore struct to store the env variables // EnvStore struct to store the env variables
@ -23,12 +20,10 @@ func NewEnvStore() *EnvStore {
// UpdateEnvStore to update the whole env store object // UpdateEnvStore to update the whole env store object
func (e *EnvStore) UpdateStore(store map[string]interface{}) { func (e *EnvStore) UpdateStore(store map[string]interface{}) {
if os.Getenv("ENV") != constants.TestEnv {
e.mutex.Lock() e.mutex.Lock()
defer e.mutex.Unlock() defer e.mutex.Unlock()
}
// just override the keys + new keys
// just override the keys + new keys
for key, value := range store { for key, value := range store {
e.store[key] = value e.store[key] = value
} }
@ -46,9 +41,8 @@ func (e *EnvStore) Get(key string) interface{} {
// Set sets the value of the key in env store // Set sets the value of the key in env store
func (e *EnvStore) Set(key string, value interface{}) { func (e *EnvStore) Set(key string, value interface{}) {
if os.Getenv("ENV") != constants.TestEnv {
e.mutex.Lock() e.mutex.Lock()
defer e.mutex.Unlock() defer e.mutex.Unlock()
}
e.store[key] = value e.store[key] = value
} }

View File

@ -1,11 +1,8 @@
package stores package stores
import ( import (
"os"
"strings" "strings"
"sync" "sync"
"github.com/authorizerdev/authorizer/server/constants"
) )
// SessionStore struct to store the env variables // SessionStore struct to store the env variables
@ -29,10 +26,9 @@ func (s *SessionStore) Get(key, subKey string) string {
// Set sets the value of the key in state store // Set sets the value of the key in state store
func (s *SessionStore) Set(key string, subKey, value string) { func (s *SessionStore) Set(key string, subKey, value string) {
if os.Getenv("ENV") != constants.TestEnv {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
}
if _, ok := s.store[key]; !ok { if _, ok := s.store[key]; !ok {
s.store[key] = make(map[string]string) s.store[key] = make(map[string]string)
} }
@ -41,19 +37,15 @@ func (s *SessionStore) Set(key string, subKey, value string) {
// RemoveAll all values for given key // RemoveAll all values for given key
func (s *SessionStore) RemoveAll(key string) { func (s *SessionStore) RemoveAll(key string) {
if os.Getenv("ENV") != constants.TestEnv {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
}
delete(s.store, key) delete(s.store, key)
} }
// Remove value for given key and subkey // Remove value for given key and subkey
func (s *SessionStore) Remove(key, subKey string) { func (s *SessionStore) Remove(key, subKey string) {
if os.Getenv("ENV") != constants.TestEnv {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
}
if _, ok := s.store[key]; ok { if _, ok := s.store[key]; ok {
delete(s.store[key], subKey) delete(s.store[key], subKey)
} }
@ -69,11 +61,8 @@ func (s *SessionStore) GetAll(key string) map[string]string {
// RemoveByNamespace to delete session for a given namespace example google,github // RemoveByNamespace to delete session for a given namespace example google,github
func (s *SessionStore) RemoveByNamespace(namespace string) error { func (s *SessionStore) RemoveByNamespace(namespace string) error {
if os.Getenv("ENV") != constants.TestEnv {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
}
for key := range s.store { for key := range s.store {
if strings.Contains(key, namespace+":") { if strings.Contains(key, namespace+":") {
delete(s.store, key) delete(s.store, key)

View File

@ -1,10 +1,7 @@
package stores package stores
import ( import (
"os"
"sync" "sync"
"github.com/authorizerdev/authorizer/server/constants"
) )
// StateStore struct to store the env variables // StateStore struct to store the env variables
@ -28,19 +25,16 @@ func (s *StateStore) Get(key string) string {
// Set sets the value of the key in state store // Set sets the value of the key in state store
func (s *StateStore) Set(key string, value string) { func (s *StateStore) Set(key string, value string) {
if os.Getenv("ENV") != constants.TestEnv {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
}
s.store[key] = value s.store[key] = value
} }
// Remove removes the key from state store // Remove removes the key from state store
func (s *StateStore) Remove(key string) { func (s *StateStore) Remove(key string) {
if os.Getenv("ENV") != constants.TestEnv {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
}
delete(s.store, key) delete(s.store, key)
} }

View File

@ -24,7 +24,11 @@ func deleteWebhookTest(t *testing.T, s TestSetup) {
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h))
// get all webhooks // get all webhooks
webhooks, err := db.Provider.ListWebhook(ctx, model.Pagination{}) webhooks, err := db.Provider.ListWebhook(ctx, model.Pagination{
Limit: 10,
Page: 1,
Offset: 0,
})
assert.NoError(t, err) assert.NoError(t, err)
for _, w := range webhooks.Webhooks { for _, w := range webhooks.Webhooks {
@ -37,12 +41,17 @@ func deleteWebhookTest(t *testing.T, s TestSetup) {
assert.NotEmpty(t, res.Message) assert.NotEmpty(t, res.Message)
} }
webhooks, err = db.Provider.ListWebhook(ctx, model.Pagination{}) webhooks, err = db.Provider.ListWebhook(ctx, model.Pagination{
Limit: 10,
Page: 1,
Offset: 0,
})
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, webhooks.Webhooks, 0) assert.Len(t, webhooks.Webhooks, 0)
webhookLogs, err := db.Provider.ListWebhookLogs(ctx, model.Pagination{ webhookLogs, err := db.Provider.ListWebhookLogs(ctx, model.Pagination{
Limit: 10, Limit: 100,
Page: 1,
Offset: 0,
}, "") }, "")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, webhookLogs.WebhookLogs, 0) assert.Len(t, webhookLogs.WebhookLogs, 0)

View File

@ -2,8 +2,8 @@ package test
import ( import (
"context" "context"
"os"
"testing" "testing"
"time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
@ -13,31 +13,45 @@ import (
func TestResolvers(t *testing.T) { func TestResolvers(t *testing.T) {
databases := map[string]string{ databases := map[string]string{
constants.DbTypeSqlite: "../../data.db", // constants.DbTypeSqlite: "../../data.db",
// constants.DbTypeArangodb: "http://localhost:8529", // constants.DbTypeArangodb: "http://localhost:8529",
// constants.DbTypeMongodb: "mongodb://localhost:27017", // constants.DbTypeMongodb: "mongodb://localhost:27017",
// constants.DbTypeCassandraDB: "127.0.0.1:9042", constants.DbTypeScyllaDB: "127.0.0.1:9042",
} }
for dbType, dbURL := range databases { testDb := "authorizer_test"
s := testSetup() s := testSetup()
defer s.Server.Close() defer s.Server.Close()
for dbType, dbURL := range databases {
ctx := context.Background() ctx := context.Background()
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseURL, dbURL) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseURL, dbURL)
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseType, dbType) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseType, dbType)
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseName, testDb)
os.Setenv(constants.EnvKeyDatabaseURL, dbURL)
os.Setenv(constants.EnvKeyDatabaseType, dbType)
os.Setenv(constants.EnvKeyDatabaseName, testDb)
memorystore.InitRequiredEnv()
err := db.InitDB() err := db.InitDB()
if err != nil { if err != nil {
t.Errorf("Error initializing database: %s", err) t.Errorf("Error initializing database: %s", err.Error())
} }
// clean the persisted config for test to use fresh config // clean the persisted config for test to use fresh config
envData, err := db.Provider.GetEnv(ctx) envData, err := db.Provider.GetEnv(ctx)
if err == nil { if err == nil {
envData.EnvData = "" envData.EnvData = ""
db.Provider.UpdateEnv(ctx, envData) _, err = db.Provider.UpdateEnv(ctx, envData)
if err != nil {
t.Errorf("Error updating env: %s", err.Error())
}
}
err = env.PersistEnv()
if err != nil {
t.Errorf("Error persisting env: %s", err.Error())
} }
env.PersistEnv()
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEnv, "test") memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEnv, "test")
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyIsProd, false) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyIsProd, false)
@ -78,7 +92,6 @@ func TestResolvers(t *testing.T) {
inviteUserTest(t, s) inviteUserTest(t, s)
validateJwtTokenTest(t, s) validateJwtTokenTest(t, s)
time.Sleep(5 * time.Second) // add sleep for webhooklogs to get generated as they are async
webhookLogsTest(t, s) // get logs after above resolver tests are done webhookLogsTest(t, s) // get logs after above resolver tests are done
deleteWebhookTest(t, s) // delete webhooks (admin resolver) deleteWebhookTest(t, s) // delete webhooks (admin resolver)
}) })

View File

@ -3,6 +3,7 @@ package test
import ( import (
"fmt" "fmt"
"testing" "testing"
"time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/crypto"
@ -14,6 +15,7 @@ import (
) )
func webhookLogsTest(t *testing.T, s TestSetup) { func webhookLogsTest(t *testing.T, s TestSetup) {
time.Sleep(30 * time.Second) // add sleep for webhooklogs to get generated as they are async
t.Helper() t.Helper()
t.Run("should get webhook logs", func(t *testing.T) { t.Run("should get webhook logs", func(t *testing.T) {
req, ctx := createContext(s) req, ctx := createContext(s)
@ -23,15 +25,16 @@ func webhookLogsTest(t *testing.T, s TestSetup) {
assert.NoError(t, err) assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h))
webhooks, err := resolvers.WebhooksResolver(ctx, nil)
assert.NoError(t, err)
assert.NotEmpty(t, webhooks)
webhookLogs, err := resolvers.WebhookLogsResolver(ctx, nil) webhookLogs, err := resolvers.WebhookLogsResolver(ctx, nil)
assert.NoError(t, err) assert.NoError(t, err)
assert.Greater(t, len(webhookLogs.WebhookLogs), 1) assert.Greater(t, len(webhookLogs.WebhookLogs), 1)
webhooks, err := resolvers.WebhooksResolver(ctx, nil)
assert.NoError(t, err)
assert.NotEmpty(t, webhooks)
for _, w := range webhooks.Webhooks { for _, w := range webhooks.Webhooks {
t.Run(fmt.Sprintf("should get webhook for webhook_id:%s", w.ID), func(t *testing.T) {
webhookLogs, err := resolvers.WebhookLogsResolver(ctx, &model.ListWebhookLogRequest{ webhookLogs, err := resolvers.WebhookLogsResolver(ctx, &model.ListWebhookLogRequest{
WebhookID: &w.ID, WebhookID: &w.ID,
}) })
@ -40,6 +43,7 @@ func webhookLogsTest(t *testing.T, s TestSetup) {
for _, wl := range webhookLogs.WebhookLogs { for _, wl := range webhookLogs.WebhookLogs {
assert.Equal(t, utils.StringValue(wl.WebhookID), w.ID) assert.Equal(t, utils.StringValue(wl.WebhookID), w.ID)
} }
})
} }
}) })
} }