feat: add role based oauth

This commit is contained in:
Lakhan Samani
2021-09-20 10:00:17 +05:30
parent aec5d5a0c7
commit 94cdbc9268
22 changed files with 307 additions and 171 deletions

View File

@@ -7,3 +7,4 @@ JWT_SECRET=random_string
JWT_TYPE=HS256 JWT_TYPE=HS256
ROLES=user,admin ROLES=user,admin
DEFAULT_ROLE=user DEFAULT_ROLE=user
JWT_ROLE_CLAIM=role

View File

@@ -9,8 +9,8 @@ For the first version we will only support setting roles master list via env
- [x] `DEFAULT_ROLE` -> default role to assign to users - [x] `DEFAULT_ROLE` -> default role to assign to users
- [x] Add roles input for signup - [x] Add roles input for signup
- [x] Add roles to update profile mutation - [x] Add roles to update profile mutation
- [ ] Add roles input for login - [x] Add roles input for login
- [ ] Return roles to user - [x] Return roles to user
- [ ] Return roles in users list for super admin - [x] Return roles in users list for super admin
- [ ] Add roles to the JWT token generation - [x] Add roles to the JWT token generation
- [ ] Validate token should also validate the role, if roles to validate again is present in request - [ ] Validate token should also validate the role, if roles to validate again is present in request

View File

