diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index 3dd7e42..9d5300e 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -144,6 +144,7 @@ type ComplexityRoot struct { Profile func(childComplexity int) int Session func(childComplexity int, params *model.SessionQueryInput) int Users func(childComplexity int, params *model.PaginatedInput) int + ValidateJwtToken func(childComplexity int, params model.ValidateJWTTokenInput) int VerificationRequests func(childComplexity int, params *model.PaginatedInput) int } @@ -176,6 +177,10 @@ type ComplexityRoot struct { Users func(childComplexity int) int } + ValidateJWTTokenResponse struct { + IsValid func(childComplexity int) int + } + VerificationRequest struct { CreatedAt func(childComplexity int) int Email func(childComplexity int) int @@ -217,6 +222,7 @@ type QueryResolver interface { Meta(ctx context.Context) (*model.Meta, error) Session(ctx context.Context, params *model.SessionQueryInput) (*model.AuthResponse, error) Profile(ctx context.Context) (*model.User, error) + ValidateJwtToken(ctx context.Context, params model.ValidateJWTTokenInput) (*model.ValidateJWTTokenResponse, error) Users(ctx context.Context, params *model.PaginatedInput) (*model.Users, error) VerificationRequests(ctx context.Context, params *model.PaginatedInput) (*model.VerificationRequests, error) AdminSession(ctx context.Context) (*model.Response, error) @@ -897,6 +903,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.Users(childComplexity, args["params"].(*model.PaginatedInput)), true + case "Query.validate_jwt_token": + if e.complexity.Query.ValidateJwtToken == nil { + break + } + + args, err := ec.field_Query_validate_jwt_token_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.ValidateJwtToken(childComplexity, args["params"].(model.ValidateJWTTokenInput)), true + case "Query._verification_requests": if e.complexity.Query.VerificationRequests == nil { break @@ -1049,6 +1067,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Users.Users(childComplexity), true + case "ValidateJWTTokenResponse.is_valid": + if e.complexity.ValidateJWTTokenResponse.IsValid == nil { + break + } + + return e.complexity.ValidateJWTTokenResponse.IsValid(childComplexity), true + case "VerificationRequest.created_at": if e.complexity.VerificationRequest.CreatedAt == nil { break @@ -1318,6 +1343,10 @@ type Env { ORGANIZATION_LOGO: String } +type ValidateJWTTokenResponse { + is_valid: Boolean! +} + input UpdateEnvInput { ADMIN_SECRET: String CUSTOM_ACCESS_TOKEN_SCRIPT: String @@ -1473,6 +1502,12 @@ input InviteMemberInput { redirect_uri: String } +input ValidateJWTTokenInput { + token_type: String! + token: String! + roles: [String!] +} + type Mutation { signup(params: SignUpInput!): AuthResponse! login(params: LoginInput!): AuthResponse! @@ -1498,6 +1533,7 @@ type Query { meta: Meta! session(params: SessionQueryInput): AuthResponse! profile: User! + validate_jwt_token(params: ValidateJWTTokenInput!): ValidateJWTTokenResponse! # admin only apis _users(params: PaginatedInput): Users! _verification_requests(params: PaginatedInput): VerificationRequests! @@ -1797,6 +1833,21 @@ func (ec *executionContext) field_Query_session_args(ctx context.Context, rawArg return args, nil } +func (ec *executionContext) field_Query_validate_jwt_token_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 model.ValidateJWTTokenInput + if tmp, ok := rawArgs["params"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("params")) + arg0, err = ec.unmarshalNValidateJWTTokenInput2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateJWTTokenInput(ctx, tmp) + if err != nil { + return nil, err + } + } + args["params"] = arg0 + return args, nil +} + func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -4598,6 +4649,48 @@ func (ec *executionContext) _Query_profile(ctx context.Context, field graphql.Co return ec.marshalNUser2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) } +func (ec *executionContext) _Query_validate_jwt_token(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + IsResolver: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Query_validate_jwt_token_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + fc.Args = args + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().ValidateJwtToken(rctx, args["params"].(model.ValidateJWTTokenInput)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*model.ValidateJWTTokenResponse) + fc.Result = res + return ec.marshalNValidateJWTTokenResponse2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateJWTTokenResponse(ctx, field.Selections, res) +} + func (ec *executionContext) _Query__users(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -5487,6 +5580,41 @@ func (ec *executionContext) _Users_users(ctx context.Context, field graphql.Coll return ec.marshalNUser2ᚕᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUserᚄ(ctx, field.Selections, res) } +func (ec *executionContext) _ValidateJWTTokenResponse_is_valid(ctx context.Context, field graphql.CollectedField, obj *model.ValidateJWTTokenResponse) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "ValidateJWTTokenResponse", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.IsValid, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + fc.Result = res + return ec.marshalNBoolean2bool(ctx, field.Selections, res) +} + func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -8025,6 +8153,45 @@ func (ec *executionContext) unmarshalInputUpdateUserInput(ctx context.Context, o return it, nil } +func (ec *executionContext) unmarshalInputValidateJWTTokenInput(ctx context.Context, obj interface{}) (model.ValidateJWTTokenInput, error) { + var it model.ValidateJWTTokenInput + asMap := map[string]interface{}{} + for k, v := range obj.(map[string]interface{}) { + asMap[k] = v + } + + for k, v := range asMap { + switch k { + case "token_type": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("token_type")) + it.TokenType, err = ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + case "token": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("token")) + it.Token, err = ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + case "roles": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles")) + it.Roles, err = ec.unmarshalOString2ᚕstringᚄ(ctx, v) + if err != nil { + return it, err + } + } + } + + return it, nil +} + func (ec *executionContext) unmarshalInputVerifyEmailInput(ctx context.Context, obj interface{}) (model.VerifyEmailInput, error) { var it model.VerifyEmailInput asMap := map[string]interface{}{} @@ -8515,6 +8682,20 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr } return res }) + case "validate_jwt_token": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_validate_jwt_token(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + }) case "_users": field := field out.Concurrently(i, func() (res graphql.Marshaler) { @@ -8716,6 +8897,33 @@ func (ec *executionContext) _Users(ctx context.Context, sel ast.SelectionSet, ob return out } +var validateJWTTokenResponseImplementors = []string{"ValidateJWTTokenResponse"} + +func (ec *executionContext) _ValidateJWTTokenResponse(ctx context.Context, sel ast.SelectionSet, obj *model.ValidateJWTTokenResponse) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, validateJWTTokenResponseImplementors) + + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("ValidateJWTTokenResponse") + case "is_valid": + out.Values[i] = ec._ValidateJWTTokenResponse_is_valid(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var verificationRequestImplementors = []string{"VerificationRequest"} func (ec *executionContext) _VerificationRequest(ctx context.Context, sel ast.SelectionSet, obj *model.VerificationRequest) graphql.Marshaler { @@ -9345,6 +9553,25 @@ func (ec *executionContext) marshalNUsers2ᚖgithubᚗcomᚋauthorizerdevᚋauth return ec._Users(ctx, sel, v) } +func (ec *executionContext) unmarshalNValidateJWTTokenInput2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateJWTTokenInput(ctx context.Context, v interface{}) (model.ValidateJWTTokenInput, error) { + res, err := ec.unmarshalInputValidateJWTTokenInput(ctx, v) + return res, graphql.ErrorOnPath(ctx, err) +} + +func (ec *executionContext) marshalNValidateJWTTokenResponse2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateJWTTokenResponse(ctx context.Context, sel ast.SelectionSet, v model.ValidateJWTTokenResponse) graphql.Marshaler { + return ec._ValidateJWTTokenResponse(ctx, sel, &v) +} + +func (ec *executionContext) marshalNValidateJWTTokenResponse2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateJWTTokenResponse(ctx context.Context, sel ast.SelectionSet, v *model.ValidateJWTTokenResponse) graphql.Marshaler { + if v == nil { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + return ec._ValidateJWTTokenResponse(ctx, sel, v) +} + func (ec *executionContext) marshalNVerificationRequest2ᚕᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐVerificationRequestᚄ(ctx context.Context, sel ast.SelectionSet, v []*model.VerificationRequest) graphql.Marshaler { ret := make(graphql.Array, len(v)) var wg sync.WaitGroup diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index 1ebad63..ec260cf 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -256,6 +256,16 @@ type Users struct { Users []*User `json:"users"` } +type ValidateJWTTokenInput struct { + TokenType string `json:"token_type"` + Token string `json:"token"` + Roles []string `json:"roles"` +} + +type ValidateJWTTokenResponse struct { + IsValid bool `json:"is_valid"` +} + type VerificationRequest struct { ID string `json:"id"` Identifier *string `json:"identifier"` diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index 13f2a1b..65cdf67 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -126,6 +126,10 @@ type Env { ORGANIZATION_LOGO: String } +type ValidateJWTTokenResponse { + is_valid: Boolean! +} + input UpdateEnvInput { ADMIN_SECRET: String CUSTOM_ACCESS_TOKEN_SCRIPT: String @@ -281,6 +285,12 @@ input InviteMemberInput { redirect_uri: String } +input ValidateJWTTokenInput { + token_type: String! + token: String! + roles: [String!] +} + type Mutation { signup(params: SignUpInput!): AuthResponse! login(params: LoginInput!): AuthResponse! @@ -306,6 +316,7 @@ type Query { meta: Meta! session(params: SessionQueryInput): AuthResponse! profile: User! + validate_jwt_token(params: ValidateJWTTokenInput!): ValidateJWTTokenResponse! # admin only apis _users(params: PaginatedInput): Users! _verification_requests(params: PaginatedInput): VerificationRequests! diff --git a/server/graph/schema.resolvers.go b/server/graph/schema.resolvers.go index e4f9275..5b8501c 100644 --- a/server/graph/schema.resolvers.go +++ b/server/graph/schema.resolvers.go @@ -91,6 +91,10 @@ func (r *queryResolver) Profile(ctx context.Context) (*model.User, error) { return resolvers.ProfileResolver(ctx) } +func (r *queryResolver) ValidateJwtToken(ctx context.Context, params model.ValidateJWTTokenInput) (*model.ValidateJWTTokenResponse, error) { + return resolvers.ValidateJwtTokenResolver(ctx, params) +} + func (r *queryResolver) Users(ctx context.Context, params *model.PaginatedInput) (*model.Users, error) { return resolvers.UsersResolver(ctx, params) } @@ -113,5 +117,7 @@ func (r *Resolver) Mutation() generated.MutationResolver { return &mutationResol // Query returns generated.QueryResolver implementation. func (r *Resolver) Query() generated.QueryResolver { return &queryResolver{r} } -type mutationResolver struct{ *Resolver } -type queryResolver struct{ *Resolver } +type ( + mutationResolver struct{ *Resolver } + queryResolver struct{ *Resolver } +) diff --git a/server/resolvers/validate_jwt_token.go b/server/resolvers/validate_jwt_token.go new file mode 100644 index 0000000..ce1c84c --- /dev/null +++ b/server/resolvers/validate_jwt_token.go @@ -0,0 +1,86 @@ +package resolvers + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/token" + "github.com/authorizerdev/authorizer/server/utils" + "github.com/golang-jwt/jwt" +) + +// ValidateJwtTokenResolver is used to validate a jwt token without its rotation +// this can be used at API level (backend) +// it can validate: +// access_token +// id_token +// refresh_token +func ValidateJwtTokenResolver(ctx context.Context, params model.ValidateJWTTokenInput) (*model.ValidateJWTTokenResponse, error) { + gc, err := utils.GinContextFromContext(ctx) + if err != nil { + return nil, err + } + + tokenType := params.TokenType + if tokenType != "access_token" && tokenType != "refresh_token" && tokenType != "id_token" { + return nil, errors.New("invalid token type") + } + + userID := "" + nonce := "" + // access_token and refresh_token should be validated from session store as well + if tokenType == "access_token" || tokenType == "refresh_token" { + savedSession := sessionstore.GetState(params.Token) + if savedSession == "" { + return &model.ValidateJWTTokenResponse{ + IsValid: false, + }, nil + } + savedSessionSplit := strings.Split(savedSession, "@") + nonce = savedSessionSplit[0] + userID = savedSessionSplit[1] + } + + hostname := utils.GetHost(gc) + var claimRoles []string + var claims jwt.MapClaims + + // we cannot validate sub and nonce in case of id_token as that token is not persisted in session store + if userID != "" && nonce != "" { + claims, err = token.ParseJWTToken(params.Token, hostname, nonce, userID) + if err != nil { + return &model.ValidateJWTTokenResponse{ + IsValid: false, + }, nil + } + } else { + claims, err = token.ParseJWTTokenWithoutNonce(params.Token, hostname) + if err != nil { + return &model.ValidateJWTTokenResponse{ + IsValid: false, + }, nil + } + + } + + claimRolesInterface := claims["roles"] + roleSlice := utils.ConvertInterfaceToSlice(claimRolesInterface) + for _, v := range roleSlice { + claimRoles = append(claimRoles, v.(string)) + } + + if params.Roles != nil && len(params.Roles) > 0 { + for _, v := range params.Roles { + if !utils.StringSliceContains(claimRoles, v) { + return nil, fmt.Errorf(`unauthorized`) + } + } + } + return &model.ValidateJWTTokenResponse{ + IsValid: true, + }, nil +} diff --git a/server/test/resolvers_test.go b/server/test/resolvers_test.go index 7e0c41d..b64695d 100644 --- a/server/test/resolvers_test.go +++ b/server/test/resolvers_test.go @@ -63,6 +63,7 @@ func TestResolvers(t *testing.T) { logoutTests(t, s) metaTests(t, s) inviteUserTest(t, s) + validateJwtTokenTest(t, s) }) } } diff --git a/server/test/validate_jwt_token_test.go b/server/test/validate_jwt_token_test.go new file mode 100644 index 0000000..5bb4268 --- /dev/null +++ b/server/test/validate_jwt_token_test.go @@ -0,0 +1,90 @@ +package test + +import ( + "testing" + "time" + + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/resolvers" + "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/token" + "github.com/authorizerdev/authorizer/server/utils" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func validateJwtTokenTest(t *testing.T, s TestSetup) { + t.Helper() + _, ctx := createContext(s) + t.Run(`validate params`, func(t *testing.T) { + res, err := resolvers.ValidateJwtTokenResolver(ctx, model.ValidateJWTTokenInput{ + TokenType: "access_token", + Token: "", + }) + assert.False(t, res.IsValid) + res, err = resolvers.ValidateJwtTokenResolver(ctx, model.ValidateJWTTokenInput{ + TokenType: "access_token", + Token: "invalid", + }) + assert.False(t, res.IsValid) + _, err = resolvers.ValidateJwtTokenResolver(ctx, model.ValidateJWTTokenInput{ + TokenType: "access_token_invalid", + Token: "invalid@invalid", + }) + assert.Error(t, err, "invalid token") + }) + + scope := []string{"openid", "email", "profile", "offline_access"} + user := models.User{ + ID: uuid.New().String(), + Email: "jwt_test_" + s.TestInfo.Email, + Roles: "user", + UpdatedAt: time.Now().Unix(), + CreatedAt: time.Now().Unix(), + } + + roles := []string{"user"} + gc, err := utils.GinContextFromContext(ctx) + assert.NoError(t, err) + authToken, err := token.CreateAuthToken(gc, user, roles, scope) + sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + sessionstore.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) + + t.Run(`should validate the access token`, func(t *testing.T) { + res, err := resolvers.ValidateJwtTokenResolver(ctx, model.ValidateJWTTokenInput{ + TokenType: "access_token", + Token: authToken.AccessToken.Token, + Roles: []string{"user"}, + }) + + assert.NoError(t, err) + assert.True(t, res.IsValid) + + res, err = resolvers.ValidateJwtTokenResolver(ctx, model.ValidateJWTTokenInput{ + TokenType: "access_token", + Token: authToken.AccessToken.Token, + Roles: []string{"invalid_role"}, + }) + + assert.Error(t, err) + }) + + t.Run(`should validate the refresh token`, func(t *testing.T) { + res, err := resolvers.ValidateJwtTokenResolver(ctx, model.ValidateJWTTokenInput{ + TokenType: "refresh_token", + Token: authToken.RefreshToken.Token, + }) + assert.NoError(t, err) + assert.True(t, res.IsValid) + }) + + t.Run(`should validate the id token`, func(t *testing.T) { + res, err := resolvers.ValidateJwtTokenResolver(ctx, model.ValidateJWTTokenInput{ + TokenType: "id_token", + Token: authToken.IDToken.Token, + }) + assert.NoError(t, err) + assert.True(t, res.IsValid) + }) +} diff --git a/server/token/auth_token.go b/server/token/auth_token.go index 350da17..8714792 100644 --- a/server/token/auth_token.go +++ b/server/token/auth_token.go @@ -161,7 +161,12 @@ func GetAccessToken(gc *gin.Context) (string, error) { return "", fmt.Errorf(`unauthorized`) } - if !strings.HasPrefix(auth, "Bearer ") { + authSplit := strings.Split(auth, " ") + if len(authSplit) != 2 { + return "", fmt.Errorf(`unauthorized`) + } + + if strings.ToLower(authSplit[0]) != "bearer" { return "", fmt.Errorf(`not a bearer token`) } @@ -350,7 +355,12 @@ func GetIDToken(gc *gin.Context) (string, error) { return "", fmt.Errorf(`unauthorized`) } - if !strings.HasPrefix(auth, "Bearer ") { + authSplit := strings.Split(auth, " ") + if len(authSplit) != 2 { + return "", fmt.Errorf(`unauthorized`) + } + + if strings.ToLower(authSplit[0]) != "bearer" { return "", fmt.Errorf(`not a bearer token`) } diff --git a/server/token/jwt.go b/server/token/jwt.go index 90f6333..0b87c09 100644 --- a/server/token/jwt.go +++ b/server/token/jwt.go @@ -105,3 +105,59 @@ func ParseJWTToken(token, hostname, nonce, subject string) (jwt.MapClaims, error return claims, nil } + +// ParseJWTTokenWithoutNonce common util to parse jwt token without nonce +// used to validate ID token as it is not persisted in store +func ParseJWTTokenWithoutNonce(token, hostname string) (jwt.MapClaims, error) { + jwtType := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + signingMethod := jwt.GetSigningMethod(jwtType) + + var err error + var claims jwt.MapClaims + + switch signingMethod { + case jwt.SigningMethodHS256, jwt.SigningMethodHS384, jwt.SigningMethodHS512: + _, err = jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret)), nil + }) + case jwt.SigningMethodRS256, jwt.SigningMethodRS384, jwt.SigningMethodRS512: + _, err = jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) { + key, err := crypto.ParseRsaPublicKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey)) + if err != nil { + return nil, err + } + return key, nil + }) + case jwt.SigningMethodES256, jwt.SigningMethodES384, jwt.SigningMethodES512: + _, err = jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) { + key, err := crypto.ParseEcdsaPublicKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey)) + if err != nil { + return nil, err + } + return key, nil + }) + default: + err = errors.New("unsupported signing method") + } + if err != nil { + return claims, err + } + + // claim parses exp & iat into float 64 with e^10, + // but we expect it to be int64 + // hence we need to assert interface and convert to int64 + intExp := int64(claims["exp"].(float64)) + intIat := int64(claims["iat"].(float64)) + claims["exp"] = intExp + claims["iat"] = intIat + + if claims["aud"] != envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) { + return claims, errors.New("invalid audience") + } + + if claims["iss"] != hostname { + return claims, errors.New("invalid issuer") + } + + return claims, nil +} diff --git a/server/utils/common.go b/server/utils/common.go index d4a8d51..6835806 100644 --- a/server/utils/common.go +++ b/server/utils/common.go @@ -2,6 +2,7 @@ package utils import ( "log" + "reflect" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" @@ -47,3 +48,24 @@ func RemoveDuplicateString(strSlice []string) []string { } return list } + +// ConvertInterfaceToSlice to convert interface to slice interface +func ConvertInterfaceToSlice(slice interface{}) []interface{} { + s := reflect.ValueOf(slice) + if s.Kind() != reflect.Slice { + return nil + } + + // Keep the distinction between nil and empty slice input + if s.IsNil() { + return nil + } + + ret := make([]interface{}, s.Len()) + + for i := 0; i < s.Len(); i++ { + ret[i] = s.Index(i).Interface() + } + + return ret +}