diff --git a/.env.sample b/.env.sample index f5d9f37..b2dffd7 100644 --- a/.env.sample +++ b/.env.sample @@ -5,6 +5,7 @@ ADMIN_SECRET=admin DISABLE_EMAIL_VERIFICATION=true JWT_SECRET=random_string JWT_TYPE=HS256 -ROLES=user,admin -DEFAULT_ROLE=user +ROLES=user +DEFAULT_ROLES=user +PROTECTED_ROLES=admin JWT_ROLE_CLAIM=role \ No newline at end of file diff --git a/server/constants/constants.go b/server/constants/constants.go index 48059e1..e8143b7 100644 --- a/server/constants/constants.go +++ b/server/constants/constants.go @@ -23,9 +23,10 @@ var ( DISABLE_BASIC_AUTHENTICATION = "false" // ROLES - ROLES = []string{} - DEFAULT_ROLE = "" - JWT_ROLE_CLAIM = "role" + ROLES = []string{} + PROTECTED_ROLES = []string{} + DEFAULT_ROLES = []string{} + JWT_ROLE_CLAIM = "role" // OAuth login GOOGLE_CLIENT_ID = "" diff --git a/server/env.go b/server/env.go index cd56f22..31dd673 100644 --- a/server/env.go +++ b/server/env.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/utils" "github.com/joho/godotenv" ) @@ -63,7 +64,6 @@ func InitEnv() { constants.RESET_PASSWORD_URL = strings.TrimPrefix(os.Getenv("RESET_PASSWORD_URL"), "/") constants.DISABLE_BASIC_AUTHENTICATION = os.Getenv("DISABLE_BASIC_AUTHENTICATION") constants.DISABLE_EMAIL_VERIFICATION = os.Getenv("DISABLE_EMAIL_VERIFICATION") - constants.DEFAULT_ROLE = os.Getenv("DEFAULT_ROLE") constants.JWT_ROLE_CLAIM = os.Getenv("JWT_ROLE_CLAIM") if constants.ADMIN_SECRET == "" { @@ -136,7 +136,26 @@ func InitEnv() { rolesSplit := strings.Split(os.Getenv("ROLES"), ",") roles := []string{} - defaultRole := "" + if len(rolesSplit) == 0 { + roles = []string{"user"} + } + + defaultRoleSplit := strings.Split(os.Getenv("DEFAULT_ROLES"), ",") + defaultRoles := []string{} + + if len(defaultRoleSplit) == 0 { + defaultRoles = []string{"user"} + } + + protectedRolesSplit := strings.Split(os.Getenv("PROTECTED_ROLES"), ",") + protectedRoles := []string{} + + if len(protectedRolesSplit) > 0 { + for _, val := range protectedRolesSplit { + trimVal := strings.TrimSpace(val) + protectedRoles = append(protectedRoles, trimVal) + } + } for _, val := range rolesSplit { trimVal := strings.TrimSpace(val) @@ -144,20 +163,18 @@ func InitEnv() { roles = append(roles, trimVal) } - if trimVal == constants.DEFAULT_ROLE { - defaultRole = trimVal + if utils.StringContains(defaultRoleSplit, trimVal) { + defaultRoles = append(defaultRoles, trimVal) } } - if len(roles) > 0 && defaultRole == "" { + + if len(roles) > 0 && len(defaultRoles) == 0 && len(defaultRoleSplit) > 0 { panic(`Invalid DEFAULT_ROLE environment variable. It can be one from give ROLES environment variable value`) } - if len(roles) == 0 { - roles = []string{"user", "admin"} - constants.DEFAULT_ROLE = "user" - } - constants.ROLES = roles + constants.DEFAULT_ROLES = defaultRoles + constants.PROTECTED_ROLES = protectedRoles if constants.JWT_ROLE_CLAIM == "" { constants.JWT_ROLE_CLAIM = "role" diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index d1ea7c0..8e4284b 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -81,7 +81,7 @@ type ComplexityRoot struct { Query struct { Meta func(childComplexity int) int Profile func(childComplexity int) int - Token func(childComplexity int, role *string) int + Token func(childComplexity int, roles []string) int Users func(childComplexity int) int VerificationRequests func(childComplexity int) int } @@ -129,7 +129,7 @@ type MutationResolver interface { type QueryResolver interface { Meta(ctx context.Context) (*model.Meta, error) Users(ctx context.Context) ([]*model.User, error) - Token(ctx context.Context, role *string) (*model.AuthResponse, error) + Token(ctx context.Context, roles []string) (*model.AuthResponse, error) Profile(ctx context.Context) (*model.User, error) VerificationRequests(ctx context.Context) ([]*model.VerificationRequest, error) } @@ -379,7 +379,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Token(childComplexity, args["role"].(*string)), true + return e.complexity.Query.Token(childComplexity, args["roles"].([]string)), true case "Query.users": if e.complexity.Query.Users == nil { @@ -648,13 +648,13 @@ input SignUpInput { password: String! confirmPassword: String! image: String - roles: [String] + roles: [String!] } input LoginInput { email: String! password: String! - role: String + roles: [String!] } input VerifyEmailInput { @@ -677,12 +677,12 @@ input UpdateProfileInput { } input AdminUpdateUserInput { - id: ID! - email: String - firstName: String - lastName: String - image: String - roles: [String] + id: ID! + email: String + firstName: String + lastName: String + image: String + roles: [String] } input ForgotPasswordInput { @@ -704,7 +704,7 @@ type Mutation { login(params: LoginInput!): AuthResponse! logout: Response! updateProfile(params: UpdateProfileInput!): Response! - adminUpdateUser(params: AdminUpdateUserInput!): User! + adminUpdateUser(params: AdminUpdateUserInput!): User! verifyEmail(params: VerifyEmailInput!): AuthResponse! resendVerifyEmail(params: ResendVerifyEmailInput!): Response! forgotPassword(params: ForgotPasswordInput!): Response! @@ -715,7 +715,7 @@ type Mutation { type Query { meta: Meta! users: [User!]! - token(role: String): AuthResponse + token(roles: [String!]): AuthResponse profile: User! verificationRequests: [VerificationRequest!]! } @@ -880,15 +880,15 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs func (ec *executionContext) field_Query_token_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} - var arg0 *string - if tmp, ok := rawArgs["role"]; ok { - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role")) - arg0, err = ec.unmarshalOString2ᚖstring(ctx, tmp) + var arg0 []string + if tmp, ok := rawArgs["roles"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles")) + arg0, err = ec.unmarshalOString2ᚕstringᚄ(ctx, tmp) if err != nil { return nil, err } } - args["role"] = arg0 + args["roles"] = arg0 return args, nil } @@ -1884,7 +1884,7 @@ func (ec *executionContext) _Query_token(ctx context.Context, field graphql.Coll fc.Args = args resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Token(rctx, args["role"].(*string)) + return ec.resolvers.Query().Token(rctx, args["roles"].([]string)) }) if err != nil { ec.Error(ctx, err) @@ -3842,11 +3842,11 @@ func (ec *executionContext) unmarshalInputLoginInput(ctx context.Context, obj in if err != nil { return it, err } - case "role": + case "roles": var err error - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role")) - it.Role, err = ec.unmarshalOString2ᚖstring(ctx, v) + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles")) + it.Roles, err = ec.unmarshalOString2ᚕstringᚄ(ctx, v) if err != nil { return it, err } @@ -3970,7 +3970,7 @@ func (ec *executionContext) unmarshalInputSignUpInput(ctx context.Context, obj i var err error ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles")) - it.Roles, err = ec.unmarshalOString2ᚕᚖstring(ctx, v) + it.Roles, err = ec.unmarshalOString2ᚕstringᚄ(ctx, v) if err != nil { return it, err } @@ -5280,6 +5280,42 @@ func (ec *executionContext) marshalOString2string(ctx context.Context, sel ast.S return graphql.MarshalString(v) } +func (ec *executionContext) unmarshalOString2ᚕstringᚄ(ctx context.Context, v interface{}) ([]string, error) { + if v == nil { + return nil, nil + } + var vSlice []interface{} + if v != nil { + if tmp1, ok := v.([]interface{}); ok { + vSlice = tmp1 + } else { + vSlice = []interface{}{v} + } + } + var err error + res := make([]string, len(vSlice)) + for i := range vSlice { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i)) + res[i], err = ec.unmarshalNString2string(ctx, vSlice[i]) + if err != nil { + return nil, err + } + } + return res, nil +} + +func (ec *executionContext) marshalOString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler { + if v == nil { + return graphql.Null + } + ret := make(graphql.Array, len(v)) + for i := range v { + ret[i] = ec.marshalNString2string(ctx, sel, v[i]) + } + + return ret +} + func (ec *executionContext) unmarshalOString2ᚕᚖstring(ctx context.Context, v interface{}) ([]*string, error) { if v == nil { return nil, nil diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index ae2c163..8ffcd8a 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -32,9 +32,9 @@ type ForgotPasswordInput struct { } type LoginInput struct { - Email string `json:"email"` - Password string `json:"password"` - Role *string `json:"role"` + Email string `json:"email"` + Password string `json:"password"` + Roles []string `json:"roles"` } type Meta struct { @@ -62,13 +62,13 @@ type Response struct { } type SignUpInput struct { - FirstName *string `json:"firstName"` - LastName *string `json:"lastName"` - Email string `json:"email"` - Password string `json:"password"` - ConfirmPassword string `json:"confirmPassword"` - Image *string `json:"image"` - Roles []*string `json:"roles"` + FirstName *string `json:"firstName"` + LastName *string `json:"lastName"` + Email string `json:"email"` + Password string `json:"password"` + ConfirmPassword string `json:"confirmPassword"` + Image *string `json:"image"` + Roles []string `json:"roles"` } type UpdateProfileInput struct { diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index cb10e84..1793ff5 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -61,13 +61,13 @@ input SignUpInput { password: String! confirmPassword: String! image: String - roles: [String] + roles: [String!] } input LoginInput { email: String! password: String! - role: String + roles: [String!] } input VerifyEmailInput { @@ -90,12 +90,12 @@ input UpdateProfileInput { } input AdminUpdateUserInput { - id: ID! - email: String - firstName: String - lastName: String - image: String - roles: [String] + id: ID! + email: String + firstName: String + lastName: String + image: String + roles: [String] } input ForgotPasswordInput { @@ -117,7 +117,7 @@ type Mutation { login(params: LoginInput!): AuthResponse! logout: Response! updateProfile(params: UpdateProfileInput!): Response! - adminUpdateUser(params: AdminUpdateUserInput!): User! + adminUpdateUser(params: AdminUpdateUserInput!): User! verifyEmail(params: VerifyEmailInput!): AuthResponse! resendVerifyEmail(params: ResendVerifyEmailInput!): Response! forgotPassword(params: ForgotPasswordInput!): Response! @@ -128,7 +128,7 @@ type Mutation { type Query { meta: Meta! users: [User!]! - token(role: String): AuthResponse + token(roles: [String!]): AuthResponse profile: User! verificationRequests: [VerificationRequest!]! } diff --git a/server/graph/schema.resolvers.go b/server/graph/schema.resolvers.go index 7a55de6..ae1f687 100644 --- a/server/graph/schema.resolvers.go +++ b/server/graph/schema.resolvers.go @@ -59,8 +59,8 @@ func (r *queryResolver) Users(ctx context.Context) ([]*model.User, error) { return resolvers.Users(ctx) } -func (r *queryResolver) Token(ctx context.Context, role *string) (*model.AuthResponse, error) { - return resolvers.Token(ctx, role) +func (r *queryResolver) Token(ctx context.Context, roles []string) (*model.AuthResponse, error) { + return resolvers.Token(ctx, roles) } func (r *queryResolver) Profile(ctx context.Context) (*model.User, error) { diff --git a/server/handlers/oauthCallback.go b/server/handlers/oauthCallback.go index 76201c2..3abc0af 100644 --- a/server/handlers/oauthCallback.go +++ b/server/handlers/oauthCallback.go @@ -19,28 +19,29 @@ import ( "golang.org/x/oauth2" ) -func processGoogleUserInfo(code string, role string, c *gin.Context) error { +func processGoogleUserInfo(code string, roles []string, c *gin.Context) (db.User, error) { + user := db.User{} token, err := oauth.OAuthProvider.GoogleConfig.Exchange(oauth2.NoContext, code) if err != nil { - return fmt.Errorf("invalid google exchange code: %s", err.Error()) + return user, fmt.Errorf("invalid google exchange code: %s", err.Error()) } client := oauth.OAuthProvider.GoogleConfig.Client(oauth2.NoContext, token) response, err := client.Get(constants.GoogleUserInfoURL) if err != nil { - return err + return user, err } defer response.Body.Close() body, err := ioutil.ReadAll(response.Body) if err != nil { - return fmt.Errorf("failed to read google response body: %s", err.Error()) + return user, fmt.Errorf("failed to read google response body: %s", err.Error()) } userRawData := make(map[string]string) json.Unmarshal(body, &userRawData) existingUser, err := db.Mgr.GetUserByEmail(userRawData["email"]) - user := db.User{ + user = db.User{ FirstName: userRawData["given_name"], LastName: userRawData["family_name"], Image: userRawData["picture"], @@ -50,7 +51,7 @@ func processGoogleUserInfo(code string, role string, c *gin.Context) error { if err != nil { // user not registered, register user and generate session token user.SignupMethod = enum.Google.String() - user.Roles = role + user.Roles = strings.Join(roles, ",") } else { // user exists in db, check if method was google // if not append google to existing signup method and save it @@ -61,34 +62,25 @@ func processGoogleUserInfo(code string, role string, c *gin.Context) error { } user.SignupMethod = signupMethod user.Password = existingUser.Password - if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) { - return fmt.Errorf("invalid role") + if !utils.IsValidRoles(strings.Split(existingUser.Roles, ","), roles) { + return user, fmt.Errorf("invalid role") } user.Roles = existingUser.Roles } - - user, _ = db.Mgr.SaveUser(user) - user, _ = db.Mgr.GetUserByEmail(user.Email) - userIdStr := fmt.Sprintf("%v", user.ID) - - refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) - - accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role) - utils.SetCookie(c, accessToken) - session.SetToken(userIdStr, refreshToken) - return nil + return user, nil } -func processGithubUserInfo(code string, role string, c *gin.Context) error { +func processGithubUserInfo(code string, roles []string, c *gin.Context) (db.User, error) { + user := db.User{} token, err := oauth.OAuthProvider.GithubConfig.Exchange(oauth2.NoContext, code) if err != nil { - return fmt.Errorf("invalid github exchange code: %s", err.Error()) + return user, fmt.Errorf("invalid github exchange code: %s", err.Error()) } client := http.Client{} req, err := http.NewRequest("GET", constants.GithubUserInfoURL, nil) if err != nil { - return fmt.Errorf("error creating github user info request: %s", err.Error()) + return user, fmt.Errorf("error creating github user info request: %s", err.Error()) } req.Header = http.Header{ "Authorization": []string{fmt.Sprintf("token %s", token.AccessToken)}, @@ -96,13 +88,13 @@ func processGithubUserInfo(code string, role string, c *gin.Context) error { response, err := client.Do(req) if err != nil { - return err + return user, err } defer response.Body.Close() body, err := ioutil.ReadAll(response.Body) if err != nil { - return fmt.Errorf("failed to read github response body: %s", err.Error()) + return user, fmt.Errorf("failed to read github response body: %s", err.Error()) } userRawData := make(map[string]string) @@ -118,7 +110,7 @@ func processGithubUserInfo(code string, role string, c *gin.Context) error { if len(name) > 1 && strings.TrimSpace(name[1]) != "" { lastName = name[0] } - user := db.User{ + user = db.User{ FirstName: firstName, LastName: lastName, Image: userRawData["avatar_url"], @@ -128,7 +120,7 @@ func processGithubUserInfo(code string, role string, c *gin.Context) error { if err != nil { // user not registered, register user and generate session token user.SignupMethod = enum.Github.String() - user.Roles = role + user.Roles = strings.Join(roles, ",") } else { // user exists in db, check if method was google // if not append google to existing signup method and save it @@ -140,45 +132,38 @@ func processGithubUserInfo(code string, role string, c *gin.Context) error { user.SignupMethod = signupMethod user.Password = existingUser.Password - if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) { - return fmt.Errorf("invalid role") + if !utils.IsValidRoles(strings.Split(existingUser.Roles, ","), roles) { + return user, fmt.Errorf("invalid role") } user.Roles = existingUser.Roles } - user, _ = db.Mgr.SaveUser(user) - user, _ = db.Mgr.GetUserByEmail(user.Email) - userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) - - accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role) - utils.SetCookie(c, accessToken) - session.SetToken(userIdStr, refreshToken) - return nil + return user, nil } -func processFacebookUserInfo(code string, role string, c *gin.Context) error { +func processFacebookUserInfo(code string, roles []string, c *gin.Context) (db.User, error) { + user := db.User{} token, err := oauth.OAuthProvider.FacebookConfig.Exchange(oauth2.NoContext, code) if err != nil { - return fmt.Errorf("invalid facebook exchange code: %s", err.Error()) + return user, fmt.Errorf("invalid facebook exchange code: %s", err.Error()) } client := http.Client{} req, err := http.NewRequest("GET", constants.FacebookUserInfoURL+token.AccessToken, nil) if err != nil { - return fmt.Errorf("error creating facebook user info request: %s", err.Error()) + return user, fmt.Errorf("error creating facebook user info request: %s", err.Error()) } response, err := client.Do(req) if err != nil { log.Println("err:", err) - return err + return user, err } defer response.Body.Close() body, err := ioutil.ReadAll(response.Body) if err != nil { - return fmt.Errorf("failed to read facebook response body: %s", err.Error()) + return user, fmt.Errorf("failed to read facebook response body: %s", err.Error()) } userRawData := make(map[string]interface{}) @@ -189,7 +174,7 @@ func processFacebookUserInfo(code string, role string, c *gin.Context) error { picObject := userRawData["picture"].(map[string]interface{})["data"] picDataObject := picObject.(map[string]interface{}) - user := db.User{ + user = db.User{ FirstName: fmt.Sprintf("%v", userRawData["first_name"]), LastName: fmt.Sprintf("%v", userRawData["last_name"]), Image: fmt.Sprintf("%v", picDataObject["url"]), @@ -200,7 +185,7 @@ func processFacebookUserInfo(code string, role string, c *gin.Context) error { if err != nil { // user not registered, register user and generate session token user.SignupMethod = enum.Github.String() - user.Roles = role + user.Roles = strings.Join(roles, ",") } else { // user exists in db, check if method was google // if not append google to existing signup method and save it @@ -212,22 +197,14 @@ func processFacebookUserInfo(code string, role string, c *gin.Context) error { user.SignupMethod = signupMethod user.Password = existingUser.Password - if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) { - return fmt.Errorf("invalid role") + if !utils.IsValidRoles(strings.Split(existingUser.Roles, ","), roles) { + return user, fmt.Errorf("invalid role") } user.Roles = existingUser.Roles } - user, _ = db.Mgr.SaveUser(user) - user, _ = db.Mgr.GetUserByEmail(user.Email) - userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) - - accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role) - utils.SetCookie(c, accessToken) - session.SetToken(userIdStr, refreshToken) - return nil + return user, nil } func OAuthCallbackHandler() gin.HandlerFunc { @@ -249,18 +226,19 @@ func OAuthCallbackHandler() gin.HandlerFunc { return } - role := sessionSplit[2] + roles := strings.Split(sessionSplit[2], ",") redirectURL := sessionSplit[1] var err error + user := db.User{} code := c.Request.FormValue("code") switch provider { case enum.Google.String(): - err = processGoogleUserInfo(code, role, c) + user, err = processGoogleUserInfo(code, roles, c) case enum.Github.String(): - err = processGithubUserInfo(code, role, c) + user, err = processGithubUserInfo(code, roles, c) case enum.Facebook.String(): - err = processFacebookUserInfo(code, role, c) + user, err = processFacebookUserInfo(code, roles, c) default: err = fmt.Errorf(`invalid oauth provider`) } @@ -269,6 +247,16 @@ func OAuthCallbackHandler() gin.HandlerFunc { c.JSON(400, gin.H{"error": err.Error()}) return } + + user, _ = db.Mgr.SaveUser(user) + user, _ = db.Mgr.GetUserByEmail(user.Email) + userIdStr := fmt.Sprintf("%v", user.ID) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, roles) + + accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, roles) + utils.SetCookie(c, accessToken) + session.SetToken(userIdStr, refreshToken) + c.Redirect(http.StatusTemporaryRedirect, redirectURL) } } diff --git a/server/handlers/oauthLogin.go b/server/handlers/oauthLogin.go index 5730d9d..31aca8e 100644 --- a/server/handlers/oauthLogin.go +++ b/server/handlers/oauthLogin.go @@ -2,6 +2,7 @@ package handlers import ( "net/http" + "strings" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/enum" @@ -18,7 +19,7 @@ func OAuthLoginHandler() gin.HandlerFunc { return func(c *gin.Context) { // TODO validate redirect URL redirectURL := c.Query("redirectURL") - role := c.Query("role") + roles := c.Query("roles") if redirectURL == "" { c.JSON(400, gin.H{ @@ -27,20 +28,24 @@ func OAuthLoginHandler() gin.HandlerFunc { return } - if role != "" { + if roles != "" { // validate role - if !utils.IsValidRole(constants.ROLES, role) { + rolesSplit := strings.Split(roles, ",") + + // use protected roles verification for admin login only. + // though if not associated with user, it will be rejected from oauth_callback + if !utils.IsValidRoles(append([]string{}, append(constants.ROLES, constants.PROTECTED_ROLES...)...), rolesSplit) { c.JSON(400, gin.H{ "error": "invalid role", }) return } } else { - role = constants.DEFAULT_ROLE + roles = strings.Join(constants.DEFAULT_ROLES, ",") } uuid := uuid.New() - oauthStateString := uuid.String() + "___" + redirectURL + "___" + role + oauthStateString := uuid.String() + "___" + redirectURL + "___" + roles provider := c.Param("oauth_provider") diff --git a/server/handlers/verifyEmail.go b/server/handlers/verifyEmail.go index 3f7e956..2e768ca 100644 --- a/server/handlers/verifyEmail.go +++ b/server/handlers/verifyEmail.go @@ -3,6 +3,7 @@ package handlers import ( "fmt" "net/http" + "strings" "time" "github.com/authorizerdev/authorizer/server/db" @@ -50,9 +51,10 @@ func VerifyEmailHandler() gin.HandlerFunc { db.Mgr.DeleteToken(claim.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, user.Roles) + roles := strings.Split(user.Roles, ",") + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, roles) - accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, user.Roles) + accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, roles) session.SetToken(userIdStr, refreshToken) utils.SetCookie(c, accessToken) diff --git a/server/resolvers/login.go b/server/resolvers/login.go index f6f5ede..b112f3f 100644 --- a/server/resolvers/login.go +++ b/server/resolvers/login.go @@ -46,19 +46,19 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e log.Println("Compare password error:", err) return res, fmt.Errorf(`invalid password`) } - role := constants.DEFAULT_ROLE - if params.Role != nil { - // validate role - if !utils.IsValidRole(strings.Split(user.Roles, ","), *params.Role) { - return res, fmt.Errorf(`invalid role`) + roles := constants.DEFAULT_ROLES + currentRoles := strings.Split(user.Roles, ",") + if len(params.Roles) > 0 { + if !utils.IsValidRoles(currentRoles, params.Roles) { + return res, fmt.Errorf(`invalid roles`) } - role = *params.Role + roles = params.Roles } userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, roles) - accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, role) + accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, roles) session.SetToken(userIdStr, refreshToken) diff --git a/server/resolvers/signup.go b/server/resolvers/signup.go index 14809aa..2eeb982 100644 --- a/server/resolvers/signup.go +++ b/server/resolvers/signup.go @@ -37,16 +37,15 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse, inputRoles := []string{} - if params.Roles != nil && len(params.Roles) > 0 { + if len(params.Roles) > 0 { // check if roles exists - for _, item := range params.Roles { - inputRoles = append(inputRoles, *item) - } - if !utils.IsValidRolesArray(inputRoles) { + if !utils.IsValidRolesArray(params.Roles) { return res, fmt.Errorf(`invalid roles`) + } else { + inputRoles = params.Roles } } else { - inputRoles = []string{constants.DEFAULT_ROLE} + inputRoles = constants.DEFAULT_ROLES } // find user with email @@ -85,6 +84,7 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse, return res, err } userIdStr := fmt.Sprintf("%v", user.ID) + roles := strings.Split(user.Roles, ",") userToReturn := &model.User{ ID: userIdStr, Email: user.Email, @@ -123,9 +123,9 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse, } } else { - refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, roles) - accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE) + accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, roles) session.SetToken(userIdStr, refreshToken) res = &model.AuthResponse{ diff --git a/server/resolvers/token.go b/server/resolvers/token.go index 581db6e..9345300 100644 --- a/server/resolvers/token.go +++ b/server/resolvers/token.go @@ -14,7 +14,7 @@ import ( "github.com/authorizerdev/authorizer/server/utils" ) -func Token(ctx context.Context, role *string) (*model.AuthResponse, error) { +func Token(ctx context.Context, roles []string) (*model.AuthResponse, error) { var res *model.AuthResponse gc, err := utils.GinContextFromContext(ctx) @@ -30,16 +30,11 @@ func Token(ctx context.Context, role *string) (*model.AuthResponse, error) { expiresAt := claim["exp"].(int64) email := fmt.Sprintf("%v", claim["email"]) - claimRole := fmt.Sprintf("%v", claim[constants.JWT_ROLE_CLAIM]) user, err := db.Mgr.GetUserByEmail(email) if err != nil { return res, err } - if role != nil && *role != claimRole { - return res, fmt.Errorf(`unauthorized`) - } - userIdStr := fmt.Sprintf("%v", user.ID) sessionToken := session.GetToken(userIdStr) @@ -47,15 +42,30 @@ func Token(ctx context.Context, role *string) (*model.AuthResponse, error) { if sessionToken == "" { return res, fmt.Errorf(`unauthorized`) } - // TODO check if refresh/session token has expired expiresTimeObj := time.Unix(expiresAt, 0) currentTimeObj := time.Now() + + claimRoleInterface := claim[constants.JWT_ROLE_CLAIM].([]interface{}) + claimRoles := make([]string, len(claimRoleInterface)) + for i, v := range claimRoleInterface { + claimRoles[i] = v.(string) + } + + if len(roles) > 0 { + for _, v := range roles { + if !utils.StringContains(claimRoles, v) { + return res, fmt.Errorf(`unauthorized`) + } + } + } + if accessTokenErr != nil || expiresTimeObj.Sub(currentTimeObj).Minutes() <= 5 { // if access token has expired and refresh/session token is valid // generate new accessToken - token, expiresAt, _ = utils.CreateAuthToken(user, enum.AccessToken, claimRole) + token, expiresAt, _ = utils.CreateAuthToken(user, enum.AccessToken, claimRoles) } + utils.SetCookie(gc, token) res = &model.AuthResponse{ Message: `Token verified`, diff --git a/server/resolvers/verifyEmail.go b/server/resolvers/verifyEmail.go index 106cc05..92bccd8 100644 --- a/server/resolvers/verifyEmail.go +++ b/server/resolvers/verifyEmail.go @@ -6,7 +6,6 @@ import ( "strings" "time" - "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/graph/model" @@ -43,9 +42,10 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut db.Mgr.DeleteToken(claim.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE) + roles := strings.Split(user.Roles, ",") + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, roles) - accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE) + accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, roles) session.SetToken(userIdStr, refreshToken) diff --git a/server/utils/authToken.go b/server/utils/authToken.go index 72f4f8a..71ddac5 100644 --- a/server/utils/authToken.go +++ b/server/utils/authToken.go @@ -26,7 +26,7 @@ type UserAuthClaim struct { *JWTCustomClaim `json:"authorizer"` } -func CreateAuthToken(user db.User, tokenType enum.TokenType, role string) (string, int64, error) { +func CreateAuthToken(user db.User, tokenType enum.TokenType, roles []string) (string, int64, error) { t := jwt.New(jwt.GetSigningMethod(constants.JWT_TYPE)) expiryBound := time.Hour if tokenType == enum.RefreshToken { @@ -41,7 +41,7 @@ func CreateAuthToken(user db.User, tokenType enum.TokenType, role string) (strin "email": user.Email, "id": user.ID, "allowed_roles": strings.Split(user.Roles, ","), - constants.JWT_ROLE_CLAIM: role, + constants.JWT_ROLE_CLAIM: roles, } t.Claims = &UserAuthClaim{ diff --git a/server/utils/common.go b/server/utils/common.go index d2d713d..82f2570 100644 --- a/server/utils/common.go +++ b/server/utils/common.go @@ -18,3 +18,12 @@ func WriteToFile(filename string, data string) error { } return file.Sync() } + +func StringContains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} diff --git a/server/utils/validator.go b/server/utils/validator.go index c1f2fed..192289c 100644 --- a/server/utils/validator.go +++ b/server/utils/validator.go @@ -56,11 +56,11 @@ func IsValidRolesArray(roles []string) bool { return valid } -func IsValidRole(userRoles []string, role string) bool { - valid := false - for _, currentRole := range userRoles { - if role == currentRole { - valid = true +func IsValidRoles(userRoles []string, roles []string) bool { + valid := true + for _, role := range roles { + if !StringContains(userRoles, role) { + valid = false break } }