diff --git a/Makefile b/Makefile index 5b434d7..c55ab9c 100644 --- a/Makefile +++ b/Makefile @@ -51,6 +51,6 @@ test-all-db: docker rm -vf authorizer_mongodb_db docker rm -vf authorizer_arangodb docker rm -vf dynamodb-local-test - # docker rm -vf couchbase-local-test + docker rm -vf couchbase-local-test generate: cd server && go run github.com/99designs/gqlgen generate && go mod tidy diff --git a/dashboard/src/components/UpdateWebhookModal.tsx b/dashboard/src/components/UpdateWebhookModal.tsx index 55af022..db35d31 100644 --- a/dashboard/src/components/UpdateWebhookModal.tsx +++ b/dashboard/src/components/UpdateWebhookModal.tsx @@ -63,6 +63,7 @@ interface headersValidatorDataType { interface selecetdWebhookDataTypes { [WebhookInputDataFields.ID]: string; [WebhookInputDataFields.EVENT_NAME]: string; + [WebhookInputDataFields.EVENT_DESCRIPTION]?: string; [WebhookInputDataFields.ENDPOINT]: string; [WebhookInputDataFields.ENABLED]: boolean; [WebhookInputDataFields.HEADERS]?: Record; @@ -86,6 +87,7 @@ const initHeadersValidatorData: headersValidatorDataType = { interface webhookDataType { [WebhookInputDataFields.EVENT_NAME]: string; + [WebhookInputDataFields.EVENT_DESCRIPTION]?: string; [WebhookInputDataFields.ENDPOINT]: string; [WebhookInputDataFields.ENABLED]: boolean; [WebhookInputDataFields.HEADERS]: headersDataType[]; @@ -98,6 +100,7 @@ interface validatorDataType { const initWebhookData: webhookDataType = { [WebhookInputDataFields.EVENT_NAME]: webhookEventNames['User login'], + [WebhookInputDataFields.EVENT_DESCRIPTION]: '', [WebhookInputDataFields.ENDPOINT]: '', [WebhookInputDataFields.ENABLED]: true, [WebhookInputDataFields.HEADERS]: [{ ...initHeadersData }], @@ -144,6 +147,9 @@ const UpdateWebhookModal = ({ case WebhookInputDataFields.EVENT_NAME: setWebhook({ ...webhook, [inputType]: value }); break; + case WebhookInputDataFields.EVENT_DESCRIPTION: + setWebhook({ ...webhook, [inputType]: value }); + break; case WebhookInputDataFields.ENDPOINT: setWebhook({ ...webhook, [inputType]: value }); setValidator({ @@ -246,6 +252,8 @@ const UpdateWebhookModal = ({ let params: any = { [WebhookInputDataFields.EVENT_NAME]: webhook[WebhookInputDataFields.EVENT_NAME], + [WebhookInputDataFields.EVENT_DESCRIPTION]: + webhook[WebhookInputDataFields.EVENT_DESCRIPTION], [WebhookInputDataFields.ENDPOINT]: webhook[WebhookInputDataFields.ENDPOINT], [WebhookInputDataFields.ENABLED]: webhook[WebhookInputDataFields.ENABLED], @@ -402,7 +410,9 @@ const UpdateWebhookModal = ({ + + Event Description + + + + inputChangehandler( + WebhookInputDataFields.EVENT_DESCRIPTION, + e.currentTarget.value, + ) + } + /> + + + ; @@ -117,6 +118,7 @@ const Webhooks = () => { useEffect(() => { fetchWebookData(); }, [paginationProps.page, paginationProps.limit]); + console.log({ webhookData }); return ( @@ -134,6 +136,7 @@ const Webhooks = () => { Event Name + Event Description Endpoint Enabled Headers @@ -147,7 +150,10 @@ const Webhooks = () => { style={{ fontSize: 14 }} > - {webhook[WebhookInputDataFields.EVENT_NAME]} + {webhook[WebhookInputDataFields.EVENT_NAME].split('-')[0]} + + + {webhook[WebhookInputDataFields.EVENT_DESCRIPTION]} {webhook[WebhookInputDataFields.ENDPOINT]} @@ -264,7 +270,7 @@ const Webhooks = () => { - Go to page:{' '} + Go to page:{' '} = pagination.Offset { var webhook models.Webhook - err := scanner.Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) + err := scanner.Scan(&webhook.ID, &webhook.EventDescription, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) if err != nil { return nil, err } @@ -127,8 +116,8 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) // GetWebhookByID to get webhook by id func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { var webhook models.Webhook - query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1`, KeySpace+"."+models.Collections.Webhook, webhookID) - err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) + query := fmt.Sprintf(`SELECT id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1`, KeySpace+"."+models.Collections.Webhook, webhookID) + err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventDescription, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) if err != nil { return nil, err } @@ -136,14 +125,19 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { - var webhook models.Webhook - query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.Webhook, eventName) - err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) - if err != nil { - return nil, err +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { + query := fmt.Sprintf(`SELECT id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name LIKE '%s' ALLOW FILTERING`, KeySpace+"."+models.Collections.Webhook, eventName+"%") + scanner := p.db.Query(query).Iter().Scanner() + webhooks := []*model.Webhook{} + for scanner.Next() { + var webhook models.Webhook + err := scanner.Scan(&webhook.ID, &webhook.EventDescription, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) + if err != nil { + return nil, err + } + webhooks = append(webhooks, webhook.AsAPIWebhook()) } - return webhook.AsAPIWebhook(), nil + return webhooks, nil } // DeleteWebhook to delete webhook diff --git a/server/db/providers/couchbase/webhook.go b/server/db/providers/couchbase/webhook.go index e438263..2f51acd 100644 --- a/server/db/providers/couchbase/webhook.go +++ b/server/db/providers/couchbase/webhook.go @@ -19,11 +19,11 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod if webhook.ID == "" { webhook.ID = uuid.New().String() } - webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() - + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) insertOpt := gocb.InsertOptions{ Context: ctx, } @@ -37,7 +37,10 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() - + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } bytes, err := json.Marshal(webhook) if err != nil { return nil, err @@ -50,17 +53,13 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* if err != nil { return nil, err } - updateFields, params := GetSetFields(webhookMap) - query := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE _id='%s'`, p.scopeName, models.Collections.Webhook, updateFields, webhook.ID) - _, err = p.db.Query(query, &gocb.QueryOptions{ Context: ctx, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, NamedParameters: params, }) - if err != nil { return nil, err } @@ -72,7 +71,6 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { webhooks := []*model.Webhook{} paginationClone := pagination - params := make(map[string]interface{}, 1) params["offset"] = paginationClone.Offset params["limit"] = paginationClone.Limit @@ -81,7 +79,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) return nil, err } paginationClone.Total = total - query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM %s.%s OFFSET $offset LIMIT $limit", p.scopeName, models.Collections.Webhook) + query := fmt.Sprintf("SELECT _id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s OFFSET $offset LIMIT $limit", p.scopeName, models.Collections.Webhook) queryResult, err := p.db.Query(query, &gocb.QueryOptions{ Context: ctx, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, @@ -110,11 +108,9 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) // GetWebhookByID to get webhook by id func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { var webhook models.Webhook - params := make(map[string]interface{}, 1) params["_id"] = webhookID - - query := fmt.Sprintf(`SELECT _id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE _id=$_id LIMIT 1`, p.scopeName, models.Collections.Webhook) + query := fmt.Sprintf(`SELECT _id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE _id=$_id LIMIT 1`, p.scopeName, models.Collections.Webhook) q, err := p.db.Query(query, &gocb.QueryOptions{ Context: ctx, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, @@ -124,42 +120,42 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model return nil, err } err = q.One(&webhook) - if err != nil { return nil, err } - return webhook.AsAPIWebhook(), nil } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { - var webhook models.Webhook +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { params := make(map[string]interface{}, 1) - params["event_name"] = eventName - - query := fmt.Sprintf(`SELECT _id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE event_name=$event_name LIMIT 1`, p.scopeName, models.Collections.Webhook) - q, err := p.db.Query(query, &gocb.QueryOptions{ + // params["event_name"] = eventName + "%" + query := fmt.Sprintf(`SELECT _id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE event_name LIKE '%s'`, p.scopeName, models.Collections.Webhook, eventName+"%") + queryResult, err := p.db.Query(query, &gocb.QueryOptions{ Context: ctx, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, NamedParameters: params, }) - if err != nil { return nil, err } - err = q.One(&webhook) - - if err != nil { + webhooks := []*model.Webhook{} + for queryResult.Next() { + var webhook models.Webhook + err := queryResult.Row(&webhook) + if err != nil { + log.Fatal(err) + } + webhooks = append(webhooks, webhook.AsAPIWebhook()) + } + if err := queryResult.Err(); err != nil { return nil, err } - - return webhook.AsAPIWebhook(), nil + return webhooks, nil } // DeleteWebhook to delete webhook func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error { - params := make(map[string]interface{}, 1) params["webhook_id"] = webhook.ID removeOpt := gocb.RemoveOptions{ diff --git a/server/db/providers/dynamodb/webhook.go b/server/db/providers/dynamodb/webhook.go index 9cf7ec7..8f1ffb7 100644 --- a/server/db/providers/dynamodb/webhook.go +++ b/server/db/providers/dynamodb/webhook.go @@ -3,28 +3,29 @@ package dynamodb import ( "context" "errors" + "fmt" + "strings" "time" + "github.com/google/uuid" + "github.com/guregu/dynamo" + "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/google/uuid" - "github.com/guregu/dynamo" ) // AddWebhook to add webhook func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { collection := p.db.Table(models.Collections.Webhook) - if webhook.ID == "" { webhook.ID = uuid.New().String() } - webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() - + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) err := collection.Put(webhook).RunWithContext(ctx) - if err != nil { return nil, err } @@ -33,11 +34,13 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { - collection := p.db.Table(models.Collections.Webhook) - webhook.UpdatedAt = time.Now().Unix() + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } + collection := p.db.Table(models.Collections.Webhook) err := UpdateByHashKey(collection, "id", webhook.ID, webhook) - if err != nil { return nil, err } @@ -51,16 +54,13 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) var lastEval dynamo.PagingKey var iter dynamo.PagingIter var iteration int64 = 0 - collection := p.db.Table(models.Collections.Webhook) paginationClone := pagination scanner := collection.Scan() count, err := scanner.Count() - if err != nil { return nil, err } - for (paginationClone.Offset + paginationClone.Limit) > iteration { iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() for iter.NextWithContext(ctx, &webhook) { @@ -75,9 +75,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) lastEval = iter.LastEvaluatedKey() iteration += paginationClone.Limit } - paginationClone.Total = count - return &model.Webhooks{ Pagination: &paginationClone, Webhooks: webhooks, @@ -88,37 +86,29 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { collection := p.db.Table(models.Collections.Webhook) var webhook models.Webhook - err := collection.Get("id", webhookID).OneWithContext(ctx, &webhook) - if err != nil { return nil, err } - if webhook.ID == "" { return webhook.AsAPIWebhook(), errors.New("no documets found") } - return webhook.AsAPIWebhook(), nil } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { - var webhook models.Webhook +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { + webhooks := []models.Webhook{} collection := p.db.Table(models.Collections.Webhook) - - iter := collection.Scan().Index("event_name").Filter("'event_name' = ?", eventName).Iter() - - for iter.NextWithContext(ctx, &webhook) { - return webhook.AsAPIWebhook(), nil - } - - err := iter.Err() - + err := collection.Scan().Index("event_name").Filter("contains(event_name, ?)", eventName).AllWithContext(ctx, &webhooks) if err != nil { - return webhook.AsAPIWebhook(), err + return nil, err } - return webhook.AsAPIWebhook(), nil + resWebhooks := []*model.Webhook{} + for _, w := range webhooks { + resWebhooks = append(resWebhooks, w.AsAPIWebhook()) + } + return resWebhooks, nil } // DeleteWebhook to delete webhook @@ -133,7 +123,6 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er return err } webhookLogs, errIs := p.ListWebhookLogs(ctx, pagination, webhook.ID) - for _, webhookLog := range webhookLogs.WebhookLogs { err = webhookLogCollection.Delete("id", webhookLog.ID).RunWithContext(ctx) if err != nil { diff --git a/server/db/providers/mongodb/webhook.go b/server/db/providers/mongodb/webhook.go index 7b29398..843aec9 100644 --- a/server/db/providers/mongodb/webhook.go +++ b/server/db/providers/mongodb/webhook.go @@ -2,6 +2,8 @@ package mongodb import ( "context" + "fmt" + "strings" "time" "github.com/authorizerdev/authorizer/server/db/models" @@ -16,11 +18,11 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod if webhook.ID == "" { webhook.ID = uuid.New().String() } - webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() - + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) _, err := webhookCollection.InsertOne(ctx, webhook) if err != nil { @@ -32,39 +34,37 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) _, err := webhookCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": webhook.ID}}, bson.M{"$set": webhook}, options.MergeUpdateOptions()) if err != nil { return nil, err } - return webhook.AsAPIWebhook(), nil } // ListWebhooks to list webhook func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { - var webhooks []*model.Webhook + webhooks := []*model.Webhook{} opts := options.Find() opts.SetLimit(pagination.Limit) opts.SetSkip(pagination.Offset) opts.SetSort(bson.M{"created_at": -1}) - paginationClone := pagination - webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) count, err := webhookCollection.CountDocuments(ctx, bson.M{}, options.Count()) if err != nil { return nil, err } - paginationClone.Total = count - cursor, err := webhookCollection.Find(ctx, bson.M{}, opts) if err != nil { return nil, err } defer cursor.Close(ctx) - for cursor.Next(ctx) { var webhook models.Webhook err := cursor.Decode(&webhook) @@ -73,7 +73,6 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) } webhooks = append(webhooks, webhook.AsAPIWebhook()) } - return &model.Webhooks{ Pagination: &paginationClone, Webhooks: webhooks, @@ -92,14 +91,27 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { - var webhook models.Webhook +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { + webhooks := []*model.Webhook{} webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) - err := webhookCollection.FindOne(ctx, bson.M{"event_name": eventName}).Decode(&webhook) + opts := options.Find() + opts.SetSort(bson.M{"created_at": -1}) + cursor, err := webhookCollection.Find(ctx, bson.M{"event_name": bson.M{ + "$regex": fmt.Sprintf("^%s", eventName), + }}, opts) if err != nil { return nil, err } - return webhook.AsAPIWebhook(), nil + defer cursor.Close(ctx) + for cursor.Next(ctx) { + var webhook models.Webhook + err := cursor.Decode(&webhook) + if err != nil { + return nil, err + } + webhooks = append(webhooks, webhook.AsAPIWebhook()) + } + return webhooks, nil } // DeleteWebhook to delete webhook @@ -109,12 +121,10 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er if err != nil { return err } - webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection()) _, err = webhookLogCollection.DeleteMany(nil, bson.M{"webhook_id": webhook.ID}, options.Delete()) if err != nil { return err } - return nil } diff --git a/server/db/providers/provider_template/webhook.go b/server/db/providers/provider_template/webhook.go index eda82ee..faf18fa 100644 --- a/server/db/providers/provider_template/webhook.go +++ b/server/db/providers/provider_template/webhook.go @@ -2,6 +2,8 @@ package provider_template import ( "context" + "fmt" + "strings" "time" "github.com/authorizerdev/authorizer/server/db/models" @@ -14,16 +16,21 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod if webhook.ID == "" { webhook.ID = uuid.New().String() } - webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) return webhook.AsAPIWebhook(), nil } // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } return webhook.AsAPIWebhook(), nil } @@ -38,7 +45,7 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { return nil, nil } diff --git a/server/db/providers/providers.go b/server/db/providers/providers.go index 7325204..f6a2aad 100644 --- a/server/db/providers/providers.go +++ b/server/db/providers/providers.go @@ -56,7 +56,7 @@ type Provider interface { // GetWebhookByID to get webhook by id GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) // GetWebhookByEventName to get webhook by event_name - GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) + GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) // DeleteWebhook to delete webhook DeleteWebhook(ctx context.Context, webhook *model.Webhook) error diff --git a/server/db/providers/sql/webhook.go b/server/db/providers/sql/webhook.go index 93f21a4..72f3cb4 100644 --- a/server/db/providers/sql/webhook.go +++ b/server/db/providers/sql/webhook.go @@ -2,6 +2,8 @@ package sql import ( "context" + "fmt" + "strings" "time" "github.com/authorizerdev/authorizer/server/db/models" @@ -14,10 +16,11 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod if webhook.ID == "" { webhook.ID = uuid.New().String() } - webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) res := p.db.Create(&webhook) if res.Error != nil { return nil, res.Error @@ -28,33 +31,31 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() - + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } result := p.db.Save(&webhook) if result.Error != nil { return nil, result.Error } - return webhook.AsAPIWebhook(), nil } // ListWebhooks to list webhook func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { var webhooks []models.Webhook - result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&webhooks) if result.Error != nil { return nil, result.Error } - var total int64 totalRes := p.db.Model(&models.Webhook{}).Count(&total) if totalRes.Error != nil { return nil, totalRes.Error } - paginationClone := pagination paginationClone.Total = total - responseWebhooks := []*model.Webhook{} for _, w := range webhooks { responseWebhooks = append(responseWebhooks, w.AsAPIWebhook()) @@ -77,14 +78,17 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { - var webhook models.Webhook - - result := p.db.Where("event_name = ?", eventName).First(&webhook) +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { + var webhooks []models.Webhook + result := p.db.Where("event_name LIKE ?", eventName+"%").Find(&webhooks) if result.Error != nil { return nil, result.Error } - return webhook.AsAPIWebhook(), nil + responseWebhooks := []*model.Webhook{} + for _, w := range webhooks { + responseWebhooks = append(responseWebhooks, w.AsAPIWebhook()) + } + return responseWebhooks, nil } // DeleteWebhook to delete webhook diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index 3eb5dce..a620c77 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -275,13 +275,14 @@ type ComplexityRoot struct { } Webhook struct { - CreatedAt func(childComplexity int) int - Enabled func(childComplexity int) int - Endpoint func(childComplexity int) int - EventName func(childComplexity int) int - Headers func(childComplexity int) int - ID func(childComplexity int) int - UpdatedAt func(childComplexity int) int + CreatedAt func(childComplexity int) int + Enabled func(childComplexity int) int + Endpoint func(childComplexity int) int + EventDescription func(childComplexity int) int + EventName func(childComplexity int) int + Headers func(childComplexity int) int + ID func(childComplexity int) int + UpdatedAt func(childComplexity int) int } WebhookLog struct { @@ -1833,6 +1834,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Webhook.Endpoint(childComplexity), true + case "Webhook.event_description": + if e.complexity.Webhook.EventDescription == nil { + break + } + + return e.complexity.Webhook.EventDescription(childComplexity), true + case "Webhook.event_name": if e.complexity.Webhook.EventName == nil { break @@ -2210,7 +2218,8 @@ type GenerateJWTKeysResponse { type Webhook { id: ID! - event_name: String + event_name: String # this is unique string + event_description: String endpoint: String enabled: Boolean headers: Map @@ -2501,6 +2510,7 @@ input ListWebhookLogRequest { input AddWebhookRequest { event_name: String! + event_description: String endpoint: String! enabled: Boolean! headers: Map @@ -2509,6 +2519,7 @@ input AddWebhookRequest { input UpdateWebhookRequest { id: ID! event_name: String + event_description: String endpoint: String enabled: Boolean headers: Map @@ -10143,6 +10154,8 @@ func (ec *executionContext) fieldContext_Query__webhook(ctx context.Context, fie return ec.fieldContext_Webhook_id(ctx, field) case "event_name": return ec.fieldContext_Webhook_event_name(ctx, field) + case "event_description": + return ec.fieldContext_Webhook_event_description(ctx, field) case "endpoint": return ec.fieldContext_Webhook_endpoint(ctx, field) case "enabled": @@ -12201,6 +12214,47 @@ func (ec *executionContext) fieldContext_Webhook_event_name(ctx context.Context, return fc, nil } +func (ec *executionContext) _Webhook_event_description(ctx context.Context, field graphql.CollectedField, obj *model.Webhook) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Webhook_event_description(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.EventDescription, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Webhook_event_description(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Webhook", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Webhook_endpoint(ctx context.Context, field graphql.CollectedField, obj *model.Webhook) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Webhook_endpoint(ctx, field) if err != nil { @@ -12907,6 +12961,8 @@ func (ec *executionContext) fieldContext_Webhooks_webhooks(ctx context.Context, return ec.fieldContext_Webhook_id(ctx, field) case "event_name": return ec.fieldContext_Webhook_event_name(ctx, field) + case "event_description": + return ec.fieldContext_Webhook_event_description(ctx, field) case "endpoint": return ec.fieldContext_Webhook_endpoint(ctx, field) case "enabled": @@ -14756,7 +14812,7 @@ func (ec *executionContext) unmarshalInputAddWebhookRequest(ctx context.Context, asMap[k] = v } - fieldsInOrder := [...]string{"event_name", "endpoint", "enabled", "headers"} + fieldsInOrder := [...]string{"event_name", "event_description", "endpoint", "enabled", "headers"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -14771,6 +14827,14 @@ func (ec *executionContext) unmarshalInputAddWebhookRequest(ctx context.Context, if err != nil { return it, err } + case "event_description": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("event_description")) + it.EventDescription, err = ec.unmarshalOString2ᚖstring(ctx, v) + if err != nil { + return it, err + } case "endpoint": var err error @@ -16612,7 +16676,7 @@ func (ec *executionContext) unmarshalInputUpdateWebhookRequest(ctx context.Conte asMap[k] = v } - fieldsInOrder := [...]string{"id", "event_name", "endpoint", "enabled", "headers"} + fieldsInOrder := [...]string{"id", "event_name", "event_description", "endpoint", "enabled", "headers"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -16635,6 +16699,14 @@ func (ec *executionContext) unmarshalInputUpdateWebhookRequest(ctx context.Conte if err != nil { return it, err } + case "event_description": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("event_description")) + it.EventDescription, err = ec.unmarshalOString2ᚖstring(ctx, v) + if err != nil { + return it, err + } case "endpoint": var err error @@ -18513,6 +18585,10 @@ func (ec *executionContext) _Webhook(ctx context.Context, sel ast.SelectionSet, out.Values[i] = ec._Webhook_event_name(ctx, field, obj) + case "event_description": + + out.Values[i] = ec._Webhook_event_description(ctx, field, obj) + case "endpoint": out.Values[i] = ec._Webhook_endpoint(ctx, field, obj) diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index 57b1aad..d2c7d23 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -10,10 +10,11 @@ type AddEmailTemplateRequest struct { } type AddWebhookRequest struct { - EventName string `json:"event_name"` - Endpoint string `json:"endpoint"` - Enabled bool `json:"enabled"` - Headers map[string]interface{} `json:"headers"` + EventName string `json:"event_name"` + EventDescription *string `json:"event_description"` + Endpoint string `json:"endpoint"` + Enabled bool `json:"enabled"` + Headers map[string]interface{} `json:"headers"` } type AdminLoginInput struct { @@ -387,11 +388,12 @@ type UpdateUserInput struct { } type UpdateWebhookRequest struct { - ID string `json:"id"` - EventName *string `json:"event_name"` - Endpoint *string `json:"endpoint"` - Enabled *bool `json:"enabled"` - Headers map[string]interface{} `json:"headers"` + ID string `json:"id"` + EventName *string `json:"event_name"` + EventDescription *string `json:"event_description"` + Endpoint *string `json:"endpoint"` + Enabled *bool `json:"enabled"` + Headers map[string]interface{} `json:"headers"` } type User struct { @@ -461,13 +463,14 @@ type VerifyOTPRequest struct { } type Webhook struct { - ID string `json:"id"` - EventName *string `json:"event_name"` - Endpoint *string `json:"endpoint"` - Enabled *bool `json:"enabled"` - Headers map[string]interface{} `json:"headers"` - CreatedAt *int64 `json:"created_at"` - UpdatedAt *int64 `json:"updated_at"` + ID string `json:"id"` + EventName *string `json:"event_name"` + EventDescription *string `json:"event_description"` + Endpoint *string `json:"endpoint"` + Enabled *bool `json:"enabled"` + Headers map[string]interface{} `json:"headers"` + CreatedAt *int64 `json:"created_at"` + UpdatedAt *int64 `json:"updated_at"` } type WebhookLog struct { diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index 9b96d13..aa95de1 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -168,7 +168,8 @@ type GenerateJWTKeysResponse { type Webhook { id: ID! - event_name: String + event_name: String # this is unique string + event_description: String endpoint: String enabled: Boolean headers: Map @@ -459,6 +460,7 @@ input ListWebhookLogRequest { input AddWebhookRequest { event_name: String! + event_description: String endpoint: String! enabled: Boolean! headers: Map @@ -467,6 +469,7 @@ input AddWebhookRequest { input UpdateWebhookRequest { id: ID! event_name: String + event_description: String endpoint: String enabled: Boolean headers: Map diff --git a/server/handlers/graphql.go b/server/handlers/graphql.go index 734a458..367a956 100644 --- a/server/handlers/graphql.go +++ b/server/handlers/graphql.go @@ -2,7 +2,7 @@ package handlers import ( "github.com/99designs/gqlgen/graphql/handler" - graph "github.com/authorizerdev/authorizer/server/graph" + "github.com/authorizerdev/authorizer/server/graph" "github.com/authorizerdev/authorizer/server/graph/generated" "github.com/gin-gonic/gin" ) diff --git a/server/resolvers/add_webhook.go b/server/resolvers/add_webhook.go index d8c768a..596b1e0 100644 --- a/server/resolvers/add_webhook.go +++ b/server/resolvers/add_webhook.go @@ -9,6 +9,7 @@ import ( "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/refs" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" "github.com/authorizerdev/authorizer/server/validators" @@ -22,32 +23,32 @@ func AddWebhookResolver(ctx context.Context, params model.AddWebhookRequest) (*m log.Debug("Failed to get GinContext: ", err) return nil, err } - if !token.IsSuperAdmin(gc) { log.Debug("Not logged in as super admin") return nil, fmt.Errorf("unauthorized") } - if !validators.IsValidWebhookEventName(params.EventName) { log.Debug("Invalid Event Name: ", params.EventName) return nil, fmt.Errorf("invalid event name %s", params.EventName) } - if strings.TrimSpace(params.Endpoint) == "" { log.Debug("empty endpoint not allowed") return nil, fmt.Errorf("empty endpoint not allowed") } - headerBytes, err := json.Marshal(params.Headers) if err != nil { return nil, err } + if params.EventDescription == nil { + params.EventDescription = refs.NewStringRef(strings.Join(strings.Split(params.EventName, "."), " ")) + } _, err = db.Provider.AddWebhook(ctx, models.Webhook{ - EventName: params.EventName, - EndPoint: params.Endpoint, - Enabled: params.Enabled, - Headers: string(headerBytes), + EventDescription: refs.StringValue(params.EventDescription), + EventName: params.EventName, + EndPoint: params.Endpoint, + Enabled: params.Enabled, + Headers: string(headerBytes), }) if err != nil { log.Debug("Failed to add webhook: ", err) diff --git a/server/resolvers/update_webhook.go b/server/resolvers/update_webhook.go index f1d1009..5783984 100644 --- a/server/resolvers/update_webhook.go +++ b/server/resolvers/update_webhook.go @@ -28,13 +28,11 @@ func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookReques log.Debug("Not logged in as super admin") return nil, fmt.Errorf("unauthorized") } - webhook, err := db.Provider.GetWebhookByID(ctx, params.ID) if err != nil { log.Debug("failed to get webhook: ", err) return nil, err } - headersString := "" if webhook.Headers != nil { headerBytes, err := json.Marshal(webhook.Headers) @@ -43,17 +41,16 @@ func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookReques } headersString = string(headerBytes) } - webhookDetails := models.Webhook{ - ID: webhook.ID, - Key: webhook.ID, - EventName: refs.StringValue(webhook.EventName), - EndPoint: refs.StringValue(webhook.Endpoint), - Enabled: refs.BoolValue(webhook.Enabled), - Headers: headersString, - CreatedAt: refs.Int64Value(webhook.CreatedAt), + ID: webhook.ID, + Key: webhook.ID, + EventName: refs.StringValue(webhook.EventName), + EventDescription: refs.StringValue(webhook.EventDescription), + EndPoint: refs.StringValue(webhook.Endpoint), + Enabled: refs.BoolValue(webhook.Enabled), + Headers: headersString, + CreatedAt: refs.Int64Value(webhook.CreatedAt), } - if params.EventName != nil && webhookDetails.EventName != refs.StringValue(params.EventName) { if isValid := validators.IsValidWebhookEventName(refs.StringValue(params.EventName)); !isValid { log.Debug("invalid event name: ", refs.StringValue(params.EventName)) @@ -61,7 +58,6 @@ func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookReques } webhookDetails.EventName = refs.StringValue(params.EventName) } - if params.Endpoint != nil && webhookDetails.EndPoint != refs.StringValue(params.Endpoint) { if strings.TrimSpace(refs.StringValue(params.Endpoint)) == "" { log.Debug("empty endpoint not allowed") @@ -69,11 +65,12 @@ func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookReques } webhookDetails.EndPoint = refs.StringValue(params.Endpoint) } - if params.Enabled != nil && webhookDetails.Enabled != refs.BoolValue(params.Enabled) { webhookDetails.Enabled = refs.BoolValue(params.Enabled) } - + if params.EventDescription != nil && webhookDetails.EventDescription != refs.StringValue(params.EventDescription) { + webhookDetails.EventDescription = refs.StringValue(params.EventDescription) + } if params.Headers != nil { headerBytes, err := json.Marshal(params.Headers) if err != nil { @@ -83,12 +80,10 @@ func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookReques webhookDetails.Headers = string(headerBytes) } - _, err = db.Provider.UpdateWebhook(ctx, webhookDetails) if err != nil { return nil, err } - return &model.Response{ Message: `Webhook updated successfully.`, }, nil diff --git a/server/test/add_webhook_test.go b/server/test/add_webhook_test.go index 3068740..500049f 100644 --- a/server/test/add_webhook_test.go +++ b/server/test/add_webhook_test.go @@ -3,11 +3,13 @@ package test import ( "fmt" "testing" + "time" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/refs" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -21,7 +23,6 @@ func addWebhookTest(t *testing.T, s TestSetup) { h, err := crypto.EncryptPassword(adminSecret) assert.NoError(t, err) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) - for _, eventType := range s.TestInfo.TestWebhookEventTypes { webhook, err := resolvers.AddWebhookResolver(ctx, model.AddWebhookRequest{ EventName: eventType, @@ -35,5 +36,21 @@ func addWebhookTest(t *testing.T, s TestSetup) { assert.NotNil(t, webhook) assert.NotEmpty(t, webhook.Message) } + time.Sleep(2 * time.Second) + // Allow setting multiple webhooks for same event + for _, eventType := range s.TestInfo.TestWebhookEventTypes { + webhook, err := resolvers.AddWebhookResolver(ctx, model.AddWebhookRequest{ + EventName: eventType, + Endpoint: s.TestInfo.WebhookEndpoint, + Enabled: true, + EventDescription: refs.NewStringRef(eventType + "-2"), + Headers: map[string]interface{}{ + "x-test": "foo", + }, + }) + assert.NoError(t, err) + assert.NotNil(t, webhook) + assert.NotEmpty(t, webhook.Message) + } }) } diff --git a/server/test/admin_signup_test.go b/server/test/admin_signup_test.go index fc0b13a..9596f4d 100644 --- a/server/test/admin_signup_test.go +++ b/server/test/admin_signup_test.go @@ -25,6 +25,6 @@ func adminSignupTests(t *testing.T, s TestSetup) { _, err = resolvers.AdminSignupResolver(ctx, model.AdminSignupInput{ AdminSecret: "admin123", }) - assert.Nil(t, err) + assert.NoError(t, err) }) } diff --git a/server/test/delete_webhook_test.go b/server/test/delete_webhook_test.go index 55df1ae..ab9b9f2 100644 --- a/server/test/delete_webhook_test.go +++ b/server/test/delete_webhook_test.go @@ -25,7 +25,7 @@ func deleteWebhookTest(t *testing.T, s TestSetup) { // get all webhooks webhooks, err := db.Provider.ListWebhook(ctx, model.Pagination{ - Limit: 10, + Limit: 20, Page: 1, Offset: 0, }) @@ -42,7 +42,7 @@ func deleteWebhookTest(t *testing.T, s TestSetup) { } webhooks, err = db.Provider.ListWebhook(ctx, model.Pagination{ - Limit: 10, + Limit: 20, Page: 1, Offset: 0, }) diff --git a/server/test/enable_access_test.go b/server/test/enable_access_test.go index be49efc..9549968 100644 --- a/server/test/enable_access_test.go +++ b/server/test/enable_access_test.go @@ -23,6 +23,8 @@ func enableAccessTest(t *testing.T, s TestSetup) { }) assert.NoError(t, err) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) + assert.NoError(t, err) + assert.NotNil(t, verificationRequest) verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ Token: verificationRequest.Token, }) diff --git a/server/test/forgot_password_test.go b/server/test/forgot_password_test.go index cd09a0a..fd8a3bd 100644 --- a/server/test/forgot_password_test.go +++ b/server/test/forgot_password_test.go @@ -15,17 +15,18 @@ func forgotPasswordTest(t *testing.T, s TestSetup) { t.Run(`should run forgot password`, func(t *testing.T) { _, ctx := createContext(s) email := "forgot_password." + s.TestInfo.Email - _, err := resolvers.SignupResolver(ctx, model.SignUpInput{ + res, err := resolvers.SignupResolver(ctx, model.SignUpInput{ Email: email, Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - - _, err = resolvers.ForgotPasswordResolver(ctx, model.ForgotPasswordInput{ + assert.NoError(t, err) + assert.NotNil(t, res) + forgotPasswordRes, err := resolvers.ForgotPasswordResolver(ctx, model.ForgotPasswordInput{ Email: email, }) assert.Nil(t, err, "no errors for forgot password") - + assert.NotNil(t, forgotPasswordRes) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeForgotPassword) assert.Nil(t, err) diff --git a/server/test/invite_member_test.go b/server/test/invite_member_test.go index 42bc017..d26b5ed 100644 --- a/server/test/invite_member_test.go +++ b/server/test/invite_member_test.go @@ -41,21 +41,20 @@ func inviteUserTest(t *testing.T, s TestSetup) { res, err = resolvers.InviteMembersResolver(ctx, model.InviteMemberInput{ Emails: invalidEmailsTest, }) - + assert.Error(t, err) + assert.Nil(t, res) // valid test res, err = resolvers.InviteMembersResolver(ctx, model.InviteMemberInput{ Emails: emails, }) assert.Nil(t, err) assert.NotNil(t, res) - // duplicate error test res, err = resolvers.InviteMembersResolver(ctx, model.InviteMemberInput{ Emails: emails, }) assert.Error(t, err) assert.Nil(t, res) - cleanData(emails[0]) }) } diff --git a/server/test/login_test.go b/server/test/login_test.go index 3a85d0f..3599efc 100644 --- a/server/test/login_test.go +++ b/server/test/login_test.go @@ -16,12 +16,13 @@ func loginTests(t *testing.T, s TestSetup) { t.Run(`should login`, func(t *testing.T) { _, ctx := createContext(s) email := "login." + s.TestInfo.Email - _, err := resolvers.SignupResolver(ctx, model.SignUpInput{ + signUpRes, err := resolvers.SignupResolver(ctx, model.SignUpInput{ Email: email, Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - + assert.NoError(t, err) + assert.NotNil(t, signUpRes) res, err := resolvers.LoginResolver(ctx, model.LoginInput{ Email: email, Password: s.TestInfo.Password, @@ -30,6 +31,8 @@ func loginTests(t *testing.T, s TestSetup) { assert.NotNil(t, err, "should fail because email is not verified") assert.Nil(t, res) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) + assert.NoError(t, err) + assert.NotNil(t, verificationRequest) n, err := utils.EncryptNonce(verificationRequest.Nonce) assert.NoError(t, err) assert.NotEmpty(t, n) diff --git a/server/test/logout_test.go b/server/test/logout_test.go index 3e1c6ff..3d95cf5 100644 --- a/server/test/logout_test.go +++ b/server/test/logout_test.go @@ -20,22 +20,24 @@ func logoutTests(t *testing.T, s TestSetup) { req, ctx := createContext(s) email := "logout." + s.TestInfo.Email - _, err := resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ + magicLoginRes, err := resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ Email: email, }) - + assert.NoError(t, err) + assert.NotNil(t, magicLoginRes) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) + assert.NoError(t, err) + assert.NotNil(t, verificationRequest) verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ Token: verificationRequest.Token, }) - + assert.NoError(t, err) + assert.NotNil(t, verifyRes) accessToken := *verifyRes.AccessToken assert.NotEmpty(t, accessToken) - claims, err := token.ParseJWTToken(accessToken) assert.NoError(t, err) assert.NotEmpty(t, claims) - loginMethod := claims["login_method"] sessionKey := verifyRes.User.ID if loginMethod != nil && loginMethod != "" { diff --git a/server/test/magic_link_login_test.go b/server/test/magic_link_login_test.go index b2cff2c..03b9c86 100644 --- a/server/test/magic_link_login_test.go +++ b/server/test/magic_link_login_test.go @@ -30,6 +30,8 @@ func magicLinkLoginTests(t *testing.T, s TestSetup) { assert.Nil(t, err, "signup should be successful") verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) + assert.NoError(t, err) + assert.NotNil(t, verificationRequest) verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ Token: verificationRequest.Token, }) diff --git a/server/test/mobile_signup_test.go b/server/test/mobile_signup_test.go index 0135a66..11deccc 100644 --- a/server/test/mobile_signup_test.go +++ b/server/test/mobile_signup_test.go @@ -29,24 +29,25 @@ func mobileSingupTest(t *testing.T, s TestSetup) { Password: "test", ConfirmPassword: "test", }) - assert.NotNil(t, err, "invalid password") - + assert.Error(t, err) + assert.Nil(t, res) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, true) res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ Email: refs.NewStringRef(email), Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - assert.NotNil(t, err, "singup disabled") + assert.Error(t, err) + assert.Nil(t, res) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, false) - memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableMobileBasicAuthentication, true) res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ Email: refs.NewStringRef(email), Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - assert.NotNil(t, err, "singup disabled") + assert.Error(t, err) + assert.Nil(t, res) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableMobileBasicAuthentication, false) res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ @@ -54,14 +55,16 @@ func mobileSingupTest(t *testing.T, s TestSetup) { Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - assert.NotNil(t, err, "invalid mobile") + assert.Error(t, err) + assert.Nil(t, res) res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ PhoneNumber: "test", Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - assert.NotNil(t, err, "invalid mobile") + assert.Error(t, err) + assert.Nil(t, res) res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ PhoneNumber: "1234567890", @@ -77,7 +80,8 @@ func mobileSingupTest(t *testing.T, s TestSetup) { Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - assert.Error(t, err, "user exists") + assert.Error(t, err) + assert.Nil(t, res) cleanData(email) cleanData("1234567890@authorizer.dev") diff --git a/server/test/profile_test.go b/server/test/profile_test.go index edb52a0..81fc6b7 100644 --- a/server/test/profile_test.go +++ b/server/test/profile_test.go @@ -27,6 +27,8 @@ func profileTests(t *testing.T, s TestSetup) { assert.NotNil(t, err, "unauthorized") verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) + assert.NoError(t, err) + assert.NotNil(t, verificationRequest) verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ Token: verificationRequest.Token, }) diff --git a/server/test/resend_otp_test.go b/server/test/resend_otp_test.go index 2ba256c..73e715d 100644 --- a/server/test/resend_otp_test.go +++ b/server/test/resend_otp_test.go @@ -44,10 +44,11 @@ func resendOTPTest(t *testing.T, s TestSetup) { // Using access token update profile s.GinContext.Request.Header.Set("Authorization", "Bearer "+refs.StringValue(verifyRes.AccessToken)) ctx = context.WithValue(req.Context(), "GinContextKey", s.GinContext) - _, err = resolvers.UpdateProfileResolver(ctx, model.UpdateProfileInput{ + updateRes, err := resolvers.UpdateProfileResolver(ctx, model.UpdateProfileInput{ IsMultiFactorAuthEnabled: refs.NewBoolRef(true), }) - + assert.NoError(t, err) + assert.NotNil(t, updateRes) // Resend otp should return error as no initial opt is being sent resendOtpRes, err := resolvers.ResendOTPResolver(ctx, model.ResendOTPRequest{ Email: email, @@ -87,7 +88,7 @@ func resendOTPTest(t *testing.T, s TestSetup) { Otp: otp.Otp, }) assert.Error(t, err) - + assert.Nil(t, verifyOtpRes) verifyOtpRes, err = resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{ Email: email, Otp: newOtp.Otp, diff --git a/server/test/resend_verify_email_test.go b/server/test/resend_verify_email_test.go index b3420b0..4119762 100644 --- a/server/test/resend_verify_email_test.go +++ b/server/test/resend_verify_email_test.go @@ -19,13 +19,12 @@ func resendVerifyEmailTests(t *testing.T, s TestSetup) { Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - + assert.NoError(t, err) _, err = resolvers.ResendVerifyEmailResolver(ctx, model.ResendVerifyEmailInput{ Email: email, Identifier: constants.VerificationTypeBasicAuthSignup, }) - - assert.Nil(t, err) + assert.NoError(t, err) cleanData(email) }) diff --git a/server/test/reset_password_test.go b/server/test/reset_password_test.go index f33e774..8b0aa6f 100644 --- a/server/test/reset_password_test.go +++ b/server/test/reset_password_test.go @@ -20,7 +20,7 @@ func resetPasswordTest(t *testing.T, s TestSetup) { Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - + assert.NoError(t, err) _, err = resolvers.ForgotPasswordResolver(ctx, model.ForgotPasswordInput{ Email: email, }) @@ -28,7 +28,7 @@ func resetPasswordTest(t *testing.T, s TestSetup) { verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeForgotPassword) assert.Nil(t, err, "should get forgot password request") - + assert.NotNil(t, verificationRequest) _, err = resolvers.ResetPasswordResolver(ctx, model.ResetPasswordInput{ Token: verificationRequest.Token, Password: "test1", diff --git a/server/test/revoke_access_test.go b/server/test/revoke_access_test.go index 4223a0d..4be042d 100644 --- a/server/test/revoke_access_test.go +++ b/server/test/revoke_access_test.go @@ -23,6 +23,8 @@ func revokeAccessTest(t *testing.T, s TestSetup) { }) assert.NoError(t, err) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) + assert.NoError(t, err) + assert.NotNil(t, verificationRequest) verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ Token: verificationRequest.Token, }) @@ -33,7 +35,7 @@ func revokeAccessTest(t *testing.T, s TestSetup) { UserID: verifyRes.User.ID, }) assert.Error(t, err) - + assert.Nil(t, res) adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) diff --git a/server/test/session_test.go b/server/test/session_test.go index 86b848a..273904b 100644 --- a/server/test/session_test.go +++ b/server/test/session_test.go @@ -30,10 +30,13 @@ func sessionTests(t *testing.T, s TestSetup) { assert.NotNil(t, err, "unauthorized") verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) + assert.NoError(t, err) + assert.NotNil(t, verificationRequest) verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ Token: verificationRequest.Token, }) - + assert.NoError(t, err) + assert.NotNil(t, verifyRes) accessToken := *verifyRes.AccessToken assert.NotEmpty(t, accessToken) diff --git a/server/test/signup_test.go b/server/test/signup_test.go index a66906c..85d179d 100644 --- a/server/test/signup_test.go +++ b/server/test/signup_test.go @@ -22,14 +22,14 @@ func signupTests(t *testing.T, s TestSetup) { ConfirmPassword: s.TestInfo.Password + "s", }) assert.NotNil(t, err, "invalid password") - + assert.Nil(t, res) res, err = resolvers.SignupResolver(ctx, model.SignUpInput{ Email: email, Password: "test", ConfirmPassword: "test", }) assert.NotNil(t, err, "invalid password") - + assert.Nil(t, res) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, true) res, err = resolvers.SignupResolver(ctx, model.SignUpInput{ Email: email, @@ -37,7 +37,7 @@ func signupTests(t *testing.T, s TestSetup) { ConfirmPassword: s.TestInfo.Password, }) assert.NotNil(t, err, "singup disabled") - + assert.Nil(t, res) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, false) res, err = resolvers.SignupResolver(ctx, model.SignUpInput{ Email: email, @@ -48,15 +48,13 @@ func signupTests(t *testing.T, s TestSetup) { user := *res.User assert.Equal(t, email, user.Email) assert.Nil(t, res.AccessToken, "access token should be nil") - res, err = resolvers.SignupResolver(ctx, model.SignUpInput{ Email: email, Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - assert.NotNil(t, err, "should throw duplicate email error") - + assert.Nil(t, res) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) assert.Nil(t, err) assert.Equal(t, email, verificationRequest.Email) diff --git a/server/test/test.go b/server/test/test.go index 1fb6492..b2727ea 100644 --- a/server/test/test.go +++ b/server/test/test.go @@ -40,31 +40,49 @@ func cleanData(email string) { verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) if err == nil { err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest) + if err != nil { + log.Debug("DeleteVerificationRequest err", err) + } } verificationRequest, err = db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeForgotPassword) if err == nil { err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest) + if err != nil { + log.Debug("DeleteVerificationRequest err", err) + } } verificationRequest, err = db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeUpdateEmail) if err == nil { err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest) + if err != nil { + log.Debug("DeleteVerificationRequest err", err) + } } verificationRequest, err = db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) if err == nil { err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest) + if err != nil { + log.Debug("DeleteVerificationRequest err", err) + } } otp, err := db.Provider.GetOTPByEmail(ctx, email) if err == nil { err = db.Provider.DeleteOTP(ctx, otp) + if err != nil { + log.Debug("DeleteOTP err", err) + } } dbUser, err := db.Provider.GetUserByEmail(ctx, email) if err == nil { - db.Provider.DeleteUser(ctx, dbUser) + err = db.Provider.DeleteUser(ctx, dbUser) + if err != nil { + log.Debug("DeleteUser err", err) + } } } diff --git a/server/test/update_all_users_tests.go b/server/test/update_all_users_tests.go index 6473908..375158f 100644 --- a/server/test/update_all_users_tests.go +++ b/server/test/update_all_users_tests.go @@ -17,15 +17,12 @@ func updateAllUsersTest(t *testing.T, s TestSetup) { t.Helper() t.Run("Should update all users", func(t *testing.T) { _, ctx := createContext(s) - - users := []models.User{} for i := 0; i < 10; i++ { user := models.User{ Email: fmt.Sprintf("update_all_user_%d_%s", i, s.TestInfo.Email), SignupMethods: constants.AuthRecipeMethodBasicAuth, Roles: "user", } - users = append(users, user) u, err := db.Provider.AddUser(ctx, user) assert.NoError(t, err) assert.NotNil(t, u) @@ -56,12 +53,15 @@ func updateAllUsersTest(t *testing.T, s TestSetup) { Limit: 20, Offset: 0, }) + assert.NoError(t, err) + assert.NotNil(t, listUsers) for _, u := range listUsers.Users { if utils.StringSliceContains(updateIds, u.ID) { assert.False(t, refs.BoolValue(u.IsMultiFactorAuthEnabled)) } else { assert.True(t, refs.BoolValue(u.IsMultiFactorAuthEnabled)) } + cleanData(u.Email) } }) } diff --git a/server/test/update_profile_test.go b/server/test/update_profile_test.go index d8e4f0f..70f974a 100644 --- a/server/test/update_profile_test.go +++ b/server/test/update_profile_test.go @@ -30,11 +30,13 @@ func updateProfileTests(t *testing.T, s TestSetup) { assert.NotNil(t, err, "unauthorized") verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) + assert.NoError(t, err) + assert.NotNil(t, verificationRequest) verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ Token: verificationRequest.Token, }) assert.NoError(t, err) - + assert.NotNil(t, verifyRes) s.GinContext.Request.Header.Set("Authorization", "Bearer "+*verifyRes.AccessToken) ctx = context.WithValue(req.Context(), "GinContextKey", s.GinContext) diff --git a/server/test/update_webhook_test.go b/server/test/update_webhook_test.go index 07f658c..14ccb94 100644 --- a/server/test/update_webhook_test.go +++ b/server/test/update_webhook_test.go @@ -24,45 +24,73 @@ func updateWebhookTest(t *testing.T, s TestSetup) { assert.NoError(t, err) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) // get webhook - webhook, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent) + webhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent) assert.NoError(t, err) - assert.NotNil(t, webhook) - // it should completely replace headers - webhook.Headers = map[string]interface{}{ - "x-new-test": "test", + assert.NotNil(t, webhooks) + assert.Equal(t, 2, len(webhooks)) + for _, webhook := range webhooks { + // it should completely replace headers + webhook.Headers = map[string]interface{}{ + "x-new-test": "test", + } + res, err := resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{ + ID: webhook.ID, + Headers: webhook.Headers, + Enabled: refs.NewBoolRef(false), + Endpoint: refs.NewStringRef("https://sometest.com"), + }) + assert.NoError(t, err) + assert.NotEmpty(t, res) + assert.NotEmpty(t, res.Message) } - + if len(webhooks) == 0 { + // avoid index out of range error + return + } + // Test updating webhook name + w := webhooks[0] res, err := resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{ - ID: webhook.ID, - Headers: webhook.Headers, - Enabled: refs.NewBoolRef(false), - Endpoint: refs.NewStringRef("https://sometest.com"), + ID: w.ID, + EventName: refs.NewStringRef(constants.UserAccessEnabledWebhookEvent), }) - assert.NoError(t, err) - assert.NotEmpty(t, res) - assert.NotEmpty(t, res.Message) - - updatedWebhook, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent) + assert.NotNil(t, res) + // Check if webhooks with new name is as per expected len + accessWebhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserAccessEnabledWebhookEvent) assert.NoError(t, err) - assert.NotNil(t, updatedWebhook) - assert.Equal(t, webhook.ID, updatedWebhook.ID) - assert.Equal(t, refs.StringValue(webhook.EventName), refs.StringValue(updatedWebhook.EventName)) - assert.Len(t, updatedWebhook.Headers, 1) - assert.False(t, refs.BoolValue(updatedWebhook.Enabled)) - for key, val := range updatedWebhook.Headers { - assert.Equal(t, val, webhook.Headers[key]) - } - assert.Equal(t, refs.StringValue(updatedWebhook.Endpoint), "https://sometest.com") - + assert.Equal(t, 3, len(accessWebhooks)) + // Revert name change res, err = resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{ - ID: webhook.ID, - Headers: webhook.Headers, - Enabled: refs.NewBoolRef(true), - Endpoint: refs.NewStringRef(s.TestInfo.WebhookEndpoint), + ID: w.ID, + EventName: refs.NewStringRef(constants.UserDeletedWebhookEvent), }) assert.NoError(t, err) - assert.NotEmpty(t, res) - assert.NotEmpty(t, res.Message) + assert.NotNil(t, res) + updatedWebhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent) + assert.NoError(t, err) + assert.NotNil(t, updatedWebhooks) + assert.Equal(t, 2, len(updatedWebhooks)) + for _, updatedWebhook := range updatedWebhooks { + assert.Contains(t, refs.StringValue(updatedWebhook.EventName), constants.UserDeletedWebhookEvent) + assert.Len(t, updatedWebhook.Headers, 1) + assert.False(t, refs.BoolValue(updatedWebhook.Enabled)) + foundUpdatedHeader := false + for key, val := range updatedWebhook.Headers { + if key == "x-new-test" && val == "test" { + foundUpdatedHeader = true + } + } + assert.True(t, foundUpdatedHeader) + assert.Equal(t, "https://sometest.com", refs.StringValue(updatedWebhook.Endpoint)) + res, err := resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{ + ID: updatedWebhook.ID, + Headers: updatedWebhook.Headers, + Enabled: refs.NewBoolRef(true), + Endpoint: refs.NewStringRef(s.TestInfo.WebhookEndpoint), + }) + assert.NoError(t, err) + assert.NotEmpty(t, res) + assert.NotEmpty(t, res.Message) + } }) } diff --git a/server/test/users_test.go b/server/test/users_test.go index 96e6537..26d7a61 100644 --- a/server/test/users_test.go +++ b/server/test/users_test.go @@ -34,7 +34,7 @@ func usersTest(t *testing.T, s TestSetup) { usersRes, err := resolvers.UsersResolver(ctx, pagination) assert.NotNil(t, err, "unauthorized") - + assert.Nil(t, usersRes) adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) h, err := crypto.EncryptPassword(adminSecret) diff --git a/server/test/validate_jwt_token_test.go b/server/test/validate_jwt_token_test.go index 52ce50b..e2fcf8c 100644 --- a/server/test/validate_jwt_token_test.go +++ b/server/test/validate_jwt_token_test.go @@ -53,6 +53,8 @@ func validateJwtTokenTest(t *testing.T, s TestSetup) { sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID nonce := uuid.New().String() authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth, nonce, "") + assert.NoError(t, err) + assert.NotNil(t, authToken) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) @@ -74,8 +76,8 @@ func validateJwtTokenTest(t *testing.T, s TestSetup) { Token: authToken.AccessToken.Token, Roles: []string{"invalid_role"}, }) - assert.Error(t, err) + assert.Nil(t, res) }) t.Run(`should validate the refresh token`, func(t *testing.T) { diff --git a/server/test/verification_requests_test.go b/server/test/verification_requests_test.go index 8cbb762..0d0ce65 100644 --- a/server/test/verification_requests_test.go +++ b/server/test/verification_requests_test.go @@ -17,17 +17,14 @@ func verificationRequestsTest(t *testing.T, s TestSetup) { t.Run(`should get verification requests with admin secret only`, func(t *testing.T) { req, ctx := createContext(s) - email := "verification_requests." + s.TestInfo.Email res, err := resolvers.SignupResolver(ctx, model.SignUpInput{ Email: email, Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - assert.NoError(t, err) assert.NotNil(t, res) - limit := int64(10) page := int64(1) pagination := &model.PaginatedInput{ @@ -39,6 +36,7 @@ func verificationRequestsTest(t *testing.T, s TestSetup) { requests, err := resolvers.VerificationRequestsResolver(ctx, pagination) assert.NotNil(t, err, "unauthorized") + assert.Nil(t, requests) adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) diff --git a/server/test/verify_email_test.go b/server/test/verify_email_test.go index 146dd27..2a07656 100644 --- a/server/test/verify_email_test.go +++ b/server/test/verify_email_test.go @@ -20,7 +20,8 @@ func verifyEmailTest(t *testing.T, s TestSetup) { Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) - + assert.NoError(t, err) + assert.NotNil(t, res) user := *res.User assert.Equal(t, email, user.Email) assert.Nil(t, res.AccessToken, "access token should be nil") diff --git a/server/test/webhook_logs_test.go b/server/test/webhook_logs_test.go index 8b9a04a..25210a3 100644 --- a/server/test/webhook_logs_test.go +++ b/server/test/webhook_logs_test.go @@ -29,17 +29,20 @@ func webhookLogsTest(t *testing.T, s TestSetup) { assert.NoError(t, err) assert.Greater(t, len(webhookLogs.WebhookLogs), 1) - webhooks, err := resolvers.WebhooksResolver(ctx, nil) + webhooks, err := resolvers.WebhooksResolver(ctx, &model.PaginatedInput{ + Pagination: &model.PaginationInput{ + Limit: refs.NewInt64Ref(20), + }, + }) assert.NoError(t, err) assert.NotEmpty(t, webhooks) - for _, w := range webhooks.Webhooks { t.Run(fmt.Sprintf("should get webhook for webhook_id:%s", w.ID), func(t *testing.T) { webhookLogs, err := resolvers.WebhookLogsResolver(ctx, &model.ListWebhookLogRequest{ WebhookID: &w.ID, }) assert.NoError(t, err) - assert.GreaterOrEqual(t, len(webhookLogs.WebhookLogs), 1) + assert.GreaterOrEqual(t, len(webhookLogs.WebhookLogs), 1, refs.StringValue(w.EventName)) for _, wl := range webhookLogs.WebhookLogs { assert.Equal(t, refs.StringValue(wl.WebhookID), w.ID) } diff --git a/server/test/webhook_test.go b/server/test/webhook_test.go index 4bbe464..0fb789f 100644 --- a/server/test/webhook_test.go +++ b/server/test/webhook_test.go @@ -25,18 +25,20 @@ func webhookTest(t *testing.T, s TestSetup) { req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) // get webhook by event name - webhook, err := db.Provider.GetWebhookByEventName(ctx, constants.UserCreatedWebhookEvent) + webhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserCreatedWebhookEvent) assert.NoError(t, err) - assert.NotNil(t, webhook) - - res, err := resolvers.WebhookResolver(ctx, model.WebhookRequest{ - ID: webhook.ID, - }) - assert.NoError(t, err) - assert.Equal(t, res.ID, webhook.ID) - assert.Equal(t, refs.StringValue(res.Endpoint), refs.StringValue(webhook.Endpoint)) - assert.Equal(t, refs.StringValue(res.EventName), refs.StringValue(webhook.EventName)) - assert.Equal(t, refs.BoolValue(res.Enabled), refs.BoolValue(webhook.Enabled)) - assert.Len(t, res.Headers, len(webhook.Headers)) + assert.NotNil(t, webhooks) + assert.Equal(t, 2, len(webhooks)) + for _, webhook := range webhooks { + res, err := resolvers.WebhookResolver(ctx, model.WebhookRequest{ + ID: webhook.ID, + }) + assert.NoError(t, err) + assert.Equal(t, res.ID, webhook.ID) + assert.Equal(t, refs.StringValue(res.Endpoint), refs.StringValue(webhook.Endpoint)) + // assert.Equal(t, refs.StringValue(res.EventName), refs.StringValue(webhook.EventName)) + assert.Equal(t, refs.BoolValue(res.Enabled), refs.BoolValue(webhook.Enabled)) + assert.Len(t, res.Headers, len(webhook.Headers)) + } }) } diff --git a/server/test/webhooks_test.go b/server/test/webhooks_test.go index b4ec561..6ed1bb2 100644 --- a/server/test/webhooks_test.go +++ b/server/test/webhooks_test.go @@ -6,7 +6,9 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" + "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/refs" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -21,9 +23,13 @@ func webhooksTest(t *testing.T, s TestSetup) { assert.NoError(t, err) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) - webhooks, err := resolvers.WebhooksResolver(ctx, nil) + webhooks, err := resolvers.WebhooksResolver(ctx, &model.PaginatedInput{ + Pagination: &model.PaginationInput{ + Limit: refs.NewInt64Ref(20), + }, + }) assert.NoError(t, err) assert.NotEmpty(t, webhooks) - assert.Len(t, webhooks.Webhooks, len(s.TestInfo.TestWebhookEventTypes)) + assert.Len(t, webhooks.Webhooks, len(s.TestInfo.TestWebhookEventTypes)*2) }) } diff --git a/server/utils/webhook.go b/server/utils/webhook.go index acacfbf..705c571 100644 --- a/server/utils/webhook.go +++ b/server/utils/webhook.go @@ -17,98 +17,101 @@ import ( ) func RegisterEvent(ctx context.Context, eventName string, authRecipe string, user models.User) error { - webhook, err := db.Provider.GetWebhookByEventName(ctx, eventName) + webhooks, err := db.Provider.GetWebhookByEventName(ctx, eventName) if err != nil { + log.Debug("Error getting webhook: %v", err) return err } + for _, webhook := range webhooks { + if !refs.BoolValue(webhook.Enabled) { + continue + } + userBytes, err := json.Marshal(user.AsAPIUser()) + if err != nil { + log.Debug("error marshalling user obj: ", err) + continue + } + userMap := map[string]interface{}{} + err = json.Unmarshal(userBytes, &userMap) + if err != nil { + log.Debug("error un-marshalling user obj: ", err) + continue + } - if !refs.BoolValue(webhook.Enabled) { - return nil - } + reqBody := map[string]interface{}{ + "webhook_id": webhook.ID, + "event_name": eventName, + "event_description": webhook.EventDescription, + "user": userMap, + } - userBytes, err := json.Marshal(user.AsAPIUser()) - if err != nil { - log.Debug("error marshalling user obj: ", err) - return err - } - userMap := map[string]interface{}{} - err = json.Unmarshal(userBytes, &userMap) - if err != nil { - log.Debug("error un-marshalling user obj: ", err) - return err - } + if eventName == constants.UserLoginWebhookEvent || eventName == constants.UserSignUpWebhookEvent { + reqBody["auth_recipe"] = authRecipe + } - reqBody := map[string]interface{}{ - "event_name": eventName, - "user": userMap, - } + requestBody, err := json.Marshal(reqBody) + if err != nil { + log.Debug("error marshalling requestBody obj: ", err) + continue + } - if eventName == constants.UserLoginWebhookEvent || eventName == constants.UserSignUpWebhookEvent { - reqBody["auth_recipe"] = authRecipe - } + // dont trigger webhook call in case of test + envKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEnv) + if err != nil { + continue + } + if envKey == constants.TestEnv { + _, err := db.Provider.AddWebhookLog(ctx, models.WebhookLog{ + HttpStatus: 200, + Request: string(requestBody), + Response: string(`{"message": "test"}`), + WebhookID: webhook.ID, + }) + if err != nil { + log.Debug("error saving webhook log:", err) + } + continue + } - requestBody, err := json.Marshal(reqBody) - if err != nil { - log.Debug("error marshalling requestBody obj: ", err) - return err - } + requestBytesBuffer := bytes.NewBuffer(requestBody) + req, err := http.NewRequest("POST", refs.StringValue(webhook.Endpoint), requestBytesBuffer) + if err != nil { + log.Debug("error creating webhook post request: ", err) + continue + } + req.Header.Set("Content-Type", "application/json") - // dont trigger webhook call in case of test - envKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEnv) - if err != nil { - return err - } - if envKey == constants.TestEnv { - db.Provider.AddWebhookLog(ctx, models.WebhookLog{ - HttpStatus: 200, + if webhook.Headers != nil { + for key, val := range webhook.Headers { + req.Header.Set(key, val.(string)) + } + } + + client := &http.Client{Timeout: time.Second * 30} + resp, err := client.Do(req) + if err != nil { + log.Debug("error making request: ", err) + continue + } + defer resp.Body.Close() + + responseBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Debug("error reading response: ", err) + continue + } + + statusCode := int64(resp.StatusCode) + _, err = db.Provider.AddWebhookLog(ctx, models.WebhookLog{ + HttpStatus: statusCode, Request: string(requestBody), - Response: string(`{"message": "test"}`), + Response: string(responseBytes), WebhookID: webhook.ID, }) - - return nil - } - - requestBytesBuffer := bytes.NewBuffer(requestBody) - req, err := http.NewRequest("POST", refs.StringValue(webhook.Endpoint), requestBytesBuffer) - if err != nil { - log.Debug("error creating webhook post request: ", err) - return err - } - req.Header.Set("Content-Type", "application/json") - - if webhook.Headers != nil { - for key, val := range webhook.Headers { - req.Header.Set(key, val.(string)) + if err != nil { + log.Debug("failed to add webhook log: ", err) + continue } } - - client := &http.Client{Timeout: time.Second * 30} - resp, err := client.Do(req) - if err != nil { - log.Debug("error making request: ", err) - return err - } - defer resp.Body.Close() - - responseBytes, err := ioutil.ReadAll(resp.Body) - if err != nil { - log.Debug("error reading response: ", err) - return err - } - - statusCode := int64(resp.StatusCode) - _, err = db.Provider.AddWebhookLog(ctx, models.WebhookLog{ - HttpStatus: statusCode, - Request: string(requestBody), - Response: string(responseBytes), - WebhookID: webhook.ID, - }) - - if err != nil { - log.Debug("failed to add webhook log: ", err) - return err - } - return nil }