@@ -24,8 +24,9 @@ var (
DISABLE_BASIC_AUTHENTICATION = "false" DISABLE_BASIC_AUTHENTICATION = "false"
// ROLES // ROLES
ROLES = []string{} ROLES = []string{}
DEFAULT_ROLE = "" DEFAULT_ROLE = ""
JWT_ROLE_CLAIM = "role"
// OAuth login // OAuth login
GOOGLE_CLIENT_ID = "" GOOGLE_CLIENT_ID = ""

View File

@@ -74,6 +74,7 @@ func InitEnv() {
constants.DISABLE_BASIC_AUTHENTICATION = os.Getenv("DISABLE_BASIC_AUTHENTICATION") constants.DISABLE_BASIC_AUTHENTICATION = os.Getenv("DISABLE_BASIC_AUTHENTICATION")
constants.DISABLE_EMAIL_VERIFICATION = os.Getenv("DISABLE_EMAIL_VERIFICATION") constants.DISABLE_EMAIL_VERIFICATION = os.Getenv("DISABLE_EMAIL_VERIFICATION")
constants.DEFAULT_ROLE = os.Getenv("DEFAULT_ROLE") constants.DEFAULT_ROLE = os.Getenv("DEFAULT_ROLE")
constants.JWT_ROLE_CLAIM = os.Getenv("JWT_ROLE_CLAIM")
if constants.ADMIN_SECRET == "" { if constants.ADMIN_SECRET == "" {
panic("root admin secret is required") panic("root admin secret is required")
@@ -148,6 +149,7 @@ func InitEnv() {
rolesSplit := strings.Split(os.Getenv("ROLES"), ",") rolesSplit := strings.Split(os.Getenv("ROLES"), ",")
roles := []string{} roles := []string{}
defaultRole := "" defaultRole := ""
for _, val := range rolesSplit { for _, val := range rolesSplit {
trimVal := strings.TrimSpace(val) trimVal := strings.TrimSpace(val)
if trimVal != "" { if trimVal != "" {
@@ -159,7 +161,7 @@ func InitEnv() {
} }
} }
if len(roles) > 0 && defaultRole == "" { if len(roles) > 0 && defaultRole == "" {
panic(`Invalid DEFAULT_ROLE environment. It can be one from give ROLES environment variable value`) panic(`Invalid DEFAULT_ROLE environment variable. It can be one from give ROLES environment variable value`)
} }
if len(roles) == 0 { if len(roles) == 0 {
@@ -168,4 +170,8 @@ func InitEnv() {
} }
constants.ROLES = roles constants.ROLES = roles
if constants.JWT_ROLE_CLAIM == "" {
constants.JWT_ROLE_CLAIM = "role"
}
} }

View File

@@ -80,7 +80,7 @@ type ComplexityRoot struct {
Query struct { Query struct {
Meta func(childComplexity int) int Meta func(childComplexity int) int
Profile func(childComplexity int) int Profile func(childComplexity int) int
Token func(childComplexity int) int Token func(childComplexity int, role *string) int
Users func(childComplexity int) int Users func(childComplexity int) int
VerificationRequests func(childComplexity int) int VerificationRequests func(childComplexity int) int
} }
@@ -127,7 +127,7 @@ type MutationResolver interface {
type QueryResolver interface { type QueryResolver interface {
Meta(ctx context.Context) (*model.Meta, error) Meta(ctx context.Context) (*model.Meta, error)
Users(ctx context.Context) ([]*model.User, error) Users(ctx context.Context) ([]*model.User, error)
Token(ctx context.Context) (*model.AuthResponse, error) Token(ctx context.Context, role *string) (*model.AuthResponse, error)
Profile(ctx context.Context) (*model.User, error) Profile(ctx context.Context) (*model.User, error)
VerificationRequests(ctx context.Context) ([]*model.VerificationRequest, error) VerificationRequests(ctx context.Context) ([]*model.VerificationRequest, error)
} }
@@ -360,7 +360,12 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
break break
} }
return e.complexity.Query.Token(childComplexity), true args, err := ec.field_Query_token_args(context.TODO(), rawArgs)
if err != nil {
return 0, false
}
return e.complexity.Query.Token(childComplexity, args["role"].(*string)), true
case "Query.users": case "Query.users":
if e.complexity.Query.Users == nil { if e.complexity.Query.Users == nil {
@@ -570,6 +575,8 @@ var sources = []*ast.Source{
# #
# https://gqlgen.com/getting-started/ # https://gqlgen.com/getting-started/
scalar Int64 scalar Int64
scalar Map
scalar Any
type Meta { type Meta {
version: String! version: String!
@@ -591,7 +598,7 @@ type User {
image: String image: String
createdAt: Int64 createdAt: Int64
updatedAt: Int64 updatedAt: Int64
roles: [String] roles: [String!]!
} }
type VerificationRequest { type VerificationRequest {
@@ -652,7 +659,7 @@ input UpdateProfileInput {
lastName: String lastName: String
image: String image: String
email: String email: String
roles: [String] # roles: [String]
} }
input ForgotPasswordInput { input ForgotPasswordInput {
@@ -684,7 +691,7 @@ type Mutation {
type Query { type Query {
meta: Meta! meta: Meta!
users: [User!]! users: [User!]!
token: AuthResponse token(role: String): AuthResponse
profile: User! profile: User!
verificationRequests: [VerificationRequest!]! verificationRequests: [VerificationRequest!]!
} }
@@ -831,6 +838,21 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs
return args, nil return args, nil
} }
func (ec *executionContext) field_Query_token_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
var arg0 *string
if tmp, ok := rawArgs["role"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role"))
arg0, err = ec.unmarshalOString2ᚖstring(ctx, tmp)
if err != nil {
return nil, err
}
}
args["role"] = arg0
return args, nil
}
func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error var err error
args := map[string]interface{}{} args := map[string]interface{}{}
@@ -1772,9 +1794,16 @@ func (ec *executionContext) _Query_token(ctx context.Context, field graphql.Coll
} }
ctx = graphql.WithFieldContext(ctx, fc) ctx = graphql.WithFieldContext(ctx, fc)
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.field_Query_token_args(ctx, rawArgs)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
fc.Args = args
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children ctx = rctx // use context from middleware stack in children
return ec.resolvers.Query().Token(rctx) return ec.resolvers.Query().Token(rctx, args["role"].(*string))
}) })
if err != nil { if err != nil {
ec.Error(ctx, err) ec.Error(ctx, err)
@@ -2286,11 +2315,14 @@ func (ec *executionContext) _User_roles(ctx context.Context, field graphql.Colle
return graphql.Null return graphql.Null
} }
if resTmp == nil { if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null return graphql.Null
} }
res := resTmp.([]*string) res := resTmp.([]string)
fc.Result = res fc.Result = res
return ec.marshalOString2ᚕstring(ctx, field.Selections, res) return ec.marshalNString2ᚕstring(ctx, field.Selections, res)
} }
func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) { func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) {
@@ -3869,14 +3901,6 @@ func (ec *executionContext) unmarshalInputUpdateProfileInput(ctx context.Context
if err != nil { if err != nil {
return it, err return it, err
} }
case "roles":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles"))
it.Roles, err = ec.unmarshalOString2ᚕᚖstring(ctx, v)
if err != nil {
return it, err
}
} }
} }
@@ -4268,6 +4292,9 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj
out.Values[i] = ec._User_updatedAt(ctx, field, obj) out.Values[i] = ec._User_updatedAt(ctx, field, obj)
case "roles": case "roles":
out.Values[i] = ec._User_roles(ctx, field, obj) out.Values[i] = ec._User_roles(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
default: default:
panic("unknown field " + strconv.Quote(field.Name)) panic("unknown field " + strconv.Quote(field.Name))
} }
@@ -4680,6 +4707,36 @@ func (ec *executionContext) marshalNString2string(ctx context.Context, sel ast.S
return res return res
} }
func (ec *executionContext) unmarshalNString2ᚕstringᚄ(ctx context.Context, v interface{}) ([]string, error) {
var vSlice []interface{}
if v != nil {
if tmp1, ok := v.([]interface{}); ok {
vSlice = tmp1
} else {
vSlice = []interface{}{v}
}
}
var err error
res := make([]string, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.unmarshalNString2string(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
}
func (ec *executionContext) marshalNString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler {
ret := make(graphql.Array, len(v))
for i := range v {
ret[i] = ec.marshalNString2string(ctx, sel, v[i])
}
return ret
}
func (ec *executionContext) unmarshalNUpdateProfileInput2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUpdateProfileInput(ctx context.Context, v interface{}) (model.UpdateProfileInput, error) { func (ec *executionContext) unmarshalNUpdateProfileInput2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUpdateProfileInput(ctx context.Context, v interface{}) (model.UpdateProfileInput, error) {
res, err := ec.unmarshalInputUpdateProfileInput(ctx, v) res, err := ec.unmarshalInputUpdateProfileInput(ctx, v)
return res, graphql.ErrorOnPath(ctx, err) return res, graphql.ErrorOnPath(ctx, err)

View File

@@ -63,27 +63,26 @@ type SignUpInput struct {
} }
type UpdateProfileInput struct { type UpdateProfileInput struct {
OldPassword *string `json:"oldPassword"` OldPassword *string `json:"oldPassword"`
NewPassword *string `json:"newPassword"` NewPassword *string `json:"newPassword"`
ConfirmNewPassword *string `json:"confirmNewPassword"` ConfirmNewPassword *string `json:"confirmNewPassword"`
FirstName *string `json:"firstName"` FirstName *string `json:"firstName"`
LastName *string `json:"lastName"` LastName *string `json:"lastName"`
Image *string `json:"image"` Image *string `json:"image"`
Email *string `json:"email"` Email *string `json:"email"`
Roles []*string `json:"roles"`
} }
type User struct { type User struct {
ID string `json:"id"` ID string `json:"id"`
Email string `json:"email"` Email string `json:"email"`
SignupMethod string `json:"signupMethod"` SignupMethod string `json:"signupMethod"`
FirstName *string `json:"firstName"` FirstName *string `json:"firstName"`
LastName *string `json:"lastName"` LastName *string `json:"lastName"`
EmailVerifiedAt *int64 `json:"emailVerifiedAt"` EmailVerifiedAt *int64 `json:"emailVerifiedAt"`
Image *string `json:"image"` Image *string `json:"image"`
CreatedAt *int64 `json:"createdAt"` CreatedAt *int64 `json:"createdAt"`
UpdatedAt *int64 `json:"updatedAt"` UpdatedAt *int64 `json:"updatedAt"`
Roles []*string `json:"roles"` Roles []string `json:"roles"`
} }
type VerificationRequest struct { type VerificationRequest struct {

View File

@@ -2,6 +2,8 @@
# #
# https://gqlgen.com/getting-started/ # https://gqlgen.com/getting-started/
scalar Int64 scalar Int64
scalar Map
scalar Any
type Meta { type Meta {
version: String! version: String!
@@ -23,7 +25,7 @@ type User {
image: String image: String
createdAt: Int64 createdAt: Int64
updatedAt: Int64 updatedAt: Int64
roles: [String] roles: [String!]!
} }
type VerificationRequest { type VerificationRequest {
@@ -84,7 +86,7 @@ input UpdateProfileInput {
lastName: String lastName: String
image: String image: String
email: String email: String
roles: [String] # roles: [String]
} }
input ForgotPasswordInput { input ForgotPasswordInput {
@@ -116,7 +118,7 @@ type Mutation {
type Query { type Query {
meta: Meta! meta: Meta!
users: [User!]! users: [User!]!
token: AuthResponse token(role: String): AuthResponse
profile: User! profile: User!
verificationRequests: [VerificationRequest!]! verificationRequests: [VerificationRequest!]!
} }

View File

@@ -55,7 +55,7 @@ func (r *queryResolver) Users(ctx context.Context) ([]*model.User, error) {
return resolvers.Users(ctx) return resolvers.Users(ctx)
} }
func (r *queryResolver) Token(ctx context.Context) (*model.AuthResponse, error) { func (r *queryResolver) Token(ctx context.Context, role *string) (*model.AuthResponse, error) {
return resolvers.Token(ctx) return resolvers.Token(ctx)
} }

View File

@@ -25,7 +25,6 @@ func AppHandler() gin.HandlerFunc {
if state == "" { if state == "" {
// cookie, err := utils.GetAuthToken(c) // cookie, err := utils.GetAuthToken(c)
// log.Println(`cookie`, cookie)
// if err != nil { // if err != nil {
// c.JSON(400, gin.H{"error": "invalid state"}) // c.JSON(400, gin.H{"error": "invalid state"})
// return // return
@@ -67,13 +66,6 @@ func AppHandler() gin.HandlerFunc {
} }
} }
log.Println(gin.H{
"data": map[string]string{
"authorizerURL": stateObj.AuthorizerURL,
"redirectURL": stateObj.RedirectURL,
},
})
// debug the request state // debug the request state
if pusher := c.Writer.Pusher(); pusher != nil { if pusher := c.Writer.Pusher(); pusher != nil {
// use pusher.Push() to do server push // use pusher.Push() to do server push

View File

@@ -19,7 +19,7 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
func processGoogleUserInfo(code string, c *gin.Context) error { func processGoogleUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.GoogleConfig.Exchange(oauth2.NoContext, code) token, err := oauth.OAuthProvider.GoogleConfig.Exchange(oauth2.NoContext, code)
if err != nil { if err != nil {
return fmt.Errorf("invalid google exchange code: %s", err.Error()) return fmt.Errorf("invalid google exchange code: %s", err.Error())
@@ -50,6 +50,7 @@ func processGoogleUserInfo(code string, c *gin.Context) error {
if err != nil { if err != nil {
// user not registered, register user and generate session token // user not registered, register user and generate session token
user.SignupMethod = enum.Google.String() user.SignupMethod = enum.Google.String()
user.Roles = role
} else { } else {
// user exists in db, check if method was google // user exists in db, check if method was google
// if not append google to existing signup method and save it // if not append google to existing signup method and save it
@@ -60,27 +61,28 @@ func processGoogleUserInfo(code string, c *gin.Context) error {
} }
user.SignupMethod = signupMethod user.SignupMethod = signupMethod
user.Password = existingUser.Password user.Password = existingUser.Password
log.Println("=> checking roles...", utils.IsValidRole(strings.Split(existingUser.Roles, ","), role))
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
log.Println("=> invalid role from google oauth")
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
} }
user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email) user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID) userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
utils.SetCookie(c, accessToken) utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
return nil return nil
} }
func processGithubUserInfo(code string, c *gin.Context) error { func processGithubUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.GithubConfig.Exchange(oauth2.NoContext, code) token, err := oauth.OAuthProvider.GithubConfig.Exchange(oauth2.NoContext, code)
if err != nil { if err != nil {
return fmt.Errorf("invalid github exchange code: %s", err.Error()) return fmt.Errorf("invalid github exchange code: %s", err.Error())
@@ -128,6 +130,7 @@ func processGithubUserInfo(code string, c *gin.Context) error {
if err != nil { if err != nil {
// user not registered, register user and generate session token // user not registered, register user and generate session token
user.SignupMethod = enum.Github.String() user.SignupMethod = enum.Github.String()
user.Roles = role
} else { } else {
// user exists in db, check if method was google // user exists in db, check if method was google
// if not append google to existing signup method and save it // if not append google to existing signup method and save it
@@ -138,26 +141,26 @@ func processGithubUserInfo(code string, c *gin.Context) error {
} }
user.SignupMethod = signupMethod user.SignupMethod = signupMethod
user.Password = existingUser.Password user.Password = existingUser.Password
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
} }
user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email) user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID) userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
utils.SetCookie(c, accessToken) utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
return nil return nil
} }
func processFacebookUserInfo(code string, c *gin.Context) error { func processFacebookUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.FacebookConfig.Exchange(oauth2.NoContext, code) token, err := oauth.OAuthProvider.FacebookConfig.Exchange(oauth2.NoContext, code)
if err != nil { if err != nil {
return fmt.Errorf("invalid facebook exchange code: %s", err.Error()) return fmt.Errorf("invalid facebook exchange code: %s", err.Error())
@@ -199,6 +202,7 @@ func processFacebookUserInfo(code string, c *gin.Context) error {
if err != nil { if err != nil {
// user not registered, register user and generate session token // user not registered, register user and generate session token
user.SignupMethod = enum.Github.String() user.SignupMethod = enum.Github.String()
user.Roles = role
} else { } else {
// user exists in db, check if method was google // user exists in db, check if method was google
// if not append google to existing signup method and save it // if not append google to existing signup method and save it
@@ -209,20 +213,20 @@ func processFacebookUserInfo(code string, c *gin.Context) error {
} }
user.SignupMethod = signupMethod user.SignupMethod = signupMethod
user.Password = existingUser.Password user.Password = existingUser.Password
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
} }
user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email) user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID) userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
utils.SetCookie(c, accessToken) utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
return nil return nil
@@ -238,23 +242,27 @@ func OAuthCallbackHandler() gin.HandlerFunc {
c.JSON(400, gin.H{"error": "invalid oauth state"}) c.JSON(400, gin.H{"error": "invalid oauth state"})
} }
session.DeleteToken(sessionState) session.DeleteToken(sessionState)
// contains random token, redirect url, role
sessionSplit := strings.Split(state, "___") sessionSplit := strings.Split(state, "___")
// TODO validate redirect url // TODO validate redirect url
if len(sessionSplit) != 2 { if len(sessionSplit) < 2 {
c.JSON(400, gin.H{"error": "invalid redirect url"}) c.JSON(400, gin.H{"error": "invalid redirect url"})
return return
} }
role := sessionSplit[2]
redirectURL := sessionSplit[1]
var err error var err error
code := c.Request.FormValue("code") code := c.Request.FormValue("code")
switch provider { switch provider {
case enum.Google.String(): case enum.Google.String():
err = processGoogleUserInfo(code, c) err = processGoogleUserInfo(code, role, c)
case enum.Github.String(): case enum.Github.String():
err = processGithubUserInfo(code, c) err = processGithubUserInfo(code, role, c)
case enum.Facebook.String(): case enum.Facebook.String():
err = processFacebookUserInfo(code, c) err = processFacebookUserInfo(code, role, c)
default: default:
err = fmt.Errorf(`invalid oauth provider`) err = fmt.Errorf(`invalid oauth provider`)
} }
@@ -263,6 +271,6 @@ func OAuthCallbackHandler() gin.HandlerFunc {
c.JSON(400, gin.H{"error": err.Error()}) c.JSON(400, gin.H{"error": err.Error()})
return return
} }
c.Redirect(http.StatusTemporaryRedirect, sessionSplit[1]) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
} }
} }

View File

@@ -7,6 +7,7 @@ import (
"github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/oauth"
"github.com/authorizerdev/authorizer/server/session" "github.com/authorizerdev/authorizer/server/session"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -17,6 +18,7 @@ func OAuthLoginHandler() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// TODO validate redirect URL // TODO validate redirect URL
redirectURL := c.Query("redirectURL") redirectURL := c.Query("redirectURL")
role := c.Query("role")
if redirectURL == "" { if redirectURL == "" {
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@@ -24,8 +26,21 @@ func OAuthLoginHandler() gin.HandlerFunc {
}) })
return return
} }
if role != "" {
// validate role
if !utils.IsValidRole(constants.ROLES, role) {
c.JSON(400, gin.H{
"error": "invalid role",
})
return
}
} else {
role = constants.DEFAULT_ROLE
}
uuid := uuid.New() uuid := uuid.New()
oauthStateString := uuid.String() + "___" + redirectURL oauthStateString := uuid.String() + "___" + redirectURL + "___" + role
provider := c.Param("oauth_provider") provider := c.Param("oauth_provider")

View File

@@ -50,15 +50,9 @@ func VerifyEmailHandler() gin.HandlerFunc {
db.Mgr.DeleteToken(claim.Email) db.Mgr.DeleteToken(claim.Email)
userIdStr := fmt.Sprintf("%v", user.ID) userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, user.Roles)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, user.Roles)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
utils.SetCookie(c, accessToken) utils.SetCookie(c, accessToken)

View File

@@ -46,16 +46,19 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e
log.Println("Compare password error:", err) log.Println("Compare password error:", err)
return res, fmt.Errorf(`invalid password`) return res, fmt.Errorf(`invalid password`)
} }
userIdStr := fmt.Sprintf("%v", user.ID) role := constants.DEFAULT_ROLE
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ if params.Role != nil {
ID: userIdStr, // validate role
Email: user.Email, if !utils.IsValidRole(strings.Split(user.Roles, ","), *params.Role) {
}, enum.RefreshToken) return res, fmt.Errorf(`invalid role`)
}
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ role = *params.Role
ID: userIdStr, }
Email: user.Email, userIdStr := fmt.Sprintf("%v", user.ID)
}, enum.AccessToken) refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
@@ -71,6 +74,7 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e
LastName: &user.LastName, LastName: &user.LastName,
SignupMethod: user.SignupMethod, SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt, EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt, CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt, UpdatedAt: &user.UpdatedAt,
}, },

