From 94cdbc92681a680d82e0caf1693b1780bce63d2a Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Mon, 20 Sep 2021 10:00:17 +0530 Subject: [PATCH] feat: add role based oauth --- .env.sample | 3 +- TODO.md | 8 +-- server/constants/constants.go | 5 +- server/env.go | 8 ++- server/graph/generated/generated.go | 91 +++++++++++++++++++++++------ server/graph/model/models_gen.go | 35 ++++++----- server/graph/schema.graphqls | 8 ++- server/graph/schema.resolvers.go | 2 +- server/handlers/app.go | 8 --- server/handlers/oauthCallback.go | 72 +++++++++++++---------- server/handlers/oauthLogin.go | 17 +++++- server/handlers/verifyEmail.go | 10 +--- server/resolvers/login.go | 22 ++++--- server/resolvers/logout.go | 4 +- server/resolvers/profile.go | 8 ++- server/resolvers/signup.go | 28 ++++----- server/resolvers/token.go | 15 ++--- server/resolvers/updateProfile.go | 45 ++++++++------ server/resolvers/verifyEmail.go | 13 ++--- server/utils/authToken.go | 40 +++++++++---- server/utils/common.go | 20 +++++++ server/utils/validator.go | 16 ++++- 22 files changed, 307 insertions(+), 171 deletions(-) create mode 100644 server/utils/common.go diff --git a/.env.sample b/.env.sample index 26d77f0..f5d9f37 100644 --- a/.env.sample +++ b/.env.sample @@ -6,4 +6,5 @@ DISABLE_EMAIL_VERIFICATION=true JWT_SECRET=random_string JWT_TYPE=HS256 ROLES=user,admin -DEFAULT_ROLE=user \ No newline at end of file +DEFAULT_ROLE=user +JWT_ROLE_CLAIM=role \ No newline at end of file diff --git a/TODO.md b/TODO.md index 5e6b62d..889f10b 100644 --- a/TODO.md +++ b/TODO.md @@ -9,8 +9,8 @@ For the first version we will only support setting roles master list via env - [x] `DEFAULT_ROLE` -> default role to assign to users - [x] Add roles input for signup - [x] Add roles to update profile mutation -- [ ] Add roles input for login -- [ ] Return roles to user -- [ ] Return roles in users list for super admin -- [ ] Add roles to the JWT token generation +- [x] Add roles input for login +- [x] Return roles to user +- [x] Return roles in users list for super admin +- [x] Add roles to the JWT token generation - [ ] Validate token should also validate the role, if roles to validate again is present in request diff --git a/server/constants/constants.go b/server/constants/constants.go index 666ddf0..6d3e400 100644 --- a/server/constants/constants.go +++ b/server/constants/constants.go @@ -24,8 +24,9 @@ var ( DISABLE_BASIC_AUTHENTICATION = "false" // ROLES - ROLES = []string{} - DEFAULT_ROLE = "" + ROLES = []string{} + DEFAULT_ROLE = "" + JWT_ROLE_CLAIM = "role" // OAuth login GOOGLE_CLIENT_ID = "" diff --git a/server/env.go b/server/env.go index 86f0b83..00c8e3e 100644 --- a/server/env.go +++ b/server/env.go @@ -74,6 +74,7 @@ func InitEnv() { constants.DISABLE_BASIC_AUTHENTICATION = os.Getenv("DISABLE_BASIC_AUTHENTICATION") constants.DISABLE_EMAIL_VERIFICATION = os.Getenv("DISABLE_EMAIL_VERIFICATION") constants.DEFAULT_ROLE = os.Getenv("DEFAULT_ROLE") + constants.JWT_ROLE_CLAIM = os.Getenv("JWT_ROLE_CLAIM") if constants.ADMIN_SECRET == "" { panic("root admin secret is required") @@ -148,6 +149,7 @@ func InitEnv() { rolesSplit := strings.Split(os.Getenv("ROLES"), ",") roles := []string{} defaultRole := "" + for _, val := range rolesSplit { trimVal := strings.TrimSpace(val) if trimVal != "" { @@ -159,7 +161,7 @@ func InitEnv() { } } if len(roles) > 0 && defaultRole == "" { - panic(`Invalid DEFAULT_ROLE environment. It can be one from give ROLES environment variable value`) + panic(`Invalid DEFAULT_ROLE environment variable. It can be one from give ROLES environment variable value`) } if len(roles) == 0 { @@ -168,4 +170,8 @@ func InitEnv() { } constants.ROLES = roles + + if constants.JWT_ROLE_CLAIM == "" { + constants.JWT_ROLE_CLAIM = "role" + } } diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index a869075..94b05ad 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -80,7 +80,7 @@ type ComplexityRoot struct { Query struct { Meta func(childComplexity int) int Profile func(childComplexity int) int - Token func(childComplexity int) int + Token func(childComplexity int, role *string) int Users func(childComplexity int) int VerificationRequests func(childComplexity int) int } @@ -127,7 +127,7 @@ type MutationResolver interface { type QueryResolver interface { Meta(ctx context.Context) (*model.Meta, error) Users(ctx context.Context) ([]*model.User, error) - Token(ctx context.Context) (*model.AuthResponse, error) + Token(ctx context.Context, role *string) (*model.AuthResponse, error) Profile(ctx context.Context) (*model.User, error) VerificationRequests(ctx context.Context) ([]*model.VerificationRequest, error) } @@ -360,7 +360,12 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in break } - return e.complexity.Query.Token(childComplexity), true + args, err := ec.field_Query_token_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.Token(childComplexity, args["role"].(*string)), true case "Query.users": if e.complexity.Query.Users == nil { @@ -570,6 +575,8 @@ var sources = []*ast.Source{ # # https://gqlgen.com/getting-started/ scalar Int64 +scalar Map +scalar Any type Meta { version: String! @@ -591,7 +598,7 @@ type User { image: String createdAt: Int64 updatedAt: Int64 - roles: [String] + roles: [String!]! } type VerificationRequest { @@ -652,7 +659,7 @@ input UpdateProfileInput { lastName: String image: String email: String - roles: [String] + # roles: [String] } input ForgotPasswordInput { @@ -684,7 +691,7 @@ type Mutation { type Query { meta: Meta! users: [User!]! - token: AuthResponse + token(role: String): AuthResponse profile: User! verificationRequests: [VerificationRequest!]! } @@ -831,6 +838,21 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs return args, nil } +func (ec *executionContext) field_Query_token_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 *string + if tmp, ok := rawArgs["role"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role")) + arg0, err = ec.unmarshalOString2ᚖstring(ctx, tmp) + if err != nil { + return nil, err + } + } + args["role"] = arg0 + return args, nil +} + func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -1772,9 +1794,16 @@ func (ec *executionContext) _Query_token(ctx context.Context, field graphql.Coll } ctx = graphql.WithFieldContext(ctx, fc) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Query_token_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + fc.Args = args resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Token(rctx) + return ec.resolvers.Query().Token(rctx, args["role"].(*string)) }) if err != nil { ec.Error(ctx, err) @@ -2286,11 +2315,14 @@ func (ec *executionContext) _User_roles(ctx context.Context, field graphql.Colle return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.([]*string) + res := resTmp.([]string) fc.Result = res - return ec.marshalOString2ᚕᚖstring(ctx, field.Selections, res) + return ec.marshalNString2ᚕstringᚄ(ctx, field.Selections, res) } func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) { @@ -3869,14 +3901,6 @@ func (ec *executionContext) unmarshalInputUpdateProfileInput(ctx context.Context if err != nil { return it, err } - case "roles": - var err error - - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles")) - it.Roles, err = ec.unmarshalOString2ᚕᚖstring(ctx, v) - if err != nil { - return it, err - } } } @@ -4268,6 +4292,9 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj out.Values[i] = ec._User_updatedAt(ctx, field, obj) case "roles": out.Values[i] = ec._User_roles(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -4680,6 +4707,36 @@ func (ec *executionContext) marshalNString2string(ctx context.Context, sel ast.S return res } +func (ec *executionContext) unmarshalNString2ᚕstringᚄ(ctx context.Context, v interface{}) ([]string, error) { + var vSlice []interface{} + if v != nil { + if tmp1, ok := v.([]interface{}); ok { + vSlice = tmp1 + } else { + vSlice = []interface{}{v} + } + } + var err error + res := make([]string, len(vSlice)) + for i := range vSlice { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i)) + res[i], err = ec.unmarshalNString2string(ctx, vSlice[i]) + if err != nil { + return nil, err + } + } + return res, nil +} + +func (ec *executionContext) marshalNString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler { + ret := make(graphql.Array, len(v)) + for i := range v { + ret[i] = ec.marshalNString2string(ctx, sel, v[i]) + } + + return ret +} + func (ec *executionContext) unmarshalNUpdateProfileInput2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUpdateProfileInput(ctx context.Context, v interface{}) (model.UpdateProfileInput, error) { res, err := ec.unmarshalInputUpdateProfileInput(ctx, v) return res, graphql.ErrorOnPath(ctx, err) diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index ccf97a7..3b8f59d 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -63,27 +63,26 @@ type SignUpInput struct { } type UpdateProfileInput struct { - OldPassword *string `json:"oldPassword"` - NewPassword *string `json:"newPassword"` - ConfirmNewPassword *string `json:"confirmNewPassword"` - FirstName *string `json:"firstName"` - LastName *string `json:"lastName"` - Image *string `json:"image"` - Email *string `json:"email"` - Roles []*string `json:"roles"` + OldPassword *string `json:"oldPassword"` + NewPassword *string `json:"newPassword"` + ConfirmNewPassword *string `json:"confirmNewPassword"` + FirstName *string `json:"firstName"` + LastName *string `json:"lastName"` + Image *string `json:"image"` + Email *string `json:"email"` } type User struct { - ID string `json:"id"` - Email string `json:"email"` - SignupMethod string `json:"signupMethod"` - FirstName *string `json:"firstName"` - LastName *string `json:"lastName"` - EmailVerifiedAt *int64 `json:"emailVerifiedAt"` - Image *string `json:"image"` - CreatedAt *int64 `json:"createdAt"` - UpdatedAt *int64 `json:"updatedAt"` - Roles []*string `json:"roles"` + ID string `json:"id"` + Email string `json:"email"` + SignupMethod string `json:"signupMethod"` + FirstName *string `json:"firstName"` + LastName *string `json:"lastName"` + EmailVerifiedAt *int64 `json:"emailVerifiedAt"` + Image *string `json:"image"` + CreatedAt *int64 `json:"createdAt"` + UpdatedAt *int64 `json:"updatedAt"` + Roles []string `json:"roles"` } type VerificationRequest struct { diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index ca9a448..651720c 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -2,6 +2,8 @@ # # https://gqlgen.com/getting-started/ scalar Int64 +scalar Map +scalar Any type Meta { version: String! @@ -23,7 +25,7 @@ type User { image: String createdAt: Int64 updatedAt: Int64 - roles: [String] + roles: [String!]! } type VerificationRequest { @@ -84,7 +86,7 @@ input UpdateProfileInput { lastName: String image: String email: String - roles: [String] + # roles: [String] } input ForgotPasswordInput { @@ -116,7 +118,7 @@ type Mutation { type Query { meta: Meta! users: [User!]! - token: AuthResponse + token(role: String): AuthResponse profile: User! verificationRequests: [VerificationRequest!]! } diff --git a/server/graph/schema.resolvers.go b/server/graph/schema.resolvers.go index 064387f..f1d0519 100644 --- a/server/graph/schema.resolvers.go +++ b/server/graph/schema.resolvers.go @@ -55,7 +55,7 @@ func (r *queryResolver) Users(ctx context.Context) ([]*model.User, error) { return resolvers.Users(ctx) } -func (r *queryResolver) Token(ctx context.Context) (*model.AuthResponse, error) { +func (r *queryResolver) Token(ctx context.Context, role *string) (*model.AuthResponse, error) { return resolvers.Token(ctx) } diff --git a/server/handlers/app.go b/server/handlers/app.go index 1f26c5b..41fc354 100644 --- a/server/handlers/app.go +++ b/server/handlers/app.go @@ -25,7 +25,6 @@ func AppHandler() gin.HandlerFunc { if state == "" { // cookie, err := utils.GetAuthToken(c) - // log.Println(`cookie`, cookie) // if err != nil { // c.JSON(400, gin.H{"error": "invalid state"}) // return @@ -67,13 +66,6 @@ func AppHandler() gin.HandlerFunc { } } - log.Println(gin.H{ - "data": map[string]string{ - "authorizerURL": stateObj.AuthorizerURL, - "redirectURL": stateObj.RedirectURL, - }, - }) - // debug the request state if pusher := c.Writer.Pusher(); pusher != nil { // use pusher.Push() to do server push diff --git a/server/handlers/oauthCallback.go b/server/handlers/oauthCallback.go index 2094e54..2c15104 100644 --- a/server/handlers/oauthCallback.go +++ b/server/handlers/oauthCallback.go @@ -19,7 +19,7 @@ import ( "golang.org/x/oauth2" ) -func processGoogleUserInfo(code string, c *gin.Context) error { +func processGoogleUserInfo(code string, role string, c *gin.Context) error { token, err := oauth.OAuthProvider.GoogleConfig.Exchange(oauth2.NoContext, code) if err != nil { return fmt.Errorf("invalid google exchange code: %s", err.Error()) @@ -50,6 +50,7 @@ func processGoogleUserInfo(code string, c *gin.Context) error { if err != nil { // user not registered, register user and generate session token user.SignupMethod = enum.Google.String() + user.Roles = role } else { // user exists in db, check if method was google // if not append google to existing signup method and save it @@ -60,27 +61,28 @@ func processGoogleUserInfo(code string, c *gin.Context) error { } user.SignupMethod = signupMethod user.Password = existingUser.Password + log.Println("=> checking roles...", utils.IsValidRole(strings.Split(existingUser.Roles, ","), role)) + if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) { + log.Println("=> invalid role from google oauth") + return fmt.Errorf("invalid role") + } + + user.Roles = existingUser.Roles } user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.GetUserByEmail(user.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) - accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role) utils.SetCookie(c, accessToken) session.SetToken(userIdStr, refreshToken) return nil } -func processGithubUserInfo(code string, c *gin.Context) error { +func processGithubUserInfo(code string, role string, c *gin.Context) error { token, err := oauth.OAuthProvider.GithubConfig.Exchange(oauth2.NoContext, code) if err != nil { return fmt.Errorf("invalid github exchange code: %s", err.Error()) @@ -128,6 +130,7 @@ func processGithubUserInfo(code string, c *gin.Context) error { if err != nil { // user not registered, register user and generate session token user.SignupMethod = enum.Github.String() + user.Roles = role } else { // user exists in db, check if method was google // if not append google to existing signup method and save it @@ -138,26 +141,26 @@ func processGithubUserInfo(code string, c *gin.Context) error { } user.SignupMethod = signupMethod user.Password = existingUser.Password + + if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) { + return fmt.Errorf("invalid role") + } + + user.Roles = existingUser.Roles } user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.GetUserByEmail(user.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) - accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role) utils.SetCookie(c, accessToken) session.SetToken(userIdStr, refreshToken) return nil } -func processFacebookUserInfo(code string, c *gin.Context) error { +func processFacebookUserInfo(code string, role string, c *gin.Context) error { token, err := oauth.OAuthProvider.FacebookConfig.Exchange(oauth2.NoContext, code) if err != nil { return fmt.Errorf("invalid facebook exchange code: %s", err.Error()) @@ -199,6 +202,7 @@ func processFacebookUserInfo(code string, c *gin.Context) error { if err != nil { // user not registered, register user and generate session token user.SignupMethod = enum.Github.String() + user.Roles = role } else { // user exists in db, check if method was google // if not append google to existing signup method and save it @@ -209,20 +213,20 @@ func processFacebookUserInfo(code string, c *gin.Context) error { } user.SignupMethod = signupMethod user.Password = existingUser.Password + + if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) { + return fmt.Errorf("invalid role") + } + + user.Roles = existingUser.Roles } user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.GetUserByEmail(user.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) - accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role) utils.SetCookie(c, accessToken) session.SetToken(userIdStr, refreshToken) return nil @@ -238,23 +242,27 @@ func OAuthCallbackHandler() gin.HandlerFunc { c.JSON(400, gin.H{"error": "invalid oauth state"}) } session.DeleteToken(sessionState) + // contains random token, redirect url, role sessionSplit := strings.Split(state, "___") // TODO validate redirect url - if len(sessionSplit) != 2 { + if len(sessionSplit) < 2 { c.JSON(400, gin.H{"error": "invalid redirect url"}) return } + role := sessionSplit[2] + redirectURL := sessionSplit[1] + var err error code := c.Request.FormValue("code") switch provider { case enum.Google.String(): - err = processGoogleUserInfo(code, c) + err = processGoogleUserInfo(code, role, c) case enum.Github.String(): - err = processGithubUserInfo(code, c) + err = processGithubUserInfo(code, role, c) case enum.Facebook.String(): - err = processFacebookUserInfo(code, c) + err = processFacebookUserInfo(code, role, c) default: err = fmt.Errorf(`invalid oauth provider`) } @@ -263,6 +271,6 @@ func OAuthCallbackHandler() gin.HandlerFunc { c.JSON(400, gin.H{"error": err.Error()}) return } - c.Redirect(http.StatusTemporaryRedirect, sessionSplit[1]) + c.Redirect(http.StatusTemporaryRedirect, redirectURL) } } diff --git a/server/handlers/oauthLogin.go b/server/handlers/oauthLogin.go index 62aa8b0..5730d9d 100644 --- a/server/handlers/oauthLogin.go +++ b/server/handlers/oauthLogin.go @@ -7,6 +7,7 @@ import ( "github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/session" + "github.com/authorizerdev/authorizer/server/utils" "github.com/gin-gonic/gin" "github.com/google/uuid" ) @@ -17,6 +18,7 @@ func OAuthLoginHandler() gin.HandlerFunc { return func(c *gin.Context) { // TODO validate redirect URL redirectURL := c.Query("redirectURL") + role := c.Query("role") if redirectURL == "" { c.JSON(400, gin.H{ @@ -24,8 +26,21 @@ func OAuthLoginHandler() gin.HandlerFunc { }) return } + + if role != "" { + // validate role + if !utils.IsValidRole(constants.ROLES, role) { + c.JSON(400, gin.H{ + "error": "invalid role", + }) + return + } + } else { + role = constants.DEFAULT_ROLE + } + uuid := uuid.New() - oauthStateString := uuid.String() + "___" + redirectURL + oauthStateString := uuid.String() + "___" + redirectURL + "___" + role provider := c.Param("oauth_provider") diff --git a/server/handlers/verifyEmail.go b/server/handlers/verifyEmail.go index 64b752f..3f7e956 100644 --- a/server/handlers/verifyEmail.go +++ b/server/handlers/verifyEmail.go @@ -50,15 +50,9 @@ func VerifyEmailHandler() gin.HandlerFunc { db.Mgr.DeleteToken(claim.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, user.Roles) - accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, user.Roles) session.SetToken(userIdStr, refreshToken) utils.SetCookie(c, accessToken) diff --git a/server/resolvers/login.go b/server/resolvers/login.go index c168ae3..f6f5ede 100644 --- a/server/resolvers/login.go +++ b/server/resolvers/login.go @@ -46,16 +46,19 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e log.Println("Compare password error:", err) return res, fmt.Errorf(`invalid password`) } - userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + role := constants.DEFAULT_ROLE + if params.Role != nil { + // validate role + if !utils.IsValidRole(strings.Split(user.Roles, ","), *params.Role) { + return res, fmt.Errorf(`invalid role`) + } - accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + role = *params.Role + } + userIdStr := fmt.Sprintf("%v", user.ID) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) + + accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, role) session.SetToken(userIdStr, refreshToken) @@ -71,6 +74,7 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e LastName: &user.LastName, SignupMethod: user.SignupMethod, EmailVerifiedAt: &user.EmailVerifiedAt, + Roles: strings.Split(user.Roles, ","), CreatedAt: &user.CreatedAt, UpdatedAt: &user.UpdatedAt, }, diff --git a/server/resolvers/logout.go b/server/resolvers/logout.go index 654dcbc..7d623c2 100644 --- a/server/resolvers/logout.go +++ b/server/resolvers/logout.go @@ -2,6 +2,7 @@ package resolvers import ( "context" + "fmt" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/session" @@ -25,7 +26,8 @@ func Logout(ctx context.Context) (*model.Response, error) { return res, err } - session.DeleteToken(claim.ID) + userId := fmt.Sprintf("%v", claim["id"]) + session.DeleteToken(userId) res = &model.Response{ Message: "Logged out successfully", } diff --git a/server/resolvers/profile.go b/server/resolvers/profile.go index 1e64a96..3c21676 100644 --- a/server/resolvers/profile.go +++ b/server/resolvers/profile.go @@ -3,6 +3,7 @@ package resolvers import ( "context" "fmt" + "strings" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/graph/model" @@ -27,13 +28,15 @@ func Profile(ctx context.Context) (*model.User, error) { return res, err } - sessionToken := session.GetToken(claim.ID) + userID := fmt.Sprintf("%v", claim["id"]) + email := fmt.Sprintf("%v", claim["email"]) + sessionToken := session.GetToken(userID) if sessionToken == "" { return res, fmt.Errorf(`unauthorized`) } - user, err := db.Mgr.GetUserByEmail(claim.Email) + user, err := db.Mgr.GetUserByEmail(email) if err != nil { return res, err } @@ -48,6 +51,7 @@ func Profile(ctx context.Context) (*model.User, error) { LastName: &user.LastName, SignupMethod: user.SignupMethod, EmailVerifiedAt: &user.EmailVerifiedAt, + Roles: strings.Split(user.Roles, ","), CreatedAt: &user.CreatedAt, UpdatedAt: &user.UpdatedAt, } diff --git a/server/resolvers/signup.go b/server/resolvers/signup.go index f333b88..14809aa 100644 --- a/server/resolvers/signup.go +++ b/server/resolvers/signup.go @@ -35,13 +35,18 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse, return res, fmt.Errorf(`invalid email address`) } + inputRoles := []string{} + if params.Roles != nil && len(params.Roles) > 0 { // check if roles exists - if !utils.IsValidRolesArray(params.Roles) { + for _, item := range params.Roles { + inputRoles = append(inputRoles, *item) + } + if !utils.IsValidRolesArray(inputRoles) { return res, fmt.Errorf(`invalid roles`) } } else { - params.Roles = []*string{&constants.DEFAULT_ROLE} + inputRoles = []string{constants.DEFAULT_ROLE} } // find user with email @@ -58,12 +63,7 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse, Email: params.Email, } - roles := "" - for _, roleInput := range params.Roles { - roles += *roleInput + "," - } - roles = strings.TrimSuffix(roles, ",") - user.Roles = roles + user.Roles = strings.Join(inputRoles, ",") password, _ := utils.HashPassword(params.Password) user.Password = password @@ -93,9 +93,9 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse, LastName: &user.LastName, SignupMethod: user.SignupMethod, EmailVerifiedAt: &user.EmailVerifiedAt, + Roles: strings.Split(user.Roles, ","), CreatedAt: &user.CreatedAt, UpdatedAt: &user.UpdatedAt, - Roles: params.Roles, } if constants.DISABLE_EMAIL_VERIFICATION != "true" { @@ -123,15 +123,9 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse, } } else { - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE) - accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE) session.SetToken(userIdStr, refreshToken) res = &model.AuthResponse{ diff --git a/server/resolvers/token.go b/server/resolvers/token.go index 9bb3e1d..a8ea886 100644 --- a/server/resolvers/token.go +++ b/server/resolvers/token.go @@ -3,8 +3,10 @@ package resolvers import ( "context" "fmt" + "strings" "time" + "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/graph/model" @@ -25,9 +27,10 @@ func Token(ctx context.Context) (*model.AuthResponse, error) { } claim, accessTokenErr := utils.VerifyAuthToken(token) - expiresAt := claim.ExpiresAt - - user, err := db.Mgr.GetUserByEmail(claim.Email) + expiresAt := claim["exp"].(int64) + email := fmt.Sprintf("%v", claim["email"]) + role := fmt.Sprintf("%v", claim[constants.JWT_ROLE_CLAIM]) + user, err := db.Mgr.GetUserByEmail(email) if err != nil { return res, err } @@ -46,10 +49,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) { if accessTokenErr != nil || expiresTimeObj.Sub(currentTimeObj).Minutes() <= 5 { // if access token has expired and refresh/session token is valid // generate new accessToken - token, expiresAt, _ = utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + token, expiresAt, _ = utils.CreateAuthToken(user, enum.AccessToken, role) } utils.SetCookie(gc, token) res = &model.AuthResponse{ @@ -62,6 +62,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) { Image: &user.Image, FirstName: &user.FirstName, LastName: &user.LastName, + Roles: strings.Split(user.Roles, ","), CreatedAt: &user.CreatedAt, UpdatedAt: &user.UpdatedAt, }, diff --git a/server/resolvers/updateProfile.go b/server/resolvers/updateProfile.go index 42ebfec..25f8095 100644 --- a/server/resolvers/updateProfile.go +++ b/server/resolvers/updateProfile.go @@ -32,18 +32,20 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model return res, err } - sessionToken := session.GetToken(claim.ID) + id := fmt.Sprintf("%v", claim["id"]) + sessionToken := session.GetToken(id) if sessionToken == "" { return res, fmt.Errorf(`unauthorized`) } // validate if all params are not empty - if params.FirstName == nil && params.LastName == nil && params.Image == nil && params.OldPassword == nil && params.Email == nil && params.Roles != nil { + if params.FirstName == nil && params.LastName == nil && params.Image == nil && params.OldPassword == nil && params.Email == nil { return res, fmt.Errorf("please enter atleast one param to update") } - user, err := db.Mgr.GetUserByEmail(claim.Email) + email := fmt.Sprintf("%v", claim["email"]) + user, err := db.Mgr.GetUserByEmail(email) if err != nil { return res, err } @@ -122,21 +124,30 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model }() } - rolesToSave := "" - if params.Roles != nil && len(params.Roles) > 0 { - currentRoles := strings.Split(user.Roles, ",") - inputRoles := []string{} - for _, item := range params.Roles { - inputRoles = append(inputRoles, *item) - } - if !utils.IsStringArrayEqual(inputRoles, currentRoles) && utils.IsValidRolesArray(params.Roles) { - rolesToSave = strings.Join(inputRoles, ",") - } - } + // TODO this idea needs to be verified otherwise every user can make themselves super admin + // rolesToSave := "" + // if params.Roles != nil && len(params.Roles) > 0 { + // currentRoles := strings.Split(user.Roles, ",") + // inputRoles := []string{} + // for _, item := range params.Roles { + // inputRoles = append(inputRoles, *item) + // } - if rolesToSave != "" { - user.Roles = rolesToSave - } + // if !utils.IsValidRolesArray(inputRoles) { + // return res, fmt.Errorf("invalid list of roles") + // } + + // if !utils.IsStringArrayEqual(inputRoles, currentRoles) { + // rolesToSave = strings.Join(inputRoles, ",") + // } + + // session.DeleteToken(fmt.Sprintf("%v", user.ID)) + // utils.DeleteCookie(gc) + // } + + // if rolesToSave != "" { + // user.Roles = rolesToSave + // } _, err = db.Mgr.UpdateUser(user) if err != nil { diff --git a/server/resolvers/verifyEmail.go b/server/resolvers/verifyEmail.go index e05abcb..106cc05 100644 --- a/server/resolvers/verifyEmail.go +++ b/server/resolvers/verifyEmail.go @@ -3,8 +3,10 @@ package resolvers import ( "context" "fmt" + "strings" "time" + "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/graph/model" @@ -41,15 +43,9 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut db.Mgr.DeleteToken(claim.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE) - accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE) session.SetToken(userIdStr, refreshToken) @@ -65,6 +61,7 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut LastName: &user.LastName, SignupMethod: user.SignupMethod, EmailVerifiedAt: &user.EmailVerifiedAt, + Roles: strings.Split(user.Roles, ","), CreatedAt: &user.CreatedAt, UpdatedAt: &user.UpdatedAt, }, diff --git a/server/utils/authToken.go b/server/utils/authToken.go index 0ce6b67..72f4f8a 100644 --- a/server/utils/authToken.go +++ b/server/utils/authToken.go @@ -1,29 +1,32 @@ package utils import ( + "encoding/json" "fmt" "log" "strings" "time" "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/enum" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" ) -type UserAuthInfo struct { - Email string `json:"email"` - ID string `json:"id"` -} +// type UserAuthInfo struct { +// Email string `json:"email"` +// ID string `json:"id"` +// } + +type JWTCustomClaim map[string]interface{} type UserAuthClaim struct { *jwt.StandardClaims - TokenType string `json:"token_type"` - UserAuthInfo + *JWTCustomClaim `json:"authorizer"` } -func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64, error) { +func CreateAuthToken(user db.User, tokenType enum.TokenType, role string) (string, int64, error) { t := jwt.New(jwt.GetSigningMethod(constants.JWT_TYPE)) expiryBound := time.Hour if tokenType == enum.RefreshToken { @@ -33,12 +36,19 @@ func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64 expiresAt := time.Now().Add(expiryBound).Unix() + customClaims := JWTCustomClaim{ + "token_type": tokenType.String(), + "email": user.Email, + "id": user.ID, + "allowed_roles": strings.Split(user.Roles, ","), + constants.JWT_ROLE_CLAIM: role, + } + t.Claims = &UserAuthClaim{ &jwt.StandardClaims{ ExpiresAt: expiresAt, }, - tokenType.String(), - user, + &customClaims, } token, err := t.SignedString([]byte(constants.JWT_SECRET)) @@ -63,14 +73,20 @@ func GetAuthToken(gc *gin.Context) (string, error) { return token, nil } -func VerifyAuthToken(token string) (*UserAuthClaim, error) { +func VerifyAuthToken(token string) (map[string]interface{}, error) { + var res map[string]interface{} claims := &UserAuthClaim{} + _, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { return []byte(constants.JWT_SECRET), nil }) if err != nil { - return claims, err + return res, err } - return claims, nil + data, _ := json.Marshal(claims.JWTCustomClaim) + json.Unmarshal(data, &res) + res["exp"] = claims.ExpiresAt + + return res, nil } diff --git a/server/utils/common.go b/server/utils/common.go new file mode 100644 index 0000000..d2d713d --- /dev/null +++ b/server/utils/common.go @@ -0,0 +1,20 @@ +package utils + +import ( + "io" + "os" +) + +func WriteToFile(filename string, data string) error { + file, err := os.Create(filename) + if err != nil { + return err + } + defer file.Close() + + _, err = io.WriteString(file, data) + if err != nil { + return err + } + return file.Sync() +} diff --git a/server/utils/validator.go b/server/utils/validator.go index 63bb92a..c1f2fed 100644 --- a/server/utils/validator.go +++ b/server/utils/validator.go @@ -40,7 +40,7 @@ func IsSuperAdmin(gc *gin.Context) bool { return secret == constants.ADMIN_SECRET } -func IsValidRolesArray(roles []*string) bool { +func IsValidRolesArray(roles []string) bool { valid := true currentRoleMap := map[string]bool{} @@ -48,7 +48,7 @@ func IsValidRolesArray(roles []*string) bool { currentRoleMap[currentRole] = true } for _, inputRole := range roles { - if !currentRoleMap[*inputRole] { + if !currentRoleMap[inputRole] { valid = false break } @@ -56,6 +56,18 @@ func IsValidRolesArray(roles []*string) bool { return valid } +func IsValidRole(userRoles []string, role string) bool { + valid := false + for _, currentRole := range userRoles { + if role == currentRole { + valid = true + break + } + } + + return valid +} + func IsStringArrayEqual(a, b []string) bool { if len(a) != len(b) { return false