From deaf1e2ff750af5e8ef0f308ca1ff4397c1dcb21 Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Sun, 26 Mar 2023 07:20:45 +0530 Subject: [PATCH] fix: allow multiple hooks for same event --- server/db/models/webhook.go | 16 +- server/db/providers/arangodb/webhook.go | 44 +++-- server/db/providers/cassandradb/webhook.go | 53 +++--- server/db/providers/couchbase/webhook.go | 58 ++++--- server/db/providers/dynamodb/webhook.go | 56 +++---- server/db/providers/mongodb/webhook.go | 45 +++-- .../db/providers/provider_template/webhook.go | 14 +- server/db/providers/providers.go | 2 +- server/db/providers/sql/webhook.go | 31 ++-- server/graph/generated/generated.go | 80 ++++++++- server/graph/model/models_gen.go | 3 + server/graph/schema.graphqls | 3 + server/test/update_webhook_test.go | 72 ++++---- server/test/webhook_test.go | 26 +-- server/utils/webhook.go | 155 +++++++++--------- 15 files changed, 385 insertions(+), 273 deletions(-) diff --git a/server/db/models/webhook.go b/server/db/models/webhook.go index 0b64133..5b0d227 100644 --- a/server/db/models/webhook.go +++ b/server/db/models/webhook.go @@ -10,11 +10,15 @@ import ( // Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation +// Event name has been kept unique as per initial design. But later on decided that we can have +// multiple hooks for same event so will be in a pattern `event_name-TIMESTAMP` + // Webhook model for db type Webhook struct { Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty" dynamo:"key,omitempty"` // for arangodb ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id" dynamo:"id,hash"` EventName string `gorm:"unique" json:"event_name" bson:"event_name" cql:"event_name" dynamo:"event_name" index:"event_name,hash"` + Title string `json:"title" bson:"title" cql:"title" dynamo:"title"` EndPoint string `json:"endpoint" bson:"endpoint" cql:"endpoint" dynamo:"endpoint"` Headers string `json:"headers" bson:"headers" cql:"headers" dynamo:"headers"` Enabled bool `json:"enabled" bson:"enabled" cql:"enabled" dynamo:"enabled"` @@ -26,14 +30,22 @@ type Webhook struct { func (w *Webhook) AsAPIWebhook() *model.Webhook { headersMap := make(map[string]interface{}) json.Unmarshal([]byte(w.Headers), &headersMap) - id := w.ID if strings.Contains(id, Collections.Webhook+"/") { id = strings.TrimPrefix(id, Collections.Webhook+"/") } - + // If event name contains timestamp trim that part + if strings.Contains(w.EventName, "-") { + splitData := strings.Split(w.EventName, "-") + w.EventName = splitData[0] + } + // set default title to event name without dot(.) + if w.Title == "" { + w.Title = strings.Join(strings.Split(w.EventName, "."), " ") + } return &model.Webhook{ ID: id, + Title: refs.NewStringRef(w.Title), EventName: refs.NewStringRef(w.EventName), Endpoint: refs.NewStringRef(w.EndPoint), Headers: headersMap, diff --git a/server/db/providers/arangodb/webhook.go b/server/db/providers/arangodb/webhook.go index 0e89b86..5124e66 100644 --- a/server/db/providers/arangodb/webhook.go +++ b/server/db/providers/arangodb/webhook.go @@ -3,8 +3,10 @@ package arangodb import ( "context" "fmt" + "strings" "time" + "github.com/arangodb/go-driver" arangoDriver "github.com/arangodb/go-driver" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/graph/model" @@ -17,8 +19,12 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod webhook.ID = uuid.New().String() webhook.Key = webhook.ID } - webhook.Key = webhook.ID + if webhook.Title == "" { + webhook.Title = strings.Join(strings.Split(webhook.EventName, "."), " ") + } + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook) @@ -32,12 +38,15 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook) meta, err := webhookCollection.UpdateDocument(ctx, webhook.Key, webhook) if err != nil { return nil, err } - webhook.Key = meta.Key webhook.ID = meta.ID.String() return webhook.AsAPIWebhook(), nil @@ -55,10 +64,8 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) return nil, err } defer cursor.Close() - paginationClone := pagination paginationClone.Total = cursor.Statistics().FullCount() - for { var webhook models.Webhook meta, err := cursor.ReadDocument(ctx, &webhook) @@ -87,13 +94,11 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model bindVars := map[string]interface{}{ "webhook_id": webhookID, } - cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { return nil, err } defer cursor.Close() - for { if !cursor.HasMore() { if webhook.Key == "" { @@ -110,32 +115,28 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { - var webhook models.Webhook - query := fmt.Sprintf("FOR d in %s FILTER d.event_name == @event_name RETURN d", models.Collections.Webhook) +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { + query := fmt.Sprintf("FOR d in %s FILTER d.event_name LIKE @event_name RETURN d", models.Collections.Webhook) bindVars := map[string]interface{}{ - "event_name": eventName, + "event_name": eventName + "%", } - cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { return nil, err } defer cursor.Close() - + webhooks := []*model.Webhook{} for { - if !cursor.HasMore() { - if webhook.Key == "" { - return nil, fmt.Errorf("webhook not found") - } + var webhook models.Webhook + if _, err := cursor.ReadDocument(ctx, &webhook); driver.IsNoMoreDocuments(err) { + // We're done break - } - _, err := cursor.ReadDocument(ctx, &webhook) - if err != nil { + } else if err != nil { return nil, err } + webhooks = append(webhooks, webhook.AsAPIWebhook()) } - return webhook.AsAPIWebhook(), nil + return webhooks, nil } // DeleteWebhook to delete webhook @@ -145,17 +146,14 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er if err != nil { return err } - query := fmt.Sprintf("FOR d IN %s FILTER d.webhook_id == @webhook_id REMOVE { _key: d._key } IN %s", models.Collections.WebhookLog, models.Collections.WebhookLog) bindVars := map[string]interface{}{ "webhook_id": webhook.ID, } - cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { return err } defer cursor.Close() - return nil } diff --git a/server/db/providers/cassandradb/webhook.go b/server/db/providers/cassandradb/webhook.go index 1954052..a8255d6 100644 --- a/server/db/providers/cassandradb/webhook.go +++ b/server/db/providers/cassandradb/webhook.go @@ -19,29 +19,29 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod if webhook.ID == "" { webhook.ID = uuid.New().String() } - webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() - - existingHook, _ := p.GetWebhookByEventName(ctx, webhook.EventName) - if existingHook != nil { - return nil, fmt.Errorf("Webhook with %s event_name already exists", webhook.EventName) + if webhook.Title == "" { + webhook.Title = strings.Join(strings.Split(webhook.EventName, "."), " ") } - - insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_name, endpoint, headers, enabled, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %t, %d, %d)", KeySpace+"."+models.Collections.Webhook, webhook.ID, webhook.EventName, webhook.EndPoint, webhook.Headers, webhook.Enabled, webhook.CreatedAt, webhook.UpdatedAt) + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + insertQuery := fmt.Sprintf("INSERT INTO %s (id, title, event_name, endpoint, headers, enabled, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', '%s', %t, %d, %d)", KeySpace+"."+models.Collections.Webhook, webhook.ID, webhook.Title, webhook.EventName, webhook.EndPoint, webhook.Headers, webhook.Enabled, webhook.CreatedAt, webhook.UpdatedAt) err := p.db.Query(insertQuery).Exec() if err != nil { return nil, err } - return webhook.AsAPIWebhook(), nil } // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() - + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } bytes, err := json.Marshal(webhook) if err != nil { return nil, err @@ -54,22 +54,18 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* if err != nil { return nil, err } - updateFields := "" for key, value := range webhookMap { if key == "_id" { continue } - if key == "_key" { continue } - if value == nil { updateFields += fmt.Sprintf("%s = null,", key) continue } - valueType := reflect.TypeOf(value) if valueType.Name() == "string" { updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string)) @@ -79,7 +75,6 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* } updateFields = strings.Trim(updateFields, " ") updateFields = strings.TrimSuffix(updateFields, ",") - query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.Webhook, updateFields, webhook.ID) err = p.db.Query(query).Exec() if err != nil { @@ -92,24 +87,21 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { webhooks := []*model.Webhook{} paginationClone := pagination - totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.Webhook) err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total) if err != nil { return nil, err } - // there is no offset in cassandra // so we fetch till limit + offset // and return the results from offset to limit - query := fmt.Sprintf("SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.Webhook, pagination.Limit+pagination.Offset) - + query := fmt.Sprintf("SELECT id, title, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.Webhook, pagination.Limit+pagination.Offset) scanner := p.db.Query(query).Iter().Scanner() counter := int64(0) for scanner.Next() { if counter >= pagination.Offset { var webhook models.Webhook - err := scanner.Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) + err := scanner.Scan(&webhook.ID, &webhook.Title, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) if err != nil { return nil, err } @@ -127,8 +119,8 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) // GetWebhookByID to get webhook by id func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { var webhook models.Webhook - query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1`, KeySpace+"."+models.Collections.Webhook, webhookID) - err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) + query := fmt.Sprintf(`SELECT id, title, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1`, KeySpace+"."+models.Collections.Webhook, webhookID) + err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.Title, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) if err != nil { return nil, err } @@ -136,14 +128,19 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { - var webhook models.Webhook - query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.Webhook, eventName) - err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) - if err != nil { - return nil, err +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { + query := fmt.Sprintf(`SELECT id, title, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name LIKE '%s' ALLOW FILTERING`, KeySpace+"."+models.Collections.Webhook, eventName+"%s") + scanner := p.db.Query(query).Iter().Scanner() + webhooks := []*model.Webhook{} + for scanner.Next() { + var webhook models.Webhook + err := scanner.Scan(&webhook.ID, &webhook.Title, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) + if err != nil { + return nil, err + } + webhooks = append(webhooks, webhook.AsAPIWebhook()) } - return webhook.AsAPIWebhook(), nil + return webhooks, nil } // DeleteWebhook to delete webhook diff --git a/server/db/providers/couchbase/webhook.go b/server/db/providers/couchbase/webhook.go index e438263..ff22c49 100644 --- a/server/db/providers/couchbase/webhook.go +++ b/server/db/providers/couchbase/webhook.go @@ -19,11 +19,14 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod if webhook.ID == "" { webhook.ID = uuid.New().String() } - webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() - + if webhook.Title == "" { + webhook.Title = strings.Join(strings.Split(webhook.EventName, "."), " ") + } + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) insertOpt := gocb.InsertOptions{ Context: ctx, } @@ -37,7 +40,10 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() - + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } bytes, err := json.Marshal(webhook) if err != nil { return nil, err @@ -50,17 +56,13 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* if err != nil { return nil, err } - updateFields, params := GetSetFields(webhookMap) - query := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE _id='%s'`, p.scopeName, models.Collections.Webhook, updateFields, webhook.ID) - _, err = p.db.Query(query, &gocb.QueryOptions{ Context: ctx, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, NamedParameters: params, }) - if err != nil { return nil, err } @@ -72,7 +74,6 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { webhooks := []*model.Webhook{} paginationClone := pagination - params := make(map[string]interface{}, 1) params["offset"] = paginationClone.Offset params["limit"] = paginationClone.Limit @@ -81,7 +82,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) return nil, err } paginationClone.Total = total - query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM %s.%s OFFSET $offset LIMIT $limit", p.scopeName, models.Collections.Webhook) + query := fmt.Sprintf("SELECT _id, title, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s OFFSET $offset LIMIT $limit", p.scopeName, models.Collections.Webhook) queryResult, err := p.db.Query(query, &gocb.QueryOptions{ Context: ctx, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, @@ -89,6 +90,8 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) }) if err != nil { return nil, err + } else if err := queryResult.Err(); err != nil { + return nil, err } for queryResult.Next() { var webhook models.Webhook @@ -98,9 +101,6 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) } webhooks = append(webhooks, webhook.AsAPIWebhook()) } - if err := queryResult.Err(); err != nil { - return nil, err - } return &model.Webhooks{ Pagination: &paginationClone, Webhooks: webhooks, @@ -110,11 +110,9 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) // GetWebhookByID to get webhook by id func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { var webhook models.Webhook - params := make(map[string]interface{}, 1) params["_id"] = webhookID - - query := fmt.Sprintf(`SELECT _id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE _id=$_id LIMIT 1`, p.scopeName, models.Collections.Webhook) + query := fmt.Sprintf(`SELECT _id, title, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE _id=$_id LIMIT 1`, p.scopeName, models.Collections.Webhook) q, err := p.db.Query(query, &gocb.QueryOptions{ Context: ctx, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, @@ -124,37 +122,37 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model return nil, err } err = q.One(&webhook) - if err != nil { return nil, err } - return webhook.AsAPIWebhook(), nil } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { - var webhook models.Webhook +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { params := make(map[string]interface{}, 1) - params["event_name"] = eventName - - query := fmt.Sprintf(`SELECT _id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE event_name=$event_name LIMIT 1`, p.scopeName, models.Collections.Webhook) - q, err := p.db.Query(query, &gocb.QueryOptions{ + params["event_name"] = eventName + "%" + query := fmt.Sprintf(`SELECT _id, title, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE event_name LIKE $event_name`, p.scopeName, models.Collections.Webhook) + queryResult, err := p.db.Query(query, &gocb.QueryOptions{ Context: ctx, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, NamedParameters: params, }) - if err != nil { return nil, err - } - err = q.One(&webhook) - - if err != nil { + } else if err := queryResult.Err(); err != nil { return nil, err } - - return webhook.AsAPIWebhook(), nil + webhooks := []*model.Webhook{} + for queryResult.Next() { + var webhook models.Webhook + err := queryResult.Row(&webhook) + if err != nil { + log.Fatal(err) + } + webhooks = append(webhooks, webhook.AsAPIWebhook()) + } + return webhooks, nil } // DeleteWebhook to delete webhook diff --git a/server/db/providers/dynamodb/webhook.go b/server/db/providers/dynamodb/webhook.go index 9cf7ec7..f2a0a62 100644 --- a/server/db/providers/dynamodb/webhook.go +++ b/server/db/providers/dynamodb/webhook.go @@ -3,28 +3,32 @@ package dynamodb import ( "context" "errors" + "fmt" + "strings" "time" + "github.com/google/uuid" + "github.com/guregu/dynamo" + "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/google/uuid" - "github.com/guregu/dynamo" ) // AddWebhook to add webhook func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { collection := p.db.Table(models.Collections.Webhook) - if webhook.ID == "" { webhook.ID = uuid.New().String() } - webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() - + if webhook.Title == "" { + webhook.Title = strings.Join(strings.Split(webhook.EventName, "."), " ") + } + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) err := collection.Put(webhook).RunWithContext(ctx) - if err != nil { return nil, err } @@ -33,11 +37,13 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { - collection := p.db.Table(models.Collections.Webhook) - webhook.UpdatedAt = time.Now().Unix() + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } + collection := p.db.Table(models.Collections.Webhook) err := UpdateByHashKey(collection, "id", webhook.ID, webhook) - if err != nil { return nil, err } @@ -51,16 +57,13 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) var lastEval dynamo.PagingKey var iter dynamo.PagingIter var iteration int64 = 0 - collection := p.db.Table(models.Collections.Webhook) paginationClone := pagination scanner := collection.Scan() count, err := scanner.Count() - if err != nil { return nil, err } - for (paginationClone.Offset + paginationClone.Limit) > iteration { iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() for iter.NextWithContext(ctx, &webhook) { @@ -75,9 +78,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) lastEval = iter.LastEvaluatedKey() iteration += paginationClone.Limit } - paginationClone.Total = count - return &model.Webhooks{ Pagination: &paginationClone, Webhooks: webhooks, @@ -88,37 +89,29 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { collection := p.db.Table(models.Collections.Webhook) var webhook models.Webhook - err := collection.Get("id", webhookID).OneWithContext(ctx, &webhook) - if err != nil { return nil, err } - if webhook.ID == "" { return webhook.AsAPIWebhook(), errors.New("no documets found") } - return webhook.AsAPIWebhook(), nil } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { - var webhook models.Webhook +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { + webhooks := []models.Webhook{} collection := p.db.Table(models.Collections.Webhook) - - iter := collection.Scan().Index("event_name").Filter("'event_name' = ?", eventName).Iter() - - for iter.NextWithContext(ctx, &webhook) { - return webhook.AsAPIWebhook(), nil - } - - err := iter.Err() - + err := collection.Scan().Index("event_name").Filter("contains(event_name, ?)", eventName).AllWithContext(ctx, &webhooks) if err != nil { - return webhook.AsAPIWebhook(), err + return nil, err } - return webhook.AsAPIWebhook(), nil + resWebhooks := []*model.Webhook{} + for _, w := range webhooks { + resWebhooks = append(resWebhooks, w.AsAPIWebhook()) + } + return resWebhooks, nil } // DeleteWebhook to delete webhook @@ -133,7 +126,6 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er return err } webhookLogs, errIs := p.ListWebhookLogs(ctx, pagination, webhook.ID) - for _, webhookLog := range webhookLogs.WebhookLogs { err = webhookLogCollection.Delete("id", webhookLog.ID).RunWithContext(ctx) if err != nil { diff --git a/server/db/providers/mongodb/webhook.go b/server/db/providers/mongodb/webhook.go index 7b29398..12d9e04 100644 --- a/server/db/providers/mongodb/webhook.go +++ b/server/db/providers/mongodb/webhook.go @@ -2,6 +2,8 @@ package mongodb import ( "context" + "fmt" + "strings" "time" "github.com/authorizerdev/authorizer/server/db/models" @@ -16,11 +18,14 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod if webhook.ID == "" { webhook.ID = uuid.New().String() } - webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() - + if webhook.Title == "" { + webhook.Title = strings.Join(strings.Split(webhook.EventName, "."), " ") + } + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) _, err := webhookCollection.InsertOne(ctx, webhook) if err != nil { @@ -32,39 +37,37 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) _, err := webhookCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": webhook.ID}}, bson.M{"$set": webhook}, options.MergeUpdateOptions()) if err != nil { return nil, err } - return webhook.AsAPIWebhook(), nil } // ListWebhooks to list webhook func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { - var webhooks []*model.Webhook + webhooks := []*model.Webhook{} opts := options.Find() opts.SetLimit(pagination.Limit) opts.SetSkip(pagination.Offset) opts.SetSort(bson.M{"created_at": -1}) - paginationClone := pagination - webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) count, err := webhookCollection.CountDocuments(ctx, bson.M{}, options.Count()) if err != nil { return nil, err } - paginationClone.Total = count - cursor, err := webhookCollection.Find(ctx, bson.M{}, opts) if err != nil { return nil, err } defer cursor.Close(ctx) - for cursor.Next(ctx) { var webhook models.Webhook err := cursor.Decode(&webhook) @@ -73,7 +76,6 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) } webhooks = append(webhooks, webhook.AsAPIWebhook()) } - return &model.Webhooks{ Pagination: &paginationClone, Webhooks: webhooks, @@ -92,14 +94,27 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { - var webhook models.Webhook +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { + webhooks := []*model.Webhook{} webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) - err := webhookCollection.FindOne(ctx, bson.M{"event_name": eventName}).Decode(&webhook) + opts := options.Find() + opts.SetSort(bson.M{"created_at": -1}) + cursor, err := webhookCollection.Find(ctx, bson.M{ + "event_name": fmt.Sprintf("'^%s'", eventName), + }, opts) if err != nil { return nil, err } - return webhook.AsAPIWebhook(), nil + defer cursor.Close(ctx) + for cursor.Next(ctx) { + var webhook models.Webhook + err := cursor.Decode(&webhook) + if err != nil { + return nil, err + } + webhooks = append(webhooks, webhook.AsAPIWebhook()) + } + return webhooks, nil } // DeleteWebhook to delete webhook @@ -109,12 +124,10 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er if err != nil { return err } - webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection()) _, err = webhookLogCollection.DeleteMany(nil, bson.M{"webhook_id": webhook.ID}, options.Delete()) if err != nil { return err } - return nil } diff --git a/server/db/providers/provider_template/webhook.go b/server/db/providers/provider_template/webhook.go index eda82ee..0aaf2f5 100644 --- a/server/db/providers/provider_template/webhook.go +++ b/server/db/providers/provider_template/webhook.go @@ -2,6 +2,8 @@ package provider_template import ( "context" + "fmt" + "strings" "time" "github.com/authorizerdev/authorizer/server/db/models" @@ -14,16 +16,24 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod if webhook.ID == "" { webhook.ID = uuid.New().String() } - webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() + if webhook.Title == "" { + webhook.Title = strings.Join(strings.Split(webhook.EventName, "."), " ") + } + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) return webhook.AsAPIWebhook(), nil } // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } return webhook.AsAPIWebhook(), nil } @@ -38,7 +48,7 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { return nil, nil } diff --git a/server/db/providers/providers.go b/server/db/providers/providers.go index 7325204..f6a2aad 100644 --- a/server/db/providers/providers.go +++ b/server/db/providers/providers.go @@ -56,7 +56,7 @@ type Provider interface { // GetWebhookByID to get webhook by id GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) // GetWebhookByEventName to get webhook by event_name - GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) + GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) // DeleteWebhook to delete webhook DeleteWebhook(ctx context.Context, webhook *model.Webhook) error diff --git a/server/db/providers/sql/webhook.go b/server/db/providers/sql/webhook.go index 93f21a4..8dba023 100644 --- a/server/db/providers/sql/webhook.go +++ b/server/db/providers/sql/webhook.go @@ -2,6 +2,8 @@ package sql import ( "context" + "fmt" + "strings" "time" "github.com/authorizerdev/authorizer/server/db/models" @@ -14,10 +16,14 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod if webhook.ID == "" { webhook.ID = uuid.New().String() } - webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() + if webhook.Title == "" { + webhook.Title = strings.Join(strings.Split(webhook.EventName, "."), " ") + } + // Add timestamp to make event name unique for legacy version + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) res := p.db.Create(&webhook) if res.Error != nil { return nil, res.Error @@ -28,33 +34,31 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() - + // Event is changed + if !strings.Contains(webhook.EventName, "-") { + webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) + } result := p.db.Save(&webhook) if result.Error != nil { return nil, result.Error } - return webhook.AsAPIWebhook(), nil } // ListWebhooks to list webhook func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { var webhooks []models.Webhook - result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&webhooks) if result.Error != nil { return nil, result.Error } - var total int64 totalRes := p.db.Model(&models.Webhook{}).Count(&total) if totalRes.Error != nil { return nil, totalRes.Error } - paginationClone := pagination paginationClone.Total = total - responseWebhooks := []*model.Webhook{} for _, w := range webhooks { responseWebhooks = append(responseWebhooks, w.AsAPIWebhook()) @@ -77,14 +81,17 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model } // GetWebhookByEventName to get webhook by event_name -func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { - var webhook models.Webhook - - result := p.db.Where("event_name = ?", eventName).First(&webhook) +func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) { + var webhooks []models.Webhook + result := p.db.Where("event_name LIKE ?", eventName+"%").Find(&webhooks) if result.Error != nil { return nil, result.Error } - return webhook.AsAPIWebhook(), nil + responseWebhooks := []*model.Webhook{} + for _, w := range webhooks { + responseWebhooks = append(responseWebhooks, w.AsAPIWebhook()) + } + return responseWebhooks, nil } // DeleteWebhook to delete webhook diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index 3eb5dce..5538be2 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -281,6 +281,7 @@ type ComplexityRoot struct { EventName func(childComplexity int) int Headers func(childComplexity int) int ID func(childComplexity int) int + Title func(childComplexity int) int UpdatedAt func(childComplexity int) int } @@ -1854,6 +1855,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Webhook.ID(childComplexity), true + case "Webhook.title": + if e.complexity.Webhook.Title == nil { + break + } + + return e.complexity.Webhook.Title(childComplexity), true + case "Webhook.updated_at": if e.complexity.Webhook.UpdatedAt == nil { break @@ -2210,6 +2218,7 @@ type GenerateJWTKeysResponse { type Webhook { id: ID! + title: String event_name: String endpoint: String enabled: Boolean @@ -2500,6 +2509,7 @@ input ListWebhookLogRequest { } input AddWebhookRequest { + title: String! event_name: String! endpoint: String! enabled: Boolean! @@ -2508,6 +2518,7 @@ input AddWebhookRequest { input UpdateWebhookRequest { id: ID! + title: String event_name: String endpoint: String enabled: Boolean @@ -10141,6 +10152,8 @@ func (ec *executionContext) fieldContext_Query__webhook(ctx context.Context, fie switch field.Name { case "id": return ec.fieldContext_Webhook_id(ctx, field) + case "title": + return ec.fieldContext_Webhook_title(ctx, field) case "event_name": return ec.fieldContext_Webhook_event_name(ctx, field) case "endpoint": @@ -12160,6 +12173,47 @@ func (ec *executionContext) fieldContext_Webhook_id(ctx context.Context, field g return fc, nil } +func (ec *executionContext) _Webhook_title(ctx context.Context, field graphql.CollectedField, obj *model.Webhook) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Webhook_title(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Title, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2áš–string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Webhook_title(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Webhook", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Webhook_event_name(ctx context.Context, field graphql.CollectedField, obj *model.Webhook) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Webhook_event_name(ctx, field) if err != nil { @@ -12905,6 +12959,8 @@ func (ec *executionContext) fieldContext_Webhooks_webhooks(ctx context.Context, switch field.Name { case "id": return ec.fieldContext_Webhook_id(ctx, field) + case "title": + return ec.fieldContext_Webhook_title(ctx, field) case "event_name": return ec.fieldContext_Webhook_event_name(ctx, field) case "endpoint": @@ -14756,13 +14812,21 @@ func (ec *executionContext) unmarshalInputAddWebhookRequest(ctx context.Context, asMap[k] = v } - fieldsInOrder := [...]string{"event_name", "endpoint", "enabled", "headers"} + fieldsInOrder := [...]string{"title", "event_name", "endpoint", "enabled", "headers"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { continue } switch k { + case "title": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("title")) + it.Title, err = ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } case "event_name": var err error @@ -16612,7 +16676,7 @@ func (ec *executionContext) unmarshalInputUpdateWebhookRequest(ctx context.Conte asMap[k] = v } - fieldsInOrder := [...]string{"id", "event_name", "endpoint", "enabled", "headers"} + fieldsInOrder := [...]string{"id", "title", "event_name", "endpoint", "enabled", "headers"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -16627,6 +16691,14 @@ func (ec *executionContext) unmarshalInputUpdateWebhookRequest(ctx context.Conte if err != nil { return it, err } + case "title": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("title")) + it.Title, err = ec.unmarshalOString2áš–string(ctx, v) + if err != nil { + return it, err + } case "event_name": var err error @@ -18509,6 +18581,10 @@ func (ec *executionContext) _Webhook(ctx context.Context, sel ast.SelectionSet, if out.Values[i] == graphql.Null { invalids++ } + case "title": + + out.Values[i] = ec._Webhook_title(ctx, field, obj) + case "event_name": out.Values[i] = ec._Webhook_event_name(ctx, field, obj) diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index 57b1aad..05da647 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -10,6 +10,7 @@ type AddEmailTemplateRequest struct { } type AddWebhookRequest struct { + Title string `json:"title"` EventName string `json:"event_name"` Endpoint string `json:"endpoint"` Enabled bool `json:"enabled"` @@ -388,6 +389,7 @@ type UpdateUserInput struct { type UpdateWebhookRequest struct { ID string `json:"id"` + Title *string `json:"title"` EventName *string `json:"event_name"` Endpoint *string `json:"endpoint"` Enabled *bool `json:"enabled"` @@ -462,6 +464,7 @@ type VerifyOTPRequest struct { type Webhook struct { ID string `json:"id"` + Title *string `json:"title"` EventName *string `json:"event_name"` Endpoint *string `json:"endpoint"` Enabled *bool `json:"enabled"` diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index 9b96d13..2dd7aa9 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -168,6 +168,7 @@ type GenerateJWTKeysResponse { type Webhook { id: ID! + title: String event_name: String endpoint: String enabled: Boolean @@ -458,6 +459,7 @@ input ListWebhookLogRequest { } input AddWebhookRequest { + title: String! event_name: String! endpoint: String! enabled: Boolean! @@ -466,6 +468,7 @@ input AddWebhookRequest { input UpdateWebhookRequest { id: ID! + title: String event_name: String endpoint: String enabled: Boolean diff --git a/server/test/update_webhook_test.go b/server/test/update_webhook_test.go index 07f658c..f1863cb 100644 --- a/server/test/update_webhook_test.go +++ b/server/test/update_webhook_test.go @@ -24,45 +24,47 @@ func updateWebhookTest(t *testing.T, s TestSetup) { assert.NoError(t, err) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) // get webhook - webhook, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent) + webhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent) assert.NoError(t, err) - assert.NotNil(t, webhook) - // it should completely replace headers - webhook.Headers = map[string]interface{}{ - "x-new-test": "test", + assert.NotNil(t, webhooks) + assert.Greater(t, len(webhooks), 0) + for _, webhook := range webhooks { + // it should completely replace headers + webhook.Headers = map[string]interface{}{ + "x-new-test": "test", + } + res, err := resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{ + ID: webhook.ID, + Headers: webhook.Headers, + Enabled: refs.NewBoolRef(false), + Endpoint: refs.NewStringRef("https://sometest.com"), + }) + assert.NoError(t, err) + assert.NotEmpty(t, res) + assert.NotEmpty(t, res.Message) } - res, err := resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{ - ID: webhook.ID, - Headers: webhook.Headers, - Enabled: refs.NewBoolRef(false), - Endpoint: refs.NewStringRef("https://sometest.com"), - }) - + updatedWebhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent) assert.NoError(t, err) - assert.NotEmpty(t, res) - assert.NotEmpty(t, res.Message) - - updatedWebhook, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent) - assert.NoError(t, err) - assert.NotNil(t, updatedWebhook) - assert.Equal(t, webhook.ID, updatedWebhook.ID) - assert.Equal(t, refs.StringValue(webhook.EventName), refs.StringValue(updatedWebhook.EventName)) - assert.Len(t, updatedWebhook.Headers, 1) - assert.False(t, refs.BoolValue(updatedWebhook.Enabled)) - for key, val := range updatedWebhook.Headers { - assert.Equal(t, val, webhook.Headers[key]) + assert.NotNil(t, updatedWebhooks) + for _, updatedWebhook := range updatedWebhooks { + assert.Contains(t, refs.StringValue(updatedWebhook.EventName), constants.UserDeletedWebhookEvent) + assert.Len(t, updatedWebhook.Headers, 1) + assert.False(t, refs.BoolValue(updatedWebhook.Enabled)) + for key, val := range updatedWebhook.Headers { + assert.Equal(t, "x-new-test", key) + assert.Equal(t, "test", val) + } + assert.Equal(t, "https://sometest.com", refs.StringValue(updatedWebhook.Endpoint)) + res, err := resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{ + ID: updatedWebhook.ID, + Headers: updatedWebhook.Headers, + Enabled: refs.NewBoolRef(true), + Endpoint: refs.NewStringRef(s.TestInfo.WebhookEndpoint), + }) + assert.NoError(t, err) + assert.NotEmpty(t, res) + assert.NotEmpty(t, res.Message) } - assert.Equal(t, refs.StringValue(updatedWebhook.Endpoint), "https://sometest.com") - - res, err = resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{ - ID: webhook.ID, - Headers: webhook.Headers, - Enabled: refs.NewBoolRef(true), - Endpoint: refs.NewStringRef(s.TestInfo.WebhookEndpoint), - }) - assert.NoError(t, err) - assert.NotEmpty(t, res) - assert.NotEmpty(t, res.Message) }) } diff --git a/server/test/webhook_test.go b/server/test/webhook_test.go index 4bbe464..f52fa62 100644 --- a/server/test/webhook_test.go +++ b/server/test/webhook_test.go @@ -25,18 +25,20 @@ func webhookTest(t *testing.T, s TestSetup) { req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) // get webhook by event name - webhook, err := db.Provider.GetWebhookByEventName(ctx, constants.UserCreatedWebhookEvent) + webhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserCreatedWebhookEvent) assert.NoError(t, err) - assert.NotNil(t, webhook) - - res, err := resolvers.WebhookResolver(ctx, model.WebhookRequest{ - ID: webhook.ID, - }) - assert.NoError(t, err) - assert.Equal(t, res.ID, webhook.ID) - assert.Equal(t, refs.StringValue(res.Endpoint), refs.StringValue(webhook.Endpoint)) - assert.Equal(t, refs.StringValue(res.EventName), refs.StringValue(webhook.EventName)) - assert.Equal(t, refs.BoolValue(res.Enabled), refs.BoolValue(webhook.Enabled)) - assert.Len(t, res.Headers, len(webhook.Headers)) + assert.NotNil(t, webhooks) + assert.Greater(t, len(webhooks), 0) + for _, webhook := range webhooks { + res, err := resolvers.WebhookResolver(ctx, model.WebhookRequest{ + ID: webhook.ID, + }) + assert.NoError(t, err) + assert.Equal(t, res.ID, webhook.ID) + assert.Equal(t, refs.StringValue(res.Endpoint), refs.StringValue(webhook.Endpoint)) + assert.Equal(t, refs.StringValue(res.EventName), refs.StringValue(webhook.EventName)) + assert.Equal(t, refs.BoolValue(res.Enabled), refs.BoolValue(webhook.Enabled)) + assert.Len(t, res.Headers, len(webhook.Headers)) + } }) } diff --git a/server/utils/webhook.go b/server/utils/webhook.go index acacfbf..6387b0d 100644 --- a/server/utils/webhook.go +++ b/server/utils/webhook.go @@ -17,98 +17,97 @@ import ( ) func RegisterEvent(ctx context.Context, eventName string, authRecipe string, user models.User) error { - webhook, err := db.Provider.GetWebhookByEventName(ctx, eventName) + webhooks, err := db.Provider.GetWebhookByEventName(ctx, eventName) if err != nil { return err } + for _, webhook := range webhooks { + if !refs.BoolValue(webhook.Enabled) { + return nil + } + userBytes, err := json.Marshal(user.AsAPIUser()) + if err != nil { + log.Debug("error marshalling user obj: ", err) + return err + } + userMap := map[string]interface{}{} + err = json.Unmarshal(userBytes, &userMap) + if err != nil { + log.Debug("error un-marshalling user obj: ", err) + return err + } - if !refs.BoolValue(webhook.Enabled) { - return nil - } + reqBody := map[string]interface{}{ + "event_name": eventName, + "user": userMap, + } - userBytes, err := json.Marshal(user.AsAPIUser()) - if err != nil { - log.Debug("error marshalling user obj: ", err) - return err - } - userMap := map[string]interface{}{} - err = json.Unmarshal(userBytes, &userMap) - if err != nil { - log.Debug("error un-marshalling user obj: ", err) - return err - } + if eventName == constants.UserLoginWebhookEvent || eventName == constants.UserSignUpWebhookEvent { + reqBody["auth_recipe"] = authRecipe + } - reqBody := map[string]interface{}{ - "event_name": eventName, - "user": userMap, - } + requestBody, err := json.Marshal(reqBody) + if err != nil { + log.Debug("error marshalling requestBody obj: ", err) + return err + } - if eventName == constants.UserLoginWebhookEvent || eventName == constants.UserSignUpWebhookEvent { - reqBody["auth_recipe"] = authRecipe - } + // dont trigger webhook call in case of test + envKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEnv) + if err != nil { + return err + } + if envKey == constants.TestEnv { + db.Provider.AddWebhookLog(ctx, models.WebhookLog{ + HttpStatus: 200, + Request: string(requestBody), + Response: string(`{"message": "test"}`), + WebhookID: webhook.ID, + }) - requestBody, err := json.Marshal(reqBody) - if err != nil { - log.Debug("error marshalling requestBody obj: ", err) - return err - } + return nil + } - // dont trigger webhook call in case of test - envKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEnv) - if err != nil { - return err - } - if envKey == constants.TestEnv { - db.Provider.AddWebhookLog(ctx, models.WebhookLog{ - HttpStatus: 200, + requestBytesBuffer := bytes.NewBuffer(requestBody) + req, err := http.NewRequest("POST", refs.StringValue(webhook.Endpoint), requestBytesBuffer) + if err != nil { + log.Debug("error creating webhook post request: ", err) + return err + } + req.Header.Set("Content-Type", "application/json") + + if webhook.Headers != nil { + for key, val := range webhook.Headers { + req.Header.Set(key, val.(string)) + } + } + + client := &http.Client{Timeout: time.Second * 30} + resp, err := client.Do(req) + if err != nil { + log.Debug("error making request: ", err) + return err + } + defer resp.Body.Close() + + responseBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Debug("error reading response: ", err) + return err + } + + statusCode := int64(resp.StatusCode) + _, err = db.Provider.AddWebhookLog(ctx, models.WebhookLog{ + HttpStatus: statusCode, Request: string(requestBody), - Response: string(`{"message": "test"}`), + Response: string(responseBytes), WebhookID: webhook.ID, }) - return nil - } - - requestBytesBuffer := bytes.NewBuffer(requestBody) - req, err := http.NewRequest("POST", refs.StringValue(webhook.Endpoint), requestBytesBuffer) - if err != nil { - log.Debug("error creating webhook post request: ", err) - return err - } - req.Header.Set("Content-Type", "application/json") - - if webhook.Headers != nil { - for key, val := range webhook.Headers { - req.Header.Set(key, val.(string)) + if err != nil { + log.Debug("failed to add webhook log: ", err) + return err } } - - client := &http.Client{Timeout: time.Second * 30} - resp, err := client.Do(req) - if err != nil { - log.Debug("error making request: ", err) - return err - } - defer resp.Body.Close() - - responseBytes, err := ioutil.ReadAll(resp.Body) - if err != nil { - log.Debug("error reading response: ", err) - return err - } - - statusCode := int64(resp.StatusCode) - _, err = db.Provider.AddWebhookLog(ctx, models.WebhookLog{ - HttpStatus: statusCode, - Request: string(requestBody), - Response: string(responseBytes), - WebhookID: webhook.ID, - }) - - if err != nil { - log.Debug("failed to add webhook log: ", err) - return err - } - return nil }