View File

@@ -2,6 +2,7 @@ package resolvers
import ( import (
"context" "context"
"fmt"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/session" "github.com/authorizerdev/authorizer/server/session"
@@ -25,7 +26,8 @@ func Logout(ctx context.Context) (*model.Response, error) {
return res, err return res, err
} }
session.DeleteToken(claim.ID) userId := fmt.Sprintf("%v", claim["id"])
session.DeleteToken(userId)
res = &model.Response{ res = &model.Response{
Message: "Logged out successfully", Message: "Logged out successfully",
} }

View File

@@ -3,6 +3,7 @@ package resolvers
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
@@ -27,13 +28,15 @@ func Profile(ctx context.Context) (*model.User, error) {
return res, err return res, err
} }
sessionToken := session.GetToken(claim.ID) userID := fmt.Sprintf("%v", claim["id"])
email := fmt.Sprintf("%v", claim["email"])
sessionToken := session.GetToken(userID)
if sessionToken == "" { if sessionToken == "" {
return res, fmt.Errorf(`unauthorized`) return res, fmt.Errorf(`unauthorized`)
} }
user, err := db.Mgr.GetUserByEmail(claim.Email) user, err := db.Mgr.GetUserByEmail(email)
if err != nil { if err != nil {
return res, err return res, err
} }
@@ -48,6 +51,7 @@ func Profile(ctx context.Context) (*model.User, error) {
LastName: &user.LastName, LastName: &user.LastName,
SignupMethod: user.SignupMethod, SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt, EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt, CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt, UpdatedAt: &user.UpdatedAt,
} }

View File

@@ -35,13 +35,18 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
return res, fmt.Errorf(`invalid email address`) return res, fmt.Errorf(`invalid email address`)
} }
inputRoles := []string{}
if params.Roles != nil && len(params.Roles) > 0 { if params.Roles != nil && len(params.Roles) > 0 {
// check if roles exists // check if roles exists
if !utils.IsValidRolesArray(params.Roles) { for _, item := range params.Roles {
inputRoles = append(inputRoles, *item)
}
if !utils.IsValidRolesArray(inputRoles) {
return res, fmt.Errorf(`invalid roles`) return res, fmt.Errorf(`invalid roles`)
} }
} else { } else {
params.Roles = []*string{&constants.DEFAULT_ROLE} inputRoles = []string{constants.DEFAULT_ROLE}
} }
// find user with email // find user with email
@@ -58,12 +63,7 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
Email: params.Email, Email: params.Email,
} }
roles := "" user.Roles = strings.Join(inputRoles, ",")
for _, roleInput := range params.Roles {
roles += *roleInput + ","
}
roles = strings.TrimSuffix(roles, ",")
user.Roles = roles
password, _ := utils.HashPassword(params.Password) password, _ := utils.HashPassword(params.Password)
user.Password = password user.Password = password
@@ -93,9 +93,9 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
LastName: &user.LastName, LastName: &user.LastName,
SignupMethod: user.SignupMethod, SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt, EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt, CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt, UpdatedAt: &user.UpdatedAt,
Roles: params.Roles,
} }
if constants.DISABLE_EMAIL_VERIFICATION != "true" { if constants.DISABLE_EMAIL_VERIFICATION != "true" {
@@ -123,15 +123,9 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
} }
} else { } else {
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
res = &model.AuthResponse{ res = &model.AuthResponse{

View File

@@ -3,8 +3,10 @@ package resolvers
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
@@ -25,9 +27,10 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
} }
claim, accessTokenErr := utils.VerifyAuthToken(token) claim, accessTokenErr := utils.VerifyAuthToken(token)
expiresAt := claim.ExpiresAt expiresAt := claim["exp"].(int64)
email := fmt.Sprintf("%v", claim["email"])
user, err := db.Mgr.GetUserByEmail(claim.Email) role := fmt.Sprintf("%v", claim[constants.JWT_ROLE_CLAIM])
user, err := db.Mgr.GetUserByEmail(email)
if err != nil { if err != nil {
return res, err return res, err
} }
@@ -46,10 +49,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
if accessTokenErr != nil || expiresTimeObj.Sub(currentTimeObj).Minutes() <= 5 { if accessTokenErr != nil || expiresTimeObj.Sub(currentTimeObj).Minutes() <= 5 {
// if access token has expired and refresh/session token is valid // if access token has expired and refresh/session token is valid
// generate new accessToken // generate new accessToken
token, expiresAt, _ = utils.CreateAuthToken(utils.UserAuthInfo{ token, expiresAt, _ = utils.CreateAuthToken(user, enum.AccessToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
} }
utils.SetCookie(gc, token) utils.SetCookie(gc, token)
res = &model.AuthResponse{ res = &model.AuthResponse{
@@ -62,6 +62,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
Image: &user.Image, Image: &user.Image,
FirstName: &user.FirstName, FirstName: &user.FirstName,
LastName: &user.LastName, LastName: &user.LastName,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt, CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt, UpdatedAt: &user.UpdatedAt,
}, },

View File

@@ -32,18 +32,20 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model
return res, err return res, err
} }
sessionToken := session.GetToken(claim.ID) id := fmt.Sprintf("%v", claim["id"])
sessionToken := session.GetToken(id)
if sessionToken == "" { if sessionToken == "" {
return res, fmt.Errorf(`unauthorized`) return res, fmt.Errorf(`unauthorized`)
} }
// validate if all params are not empty // validate if all params are not empty
if params.FirstName == nil && params.LastName == nil && params.Image == nil && params.OldPassword == nil && params.Email == nil && params.Roles != nil { if params.FirstName == nil && params.LastName == nil && params.Image == nil && params.OldPassword == nil && params.Email == nil {
return res, fmt.Errorf("please enter atleast one param to update") return res, fmt.Errorf("please enter atleast one param to update")
} }
user, err := db.Mgr.GetUserByEmail(claim.Email) email := fmt.Sprintf("%v", claim["email"])
user, err := db.Mgr.GetUserByEmail(email)
if err != nil { if err != nil {
return res, err return res, err
} }
@@ -122,21 +124,30 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model
}() }()
} }
rolesToSave := "" // TODO this idea needs to be verified otherwise every user can make themselves super admin
if params.Roles != nil && len(params.Roles) > 0 { // rolesToSave := ""
currentRoles := strings.Split(user.Roles, ",") // if params.Roles != nil && len(params.Roles) > 0 {
inputRoles := []string{} // currentRoles := strings.Split(user.Roles, ",")
for _, item := range params.Roles { // inputRoles := []string{}
inputRoles = append(inputRoles, *item) // for _, item := range params.Roles {
} // inputRoles = append(inputRoles, *item)
if !utils.IsStringArrayEqual(inputRoles, currentRoles) && utils.IsValidRolesArray(params.Roles) { // }
rolesToSave = strings.Join(inputRoles, ",")
}
}
if rolesToSave != "" { // if !utils.IsValidRolesArray(inputRoles) {
user.Roles = rolesToSave // return res, fmt.Errorf("invalid list of roles")
} // }
// if !utils.IsStringArrayEqual(inputRoles, currentRoles) {
// rolesToSave = strings.Join(inputRoles, ",")
// }
// session.DeleteToken(fmt.Sprintf("%v", user.ID))
// utils.DeleteCookie(gc)
// }
// if rolesToSave != "" {
// user.Roles = rolesToSave
// }
_, err = db.Mgr.UpdateUser(user) _, err = db.Mgr.UpdateUser(user)
if err != nil { if err != nil {

View File

@@ -3,8 +3,10 @@ package resolvers
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
@@ -41,15 +43,9 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut
db.Mgr.DeleteToken(claim.Email) db.Mgr.DeleteToken(claim.Email)
userIdStr := fmt.Sprintf("%v", user.ID) userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
@@ -65,6 +61,7 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut
LastName: &user.LastName, LastName: &user.LastName,
SignupMethod: user.SignupMethod, SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt, EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt, CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt, UpdatedAt: &user.UpdatedAt,
}, },

