Compare commits

..

9 Commits

Author SHA1 Message Date
Lakhan Samani
330f35f2fc Merge pull request #322 from authorizerdev/fix/use-scopes-as-string
[server] use scope string instead of string array in tokens
2023-02-08 09:41:17 +05:30
Lakhan Samani
70242debe1 [server] fix scope response type + add extra claims to access token 2023-02-08 09:39:08 +05:30
Lakhan Samani
4018da6697 [server] use scope string instead of string array in tokens 2023-02-07 01:13:03 +05:30
Lakhan Samani
a73c6ee49e Merge pull request #319 from authorizerdev/fix/arangodb-connection 2023-02-06 21:32:32 +05:30
Lakhan Samani
c23fb1bb32 [server] update arangodb version for test 3.10.3 2023-02-06 20:39:22 +05:30
Lakhan Samani
270853a6a3 [server] add support for arangodb cert, username, password
Adding support for db username, password and ca cert will
enable users to connect arangodb cloud platform
2023-02-06 18:14:19 +05:30
Lakhan Samani
2d0346ff23 [server] remove unused error condition for couchbase 2023-02-05 11:03:26 +05:30
Lakhan Samani
4b26e1ce85 [server] fix bucket creation for couchbase
Run create bucket query only if bucket is not found.
Required while running couchbase cloud version
2023-02-05 11:01:20 +05:30
Lakhan Samani
8212e81023 [server] common util for couchbase syntax 2023-02-02 13:06:18 +05:30
19 changed files with 188 additions and 88 deletions

View File

@@ -1,4 +1,4 @@
FROM golang:1.19.1-alpine as go-builder
FROM golang:1.19.5-alpine as go-builder
WORKDIR /authorizer
COPY server server
COPY Makefile .

View File

