feat: add role based oauth

This commit is contained in:
Lakhan Samani
2021-09-20 10:00:17 +05:30
parent aec5d5a0c7
commit 94cdbc9268
22 changed files with 307 additions and 171 deletions

View File

@@ -6,4 +6,5 @@ DISABLE_EMAIL_VERIFICATION=true
JWT_SECRET=random_string
JWT_TYPE=HS256
ROLES=user,admin
DEFAULT_ROLE=user
DEFAULT_ROLE=user
JWT_ROLE_CLAIM=role

View File

@@ -9,8 +9,8 @@ For the first version we will only support setting roles master list via env
- [x] `DEFAULT_ROLE` -> default role to assign to users
- [x] Add roles input for signup
- [x] Add roles to update profile mutation
- [ ] Add roles input for login
- [ ] Return roles to user
- [ ] Return roles in users list for super admin
- [ ] Add roles to the JWT token generation
- [x] Add roles input for login
- [x] Return roles to user
- [x] Return roles in users list for super admin
- [x] Add roles to the JWT token generation
- [ ] Validate token should also validate the role, if roles to validate again is present in request

View File

@@ -24,8 +24,9 @@ var (
DISABLE_BASIC_AUTHENTICATION = "false"
// ROLES
ROLES = []string{}
DEFAULT_ROLE = ""
ROLES = []string{}
DEFAULT_ROLE = ""
JWT_ROLE_CLAIM = "role"
// OAuth login
GOOGLE_CLIENT_ID = ""

View File

@@ -74,6 +74,7 @@ func InitEnv() {
constants.DISABLE_BASIC_AUTHENTICATION = os.Getenv("DISABLE_BASIC_AUTHENTICATION")
constants.DISABLE_EMAIL_VERIFICATION = os.Getenv("DISABLE_EMAIL_VERIFICATION")
constants.DEFAULT_ROLE = os.Getenv("DEFAULT_ROLE")
constants.JWT_ROLE_CLAIM = os.Getenv("JWT_ROLE_CLAIM")
if constants.ADMIN_SECRET == "" {
panic("root admin secret is required")
@@ -148,6 +149,7 @@ func InitEnv() {
rolesSplit := strings.Split(os.Getenv("ROLES"), ",")
roles := []string{}
defaultRole := ""
for _, val := range rolesSplit {
trimVal := strings.TrimSpace(val)
if trimVal != "" {
@@ -159,7 +161,7 @@ func InitEnv() {
}
}
if len(roles) > 0 && defaultRole == "" {
panic(`Invalid DEFAULT_ROLE environment. It can be one from give ROLES environment variable value`)
panic(`Invalid DEFAULT_ROLE environment variable. It can be one from give ROLES environment variable value`)
}
if len(roles) == 0 {
@@ -168,4 +170,8 @@ func InitEnv() {
}
constants.ROLES = roles
if constants.JWT_ROLE_CLAIM == "" {
constants.JWT_ROLE_CLAIM = "role"
}
}

View File

@@ -80,7 +80,7 @@ type ComplexityRoot struct {
Query struct {
Meta func(childComplexity int) int
Profile func(childComplexity int) int
Token func(childComplexity int) int
Token func(childComplexity int, role *string) int
Users func(childComplexity int) int
VerificationRequests func(childComplexity int) int
}
@@ -127,7 +127,7 @@ type MutationResolver interface {
type QueryResolver interface {
Meta(ctx context.Context) (*model.Meta, error)
Users(ctx context.Context) ([]*model.User, error)
Token(ctx context.Context) (*model.AuthResponse, error)
Token(ctx context.Context, role *string) (*model.AuthResponse, error)
Profile(ctx context.Context) (*model.User, error)
VerificationRequests(ctx context.Context) ([]*model.VerificationRequest, error)
}
@@ -360,7 +360,12 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
break
}
return e.complexity.Query.Token(childComplexity), true
args, err := ec.field_Query_token_args(context.TODO(), rawArgs)
if err != nil {
return 0, false
}
return e.complexity.Query.Token(childComplexity, args["role"].(*string)), true
case "Query.users":
if e.complexity.Query.Users == nil {
@@ -570,6 +575,8 @@ var sources = []*ast.Source{
#
# https://gqlgen.com/getting-started/
scalar Int64
scalar Map
scalar Any
type Meta {
version: String!
@@ -591,7 +598,7 @@ type User {
image: String
createdAt: Int64
updatedAt: Int64
roles: [String]
roles: [String!]!
}
type VerificationRequest {
@@ -652,7 +659,7 @@ input UpdateProfileInput {
lastName: String
image: String
email: String
roles: [String]
# roles: [String]
}
input ForgotPasswordInput {
@@ -684,7 +691,7 @@ type Mutation {
type Query {
meta: Meta!
users: [User!]!
token: AuthResponse
token(role: String): AuthResponse
profile: User!
verificationRequests: [VerificationRequest!]!
}
@@ -831,6 +838,21 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs
return args, nil
}
func (ec *executionContext) field_Query_token_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
var arg0 *string
if tmp, ok := rawArgs["role"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role"))
arg0, err = ec.unmarshalOString2ᚖstring(ctx, tmp)
if err != nil {
return nil, err
}
}
args["role"] = arg0
return args, nil
}
func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
@@ -1772,9 +1794,16 @@ func (ec *executionContext) _Query_token(ctx context.Context, field graphql.Coll
}
ctx = graphql.WithFieldContext(ctx, fc)
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.field_Query_token_args(ctx, rawArgs)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
fc.Args = args
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Query().Token(rctx)
return ec.resolvers.Query().Token(rctx, args["role"].(*string))
})
if err != nil {
ec.Error(ctx, err)
@@ -2286,11 +2315,14 @@ func (ec *executionContext) _User_roles(ctx context.Context, field graphql.Colle
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.([]*string)
res := resTmp.([]string)
fc.Result = res
return ec.marshalOString2ᚕstring(ctx, field.Selections, res)
return ec.marshalNString2ᚕstring(ctx, field.Selections, res)
}
func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) {
@@ -3869,14 +3901,6 @@ func (ec *executionContext) unmarshalInputUpdateProfileInput(ctx context.Context
if err != nil {
return it, err
}
case "roles":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles"))
it.Roles, err = ec.unmarshalOString2ᚕᚖstring(ctx, v)
if err != nil {
return it, err
}
}
}
@@ -4268,6 +4292,9 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj
out.Values[i] = ec._User_updatedAt(ctx, field, obj)
case "roles":
out.Values[i] = ec._User_roles(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
default:
panic("unknown field " + strconv.Quote(field.Name))
}
@@ -4680,6 +4707,36 @@ func (ec *executionContext) marshalNString2string(ctx context.Context, sel ast.S
return res
}
func (ec *executionContext) unmarshalNString2ᚕstringᚄ(ctx context.Context, v interface{}) ([]string, error) {
var vSlice []interface{}
if v != nil {
if tmp1, ok := v.([]interface{}); ok {
vSlice = tmp1
} else {
vSlice = []interface{}{v}
}
}
var err error
res := make([]string, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.unmarshalNString2string(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
}
func (ec *executionContext) marshalNString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler {
ret := make(graphql.Array, len(v))
for i := range v {
ret[i] = ec.marshalNString2string(ctx, sel, v[i])
}
return ret
}
func (ec *executionContext) unmarshalNUpdateProfileInput2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUpdateProfileInput(ctx context.Context, v interface{}) (model.UpdateProfileInput, error) {
res, err := ec.unmarshalInputUpdateProfileInput(ctx, v)
return res, graphql.ErrorOnPath(ctx, err)

View File

@@ -63,27 +63,26 @@ type SignUpInput struct {
}
type UpdateProfileInput struct {
OldPassword *string `json:"oldPassword"`
NewPassword *string `json:"newPassword"`
ConfirmNewPassword *string `json:"confirmNewPassword"`
FirstName *string `json:"firstName"`
LastName *string `json:"lastName"`
Image *string `json:"image"`
Email *string `json:"email"`
Roles []*string `json:"roles"`
OldPassword *string `json:"oldPassword"`
NewPassword *string `json:"newPassword"`
ConfirmNewPassword *string `json:"confirmNewPassword"`
FirstName *string `json:"firstName"`
LastName *string `json:"lastName"`
Image *string `json:"image"`
Email *string `json:"email"`
}
type User struct {
ID string `json:"id"`
Email string `json:"email"`
SignupMethod string `json:"signupMethod"`
FirstName *string `json:"firstName"`
LastName *string `json:"lastName"`
EmailVerifiedAt *int64 `json:"emailVerifiedAt"`
Image *string `json:"image"`
CreatedAt *int64 `json:"createdAt"`
UpdatedAt *int64 `json:"updatedAt"`
Roles []*string `json:"roles"`
ID string `json:"id"`
Email string `json:"email"`
SignupMethod string `json:"signupMethod"`
FirstName *string `json:"firstName"`
LastName *string `json:"lastName"`
EmailVerifiedAt *int64 `json:"emailVerifiedAt"`
Image *string `json:"image"`
CreatedAt *int64 `json:"createdAt"`
UpdatedAt *int64 `json:"updatedAt"`
Roles []string `json:"roles"`
}
type VerificationRequest struct {

View File

@@ -2,6 +2,8 @@
#
# https://gqlgen.com/getting-started/
scalar Int64
scalar Map
scalar Any
type Meta {
version: String!
@@ -23,7 +25,7 @@ type User {
image: String
createdAt: Int64
updatedAt: Int64
roles: [String]
roles: [String!]!
}
type VerificationRequest {
@@ -84,7 +86,7 @@ input UpdateProfileInput {
lastName: String
image: String
email: String
roles: [String]
# roles: [String]
}
input ForgotPasswordInput {
@@ -116,7 +118,7 @@ type Mutation {
type Query {
meta: Meta!
users: [User!]!
token: AuthResponse
token(role: String): AuthResponse
profile: User!
verificationRequests: [VerificationRequest!]!
}

View File

@@ -55,7 +55,7 @@ func (r *queryResolver) Users(ctx context.Context) ([]*model.User, error) {
return resolvers.Users(ctx)
}
func (r *queryResolver) Token(ctx context.Context) (*model.AuthResponse, error) {
func (r *queryResolver) Token(ctx context.Context, role *string) (*model.AuthResponse, error) {
return resolvers.Token(ctx)
}

View File

@@ -25,7 +25,6 @@ func AppHandler() gin.HandlerFunc {
if state == "" {
// cookie, err := utils.GetAuthToken(c)
// log.Println(`cookie`, cookie)
// if err != nil {
// c.JSON(400, gin.H{"error": "invalid state"})
// return
@@ -67,13 +66,6 @@ func AppHandler() gin.HandlerFunc {
}
}
log.Println(gin.H{
"data": map[string]string{
"authorizerURL": stateObj.AuthorizerURL,
"redirectURL": stateObj.RedirectURL,
},
})
// debug the request state
if pusher := c.Writer.Pusher(); pusher != nil {
// use pusher.Push() to do server push

View File

@@ -19,7 +19,7 @@ import (
"golang.org/x/oauth2"
)
func processGoogleUserInfo(code string, c *gin.Context) error {
func processGoogleUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.GoogleConfig.Exchange(oauth2.NoContext, code)
if err != nil {
return fmt.Errorf("invalid google exchange code: %s", err.Error())
@@ -50,6 +50,7 @@ func processGoogleUserInfo(code string, c *gin.Context) error {
if err != nil {
// user not registered, register user and generate session token
user.SignupMethod = enum.Google.String()
user.Roles = role
} else {
// user exists in db, check if method was google
// if not append google to existing signup method and save it
@@ -60,27 +61,28 @@ func processGoogleUserInfo(code string, c *gin.Context) error {
}
user.SignupMethod = signupMethod
user.Password = existingUser.Password
log.Println("=> checking roles...", utils.IsValidRole(strings.Split(existingUser.Roles, ","), role))
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
log.Println("=> invalid role from google oauth")
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
}
user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken)
return nil
}
func processGithubUserInfo(code string, c *gin.Context) error {
func processGithubUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.GithubConfig.Exchange(oauth2.NoContext, code)
if err != nil {
return fmt.Errorf("invalid github exchange code: %s", err.Error())
@@ -128,6 +130,7 @@ func processGithubUserInfo(code string, c *gin.Context) error {
if err != nil {
// user not registered, register user and generate session token
user.SignupMethod = enum.Github.String()
user.Roles = role
} else {
// user exists in db, check if method was google
// if not append google to existing signup method and save it
@@ -138,26 +141,26 @@ func processGithubUserInfo(code string, c *gin.Context) error {
}
user.SignupMethod = signupMethod
user.Password = existingUser.Password
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
}
user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken)
return nil
}
func processFacebookUserInfo(code string, c *gin.Context) error {
func processFacebookUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.FacebookConfig.Exchange(oauth2.NoContext, code)
if err != nil {
return fmt.Errorf("invalid facebook exchange code: %s", err.Error())
@@ -199,6 +202,7 @@ func processFacebookUserInfo(code string, c *gin.Context) error {
if err != nil {
// user not registered, register user and generate session token
user.SignupMethod = enum.Github.String()
user.Roles = role
} else {
// user exists in db, check if method was google
// if not append google to existing signup method and save it
@@ -209,20 +213,20 @@ func processFacebookUserInfo(code string, c *gin.Context) error {
}
user.SignupMethod = signupMethod
user.Password = existingUser.Password
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
}
user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken)
return nil
@@ -238,23 +242,27 @@ func OAuthCallbackHandler() gin.HandlerFunc {
c.JSON(400, gin.H{"error": "invalid oauth state"})
}
session.DeleteToken(sessionState)
// contains random token, redirect url, role
sessionSplit := strings.Split(state, "___")
// TODO validate redirect url
if len(sessionSplit) != 2 {
if len(sessionSplit) < 2 {
c.JSON(400, gin.H{"error": "invalid redirect url"})
return
}
role := sessionSplit[2]
redirectURL := sessionSplit[1]
var err error
code := c.Request.FormValue("code")
switch provider {
case enum.Google.String():
err = processGoogleUserInfo(code, c)
err = processGoogleUserInfo(code, role, c)
case enum.Github.String():
err = processGithubUserInfo(code, c)
err = processGithubUserInfo(code, role, c)
case enum.Facebook.String():
err = processFacebookUserInfo(code, c)
err = processFacebookUserInfo(code, role, c)
default:
err = fmt.Errorf(`invalid oauth provider`)
}
@@ -263,6 +271,6 @@ func OAuthCallbackHandler() gin.HandlerFunc {
c.JSON(400, gin.H{"error": err.Error()})
return
}
c.Redirect(http.StatusTemporaryRedirect, sessionSplit[1])
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
}
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/oauth"
"github.com/authorizerdev/authorizer/server/session"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -17,6 +18,7 @@ func OAuthLoginHandler() gin.HandlerFunc {
return func(c *gin.Context) {
// TODO validate redirect URL
redirectURL := c.Query("redirectURL")
role := c.Query("role")
if redirectURL == "" {
c.JSON(400, gin.H{
@@ -24,8 +26,21 @@ func OAuthLoginHandler() gin.HandlerFunc {
})
return
}
if role != "" {
// validate role
if !utils.IsValidRole(constants.ROLES, role) {
c.JSON(400, gin.H{
"error": "invalid role",
})
return
}
} else {
role = constants.DEFAULT_ROLE
}
uuid := uuid.New()
oauthStateString := uuid.String() + "___" + redirectURL
oauthStateString := uuid.String() + "___" + redirectURL + "___" + role
provider := c.Param("oauth_provider")

View File

@@ -50,15 +50,9 @@ func VerifyEmailHandler() gin.HandlerFunc {
db.Mgr.DeleteToken(claim.Email)
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, user.Roles)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, user.Roles)
session.SetToken(userIdStr, refreshToken)
utils.SetCookie(c, accessToken)

View File

@@ -46,16 +46,19 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e
log.Println("Compare password error:", err)
return res, fmt.Errorf(`invalid password`)
}
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
role := constants.DEFAULT_ROLE
if params.Role != nil {
// validate role
if !utils.IsValidRole(strings.Split(user.Roles, ","), *params.Role) {
return res, fmt.Errorf(`invalid role`)
}
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
role = *params.Role
}
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
session.SetToken(userIdStr, refreshToken)
@@ -71,6 +74,7 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e
LastName: &user.LastName,
SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt,
},

View File

@@ -2,6 +2,7 @@ package resolvers
import (
"context"
"fmt"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/session"
@@ -25,7 +26,8 @@ func Logout(ctx context.Context) (*model.Response, error) {
return res, err
}
session.DeleteToken(claim.ID)
userId := fmt.Sprintf("%v", claim["id"])
session.DeleteToken(userId)
res = &model.Response{
Message: "Logged out successfully",
}

View File

@@ -3,6 +3,7 @@ package resolvers
import (
"context"
"fmt"
"strings"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
@@ -27,13 +28,15 @@ func Profile(ctx context.Context) (*model.User, error) {
return res, err
}
sessionToken := session.GetToken(claim.ID)
userID := fmt.Sprintf("%v", claim["id"])
email := fmt.Sprintf("%v", claim["email"])
sessionToken := session.GetToken(userID)
if sessionToken == "" {
return res, fmt.Errorf(`unauthorized`)
}
user, err := db.Mgr.GetUserByEmail(claim.Email)
user, err := db.Mgr.GetUserByEmail(email)
if err != nil {
return res, err
}
@@ -48,6 +51,7 @@ func Profile(ctx context.Context) (*model.User, error) {
LastName: &user.LastName,
SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt,
}

View File

@@ -35,13 +35,18 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
return res, fmt.Errorf(`invalid email address`)
}
inputRoles := []string{}
if params.Roles != nil && len(params.Roles) > 0 {
// check if roles exists
if !utils.IsValidRolesArray(params.Roles) {
for _, item := range params.Roles {
inputRoles = append(inputRoles, *item)
}
if !utils.IsValidRolesArray(inputRoles) {
return res, fmt.Errorf(`invalid roles`)
}
} else {
params.Roles = []*string{&constants.DEFAULT_ROLE}
inputRoles = []string{constants.DEFAULT_ROLE}
}
// find user with email
@@ -58,12 +63,7 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
Email: params.Email,
}
roles := ""
for _, roleInput := range params.Roles {
roles += *roleInput + ","
}
roles = strings.TrimSuffix(roles, ",")
user.Roles = roles
user.Roles = strings.Join(inputRoles, ",")
password, _ := utils.HashPassword(params.Password)
user.Password = password
@@ -93,9 +93,9 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
LastName: &user.LastName,
SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt,
Roles: params.Roles,
}
if constants.DISABLE_EMAIL_VERIFICATION != "true" {
@@ -123,15 +123,9 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
}
} else {
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE)
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE)
session.SetToken(userIdStr, refreshToken)
res = &model.AuthResponse{

View File

@@ -3,8 +3,10 @@ package resolvers
import (
"context"
"fmt"
"strings"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/graph/model"
@@ -25,9 +27,10 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
}
claim, accessTokenErr := utils.VerifyAuthToken(token)
expiresAt := claim.ExpiresAt
user, err := db.Mgr.GetUserByEmail(claim.Email)
expiresAt := claim["exp"].(int64)
email := fmt.Sprintf("%v", claim["email"])
role := fmt.Sprintf("%v", claim[constants.JWT_ROLE_CLAIM])
user, err := db.Mgr.GetUserByEmail(email)
if err != nil {
return res, err
}
@@ -46,10 +49,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
if accessTokenErr != nil || expiresTimeObj.Sub(currentTimeObj).Minutes() <= 5 {
// if access token has expired and refresh/session token is valid
// generate new accessToken
token, expiresAt, _ = utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
token, expiresAt, _ = utils.CreateAuthToken(user, enum.AccessToken, role)
}
utils.SetCookie(gc, token)
res = &model.AuthResponse{
@@ -62,6 +62,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
Image: &user.Image,
FirstName: &user.FirstName,
LastName: &user.LastName,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt,
},

View File

@@ -32,18 +32,20 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model
return res, err
}
sessionToken := session.GetToken(claim.ID)
id := fmt.Sprintf("%v", claim["id"])
sessionToken := session.GetToken(id)
if sessionToken == "" {
return res, fmt.Errorf(`unauthorized`)
}
// validate if all params are not empty
if params.FirstName == nil && params.LastName == nil && params.Image == nil && params.OldPassword == nil && params.Email == nil && params.Roles != nil {
if params.FirstName == nil && params.LastName == nil && params.Image == nil && params.OldPassword == nil && params.Email == nil {
return res, fmt.Errorf("please enter atleast one param to update")
}
user, err := db.Mgr.GetUserByEmail(claim.Email)
email := fmt.Sprintf("%v", claim["email"])
user, err := db.Mgr.GetUserByEmail(email)
if err != nil {
return res, err
}
@@ -122,21 +124,30 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model
}()
}
rolesToSave := ""
if params.Roles != nil && len(params.Roles) > 0 {
currentRoles := strings.Split(user.Roles, ",")
inputRoles := []string{}
for _, item := range params.Roles {
inputRoles = append(inputRoles, *item)
}
if !utils.IsStringArrayEqual(inputRoles, currentRoles) && utils.IsValidRolesArray(params.Roles) {
rolesToSave = strings.Join(inputRoles, ",")
}
}
// TODO this idea needs to be verified otherwise every user can make themselves super admin
// rolesToSave := ""
// if params.Roles != nil && len(params.Roles) > 0 {
// currentRoles := strings.Split(user.Roles, ",")
// inputRoles := []string{}
// for _, item := range params.Roles {
// inputRoles = append(inputRoles, *item)
// }
if rolesToSave != "" {
user.Roles = rolesToSave
}
// if !utils.IsValidRolesArray(inputRoles) {
// return res, fmt.Errorf("invalid list of roles")
// }
// if !utils.IsStringArrayEqual(inputRoles, currentRoles) {
// rolesToSave = strings.Join(inputRoles, ",")
// }
// session.DeleteToken(fmt.Sprintf("%v", user.ID))
// utils.DeleteCookie(gc)
// }
// if rolesToSave != "" {
// user.Roles = rolesToSave
// }
_, err = db.Mgr.UpdateUser(user)
if err != nil {

View File

@@ -3,8 +3,10 @@ package resolvers
import (
"context"
"fmt"
"strings"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/graph/model"
@@ -41,15 +43,9 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut
db.Mgr.DeleteToken(claim.Email)
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE)
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE)
session.SetToken(userIdStr, refreshToken)
@@ -65,6 +61,7 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut
LastName: &user.LastName,
SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt,
},

View File

@@ -1,29 +1,32 @@
package utils
import (
"encoding/json"
"fmt"
"log"
"strings"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
)
type UserAuthInfo struct {
Email string `json:"email"`
ID string `json:"id"`
}
// type UserAuthInfo struct {
// Email string `json:"email"`
// ID string `json:"id"`
// }
type JWTCustomClaim map[string]interface{}
type UserAuthClaim struct {
*jwt.StandardClaims
TokenType string `json:"token_type"`
UserAuthInfo
*JWTCustomClaim `json:"authorizer"`
}
func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64, error) {
func CreateAuthToken(user db.User, tokenType enum.TokenType, role string) (string, int64, error) {
t := jwt.New(jwt.GetSigningMethod(constants.JWT_TYPE))
expiryBound := time.Hour
if tokenType == enum.RefreshToken {
@@ -33,12 +36,19 @@ func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64
expiresAt := time.Now().Add(expiryBound).Unix()
customClaims := JWTCustomClaim{
"token_type": tokenType.String(),
"email": user.Email,
"id": user.ID,
"allowed_roles": strings.Split(user.Roles, ","),
constants.JWT_ROLE_CLAIM: role,
}
t.Claims = &UserAuthClaim{
&jwt.StandardClaims{
ExpiresAt: expiresAt,
},
tokenType.String(),
user,
&customClaims,
}
token, err := t.SignedString([]byte(constants.JWT_SECRET))
@@ -63,14 +73,20 @@ func GetAuthToken(gc *gin.Context) (string, error) {
return token, nil
}
func VerifyAuthToken(token string) (*UserAuthClaim, error) {
func VerifyAuthToken(token string) (map[string]interface{}, error) {
var res map[string]interface{}
claims := &UserAuthClaim{}
_, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte(constants.JWT_SECRET), nil
})
if err != nil {
return claims, err
return res, err
}
return claims, nil
data, _ := json.Marshal(claims.JWTCustomClaim)
json.Unmarshal(data, &res)
res["exp"] = claims.ExpiresAt
return res, nil
}

20
server/utils/common.go Normal file
View File

@@ -0,0 +1,20 @@
package utils
import (
"io"
"os"
)
func WriteToFile(filename string, data string) error {
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
_, err = io.WriteString(file, data)
if err != nil {
return err
}
return file.Sync()
}

View File

@@ -40,7 +40,7 @@ func IsSuperAdmin(gc *gin.Context) bool {
return secret == constants.ADMIN_SECRET
}
func IsValidRolesArray(roles []*string) bool {
func IsValidRolesArray(roles []string) bool {
valid := true
currentRoleMap := map[string]bool{}
@@ -48,7 +48,7 @@ func IsValidRolesArray(roles []*string) bool {
currentRoleMap[currentRole] = true
}
for _, inputRole := range roles {
if !currentRoleMap[*inputRole] {
if !currentRoleMap[inputRole] {
valid = false
break
}
@@ -56,6 +56,18 @@ func IsValidRolesArray(roles []*string) bool {
return valid
}
func IsValidRole(userRoles []string, role string) bool {
valid := false
for _, currentRole := range userRoles {
if role == currentRole {
valid = true
break
}
}
return valid
}
func IsStringArrayEqual(a, b []string) bool {
if len(a) != len(b) {
return false