View File

@@ -1,29 +1,32 @@
package utils package utils
import ( import (
"encoding/json"
"fmt" "fmt"
"log" "log"
"strings" "strings"
"time" "time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/enum"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
) )
type UserAuthInfo struct { // type UserAuthInfo struct {
Email string `json:"email"` // Email string `json:"email"`
ID string `json:"id"` // ID string `json:"id"`
} // }
type JWTCustomClaim map[string]interface{}
type UserAuthClaim struct { type UserAuthClaim struct {
*jwt.StandardClaims *jwt.StandardClaims
TokenType string `json:"token_type"` *JWTCustomClaim `json:"authorizer"`
UserAuthInfo
} }
func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64, error) { func CreateAuthToken(user db.User, tokenType enum.TokenType, role string) (string, int64, error) {
t := jwt.New(jwt.GetSigningMethod(constants.JWT_TYPE)) t := jwt.New(jwt.GetSigningMethod(constants.JWT_TYPE))
expiryBound := time.Hour expiryBound := time.Hour
if tokenType == enum.RefreshToken { if tokenType == enum.RefreshToken {
@@ -33,12 +36,19 @@ func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64
expiresAt := time.Now().Add(expiryBound).Unix() expiresAt := time.Now().Add(expiryBound).Unix()
customClaims := JWTCustomClaim{
"token_type": tokenType.String(),
"email": user.Email,
"id": user.ID,
"allowed_roles": strings.Split(user.Roles, ","),
constants.JWT_ROLE_CLAIM: role,
}
t.Claims = &UserAuthClaim{ t.Claims = &UserAuthClaim{
&jwt.StandardClaims{ &jwt.StandardClaims{
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
}, },
tokenType.String(), &customClaims,
user,
} }
token, err := t.SignedString([]byte(constants.JWT_SECRET)) token, err := t.SignedString([]byte(constants.JWT_SECRET))
@@ -63,14 +73,20 @@ func GetAuthToken(gc *gin.Context) (string, error) {
return token, nil return token, nil
} }
func VerifyAuthToken(token string) (*UserAuthClaim, error) { func VerifyAuthToken(token string) (map[string]interface{}, error) {
var res map[string]interface{}
claims := &UserAuthClaim{} claims := &UserAuthClaim{}
_, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { _, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte(constants.JWT_SECRET), nil return []byte(constants.JWT_SECRET), nil
}) })
if err != nil { if err != nil {
return claims, err return res, err
} }
return claims, nil data, _ := json.Marshal(claims.JWTCustomClaim)
json.Unmarshal(data, &res)
res["exp"] = claims.ExpiresAt
return res, nil
} }

