From 21e3425e76c975e811c538a14ef4b7f3a8e2c138 Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Mon, 20 Sep 2021 10:36:26 +0530 Subject: [PATCH] feat/role based access (#50) * feat: add roles based access * feat: update roles env + todo * feat: add roles to update profile * feat: add role based oauth * feat: validate role for a given token --- .env.sample | 5 +- TODO.md | 16 +++ server/constants/constants.go | 5 + server/db/db.go | 3 +- server/db/roles.go | 19 +++ server/db/user.go | 1 + server/env.go | 31 +++++ server/graph/generated/generated.go | 173 +++++++++++++++++++++++++++- server/graph/model/models_gen.go | 37 +++--- server/graph/schema.graphqls | 8 +- server/graph/schema.resolvers.go | 4 +- server/handlers/app.go | 8 -- server/handlers/oauthCallback.go | 70 ++++++----- server/handlers/oauthLogin.go | 17 ++- server/handlers/verifyEmail.go | 10 +- server/main.go | 2 + server/resolvers/login.go | 22 ++-- server/resolvers/logout.go | 4 +- server/resolvers/profile.go | 8 +- server/resolvers/signup.go | 27 +++-- server/resolvers/token.go | 20 ++-- server/resolvers/updateProfile.go | 32 ++++- server/resolvers/verifyEmail.go | 13 +-- server/utils/authToken.go | 40 +++++-- server/utils/common.go | 20 ++++ server/utils/initServer.go | 25 ++++ server/utils/validateSuperAdmin.go | 15 --- server/utils/validator.go | 50 ++++++++ 28 files changed, 544 insertions(+), 141 deletions(-) create mode 100644 server/db/roles.go create mode 100644 server/utils/common.go create mode 100644 server/utils/initServer.go delete mode 100644 server/utils/validateSuperAdmin.go diff --git a/.env.sample b/.env.sample index 7e65f00..f5d9f37 100644 --- a/.env.sample +++ b/.env.sample @@ -4,4 +4,7 @@ DATABASE_TYPE=sqlite ADMIN_SECRET=admin DISABLE_EMAIL_VERIFICATION=true JWT_SECRET=random_string -JWT_TYPE=HS256 \ No newline at end of file +JWT_TYPE=HS256 +ROLES=user,admin +DEFAULT_ROLE=user +JWT_ROLE_CLAIM=role \ No newline at end of file diff --git a/TODO.md b/TODO.md index e69de29..461d024 100644 --- a/TODO.md +++ b/TODO.md @@ -0,0 +1,16 @@ +# Task List + +# Feature roles + +For the first version we will only support setting roles master list via env + +- [x] Support following ENV + - [x] `ROLES` -> comma separated list of role names + - [x] `DEFAULT_ROLE` -> default role to assign to users +- [x] Add roles input for signup +- [x] Add roles to update profile mutation +- [x] Add roles input for login +- [x] Return roles to user +- [x] Return roles in users list for super admin +- [x] Add roles to the JWT token generation +- [x] Validate token should also validate the role, if roles to validate again is present in request diff --git a/server/constants/constants.go b/server/constants/constants.go index 178080e..6d3e400 100644 --- a/server/constants/constants.go +++ b/server/constants/constants.go @@ -23,6 +23,11 @@ var ( DISABLE_EMAIL_VERIFICATION = "false" DISABLE_BASIC_AUTHENTICATION = "false" + // ROLES + ROLES = []string{} + DEFAULT_ROLE = "" + JWT_ROLE_CLAIM = "role" + // OAuth login GOOGLE_CLIENT_ID = "" GOOGLE_CLIENT_SECRET = "" diff --git a/server/db/db.go b/server/db/db.go index 0110322..dbd6350 100644 --- a/server/db/db.go +++ b/server/db/db.go @@ -24,6 +24,7 @@ type Manager interface { GetVerificationRequests() ([]VerificationRequest, error) GetVerificationByEmail(email string) (VerificationRequest, error) DeleteUser(email string) error + SaveRoles(roles []Role) error } type manager struct { @@ -53,7 +54,7 @@ func InitDB() { if err != nil { log.Fatal("Failed to init db:", err) } else { - db.AutoMigrate(&User{}, &VerificationRequest{}) + db.AutoMigrate(&User{}, &VerificationRequest{}, &Role{}) } Mgr = &manager{db: db} diff --git a/server/db/roles.go b/server/db/roles.go new file mode 100644 index 0000000..2506ad0 --- /dev/null +++ b/server/db/roles.go @@ -0,0 +1,19 @@ +package db + +import "log" + +type Role struct { + ID uint `gorm:"primaryKey"` + Role string +} + +// SaveRoles function to save roles +func (mgr *manager) SaveRoles(roles []Role) error { + res := mgr.db.Create(&roles) + if res.Error != nil { + log.Println(`Error saving roles`) + return res.Error + } + + return nil +} diff --git a/server/db/user.go b/server/db/user.go index 7a3dd59..ede8434 100644 --- a/server/db/user.go +++ b/server/db/user.go @@ -17,6 +17,7 @@ type User struct { CreatedAt int64 `gorm:"autoCreateTime"` UpdatedAt int64 `gorm:"autoUpdateTime"` Image string + Roles string } // SaveUser function to add user even with email conflict diff --git a/server/env.go b/server/env.go index 0ccfef2..00c8e3e 100644 --- a/server/env.go +++ b/server/env.go @@ -73,6 +73,8 @@ func InitEnv() { constants.RESET_PASSWORD_URL = strings.TrimPrefix(os.Getenv("RESET_PASSWORD_URL"), "/") constants.DISABLE_BASIC_AUTHENTICATION = os.Getenv("DISABLE_BASIC_AUTHENTICATION") constants.DISABLE_EMAIL_VERIFICATION = os.Getenv("DISABLE_EMAIL_VERIFICATION") + constants.DEFAULT_ROLE = os.Getenv("DEFAULT_ROLE") + constants.JWT_ROLE_CLAIM = os.Getenv("JWT_ROLE_CLAIM") if constants.ADMIN_SECRET == "" { panic("root admin secret is required") @@ -143,4 +145,33 @@ func InitEnv() { constants.DISABLE_EMAIL_VERIFICATION = "false" } } + + rolesSplit := strings.Split(os.Getenv("ROLES"), ",") + roles := []string{} + defaultRole := "" + + for _, val := range rolesSplit { + trimVal := strings.TrimSpace(val) + if trimVal != "" { + roles = append(roles, trimVal) + } + + if trimVal == constants.DEFAULT_ROLE { + defaultRole = trimVal + } + } + if len(roles) > 0 && defaultRole == "" { + panic(`Invalid DEFAULT_ROLE environment variable. It can be one from give ROLES environment variable value`) + } + + if len(roles) == 0 { + roles = []string{"user", "admin"} + constants.DEFAULT_ROLE = "user" + } + + constants.ROLES = roles + + if constants.JWT_ROLE_CLAIM == "" { + constants.JWT_ROLE_CLAIM = "role" + } } diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index beb3d50..94b05ad 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -80,7 +80,7 @@ type ComplexityRoot struct { Query struct { Meta func(childComplexity int) int Profile func(childComplexity int) int - Token func(childComplexity int) int + Token func(childComplexity int, role *string) int Users func(childComplexity int) int VerificationRequests func(childComplexity int) int } @@ -97,6 +97,7 @@ type ComplexityRoot struct { ID func(childComplexity int) int Image func(childComplexity int) int LastName func(childComplexity int) int + Roles func(childComplexity int) int SignupMethod func(childComplexity int) int UpdatedAt func(childComplexity int) int } @@ -126,7 +127,7 @@ type MutationResolver interface { type QueryResolver interface { Meta(ctx context.Context) (*model.Meta, error) Users(ctx context.Context) ([]*model.User, error) - Token(ctx context.Context) (*model.AuthResponse, error) + Token(ctx context.Context, role *string) (*model.AuthResponse, error) Profile(ctx context.Context) (*model.User, error) VerificationRequests(ctx context.Context) ([]*model.VerificationRequest, error) } @@ -359,7 +360,12 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in break } - return e.complexity.Query.Token(childComplexity), true + args, err := ec.field_Query_token_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.Token(childComplexity, args["role"].(*string)), true case "Query.users": if e.complexity.Query.Users == nil { @@ -431,6 +437,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.User.LastName(childComplexity), true + case "User.roles": + if e.complexity.User.Roles == nil { + break + } + + return e.complexity.User.Roles(childComplexity), true + case "User.signupMethod": if e.complexity.User.SignupMethod == nil { break @@ -562,6 +575,8 @@ var sources = []*ast.Source{ # # https://gqlgen.com/getting-started/ scalar Int64 +scalar Map +scalar Any type Meta { version: String! @@ -583,6 +598,7 @@ type User { image: String createdAt: Int64 updatedAt: Int64 + roles: [String!]! } type VerificationRequest { @@ -618,11 +634,13 @@ input SignUpInput { password: String! confirmPassword: String! image: String + roles: [String] } input LoginInput { email: String! password: String! + role: String } input VerifyEmailInput { @@ -641,6 +659,7 @@ input UpdateProfileInput { lastName: String image: String email: String + # roles: [String] } input ForgotPasswordInput { @@ -672,7 +691,7 @@ type Mutation { type Query { meta: Meta! users: [User!]! - token: AuthResponse + token(role: String): AuthResponse profile: User! verificationRequests: [VerificationRequest!]! } @@ -819,6 +838,21 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs return args, nil } +func (ec *executionContext) field_Query_token_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 *string + if tmp, ok := rawArgs["role"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role")) + arg0, err = ec.unmarshalOString2ᚖstring(ctx, tmp) + if err != nil { + return nil, err + } + } + args["role"] = arg0 + return args, nil +} + func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -1760,9 +1794,16 @@ func (ec *executionContext) _Query_token(ctx context.Context, field graphql.Coll } ctx = graphql.WithFieldContext(ctx, fc) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Query_token_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + fc.Args = args resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Token(rctx) + return ec.resolvers.Query().Token(rctx, args["role"].(*string)) }) if err != nil { ec.Error(ctx, err) @@ -2249,6 +2290,41 @@ func (ec *executionContext) _User_updatedAt(ctx context.Context, field graphql.C return ec.marshalOInt642ᚖint64(ctx, field.Selections, res) } +func (ec *executionContext) _User_roles(ctx context.Context, field graphql.CollectedField, obj *model.User) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "User", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Roles, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.([]string) + fc.Result = res + return ec.marshalNString2ᚕstringᚄ(ctx, field.Selections, res) +} + func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -3625,6 +3701,14 @@ func (ec *executionContext) unmarshalInputLoginInput(ctx context.Context, obj in if err != nil { return it, err } + case "role": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role")) + it.Role, err = ec.unmarshalOString2ᚖstring(ctx, v) + if err != nil { + return it, err + } } } @@ -3741,6 +3825,14 @@ func (ec *executionContext) unmarshalInputSignUpInput(ctx context.Context, obj i if err != nil { return it, err } + case "roles": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles")) + it.Roles, err = ec.unmarshalOString2ᚕᚖstring(ctx, v) + if err != nil { + return it, err + } } } @@ -4198,6 +4290,11 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj out.Values[i] = ec._User_createdAt(ctx, field, obj) case "updatedAt": out.Values[i] = ec._User_updatedAt(ctx, field, obj) + case "roles": + out.Values[i] = ec._User_roles(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -4610,6 +4707,36 @@ func (ec *executionContext) marshalNString2string(ctx context.Context, sel ast.S return res } +func (ec *executionContext) unmarshalNString2ᚕstringᚄ(ctx context.Context, v interface{}) ([]string, error) { + var vSlice []interface{} + if v != nil { + if tmp1, ok := v.([]interface{}); ok { + vSlice = tmp1 + } else { + vSlice = []interface{}{v} + } + } + var err error + res := make([]string, len(vSlice)) + for i := range vSlice { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i)) + res[i], err = ec.unmarshalNString2string(ctx, vSlice[i]) + if err != nil { + return nil, err + } + } + return res, nil +} + +func (ec *executionContext) marshalNString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler { + ret := make(graphql.Array, len(v)) + for i := range v { + ret[i] = ec.marshalNString2string(ctx, sel, v[i]) + } + + return ret +} + func (ec *executionContext) unmarshalNUpdateProfileInput2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUpdateProfileInput(ctx context.Context, v interface{}) (model.UpdateProfileInput, error) { res, err := ec.unmarshalInputUpdateProfileInput(ctx, v) return res, graphql.ErrorOnPath(ctx, err) @@ -5002,6 +5129,42 @@ func (ec *executionContext) marshalOString2string(ctx context.Context, sel ast.S return graphql.MarshalString(v) } +func (ec *executionContext) unmarshalOString2ᚕᚖstring(ctx context.Context, v interface{}) ([]*string, error) { + if v == nil { + return nil, nil + } + var vSlice []interface{} + if v != nil { + if tmp1, ok := v.([]interface{}); ok { + vSlice = tmp1 + } else { + vSlice = []interface{}{v} + } + } + var err error + res := make([]*string, len(vSlice)) + for i := range vSlice { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i)) + res[i], err = ec.unmarshalOString2ᚖstring(ctx, vSlice[i]) + if err != nil { + return nil, err + } + } + return res, nil +} + +func (ec *executionContext) marshalOString2ᚕᚖstring(ctx context.Context, sel ast.SelectionSet, v []*string) graphql.Marshaler { + if v == nil { + return graphql.Null + } + ret := make(graphql.Array, len(v)) + for i := range v { + ret[i] = ec.marshalOString2ᚖstring(ctx, sel, v[i]) + } + + return ret +} + func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v interface{}) (*string, error) { if v == nil { return nil, nil diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index 1b564fe..3b8f59d 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -23,8 +23,9 @@ type ForgotPasswordInput struct { } type LoginInput struct { - Email string `json:"email"` - Password string `json:"password"` + Email string `json:"email"` + Password string `json:"password"` + Role *string `json:"role"` } type Meta struct { @@ -52,12 +53,13 @@ type Response struct { } type SignUpInput struct { - FirstName *string `json:"firstName"` - LastName *string `json:"lastName"` - Email string `json:"email"` - Password string `json:"password"` - ConfirmPassword string `json:"confirmPassword"` - Image *string `json:"image"` + FirstName *string `json:"firstName"` + LastName *string `json:"lastName"` + Email string `json:"email"` + Password string `json:"password"` + ConfirmPassword string `json:"confirmPassword"` + Image *string `json:"image"` + Roles []*string `json:"roles"` } type UpdateProfileInput struct { @@ -71,15 +73,16 @@ type UpdateProfileInput struct { } type User struct { - ID string `json:"id"` - Email string `json:"email"` - SignupMethod string `json:"signupMethod"` - FirstName *string `json:"firstName"` - LastName *string `json:"lastName"` - EmailVerifiedAt *int64 `json:"emailVerifiedAt"` - Image *string `json:"image"` - CreatedAt *int64 `json:"createdAt"` - UpdatedAt *int64 `json:"updatedAt"` + ID string `json:"id"` + Email string `json:"email"` + SignupMethod string `json:"signupMethod"` + FirstName *string `json:"firstName"` + LastName *string `json:"lastName"` + EmailVerifiedAt *int64 `json:"emailVerifiedAt"` + Image *string `json:"image"` + CreatedAt *int64 `json:"createdAt"` + UpdatedAt *int64 `json:"updatedAt"` + Roles []string `json:"roles"` } type VerificationRequest struct { diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index 04a14d3..651720c 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -2,6 +2,8 @@ # # https://gqlgen.com/getting-started/ scalar Int64 +scalar Map +scalar Any type Meta { version: String! @@ -23,6 +25,7 @@ type User { image: String createdAt: Int64 updatedAt: Int64 + roles: [String!]! } type VerificationRequest { @@ -58,11 +61,13 @@ input SignUpInput { password: String! confirmPassword: String! image: String + roles: [String] } input LoginInput { email: String! password: String! + role: String } input VerifyEmailInput { @@ -81,6 +86,7 @@ input UpdateProfileInput { lastName: String image: String email: String + # roles: [String] } input ForgotPasswordInput { @@ -112,7 +118,7 @@ type Mutation { type Query { meta: Meta! users: [User!]! - token: AuthResponse + token(role: String): AuthResponse profile: User! verificationRequests: [VerificationRequest!]! } diff --git a/server/graph/schema.resolvers.go b/server/graph/schema.resolvers.go index 50bb5cf..448346c 100644 --- a/server/graph/schema.resolvers.go +++ b/server/graph/schema.resolvers.go @@ -55,8 +55,8 @@ func (r *queryResolver) Users(ctx context.Context) ([]*model.User, error) { return resolvers.Users(ctx) } -func (r *queryResolver) Token(ctx context.Context) (*model.AuthResponse, error) { - return resolvers.Token(ctx) +func (r *queryResolver) Token(ctx context.Context, role *string) (*model.AuthResponse, error) { + return resolvers.Token(ctx, role) } func (r *queryResolver) Profile(ctx context.Context) (*model.User, error) { diff --git a/server/handlers/app.go b/server/handlers/app.go index 1f26c5b..41fc354 100644 --- a/server/handlers/app.go +++ b/server/handlers/app.go @@ -25,7 +25,6 @@ func AppHandler() gin.HandlerFunc { if state == "" { // cookie, err := utils.GetAuthToken(c) - // log.Println(`cookie`, cookie) // if err != nil { // c.JSON(400, gin.H{"error": "invalid state"}) // return @@ -67,13 +66,6 @@ func AppHandler() gin.HandlerFunc { } } - log.Println(gin.H{ - "data": map[string]string{ - "authorizerURL": stateObj.AuthorizerURL, - "redirectURL": stateObj.RedirectURL, - }, - }) - // debug the request state if pusher := c.Writer.Pusher(); pusher != nil { // use pusher.Push() to do server push diff --git a/server/handlers/oauthCallback.go b/server/handlers/oauthCallback.go index 2094e54..76201c2 100644 --- a/server/handlers/oauthCallback.go +++ b/server/handlers/oauthCallback.go @@ -19,7 +19,7 @@ import ( "golang.org/x/oauth2" ) -func processGoogleUserInfo(code string, c *gin.Context) error { +func processGoogleUserInfo(code string, role string, c *gin.Context) error { token, err := oauth.OAuthProvider.GoogleConfig.Exchange(oauth2.NoContext, code) if err != nil { return fmt.Errorf("invalid google exchange code: %s", err.Error()) @@ -50,6 +50,7 @@ func processGoogleUserInfo(code string, c *gin.Context) error { if err != nil { // user not registered, register user and generate session token user.SignupMethod = enum.Google.String() + user.Roles = role } else { // user exists in db, check if method was google // if not append google to existing signup method and save it @@ -60,27 +61,26 @@ func processGoogleUserInfo(code string, c *gin.Context) error { } user.SignupMethod = signupMethod user.Password = existingUser.Password + if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) { + return fmt.Errorf("invalid role") + } + + user.Roles = existingUser.Roles } user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.GetUserByEmail(user.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) - accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role) utils.SetCookie(c, accessToken) session.SetToken(userIdStr, refreshToken) return nil } -func processGithubUserInfo(code string, c *gin.Context) error { +func processGithubUserInfo(code string, role string, c *gin.Context) error { token, err := oauth.OAuthProvider.GithubConfig.Exchange(oauth2.NoContext, code) if err != nil { return fmt.Errorf("invalid github exchange code: %s", err.Error()) @@ -128,6 +128,7 @@ func processGithubUserInfo(code string, c *gin.Context) error { if err != nil { // user not registered, register user and generate session token user.SignupMethod = enum.Github.String() + user.Roles = role } else { // user exists in db, check if method was google // if not append google to existing signup method and save it @@ -138,26 +139,26 @@ func processGithubUserInfo(code string, c *gin.Context) error { } user.SignupMethod = signupMethod user.Password = existingUser.Password + + if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) { + return fmt.Errorf("invalid role") + } + + user.Roles = existingUser.Roles } user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.GetUserByEmail(user.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) - accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role) utils.SetCookie(c, accessToken) session.SetToken(userIdStr, refreshToken) return nil } -func processFacebookUserInfo(code string, c *gin.Context) error { +func processFacebookUserInfo(code string, role string, c *gin.Context) error { token, err := oauth.OAuthProvider.FacebookConfig.Exchange(oauth2.NoContext, code) if err != nil { return fmt.Errorf("invalid facebook exchange code: %s", err.Error()) @@ -199,6 +200,7 @@ func processFacebookUserInfo(code string, c *gin.Context) error { if err != nil { // user not registered, register user and generate session token user.SignupMethod = enum.Github.String() + user.Roles = role } else { // user exists in db, check if method was google // if not append google to existing signup method and save it @@ -209,20 +211,20 @@ func processFacebookUserInfo(code string, c *gin.Context) error { } user.SignupMethod = signupMethod user.Password = existingUser.Password + + if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) { + return fmt.Errorf("invalid role") + } + + user.Roles = existingUser.Roles } user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.GetUserByEmail(user.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) - accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role) utils.SetCookie(c, accessToken) session.SetToken(userIdStr, refreshToken) return nil @@ -238,23 +240,27 @@ func OAuthCallbackHandler() gin.HandlerFunc { c.JSON(400, gin.H{"error": "invalid oauth state"}) } session.DeleteToken(sessionState) + // contains random token, redirect url, role sessionSplit := strings.Split(state, "___") // TODO validate redirect url - if len(sessionSplit) != 2 { + if len(sessionSplit) < 2 { c.JSON(400, gin.H{"error": "invalid redirect url"}) return } + role := sessionSplit[2] + redirectURL := sessionSplit[1] + var err error code := c.Request.FormValue("code") switch provider { case enum.Google.String(): - err = processGoogleUserInfo(code, c) + err = processGoogleUserInfo(code, role, c) case enum.Github.String(): - err = processGithubUserInfo(code, c) + err = processGithubUserInfo(code, role, c) case enum.Facebook.String(): - err = processFacebookUserInfo(code, c) + err = processFacebookUserInfo(code, role, c) default: err = fmt.Errorf(`invalid oauth provider`) } @@ -263,6 +269,6 @@ func OAuthCallbackHandler() gin.HandlerFunc { c.JSON(400, gin.H{"error": err.Error()}) return } - c.Redirect(http.StatusTemporaryRedirect, sessionSplit[1]) + c.Redirect(http.StatusTemporaryRedirect, redirectURL) } } diff --git a/server/handlers/oauthLogin.go b/server/handlers/oauthLogin.go index 62aa8b0..5730d9d 100644 --- a/server/handlers/oauthLogin.go +++ b/server/handlers/oauthLogin.go @@ -7,6 +7,7 @@ import ( "github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/session" + "github.com/authorizerdev/authorizer/server/utils" "github.com/gin-gonic/gin" "github.com/google/uuid" ) @@ -17,6 +18,7 @@ func OAuthLoginHandler() gin.HandlerFunc { return func(c *gin.Context) { // TODO validate redirect URL redirectURL := c.Query("redirectURL") + role := c.Query("role") if redirectURL == "" { c.JSON(400, gin.H{ @@ -24,8 +26,21 @@ func OAuthLoginHandler() gin.HandlerFunc { }) return } + + if role != "" { + // validate role + if !utils.IsValidRole(constants.ROLES, role) { + c.JSON(400, gin.H{ + "error": "invalid role", + }) + return + } + } else { + role = constants.DEFAULT_ROLE + } + uuid := uuid.New() - oauthStateString := uuid.String() + "___" + redirectURL + oauthStateString := uuid.String() + "___" + redirectURL + "___" + role provider := c.Param("oauth_provider") diff --git a/server/handlers/verifyEmail.go b/server/handlers/verifyEmail.go index 64b752f..3f7e956 100644 --- a/server/handlers/verifyEmail.go +++ b/server/handlers/verifyEmail.go @@ -50,15 +50,9 @@ func VerifyEmailHandler() gin.HandlerFunc { db.Mgr.DeleteToken(claim.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, user.Roles) - accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, user.Roles) session.SetToken(userIdStr, refreshToken) utils.SetCookie(c, accessToken) diff --git a/server/main.go b/server/main.go index 58a3b4f..d189314 100644 --- a/server/main.go +++ b/server/main.go @@ -9,6 +9,7 @@ import ( "github.com/authorizerdev/authorizer/server/handlers" "github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/session" + "github.com/authorizerdev/authorizer/server/utils" "github.com/gin-contrib/location" "github.com/gin-gonic/gin" ) @@ -50,6 +51,7 @@ func main() { db.InitDB() session.InitSession() oauth.InitOAuth() + utils.InitServer() r := gin.Default() r.Use(location.Default()) diff --git a/server/resolvers/login.go b/server/resolvers/login.go index c168ae3..f6f5ede 100644 --- a/server/resolvers/login.go +++ b/server/resolvers/login.go @@ -46,16 +46,19 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e log.Println("Compare password error:", err) return res, fmt.Errorf(`invalid password`) } - userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + role := constants.DEFAULT_ROLE + if params.Role != nil { + // validate role + if !utils.IsValidRole(strings.Split(user.Roles, ","), *params.Role) { + return res, fmt.Errorf(`invalid role`) + } - accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + role = *params.Role + } + userIdStr := fmt.Sprintf("%v", user.ID) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) + + accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, role) session.SetToken(userIdStr, refreshToken) @@ -71,6 +74,7 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e LastName: &user.LastName, SignupMethod: user.SignupMethod, EmailVerifiedAt: &user.EmailVerifiedAt, + Roles: strings.Split(user.Roles, ","), CreatedAt: &user.CreatedAt, UpdatedAt: &user.UpdatedAt, }, diff --git a/server/resolvers/logout.go b/server/resolvers/logout.go index 654dcbc..7d623c2 100644 --- a/server/resolvers/logout.go +++ b/server/resolvers/logout.go @@ -2,6 +2,7 @@ package resolvers import ( "context" + "fmt" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/session" @@ -25,7 +26,8 @@ func Logout(ctx context.Context) (*model.Response, error) { return res, err } - session.DeleteToken(claim.ID) + userId := fmt.Sprintf("%v", claim["id"]) + session.DeleteToken(userId) res = &model.Response{ Message: "Logged out successfully", } diff --git a/server/resolvers/profile.go b/server/resolvers/profile.go index 1e64a96..3c21676 100644 --- a/server/resolvers/profile.go +++ b/server/resolvers/profile.go @@ -3,6 +3,7 @@ package resolvers import ( "context" "fmt" + "strings" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/graph/model" @@ -27,13 +28,15 @@ func Profile(ctx context.Context) (*model.User, error) { return res, err } - sessionToken := session.GetToken(claim.ID) + userID := fmt.Sprintf("%v", claim["id"]) + email := fmt.Sprintf("%v", claim["email"]) + sessionToken := session.GetToken(userID) if sessionToken == "" { return res, fmt.Errorf(`unauthorized`) } - user, err := db.Mgr.GetUserByEmail(claim.Email) + user, err := db.Mgr.GetUserByEmail(email) if err != nil { return res, err } @@ -48,6 +51,7 @@ func Profile(ctx context.Context) (*model.User, error) { LastName: &user.LastName, SignupMethod: user.SignupMethod, EmailVerifiedAt: &user.EmailVerifiedAt, + Roles: strings.Split(user.Roles, ","), CreatedAt: &user.CreatedAt, UpdatedAt: &user.UpdatedAt, } diff --git a/server/resolvers/signup.go b/server/resolvers/signup.go index 1c82515..14809aa 100644 --- a/server/resolvers/signup.go +++ b/server/resolvers/signup.go @@ -35,6 +35,20 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse, return res, fmt.Errorf(`invalid email address`) } + inputRoles := []string{} + + if params.Roles != nil && len(params.Roles) > 0 { + // check if roles exists + for _, item := range params.Roles { + inputRoles = append(inputRoles, *item) + } + if !utils.IsValidRolesArray(inputRoles) { + return res, fmt.Errorf(`invalid roles`) + } + } else { + inputRoles = []string{constants.DEFAULT_ROLE} + } + // find user with email existingUser, err := db.Mgr.GetUserByEmail(params.Email) if err != nil { @@ -49,6 +63,8 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse, Email: params.Email, } + user.Roles = strings.Join(inputRoles, ",") + password, _ := utils.HashPassword(params.Password) user.Password = password @@ -77,6 +93,7 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse, LastName: &user.LastName, SignupMethod: user.SignupMethod, EmailVerifiedAt: &user.EmailVerifiedAt, + Roles: strings.Split(user.Roles, ","), CreatedAt: &user.CreatedAt, UpdatedAt: &user.UpdatedAt, } @@ -106,15 +123,9 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse, } } else { - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE) - accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE) session.SetToken(userIdStr, refreshToken) res = &model.AuthResponse{ diff --git a/server/resolvers/token.go b/server/resolvers/token.go index 9bb3e1d..892b268 100644 --- a/server/resolvers/token.go +++ b/server/resolvers/token.go @@ -3,8 +3,10 @@ package resolvers import ( "context" "fmt" + "strings" "time" + "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/graph/model" @@ -12,7 +14,7 @@ import ( "github.com/authorizerdev/authorizer/server/utils" ) -func Token(ctx context.Context) (*model.AuthResponse, error) { +func Token(ctx context.Context, role *string) (*model.AuthResponse, error) { var res *model.AuthResponse gc, err := utils.GinContextFromContext(ctx) @@ -25,13 +27,19 @@ func Token(ctx context.Context) (*model.AuthResponse, error) { } claim, accessTokenErr := utils.VerifyAuthToken(token) - expiresAt := claim.ExpiresAt + expiresAt := claim["exp"].(int64) + email := fmt.Sprintf("%v", claim["email"]) - user, err := db.Mgr.GetUserByEmail(claim.Email) + claimRole := fmt.Sprintf("%v", claim[constants.JWT_ROLE_CLAIM]) + user, err := db.Mgr.GetUserByEmail(email) if err != nil { return res, err } + if role != nil && role != &claimRole { + return res, fmt.Errorf(`unauthorized. invalid role for a given token`) + } + userIdStr := fmt.Sprintf("%v", user.ID) sessionToken := session.GetToken(userIdStr) @@ -46,10 +54,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) { if accessTokenErr != nil || expiresTimeObj.Sub(currentTimeObj).Minutes() <= 5 { // if access token has expired and refresh/session token is valid // generate new accessToken - token, expiresAt, _ = utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + token, expiresAt, _ = utils.CreateAuthToken(user, enum.AccessToken, claimRole) } utils.SetCookie(gc, token) res = &model.AuthResponse{ @@ -62,6 +67,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) { Image: &user.Image, FirstName: &user.FirstName, LastName: &user.LastName, + Roles: strings.Split(user.Roles, ","), CreatedAt: &user.CreatedAt, UpdatedAt: &user.UpdatedAt, }, diff --git a/server/resolvers/updateProfile.go b/server/resolvers/updateProfile.go index db4da8d..25f8095 100644 --- a/server/resolvers/updateProfile.go +++ b/server/resolvers/updateProfile.go @@ -32,7 +32,8 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model return res, err } - sessionToken := session.GetToken(claim.ID) + id := fmt.Sprintf("%v", claim["id"]) + sessionToken := session.GetToken(id) if sessionToken == "" { return res, fmt.Errorf(`unauthorized`) @@ -43,7 +44,8 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model return res, fmt.Errorf("please enter atleast one param to update") } - user, err := db.Mgr.GetUserByEmail(claim.Email) + email := fmt.Sprintf("%v", claim["email"]) + user, err := db.Mgr.GetUserByEmail(email) if err != nil { return res, err } @@ -120,9 +122,33 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model go func() { utils.SendVerificationMail(newEmail, token) }() - } + // TODO this idea needs to be verified otherwise every user can make themselves super admin + // rolesToSave := "" + // if params.Roles != nil && len(params.Roles) > 0 { + // currentRoles := strings.Split(user.Roles, ",") + // inputRoles := []string{} + // for _, item := range params.Roles { + // inputRoles = append(inputRoles, *item) + // } + + // if !utils.IsValidRolesArray(inputRoles) { + // return res, fmt.Errorf("invalid list of roles") + // } + + // if !utils.IsStringArrayEqual(inputRoles, currentRoles) { + // rolesToSave = strings.Join(inputRoles, ",") + // } + + // session.DeleteToken(fmt.Sprintf("%v", user.ID)) + // utils.DeleteCookie(gc) + // } + + // if rolesToSave != "" { + // user.Roles = rolesToSave + // } + _, err = db.Mgr.UpdateUser(user) if err != nil { log.Println("Error updating user:", err) diff --git a/server/resolvers/verifyEmail.go b/server/resolvers/verifyEmail.go index e05abcb..106cc05 100644 --- a/server/resolvers/verifyEmail.go +++ b/server/resolvers/verifyEmail.go @@ -3,8 +3,10 @@ package resolvers import ( "context" "fmt" + "strings" "time" + "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/graph/model" @@ -41,15 +43,9 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut db.Mgr.DeleteToken(claim.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.RefreshToken) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE) - accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ - ID: userIdStr, - Email: user.Email, - }, enum.AccessToken) + accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE) session.SetToken(userIdStr, refreshToken) @@ -65,6 +61,7 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut LastName: &user.LastName, SignupMethod: user.SignupMethod, EmailVerifiedAt: &user.EmailVerifiedAt, + Roles: strings.Split(user.Roles, ","), CreatedAt: &user.CreatedAt, UpdatedAt: &user.UpdatedAt, }, diff --git a/server/utils/authToken.go b/server/utils/authToken.go index 0ce6b67..72f4f8a 100644 --- a/server/utils/authToken.go +++ b/server/utils/authToken.go @@ -1,29 +1,32 @@ package utils import ( + "encoding/json" "fmt" "log" "strings" "time" "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/enum" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" ) -type UserAuthInfo struct { - Email string `json:"email"` - ID string `json:"id"` -} +// type UserAuthInfo struct { +// Email string `json:"email"` +// ID string `json:"id"` +// } + +type JWTCustomClaim map[string]interface{} type UserAuthClaim struct { *jwt.StandardClaims - TokenType string `json:"token_type"` - UserAuthInfo + *JWTCustomClaim `json:"authorizer"` } -func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64, error) { +func CreateAuthToken(user db.User, tokenType enum.TokenType, role string) (string, int64, error) { t := jwt.New(jwt.GetSigningMethod(constants.JWT_TYPE)) expiryBound := time.Hour if tokenType == enum.RefreshToken { @@ -33,12 +36,19 @@ func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64 expiresAt := time.Now().Add(expiryBound).Unix() + customClaims := JWTCustomClaim{ + "token_type": tokenType.String(), + "email": user.Email, + "id": user.ID, + "allowed_roles": strings.Split(user.Roles, ","), + constants.JWT_ROLE_CLAIM: role, + } + t.Claims = &UserAuthClaim{ &jwt.StandardClaims{ ExpiresAt: expiresAt, }, - tokenType.String(), - user, + &customClaims, } token, err := t.SignedString([]byte(constants.JWT_SECRET)) @@ -63,14 +73,20 @@ func GetAuthToken(gc *gin.Context) (string, error) { return token, nil } -func VerifyAuthToken(token string) (*UserAuthClaim, error) { +func VerifyAuthToken(token string) (map[string]interface{}, error) { + var res map[string]interface{} claims := &UserAuthClaim{} + _, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { return []byte(constants.JWT_SECRET), nil }) if err != nil { - return claims, err + return res, err } - return claims, nil + data, _ := json.Marshal(claims.JWTCustomClaim) + json.Unmarshal(data, &res) + res["exp"] = claims.ExpiresAt + + return res, nil } diff --git a/server/utils/common.go b/server/utils/common.go new file mode 100644 index 0000000..d2d713d --- /dev/null +++ b/server/utils/common.go @@ -0,0 +1,20 @@ +package utils + +import ( + "io" + "os" +) + +func WriteToFile(filename string, data string) error { + file, err := os.Create(filename) + if err != nil { + return err + } + defer file.Close() + + _, err = io.WriteString(file, data) + if err != nil { + return err + } + return file.Sync() +} diff --git a/server/utils/initServer.go b/server/utils/initServer.go new file mode 100644 index 0000000..a15c08d --- /dev/null +++ b/server/utils/initServer.go @@ -0,0 +1,25 @@ +package utils + +import ( + "log" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/db" +) + +// any jobs that we want to run at start of server can be executed here + +// 1. create roles table and add the roles list from env to table + +func InitServer() { + roles := []db.Role{} + for _, val := range constants.ROLES { + roles = append(roles, db.Role{ + Role: val, + }) + } + err := db.Mgr.SaveRoles(roles) + if err != nil { + log.Println(`Error saving roles`, err) + } +} diff --git a/server/utils/validateSuperAdmin.go b/server/utils/validateSuperAdmin.go deleted file mode 100644 index c19d62b..0000000 --- a/server/utils/validateSuperAdmin.go +++ /dev/null @@ -1,15 +0,0 @@ -package utils - -import ( - "github.com/authorizerdev/authorizer/server/constants" - "github.com/gin-gonic/gin" -) - -func IsSuperAdmin(gc *gin.Context) bool { - secret := gc.Request.Header.Get("x-authorizer-admin-secret") - if secret == "" { - return false - } - - return secret == constants.ADMIN_SECRET -} diff --git a/server/utils/validator.go b/server/utils/validator.go index b953045..c1f2fed 100644 --- a/server/utils/validator.go +++ b/server/utils/validator.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/authorizerdev/authorizer/server/constants" + "github.com/gin-gonic/gin" ) func IsValidEmail(email string) bool { @@ -29,3 +30,52 @@ func IsValidRedirectURL(url string) bool { return hasValidURL } + +func IsSuperAdmin(gc *gin.Context) bool { + secret := gc.Request.Header.Get("x-authorizer-admin-secret") + if secret == "" { + return false + } + + return secret == constants.ADMIN_SECRET +} + +func IsValidRolesArray(roles []string) bool { + valid := true + currentRoleMap := map[string]bool{} + + for _, currentRole := range constants.ROLES { + currentRoleMap[currentRole] = true + } + for _, inputRole := range roles { + if !currentRoleMap[inputRole] { + valid = false + break + } + } + return valid +} + +func IsValidRole(userRoles []string, role string) bool { + valid := false + for _, currentRole := range userRoles { + if role == currentRole { + valid = true + break + } + } + + return valid +} + +func IsStringArrayEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +}