From 6fa0ad180971b6f5ee897afb028968a67668c5ab Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Wed, 12 Jul 2023 22:12:17 +0530 Subject: [PATCH] feat: add resolver to validate browser session --- server/graph/generated/generated.go | 263 +++++++++++++++++++++++++++ server/graph/model/models_gen.go | 10 +- server/graph/schema.graphqls | 10 + server/graph/schema.resolvers.go | 5 + server/resolvers/validate_session.go | 59 ++++++ server/test/resolvers_test.go | 1 + server/test/validate_session_test.go | 61 +++++++ 7 files changed, 408 insertions(+), 1 deletion(-) create mode 100644 server/resolvers/validate_session.go create mode 100644 server/test/validate_session_test.go diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index e849306..885b04d 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -219,6 +219,7 @@ type ComplexityRoot struct { User func(childComplexity int, params model.GetUserRequest) int Users func(childComplexity int, params *model.PaginatedInput) int ValidateJwtToken func(childComplexity int, params model.ValidateJWTTokenInput) int + ValidateSession func(childComplexity int, params *model.ValidateSessionInput) int VerificationRequests func(childComplexity int, params *model.PaginatedInput) int Webhook func(childComplexity int, params model.WebhookRequest) int WebhookLogs func(childComplexity int, params *model.ListWebhookLogRequest) int @@ -275,6 +276,10 @@ type ComplexityRoot struct { IsValid func(childComplexity int) int } + ValidateSessionResponse struct { + IsValid func(childComplexity int) int + } + VerificationRequest struct { CreatedAt func(childComplexity int) int Email func(childComplexity int) int @@ -363,6 +368,7 @@ type QueryResolver interface { Session(ctx context.Context, params *model.SessionQueryInput) (*model.AuthResponse, error) Profile(ctx context.Context) (*model.User, error) ValidateJwtToken(ctx context.Context, params model.ValidateJWTTokenInput) (*model.ValidateJWTTokenResponse, error) + ValidateSession(ctx context.Context, params *model.ValidateSessionInput) (*model.ValidateSessionResponse, error) Users(ctx context.Context, params *model.PaginatedInput) (*model.Users, error) User(ctx context.Context, params model.GetUserRequest) (*model.User, error) VerificationRequests(ctx context.Context, params *model.PaginatedInput) (*model.VerificationRequests, error) @@ -1572,6 +1578,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.ValidateJwtToken(childComplexity, args["params"].(model.ValidateJWTTokenInput)), true + case "Query.validate_session": + if e.complexity.Query.ValidateSession == nil { + break + } + + args, err := ec.field_Query_validate_session_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.ValidateSession(childComplexity, args["params"].(*model.ValidateSessionInput)), true + case "Query._verification_requests": if e.complexity.Query.VerificationRequests == nil { break @@ -1844,6 +1862,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ValidateJWTTokenResponse.IsValid(childComplexity), true + case "ValidateSessionResponse.is_valid": + if e.complexity.ValidateSessionResponse.IsValid == nil { + break + } + + return e.complexity.ValidateSessionResponse.IsValid(childComplexity), true + case "VerificationRequest.created_at": if e.complexity.VerificationRequest.CreatedAt == nil { break @@ -2093,6 +2118,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { ec.unmarshalInputUpdateUserInput, ec.unmarshalInputUpdateWebhookRequest, ec.unmarshalInputValidateJWTTokenInput, + ec.unmarshalInputValidateSessionInput, ec.unmarshalInputVerifyEmailInput, ec.unmarshalInputVerifyMobileRequest, ec.unmarshalInputVerifyOTPRequest, @@ -2341,6 +2367,10 @@ type ValidateJWTTokenResponse { claims: Map } +type ValidateSessionResponse { + is_valid: Boolean! +} + type GenerateJWTKeysResponse { secret: String public_key: String @@ -2633,6 +2663,11 @@ input ValidateJWTTokenInput { roles: [String!] } +input ValidateSessionInput { + cookie: String! + roles: [String!] +} + input GenerateJWTKeysInput { type: String! } @@ -2755,6 +2790,7 @@ type Query { session(params: SessionQueryInput): AuthResponse! profile: User! validate_jwt_token(params: ValidateJWTTokenInput!): ValidateJWTTokenResponse! + validate_session(params: ValidateSessionInput): ValidateSessionResponse! # admin only apis _users(params: PaginatedInput): Users! _user(params: GetUserRequest!): User! @@ -3374,6 +3410,21 @@ func (ec *executionContext) field_Query_validate_jwt_token_args(ctx context.Cont return args, nil } +func (ec *executionContext) field_Query_validate_session_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 *model.ValidateSessionInput + if tmp, ok := rawArgs["params"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("params")) + arg0, err = ec.unmarshalOValidateSessionInput2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateSessionInput(ctx, tmp) + if err != nil { + return nil, err + } + } + args["params"] = arg0 + return args, nil +} + func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -10159,6 +10210,65 @@ func (ec *executionContext) fieldContext_Query_validate_jwt_token(ctx context.Co return fc, nil } +func (ec *executionContext) _Query_validate_session(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Query_validate_session(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().ValidateSession(rctx, fc.Args["params"].(*model.ValidateSessionInput)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*model.ValidateSessionResponse) + fc.Result = res + return ec.marshalNValidateSessionResponse2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateSessionResponse(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Query_validate_session(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Query", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "is_valid": + return ec.fieldContext_ValidateSessionResponse_is_valid(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type ValidateSessionResponse", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Query_validate_session_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return + } + return fc, nil +} + func (ec *executionContext) _Query__users(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Query__users(ctx, field) if err != nil { @@ -12381,6 +12491,50 @@ func (ec *executionContext) fieldContext_ValidateJWTTokenResponse_claims(ctx con return fc, nil } +func (ec *executionContext) _ValidateSessionResponse_is_valid(ctx context.Context, field graphql.CollectedField, obj *model.ValidateSessionResponse) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ValidateSessionResponse_is_valid(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.IsValid, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + fc.Result = res + return ec.marshalNBoolean2bool(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_ValidateSessionResponse_is_valid(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ValidateSessionResponse", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Boolean does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) { fc, err := ec.fieldContext_VerificationRequest_id(ctx, field) if err != nil { @@ -17555,6 +17709,42 @@ func (ec *executionContext) unmarshalInputValidateJWTTokenInput(ctx context.Cont return it, nil } +func (ec *executionContext) unmarshalInputValidateSessionInput(ctx context.Context, obj interface{}) (model.ValidateSessionInput, error) { + var it model.ValidateSessionInput + asMap := map[string]interface{}{} + for k, v := range obj.(map[string]interface{}) { + asMap[k] = v + } + + fieldsInOrder := [...]string{"cookie", "roles"} + for _, k := range fieldsInOrder { + v, ok := asMap[k] + if !ok { + continue + } + switch k { + case "cookie": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("cookie")) + it.Cookie, err = ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + case "roles": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles")) + it.Roles, err = ec.unmarshalOString2ᚕstringᚄ(ctx, v) + if err != nil { + return it, err + } + } + } + + return it, nil +} + func (ec *executionContext) unmarshalInputVerifyEmailInput(ctx context.Context, obj interface{}) (model.VerifyEmailInput, error) { var it model.VerifyEmailInput asMap := map[string]interface{}{} @@ -18866,6 +19056,29 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr return ec.OperationContext.RootResolverMiddleware(ctx, innerFunc) } + out.Concurrently(i, func() graphql.Marshaler { + return rrm(innerCtx) + }) + case "validate_session": + field := field + + innerFunc := func(ctx context.Context) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_validate_session(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + } + + rrm := func(ctx context.Context) graphql.Marshaler { + return ec.OperationContext.RootResolverMiddleware(ctx, innerFunc) + } + out.Concurrently(i, func() graphql.Marshaler { return rrm(innerCtx) }) @@ -19395,6 +19608,34 @@ func (ec *executionContext) _ValidateJWTTokenResponse(ctx context.Context, sel a return out } +var validateSessionResponseImplementors = []string{"ValidateSessionResponse"} + +func (ec *executionContext) _ValidateSessionResponse(ctx context.Context, sel ast.SelectionSet, obj *model.ValidateSessionResponse) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, validateSessionResponseImplementors) + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("ValidateSessionResponse") + case "is_valid": + + out.Values[i] = ec._ValidateSessionResponse_is_valid(ctx, field, obj) + + if out.Values[i] == graphql.Null { + invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var verificationRequestImplementors = []string{"VerificationRequest"} func (ec *executionContext) _VerificationRequest(ctx context.Context, sel ast.SelectionSet, obj *model.VerificationRequest) graphql.Marshaler { @@ -20470,6 +20711,20 @@ func (ec *executionContext) marshalNValidateJWTTokenResponse2ᚖgithubᚗcomᚋa return ec._ValidateJWTTokenResponse(ctx, sel, v) } +func (ec *executionContext) marshalNValidateSessionResponse2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateSessionResponse(ctx context.Context, sel ast.SelectionSet, v model.ValidateSessionResponse) graphql.Marshaler { + return ec._ValidateSessionResponse(ctx, sel, &v) +} + +func (ec *executionContext) marshalNValidateSessionResponse2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateSessionResponse(ctx context.Context, sel ast.SelectionSet, v *model.ValidateSessionResponse) graphql.Marshaler { + if v == nil { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "the requested element is null which the schema does not allow") + } + return graphql.Null + } + return ec._ValidateSessionResponse(ctx, sel, v) +} + func (ec *executionContext) marshalNVerificationRequest2ᚕᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐVerificationRequestᚄ(ctx context.Context, sel ast.SelectionSet, v []*model.VerificationRequest) graphql.Marshaler { ret := make(graphql.Array, len(v)) var wg sync.WaitGroup @@ -21158,6 +21413,14 @@ func (ec *executionContext) marshalOUser2ᚖgithubᚗcomᚋauthorizerdevᚋautho return ec._User(ctx, sel, v) } +func (ec *executionContext) unmarshalOValidateSessionInput2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateSessionInput(ctx context.Context, v interface{}) (*model.ValidateSessionInput, error) { + if v == nil { + return nil, nil + } + res, err := ec.unmarshalInputValidateSessionInput(ctx, v) + return &res, graphql.ErrorOnPath(ctx, err) +} + func (ec *executionContext) marshalO__EnumValue2ᚕgithubᚗcomᚋ99designsᚋgqlgenᚋgraphqlᚋintrospectionᚐEnumValueᚄ(ctx context.Context, sel ast.SelectionSet, v []introspection.EnumValue) graphql.Marshaler { if v == nil { return graphql.Null diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index 7a1e376..d327f9a 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -120,7 +120,6 @@ type Env struct { AdminCookieSecure bool `json:"ADMIN_COOKIE_SECURE"` DefaultAuthorizeResponseType *string `json:"DEFAULT_AUTHORIZE_RESPONSE_TYPE"` DefaultAuthorizeResponseMode *string `json:"DEFAULT_AUTHORIZE_RESPONSE_MODE"` - SmsCodeExpiryTime *string `json:"SMS_CODE_EXPIRY_TIME"` } type Error struct { @@ -456,6 +455,15 @@ type ValidateJWTTokenResponse struct { Claims map[string]interface{} `json:"claims"` } +type ValidateSessionInput struct { + Cookie string `json:"cookie"` + Roles []string `json:"roles"` +} + +type ValidateSessionResponse struct { + IsValid bool `json:"is_valid"` +} + type VerificationRequest struct { ID string `json:"id"` Identifier *string `json:"identifier"` diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index 4013b12..c830236 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -182,6 +182,10 @@ type ValidateJWTTokenResponse { claims: Map } +type ValidateSessionResponse { + is_valid: Boolean! +} + type GenerateJWTKeysResponse { secret: String public_key: String @@ -474,6 +478,11 @@ input ValidateJWTTokenInput { roles: [String!] } +input ValidateSessionInput { + cookie: String! + roles: [String!] +} + input GenerateJWTKeysInput { type: String! } @@ -596,6 +605,7 @@ type Query { session(params: SessionQueryInput): AuthResponse! profile: User! validate_jwt_token(params: ValidateJWTTokenInput!): ValidateJWTTokenResponse! + validate_session(params: ValidateSessionInput): ValidateSessionResponse! # admin only apis _users(params: PaginatedInput): Users! _user(params: GetUserRequest!): User! diff --git a/server/graph/schema.resolvers.go b/server/graph/schema.resolvers.go index 75dad85..beb49b2 100644 --- a/server/graph/schema.resolvers.go +++ b/server/graph/schema.resolvers.go @@ -191,6 +191,11 @@ func (r *queryResolver) ValidateJwtToken(ctx context.Context, params model.Valid return resolvers.ValidateJwtTokenResolver(ctx, params) } +// ValidateSession is the resolver for the validate_session field. +func (r *queryResolver) ValidateSession(ctx context.Context, params *model.ValidateSessionInput) (*model.ValidateSessionResponse, error) { + return resolvers.ValidateSessionResolver(ctx, params) +} + // Users is the resolver for the _users field. func (r *queryResolver) Users(ctx context.Context, params *model.PaginatedInput) (*model.Users, error) { return resolvers.UsersResolver(ctx, params) diff --git a/server/resolvers/validate_session.go b/server/resolvers/validate_session.go new file mode 100644 index 0000000..cb2b11f --- /dev/null +++ b/server/resolvers/validate_session.go @@ -0,0 +1,59 @@ +package resolvers + +import ( + "context" + "errors" + "fmt" + + "github.com/authorizerdev/authorizer/server/cookie" + "github.com/authorizerdev/authorizer/server/db" + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/token" + "github.com/authorizerdev/authorizer/server/utils" + log "github.com/sirupsen/logrus" +) + +// ValidateSessionResolver is used to validate a cookie session without its rotation +func ValidateSessionResolver(ctx context.Context, params *model.ValidateSessionInput) (*model.ValidateSessionResponse, error) { + gc, err := utils.GinContextFromContext(ctx) + if err != nil { + log.Debug("Failed to get GinContext: ", err) + return nil, err + } + sessionToken := params.Cookie + if sessionToken == "" { + sessionToken, err = cookie.GetSession(gc) + if err != nil { + log.Debug("Failed to get session token: ", err) + return nil, errors.New("unauthorized") + } + } + claims, err := token.ValidateBrowserSession(gc, sessionToken) + if err != nil { + log.Debug("Failed to validate session token", err) + return nil, errors.New("unauthorized") + } + userID := claims.Subject + log := log.WithFields(log.Fields{ + "user_id": userID, + }) + _, err = db.Provider.GetUserByID(ctx, userID) + if err != nil { + return nil, err + } + // refresh token has "roles" as claim + claimRoleInterface := claims.Roles + claimRoles := []string{} + claimRoles = append(claimRoles, claimRoleInterface...) + if params != nil && params.Roles != nil && len(params.Roles) > 0 { + for _, v := range params.Roles { + if !utils.StringSliceContains(claimRoles, v) { + log.Debug("User does not have required role: ", claimRoles, v) + return nil, fmt.Errorf(`unauthorized`) + } + } + } + return &model.ValidateSessionResponse{ + IsValid: true, + }, nil +} diff --git a/server/test/resolvers_test.go b/server/test/resolvers_test.go index 4c83bf3..446986c 100644 --- a/server/test/resolvers_test.go +++ b/server/test/resolvers_test.go @@ -136,6 +136,7 @@ func TestResolvers(t *testing.T) { verifyOTPTest(t, s) resendOTPTest(t, s) verifyMobileTest(t, s) + validateSessionTests(t, s) updateAllUsersTest(t, s) webhookLogsTest(t, s) // get logs after above resolver tests are done diff --git a/server/test/validate_session_test.go b/server/test/validate_session_test.go new file mode 100644 index 0000000..16211ad --- /dev/null +++ b/server/test/validate_session_test.go @@ -0,0 +1,61 @@ +package test + +import ( + "fmt" + "strings" + "testing" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/db" + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/resolvers" + "github.com/authorizerdev/authorizer/server/token" + "github.com/stretchr/testify/assert" +) + +// ValidateSessionTests tests all the validate session resolvers +func validateSessionTests(t *testing.T, s TestSetup) { + t.Helper() + t.Run(`should validate session`, func(t *testing.T) { + req, ctx := createContext(s) + email := "validate_session." + s.TestInfo.Email + + resolvers.SignupResolver(ctx, model.SignUpInput{ + Email: email, + Password: s.TestInfo.Password, + ConfirmPassword: s.TestInfo.Password, + }) + _, err := resolvers.ValidateSessionResolver(ctx, &model.ValidateSessionInput{}) + assert.NotNil(t, err, "unauthorized") + verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) + assert.NoError(t, err) + assert.NotNil(t, verificationRequest) + verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ + Token: verificationRequest.Token, + }) + assert.NoError(t, err) + assert.NotNil(t, verifyRes) + accessToken := *verifyRes.AccessToken + assert.NotEmpty(t, accessToken) + claims, err := token.ParseJWTToken(accessToken) + assert.NoError(t, err) + assert.NotEmpty(t, claims) + sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + verifyRes.User.ID + sessionToken, err := memorystore.Provider.GetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+claims["nonce"].(string)) + assert.NoError(t, err) + assert.NotEmpty(t, sessionToken) + cookie := fmt.Sprintf("%s=%s;", constants.AppCookieName+"_session", sessionToken) + cookie = strings.TrimSuffix(cookie, ";") + res, err := resolvers.ValidateSessionResolver(ctx, &model.ValidateSessionInput{ + Cookie: sessionToken, + }) + assert.Nil(t, err) + assert.True(t, res.IsValid) + req.Header.Set("Cookie", cookie) + res, err = resolvers.ValidateSessionResolver(ctx, &model.ValidateSessionInput{}) + assert.Nil(t, err) + assert.True(t, res.IsValid) + cleanData(email) + }) +}