20
server/utils/common.go Normal file
View File

@@ -0,0 +1,20 @@
package utils
import (
"io"
"os"
)
func WriteToFile(filename string, data string) error {
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
_, err = io.WriteString(file, data)
if err != nil {
return err
}
return file.Sync()
}

View File

@@ -40,7 +40,7 @@ func IsSuperAdmin(gc *gin.Context) bool {
return secret == constants.ADMIN_SECRET return secret == constants.ADMIN_SECRET
} }
func IsValidRolesArray(roles []*string) bool { func IsValidRolesArray(roles []string) bool {
valid := true valid := true
currentRoleMap := map[string]bool{} currentRoleMap := map[string]bool{}
@@ -48,7 +48,7 @@ func IsValidRolesArray(roles []*string) bool {
currentRoleMap[currentRole] = true currentRoleMap[currentRole] = true
} }
for _, inputRole := range roles { for _, inputRole := range roles {
if !currentRoleMap[*inputRole] { if !currentRoleMap[inputRole] {
valid = false valid = false
break break
} }
@@ -56,6 +56,18 @@ func IsValidRolesArray(roles []*string) bool {
return valid return valid
} }
func IsValidRole(userRoles []string, role string) bool {
valid := false
for _, currentRole := range userRoles {
if role == currentRole {
valid = true
break
}
}
return valid
}
func IsStringArrayEqual(a, b []string) bool { func IsStringArrayEqual(a, b []string) bool {
if len(a) != len(b) { if len(a) != len(b) {
return false return false