@@ -26,7 +26,7 @@ test-scylladb:
cd server && go clean --testcache && TEST_DBS="scylladb" go test -p 1 -v ./test
docker rm -vf authorizer_scylla_db
test-arangodb:
docker run -d --name authorizer_arangodb -p 8529:8529 -e ARANGO_NO_AUTH=1 arangodb/arangodb:3.8.4
docker run -d --name authorizer_arangodb -p 8529:8529 -e ARANGO_NO_AUTH=1 arangodb/arangodb:3.10.3
cd server && go clean --testcache && TEST_DBS="arangodb" go test -p 1 -v ./test
docker rm -vf authorizer_arangodb
test-dynamodb:
@@ -42,7 +42,7 @@ test-all-db:
rm -rf server/test/test.db server/test/test.db-shm server/test/test.db-wal && rm -rf test.db test.db-shm test.db-wal
docker run -d --name authorizer_scylla_db -p 9042:9042 scylladb/scylla
docker run -d --name authorizer_mongodb_db -p 27017:27017 mongo:4.4.15
docker run -d --name authorizer_arangodb -p 8529:8529 -e ARANGO_NO_AUTH=1 arangodb/arangodb:3.8.4
docker run -d --name authorizer_arangodb -p 8529:8529 -e ARANGO_NO_AUTH=1 arangodb/arangodb:3.10.3
docker run -d --name dynamodb-local-test -p 8000:8000 amazon/dynamodb-local:latest
docker run -d --name couchbase-local-test -p 8091-8097:8091-8097 -p 11210:11210 -p 11207:11207 -p 18091-18095:18091-18095 -p 18096:18096 -p 18097:18097 couchbase:latest
sh scripts/couchbase-test.sh

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"time"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
@@ -52,7 +51,7 @@ func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagin
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.EmailTemplate, pagination.Offset, pagination.Limit)
sctx := driver.WithQueryFullCount(ctx)
sctx := arangoDriver.WithQueryFullCount(ctx)
cursor, err := p.db.Query(sctx, query, nil)
if err != nil {
return nil, err

View File

@@ -2,8 +2,11 @@ package arangodb
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver"
"github.com/arangodb/go-driver/http"
"github.com/authorizerdev/authorizer/server/db/models"
@@ -22,44 +25,75 @@ type provider struct {
func NewProvider() (*provider, error) {
ctx := context.Background()
dbURL := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseURL
conn, err := http.NewConnection(http.ConnectionConfig{
dbUsername := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseUsername
dbPassword := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabasePassword
dbCACertificate := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseCACert
httpConfig := http.ConnectionConfig{
Endpoints: []string{dbURL},
})
}
// If ca certificate if present, create tls config
if dbCACertificate != "" {
caCert, err := base64.StdEncoding.DecodeString(dbCACertificate)
if err != nil {
return nil, err
}
arangoClient, err := arangoDriver.NewClient(arangoDriver.ClientConfig{
// Prepare TLS Config
tlsConfig := &tls.Config{}
certPool := x509.NewCertPool()
if success := certPool.AppendCertsFromPEM(caCert); !success {
return nil, fmt.Errorf("invalid certificate")
}
tlsConfig.RootCAs = certPool
httpConfig.TLSConfig = tlsConfig
}
// Create new http connection
conn, err := http.NewConnection(httpConfig)
if err != nil {
return nil, err
}
clientConfig := arangoDriver.ClientConfig{
Connection: conn,
})
}
if dbUsername != "" && dbPassword != "" {
clientConfig.Authentication = arangoDriver.BasicAuthentication(dbUsername, dbPassword)
}
arangoClient, err := arangoDriver.NewClient(clientConfig)
if err != nil {
return nil, err
}
var arangodb driver.Database
var arangodb arangoDriver.Database
dbName := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseName
arangodb_exists, err := arangoClient.DatabaseExists(nil, dbName)
arangodb_exists, err := arangoClient.DatabaseExists(ctx, dbName)
if err != nil {
return nil, err
}
if arangodb_exists {
arangodb, err = arangoClient.Database(nil, dbName)
arangodb, err = arangoClient.Database(ctx, dbName)
if err != nil {
return nil, err
}
} else {
arangodb, err = arangoClient.CreateDatabase(nil, dbName, nil)
arangodb, err = arangoClient.CreateDatabase(ctx, dbName, nil)
if err != nil {
return nil, err
}
}
userCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.User)
if err != nil {
return nil, err
}
if !userCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.User, nil)
if err != nil {
return nil, err
}
}
userCollection, _ := arangodb.Collection(nil, models.Collections.User)
userCollection, err := arangodb.Collection(ctx, models.Collections.User)
if err != nil {
return nil, err
}
userCollection.EnsureHashIndex(ctx, []string{"email"}, &arangoDriver.EnsureHashIndexOptions{
Unique: true,
Sparse: true,
@@ -70,6 +104,9 @@ func NewProvider() (*provider, error) {
})
verificationRequestCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.VerificationRequest)
if err != nil {
return nil, err
}
if !verificationRequestCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.VerificationRequest, nil)
if err != nil {
@@ -77,7 +114,10 @@ func NewProvider() (*provider, error) {
}
}
verificationRequestCollection, _ := arangodb.Collection(nil, models.Collections.VerificationRequest)
verificationRequestCollection, err := arangodb.Collection(ctx, models.Collections.VerificationRequest)
if err != nil {
return nil, err
}
verificationRequestCollection.EnsureHashIndex(ctx, []string{"email", "identifier"}, &arangoDriver.EnsureHashIndexOptions{
Unique: true,
Sparse: true,
@@ -87,6 +127,9 @@ func NewProvider() (*provider, error) {
})
sessionCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.Session)
if err != nil {
return nil, err
}
if !sessionCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.Session, nil)
if err != nil {
@@ -94,13 +137,19 @@ func NewProvider() (*provider, error) {
}
}
sessionCollection, _ := arangodb.Collection(nil, models.Collections.Session)
sessionCollection, err := arangodb.Collection(ctx, models.Collections.Session)
if err != nil {
return nil, err
}
sessionCollection.EnsureHashIndex(ctx, []string{"user_id"}, &arangoDriver.EnsureHashIndexOptions{
Sparse: true,
})
configCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.Env)
if !configCollectionExists {
envCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.Env)
if err != nil {
return nil, err
}
if !envCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.Env, nil)
if err != nil {
return nil, err
@@ -108,6 +157,9 @@ func NewProvider() (*provider, error) {
}
webhookCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.Webhook)
if err != nil {
return nil, err
}
if !webhookCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.Webhook, nil)
if err != nil {
@@ -115,13 +167,19 @@ func NewProvider() (*provider, error) {
}
}
webhookCollection, _ := arangodb.Collection(nil, models.Collections.Webhook)
webhookCollection, err := arangodb.Collection(ctx, models.Collections.Webhook)
if err != nil {
return nil, err
}
webhookCollection.EnsureHashIndex(ctx, []string{"event_name"}, &arangoDriver.EnsureHashIndexOptions{
Unique: true,
Sparse: true,
})
webhookLogCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.WebhookLog)
if err != nil {
return nil, err
}
if !webhookLogCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.WebhookLog, nil)
if err != nil {
@@ -129,12 +187,18 @@ func NewProvider() (*provider, error) {
}
}
webhookLogCollection, _ := arangodb.Collection(nil, models.Collections.WebhookLog)
webhookLogCollection, err := arangodb.Collection(ctx, models.Collections.WebhookLog)
if err != nil {
return nil, err
}
webhookLogCollection.EnsureHashIndex(ctx, []string{"webhook_id"}, &arangoDriver.EnsureHashIndexOptions{
Sparse: true,
})
emailTemplateCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.EmailTemplate)
if err != nil {
return nil, err
}
if !emailTemplateCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.EmailTemplate, nil)
if err != nil {
@@ -142,13 +206,19 @@ func NewProvider() (*provider, error) {
}
}
emailTemplateCollection, _ := arangodb.Collection(nil, models.Collections.EmailTemplate)
emailTemplateCollection, err := arangodb.Collection(ctx, models.Collections.EmailTemplate)
if err != nil {
return nil, err
}
emailTemplateCollection.EnsureHashIndex(ctx, []string{"event_name"}, &arangoDriver.EnsureHashIndexOptions{
Unique: true,
Sparse: true,
})
otpCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.OTP)
if err != nil {
return nil, err
}
if !otpCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.OTP, nil)
if err != nil {
@@ -156,7 +226,10 @@ func NewProvider() (*provider, error) {
}
}
otpCollection, _ := arangodb.Collection(nil, models.Collections.OTP)
otpCollection, err := arangodb.Collection(ctx, models.Collections.OTP)
if err != nil {
return nil, err
}
otpCollection.EnsureHashIndex(ctx, []string{"email"}, &arangoDriver.EnsureHashIndexOptions{
Unique: true,
Sparse: true,

View File

@@ -7,7 +7,6 @@ import (
"strings"
"time"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver"
"github.com/google/uuid"
@@ -91,7 +90,7 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
// ListUsers to get list of users from database
func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
var users []*model.User
sctx := driver.WithQueryFullCount(ctx)
sctx := arangoDriver.WithQueryFullCount(ctx)
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.User, pagination.Offset, pagination.Limit)
@@ -199,7 +198,7 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{},
}
query := ""
if ids != nil && len(ids) > 0 {
if len(ids) > 0 {
keysArray := ""
for _, id := range ids {
keysArray += fmt.Sprintf("'%s', ", id)
@@ -212,7 +211,6 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{},
}
_, err = p.db.Query(ctx, query, nil)
if err != nil {
return err
}

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"time"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
@@ -96,7 +96,7 @@ func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email stri
// ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
var verificationRequests []*model.VerificationRequest
sctx := driver.WithQueryFullCount(ctx)
sctx := arangoDriver.WithQueryFullCount(ctx)
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.VerificationRequest, pagination.Offset, pagination.Limit)
cursor, err := p.db.Query(sctx, query, nil)
@@ -112,7 +112,7 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination mode
var verificationRequest models.VerificationRequest
meta, err := cursor.ReadDocument(ctx, &verificationRequest)
if driver.IsNoMoreDocuments(err) {
if arangoDriver.IsNoMoreDocuments(err) {
break
} else if err != nil {
return nil, err
@@ -132,8 +132,8 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination mode
// DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error {
collection, _ := p.db.Collection(nil, models.Collections.VerificationRequest)
_, err := collection.RemoveDocument(nil, verificationRequest.Key)
collection, _ := p.db.Collection(ctx, models.Collections.VerificationRequest)
_, err := collection.RemoveDocument(ctx, verificationRequest.Key)
if err != nil {
return err
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"time"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
@@ -50,7 +49,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.Webhook, pagination.Offset, pagination.Limit)
sctx := driver.WithQueryFullCount(ctx)
sctx := arangoDriver.WithQueryFullCount(ctx)
cursor, err := p.db.Query(sctx, query, nil)
if err != nil {
return nil, err

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"time"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
@@ -44,7 +43,7 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat
}
}
sctx := driver.WithQueryFullCount(ctx)
sctx := arangoDriver.WithQueryFullCount(ctx)
cursor, err := p.db.Query(sctx, query, bindVariables)
if err != nil {
return nil, err

View File

@@ -70,9 +70,11 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models
func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) {
emailTemplates := []*model.EmailTemplate{}
paginationClone := pagination
_, paginationClone.Total = p.GetTotalDocs(ctx, models.Collections.EmailTemplate)
total, err := p.GetTotalDocs(ctx, models.Collections.EmailTemplate)
if err != nil {
return nil, err
}
paginationClone.Total = total
userQuery := fmt.Sprintf("SELECT _id, event_name, subject, design, template, created_at, updated_at FROM %s.%s ORDER BY _id OFFSET $1 LIMIT $2", p.scopeName, models.Collections.EmailTemplate)
queryResult, err := p.db.Query(userQuery, &gocb.QueryOptions{

View File

@@ -107,14 +107,23 @@ func CreateBucketAndScope(cluster *gocb.Cluster, bucketName string, scopeName st
FlushEnabled: true,
CompressionMode: gocb.CompressionModeActive,
}
shouldCreateBucket := false
// check if bucket exists
_, err = cluster.Buckets().GetBucket(bucketName, nil)
if err != nil {
// bucket not found
shouldCreateBucket = true
}
if shouldCreateBucket {
err = cluster.Buckets().CreateBucket(gocb.CreateBucketSettings{
BucketSettings: settings,
ConflictResolutionType: gocb.ConflictResolutionTypeSequenceNumber,
}, nil)
bucket := cluster.Bucket(bucketName)
if err != nil && !errors.Is(err, gocb.ErrBucketExists) {
return bucket, err
if err != nil {
return nil, err
}
}
bucket := cluster.Bucket(bucketName)
if scopeName != defaultScope {
err = bucket.Collections().CreateScope(scopeName, nil)
if err != nil && !errors.Is(err, gocb.ErrScopeExists) {

View File

@@ -44,7 +44,7 @@ func GetSetFields(webhookMap map[string]interface{}) (string, map[string]interfa
return updateFields, params
}
func (p *provider) GetTotalDocs(ctx context.Context, collection string) (error, int64) {
func (p *provider) GetTotalDocs(ctx context.Context, collection string) (int64, error) {
totalDocs := TotalDocs{}
countQuery := fmt.Sprintf("SELECT COUNT(*) as Total FROM %s.%s", p.scopeName, collection)
@@ -55,9 +55,9 @@ func (p *provider) GetTotalDocs(ctx context.Context, collection string) (error,
queryRes.One(&totalDocs)
if err != nil {
return err, totalDocs.Total
return totalDocs.Total, err
}
return nil, totalDocs.Total
return totalDocs.Total, nil
}
type TotalDocs struct {

View File

@@ -77,13 +77,14 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (
Context: ctx,
PositionalParameters: []interface{}{paginationClone.Offset, paginationClone.Limit},
})
_, paginationClone.Total = p.GetTotalDocs(ctx, models.Collections.User)
if err != nil {
return nil, err
}
total, err := p.GetTotalDocs(ctx, models.Collections.User)
if err != nil {
return nil, err
}
paginationClone.Total = total
for queryResult.Next() {
var user models.User
err := queryResult.Row(&user)
@@ -92,12 +93,9 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (
}
users = append(users, user.AsAPIUser())
}
if err := queryResult.Err(); err != nil {
return nil, err
}
return &model.Users{
Pagination: &paginationClone,
Users: users,
@@ -150,10 +148,8 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err
func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error {
// set updated_at time for all users
data["updated_at"] = time.Now().Unix()
updateFields, params := GetSetFields(data)
if ids != nil && len(ids) > 0 {
if len(ids) > 0 {
for _, id := range ids {
params["id"] = id
userQuery := fmt.Sprintf("UPDATE %s.%s SET %s WHERE _id = $id", p.scopeName, models.Collections.User, updateFields)

View File

@@ -83,16 +83,17 @@ func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email stri
func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
var verificationRequests []*model.VerificationRequest
paginationClone := pagination
_, paginationClone.Total = p.GetTotalDocs(ctx, models.Collections.VerificationRequest)
total, err := p.GetTotalDocs(ctx, models.Collections.VerificationRequest)
if err != nil {
return nil, err
}
paginationClone.Total = total
query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM %s.%s OFFSET $1 LIMIT $2", p.scopeName, models.Collections.VerificationRequest)
queryResult, err := p.db.Query(query, &gocb.QueryOptions{
Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
PositionalParameters: []interface{}{paginationClone.Offset, paginationClone.Limit},
})
if err != nil {
return nil, err
}
@@ -104,7 +105,6 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination mode
}
verificationRequests = append(verificationRequests, verificationRequest.AsAPIVerificationRequest())
}
if err := queryResult.Err(); err != nil {
return nil, err

View File

@@ -76,17 +76,17 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
params := make(map[string]interface{}, 1)
params["offset"] = paginationClone.Offset
params["limit"] = paginationClone.Limit
total, err := p.GetTotalDocs(ctx, models.Collections.Webhook)
if err != nil {
return nil, err
}
paginationClone.Total = total
query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM %s.%s OFFSET $offset LIMIT $limit", p.scopeName, models.Collections.Webhook)
_, paginationClone.Total = p.GetTotalDocs(ctx, models.Collections.Webhook)
queryResult, err := p.db.Query(query, &gocb.QueryOptions{
Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
NamedParameters: params,
})
if err != nil {
return nil, err
}
@@ -98,10 +98,8 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
}
webhooks = append(webhooks, webhook.AsAPIWebhook())
}
if err := queryResult.Err(); err != nil {
return nil, err
}
return &model.Webhooks{
Pagination: &paginationClone,

View File

@@ -45,9 +45,11 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat
params["webhookID"] = webhookID
params["offset"] = paginationClone.Offset
params["limit"] = paginationClone.Limit
_, paginationClone.Total = p.GetTotalDocs(ctx, models.Collections.WebhookLog)
total, err := p.GetTotalDocs(ctx, models.Collections.WebhookLog)
if err != nil {
return nil, err
}
paginationClone.Total = total
if webhookID != "" {
query = fmt.Sprintf(`SELECT _id, http_status, response, request, webhook_id, created_at, updated_at FROM %s.%s WHERE webhook_id=$webhookID`, p.scopeName, models.Collections.WebhookLog)
} else {

View File

@@ -76,7 +76,7 @@ func AppHandler() gin.HandlerFunc {
"data": map[string]interface{}{
"authorizerURL": hostname,
"redirectURL": redirectURI,
"scope": scope,
"scope": strings.Join(scope, " "),
"state": state,
"organizationName": orgName,
"organizationLogo": orgLogo,

View File

@@ -284,7 +284,7 @@ func AuthorizeHandler() gin.HandlerFunc {
"access_token": authToken.AccessToken.Token,
"id_token": authToken.IDToken.Token,
"state": state,
"scope": scope,
"scope": strings.Join(scope, " "),
"token_type": "Bearer",
"expires_in": authToken.AccessToken.ExpiresAt,
}

View File

@@ -259,7 +259,7 @@ func TokenHandler() gin.HandlerFunc {
res := map[string]interface{}{
"access_token": authToken.AccessToken.Token,
"id_token": authToken.IDToken.Token,
"scope": scope,
"scope": strings.Join(scope, " "),
"roles": roles,
"expires_in": expiresIn,
}

View File

@@ -162,9 +162,7 @@ func CreateAccessToken(user models.User, roles, scopes []string, hostName, nonce
if err != nil {
expiryBound = time.Minute * 30
}
expiresAt := time.Now().Add(expiryBound).Unix()
clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID)
if err != nil {
return "", 0, err
@@ -182,7 +180,41 @@ func CreateAccessToken(user models.User, roles, scopes []string, hostName, nonce
"login_method": loginMethod,
"allowed_roles": strings.Split(user.Roles, ","),
}
// check for the extra access token script
accessTokenScript, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyCustomAccessTokenScript)
if err != nil {
log.Debug("Failed to get custom access token script: ", err)
accessTokenScript = ""
}
if accessTokenScript != "" {
resUser := user.AsAPIUser()
userBytes, _ := json.Marshal(&resUser)
var userMap map[string]interface{}
json.Unmarshal(userBytes, &userMap)
vm := otto.New()
claimBytes, _ := json.Marshal(customClaims)
vm.Run(fmt.Sprintf(`
var user = %s;
var tokenPayload = %s;
var customFunction = %s;
var functionRes = JSON.stringify(customFunction(user, tokenPayload));
`, string(userBytes), string(claimBytes), accessTokenScript))
val, err := vm.Get("functionRes")
if err != nil {
log.Debug("error getting custom access token script: ", err)
} else {
extraPayload := make(map[string]interface{})
err = json.Unmarshal([]byte(fmt.Sprintf("%s", val)), &extraPayload)
if err != nil {
log.Debug("error converting accessTokenScript response to map: ", err)
} else {
for k, v := range extraPayload {
customClaims[k] = v
}
}
}
}
token, err := SignJWTToken(customClaims)
if err != nil {
return "", 0, err
@@ -345,14 +377,11 @@ func CreateIDToken(user models.User, roles []string, hostname, nonce, atHash, cH
if err != nil {
expiryBound = time.Minute * 30
}
expiresAt := time.Now().Add(expiryBound).Unix()
resUser := user.AsAPIUser()
userBytes, _ := json.Marshal(&resUser)
var userMap map[string]interface{}
json.Unmarshal(userBytes, &userMap)
claimKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtRoleClaim)
if err != nil {
claimKey = "roles"
@@ -376,7 +405,6 @@ func CreateIDToken(user models.User, roles []string, hostname, nonce, atHash, cH
}
// split nonce to see if its authorization code grant method
if cHash != "" {
customClaims["at_hash"] = atHash
customClaims["c_hash"] = cHash
@@ -384,13 +412,11 @@ func CreateIDToken(user models.User, roles []string, hostname, nonce, atHash, cH
customClaims["nonce"] = nonce
customClaims["at_hash"] = atHash
}
for k, v := range userMap {
if k != "roles" {
customClaims[k] = v
}
}
// check for the extra access token script
accessTokenScript, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyCustomAccessTokenScript)
if err != nil {
@@ -399,7 +425,6 @@ func CreateIDToken(user models.User, roles []string, hostname, nonce, atHash, cH
}
if accessTokenScript != "" {
vm := otto.New()
claimBytes, _ := json.Marshal(customClaims)
vm.Run(fmt.Sprintf(`
var user = %s;