feat/role based access (#50)

* feat: add roles based access

* feat: update roles env + todo

* feat: add roles to update profile

* feat: add role based oauth

* feat: validate role for a given token
This commit is contained in:
Lakhan Samani 2021-09-20 10:36:26 +05:30 committed by GitHub
parent 195270525c
commit 21e3425e76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 544 additions and 141 deletions

View File

@ -4,4 +4,7 @@ DATABASE_TYPE=sqlite
ADMIN_SECRET=admin ADMIN_SECRET=admin
DISABLE_EMAIL_VERIFICATION=true DISABLE_EMAIL_VERIFICATION=true
JWT_SECRET=random_string JWT_SECRET=random_string
JWT_TYPE=HS256 JWT_TYPE=HS256
ROLES=user,admin
DEFAULT_ROLE=user
JWT_ROLE_CLAIM=role

16
TODO.md
View File

@ -0,0 +1,16 @@
# Task List
# Feature roles
For the first version we will only support setting roles master list via env
- [x] Support following ENV
- [x] `ROLES` -> comma separated list of role names
- [x] `DEFAULT_ROLE` -> default role to assign to users
- [x] Add roles input for signup
- [x] Add roles to update profile mutation
- [x] Add roles input for login
- [x] Return roles to user
- [x] Return roles in users list for super admin
- [x] Add roles to the JWT token generation
- [x] Validate token should also validate the role, if roles to validate again is present in request

View File

@ -23,6 +23,11 @@ var (
DISABLE_EMAIL_VERIFICATION = "false" DISABLE_EMAIL_VERIFICATION = "false"
DISABLE_BASIC_AUTHENTICATION = "false" DISABLE_BASIC_AUTHENTICATION = "false"
// ROLES
ROLES = []string{}
DEFAULT_ROLE = ""
JWT_ROLE_CLAIM = "role"
// OAuth login // OAuth login
GOOGLE_CLIENT_ID = "" GOOGLE_CLIENT_ID = ""
GOOGLE_CLIENT_SECRET = "" GOOGLE_CLIENT_SECRET = ""

View File

@ -24,6 +24,7 @@ type Manager interface {
GetVerificationRequests() ([]VerificationRequest, error) GetVerificationRequests() ([]VerificationRequest, error)
GetVerificationByEmail(email string) (VerificationRequest, error) GetVerificationByEmail(email string) (VerificationRequest, error)
DeleteUser(email string) error DeleteUser(email string) error
SaveRoles(roles []Role) error
} }
type manager struct { type manager struct {
@ -53,7 +54,7 @@ func InitDB() {
if err != nil { if err != nil {
log.Fatal("Failed to init db:", err) log.Fatal("Failed to init db:", err)
} else { } else {
db.AutoMigrate(&User{}, &VerificationRequest{}) db.AutoMigrate(&User{}, &VerificationRequest{}, &Role{})
} }
Mgr = &manager{db: db} Mgr = &manager{db: db}

19
server/db/roles.go Normal file
View File

@ -0,0 +1,19 @@
package db
import "log"
type Role struct {
ID uint `gorm:"primaryKey"`
Role string
}
// SaveRoles function to save roles
func (mgr *manager) SaveRoles(roles []Role) error {
res := mgr.db.Create(&roles)
if res.Error != nil {
log.Println(`Error saving roles`)
return res.Error
}
return nil
}

View File

@ -17,6 +17,7 @@ type User struct {
CreatedAt int64 `gorm:"autoCreateTime"` CreatedAt int64 `gorm:"autoCreateTime"`
UpdatedAt int64 `gorm:"autoUpdateTime"` UpdatedAt int64 `gorm:"autoUpdateTime"`
Image string Image string
Roles string
} }
// SaveUser function to add user even with email conflict // SaveUser function to add user even with email conflict

View File

@ -73,6 +73,8 @@ func InitEnv() {
constants.RESET_PASSWORD_URL = strings.TrimPrefix(os.Getenv("RESET_PASSWORD_URL"), "/") constants.RESET_PASSWORD_URL = strings.TrimPrefix(os.Getenv("RESET_PASSWORD_URL"), "/")
constants.DISABLE_BASIC_AUTHENTICATION = os.Getenv("DISABLE_BASIC_AUTHENTICATION") constants.DISABLE_BASIC_AUTHENTICATION = os.Getenv("DISABLE_BASIC_AUTHENTICATION")
constants.DISABLE_EMAIL_VERIFICATION = os.Getenv("DISABLE_EMAIL_VERIFICATION") constants.DISABLE_EMAIL_VERIFICATION = os.Getenv("DISABLE_EMAIL_VERIFICATION")
constants.DEFAULT_ROLE = os.Getenv("DEFAULT_ROLE")
constants.JWT_ROLE_CLAIM = os.Getenv("JWT_ROLE_CLAIM")
if constants.ADMIN_SECRET == "" { if constants.ADMIN_SECRET == "" {
panic("root admin secret is required") panic("root admin secret is required")
@ -143,4 +145,33 @@ func InitEnv() {
constants.DISABLE_EMAIL_VERIFICATION = "false" constants.DISABLE_EMAIL_VERIFICATION = "false"
} }
} }
rolesSplit := strings.Split(os.Getenv("ROLES"), ",")
roles := []string{}
defaultRole := ""
for _, val := range rolesSplit {
trimVal := strings.TrimSpace(val)
if trimVal != "" {
roles = append(roles, trimVal)
}
if trimVal == constants.DEFAULT_ROLE {
defaultRole = trimVal
}
}
if len(roles) > 0 && defaultRole == "" {
panic(`Invalid DEFAULT_ROLE environment variable. It can be one from give ROLES environment variable value`)
}
if len(roles) == 0 {
roles = []string{"user", "admin"}
constants.DEFAULT_ROLE = "user"
}
constants.ROLES = roles
if constants.JWT_ROLE_CLAIM == "" {
constants.JWT_ROLE_CLAIM = "role"
}
} }

View File

@ -80,7 +80,7 @@ type ComplexityRoot struct {
Query struct { Query struct {
Meta func(childComplexity int) int Meta func(childComplexity int) int
Profile func(childComplexity int) int Profile func(childComplexity int) int
Token func(childComplexity int) int Token func(childComplexity int, role *string) int
Users func(childComplexity int) int Users func(childComplexity int) int
VerificationRequests func(childComplexity int) int VerificationRequests func(childComplexity int) int
} }
@ -97,6 +97,7 @@ type ComplexityRoot struct {
ID func(childComplexity int) int ID func(childComplexity int) int
Image func(childComplexity int) int Image func(childComplexity int) int
LastName func(childComplexity int) int LastName func(childComplexity int) int
Roles func(childComplexity int) int
SignupMethod func(childComplexity int) int SignupMethod func(childComplexity int) int
UpdatedAt func(childComplexity int) int UpdatedAt func(childComplexity int) int
} }
@ -126,7 +127,7 @@ type MutationResolver interface {
type QueryResolver interface { type QueryResolver interface {
Meta(ctx context.Context) (*model.Meta, error) Meta(ctx context.Context) (*model.Meta, error)
Users(ctx context.Context) ([]*model.User, error) Users(ctx context.Context) ([]*model.User, error)
Token(ctx context.Context) (*model.AuthResponse, error) Token(ctx context.Context, role *string) (*model.AuthResponse, error)
Profile(ctx context.Context) (*model.User, error) Profile(ctx context.Context) (*model.User, error)
VerificationRequests(ctx context.Context) ([]*model.VerificationRequest, error) VerificationRequests(ctx context.Context) ([]*model.VerificationRequest, error)
} }
@ -359,7 +360,12 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
break break
} }
return e.complexity.Query.Token(childComplexity), true args, err := ec.field_Query_token_args(context.TODO(), rawArgs)
if err != nil {
return 0, false
}
return e.complexity.Query.Token(childComplexity, args["role"].(*string)), true
case "Query.users": case "Query.users":
if e.complexity.Query.Users == nil { if e.complexity.Query.Users == nil {
@ -431,6 +437,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.User.LastName(childComplexity), true return e.complexity.User.LastName(childComplexity), true
case "User.roles":
if e.complexity.User.Roles == nil {
break
}
return e.complexity.User.Roles(childComplexity), true
case "User.signupMethod": case "User.signupMethod":
if e.complexity.User.SignupMethod == nil { if e.complexity.User.SignupMethod == nil {
break break
@ -562,6 +575,8 @@ var sources = []*ast.Source{
# #
# https://gqlgen.com/getting-started/ # https://gqlgen.com/getting-started/
scalar Int64 scalar Int64
scalar Map
scalar Any
type Meta { type Meta {
version: String! version: String!
@ -583,6 +598,7 @@ type User {
image: String image: String
createdAt: Int64 createdAt: Int64
updatedAt: Int64 updatedAt: Int64
roles: [String!]!
} }
type VerificationRequest { type VerificationRequest {
@ -618,11 +634,13 @@ input SignUpInput {
password: String! password: String!
confirmPassword: String! confirmPassword: String!
image: String image: String
roles: [String]
} }
input LoginInput { input LoginInput {
email: String! email: String!
password: String! password: String!
role: String
} }
input VerifyEmailInput { input VerifyEmailInput {
@ -641,6 +659,7 @@ input UpdateProfileInput {
lastName: String lastName: String
image: String image: String
email: String email: String
# roles: [String]
} }
input ForgotPasswordInput { input ForgotPasswordInput {
@ -672,7 +691,7 @@ type Mutation {
type Query { type Query {
meta: Meta! meta: Meta!
users: [User!]! users: [User!]!
token: AuthResponse token(role: String): AuthResponse
profile: User! profile: User!
verificationRequests: [VerificationRequest!]! verificationRequests: [VerificationRequest!]!
} }
@ -819,6 +838,21 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs
return args, nil return args, nil
} }
func (ec *executionContext) field_Query_token_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
var arg0 *string
if tmp, ok := rawArgs["role"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role"))
arg0, err = ec.unmarshalOString2ᚖstring(ctx, tmp)
if err != nil {
return nil, err
}
}
args["role"] = arg0
return args, nil
}
func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error var err error
args := map[string]interface{}{} args := map[string]interface{}{}
@ -1760,9 +1794,16 @@ func (ec *executionContext) _Query_token(ctx context.Context, field graphql.Coll
} }
ctx = graphql.WithFieldContext(ctx, fc) ctx = graphql.WithFieldContext(ctx, fc)
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.field_Query_token_args(ctx, rawArgs)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
fc.Args = args
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children ctx = rctx // use context from middleware stack in children
return ec.resolvers.Query().Token(rctx) return ec.resolvers.Query().Token(rctx, args["role"].(*string))
}) })
if err != nil { if err != nil {
ec.Error(ctx, err) ec.Error(ctx, err)
@ -2249,6 +2290,41 @@ func (ec *executionContext) _User_updatedAt(ctx context.Context, field graphql.C
return ec.marshalOInt642ᚖint64(ctx, field.Selections, res) return ec.marshalOInt642ᚖint64(ctx, field.Selections, res)
} }
func (ec *executionContext) _User_roles(ctx context.Context, field graphql.CollectedField, obj *model.User) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "User",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.Roles, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.([]string)
fc.Result = res
return ec.marshalNString2ᚕstringᚄ(ctx, field.Selections, res)
}
func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) { func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -3625,6 +3701,14 @@ func (ec *executionContext) unmarshalInputLoginInput(ctx context.Context, obj in
if err != nil { if err != nil {
return it, err return it, err
} }
case "role":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role"))
it.Role, err = ec.unmarshalOString2ᚖstring(ctx, v)
if err != nil {
return it, err
}
} }
} }
@ -3741,6 +3825,14 @@ func (ec *executionContext) unmarshalInputSignUpInput(ctx context.Context, obj i
if err != nil { if err != nil {
return it, err return it, err
} }
case "roles":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles"))
it.Roles, err = ec.unmarshalOString2ᚕᚖstring(ctx, v)
if err != nil {
return it, err
}
} }
} }
@ -4198,6 +4290,11 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj
out.Values[i] = ec._User_createdAt(ctx, field, obj) out.Values[i] = ec._User_createdAt(ctx, field, obj)
case "updatedAt": case "updatedAt":
out.Values[i] = ec._User_updatedAt(ctx, field, obj) out.Values[i] = ec._User_updatedAt(ctx, field, obj)
case "roles":
out.Values[i] = ec._User_roles(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
default: default:
panic("unknown field " + strconv.Quote(field.Name)) panic("unknown field " + strconv.Quote(field.Name))
} }
@ -4610,6 +4707,36 @@ func (ec *executionContext) marshalNString2string(ctx context.Context, sel ast.S
return res return res
} }
func (ec *executionContext) unmarshalNString2ᚕstringᚄ(ctx context.Context, v interface{}) ([]string, error) {
var vSlice []interface{}
if v != nil {
if tmp1, ok := v.([]interface{}); ok {
vSlice = tmp1
} else {
vSlice = []interface{}{v}
}
}
var err error
res := make([]string, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.unmarshalNString2string(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
}
func (ec *executionContext) marshalNString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler {
ret := make(graphql.Array, len(v))
for i := range v {
ret[i] = ec.marshalNString2string(ctx, sel, v[i])
}
return ret
}
func (ec *executionContext) unmarshalNUpdateProfileInput2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUpdateProfileInput(ctx context.Context, v interface{}) (model.UpdateProfileInput, error) { func (ec *executionContext) unmarshalNUpdateProfileInput2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUpdateProfileInput(ctx context.Context, v interface{}) (model.UpdateProfileInput, error) {
res, err := ec.unmarshalInputUpdateProfileInput(ctx, v) res, err := ec.unmarshalInputUpdateProfileInput(ctx, v)
return res, graphql.ErrorOnPath(ctx, err) return res, graphql.ErrorOnPath(ctx, err)
@ -5002,6 +5129,42 @@ func (ec *executionContext) marshalOString2string(ctx context.Context, sel ast.S
return graphql.MarshalString(v) return graphql.MarshalString(v)
} }
func (ec *executionContext) unmarshalOString2ᚕᚖstring(ctx context.Context, v interface{}) ([]*string, error) {
if v == nil {
return nil, nil
}
var vSlice []interface{}
if v != nil {
if tmp1, ok := v.([]interface{}); ok {
vSlice = tmp1
} else {
vSlice = []interface{}{v}
}
}
var err error
res := make([]*string, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.unmarshalOString2ᚖstring(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
}
func (ec *executionContext) marshalOString2ᚕᚖstring(ctx context.Context, sel ast.SelectionSet, v []*string) graphql.Marshaler {
if v == nil {
return graphql.Null
}
ret := make(graphql.Array, len(v))
for i := range v {
ret[i] = ec.marshalOString2ᚖstring(ctx, sel, v[i])
}
return ret
}
func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v interface{}) (*string, error) { func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v interface{}) (*string, error) {
if v == nil { if v == nil {
return nil, nil return nil, nil

View File

@ -23,8 +23,9 @@ type ForgotPasswordInput struct {
} }
type LoginInput struct { type LoginInput struct {
Email string `json:"email"` Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
Role *string `json:"role"`
} }
type Meta struct { type Meta struct {
@ -52,12 +53,13 @@ type Response struct {
} }
type SignUpInput struct { type SignUpInput struct {
FirstName *string `json:"firstName"` FirstName *string `json:"firstName"`
LastName *string `json:"lastName"` LastName *string `json:"lastName"`
Email string `json:"email"` Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
ConfirmPassword string `json:"confirmPassword"` ConfirmPassword string `json:"confirmPassword"`
Image *string `json:"image"` Image *string `json:"image"`
Roles []*string `json:"roles"`
} }
type UpdateProfileInput struct { type UpdateProfileInput struct {
@ -71,15 +73,16 @@ type UpdateProfileInput struct {
} }
type User struct { type User struct {
ID string `json:"id"` ID string `json:"id"`
Email string `json:"email"` Email string `json:"email"`
SignupMethod string `json:"signupMethod"` SignupMethod string `json:"signupMethod"`
FirstName *string `json:"firstName"` FirstName *string `json:"firstName"`
LastName *string `json:"lastName"` LastName *string `json:"lastName"`
EmailVerifiedAt *int64 `json:"emailVerifiedAt"` EmailVerifiedAt *int64 `json:"emailVerifiedAt"`
Image *string `json:"image"` Image *string `json:"image"`
CreatedAt *int64 `json:"createdAt"` CreatedAt *int64 `json:"createdAt"`
UpdatedAt *int64 `json:"updatedAt"` UpdatedAt *int64 `json:"updatedAt"`
Roles []string `json:"roles"`
} }
type VerificationRequest struct { type VerificationRequest struct {

View File

@ -2,6 +2,8 @@
# #
# https://gqlgen.com/getting-started/ # https://gqlgen.com/getting-started/
scalar Int64 scalar Int64
scalar Map
scalar Any
type Meta { type Meta {
version: String! version: String!
@ -23,6 +25,7 @@ type User {
image: String image: String
createdAt: Int64 createdAt: Int64
updatedAt: Int64 updatedAt: Int64
roles: [String!]!
} }
type VerificationRequest { type VerificationRequest {
@ -58,11 +61,13 @@ input SignUpInput {
password: String! password: String!
confirmPassword: String! confirmPassword: String!
image: String image: String
roles: [String]
} }
input LoginInput { input LoginInput {
email: String! email: String!
password: String! password: String!
role: String
} }
input VerifyEmailInput { input VerifyEmailInput {
@ -81,6 +86,7 @@ input UpdateProfileInput {
lastName: String lastName: String
image: String image: String
email: String email: String
# roles: [String]
} }
input ForgotPasswordInput { input ForgotPasswordInput {
@ -112,7 +118,7 @@ type Mutation {
type Query { type Query {
meta: Meta! meta: Meta!
users: [User!]! users: [User!]!
token: AuthResponse token(role: String): AuthResponse
profile: User! profile: User!
verificationRequests: [VerificationRequest!]! verificationRequests: [VerificationRequest!]!
} }

View File

@ -55,8 +55,8 @@ func (r *queryResolver) Users(ctx context.Context) ([]*model.User, error) {
return resolvers.Users(ctx) return resolvers.Users(ctx)
} }
func (r *queryResolver) Token(ctx context.Context) (*model.AuthResponse, error) { func (r *queryResolver) Token(ctx context.Context, role *string) (*model.AuthResponse, error) {
return resolvers.Token(ctx) return resolvers.Token(ctx, role)
} }
func (r *queryResolver) Profile(ctx context.Context) (*model.User, error) { func (r *queryResolver) Profile(ctx context.Context) (*model.User, error) {

View File

@ -25,7 +25,6 @@ func AppHandler() gin.HandlerFunc {
if state == "" { if state == "" {
// cookie, err := utils.GetAuthToken(c) // cookie, err := utils.GetAuthToken(c)
// log.Println(`cookie`, cookie)
// if err != nil { // if err != nil {
// c.JSON(400, gin.H{"error": "invalid state"}) // c.JSON(400, gin.H{"error": "invalid state"})
// return // return
@ -67,13 +66,6 @@ func AppHandler() gin.HandlerFunc {
} }
} }
log.Println(gin.H{
"data": map[string]string{
"authorizerURL": stateObj.AuthorizerURL,
"redirectURL": stateObj.RedirectURL,
},
})
// debug the request state // debug the request state
if pusher := c.Writer.Pusher(); pusher != nil { if pusher := c.Writer.Pusher(); pusher != nil {
// use pusher.Push() to do server push // use pusher.Push() to do server push

View File

@ -19,7 +19,7 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
func processGoogleUserInfo(code string, c *gin.Context) error { func processGoogleUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.GoogleConfig.Exchange(oauth2.NoContext, code) token, err := oauth.OAuthProvider.GoogleConfig.Exchange(oauth2.NoContext, code)
if err != nil { if err != nil {
return fmt.Errorf("invalid google exchange code: %s", err.Error()) return fmt.Errorf("invalid google exchange code: %s", err.Error())
@ -50,6 +50,7 @@ func processGoogleUserInfo(code string, c *gin.Context) error {
if err != nil { if err != nil {
// user not registered, register user and generate session token // user not registered, register user and generate session token
user.SignupMethod = enum.Google.String() user.SignupMethod = enum.Google.String()
user.Roles = role
} else { } else {
// user exists in db, check if method was google // user exists in db, check if method was google
// if not append google to existing signup method and save it // if not append google to existing signup method and save it
@ -60,27 +61,26 @@ func processGoogleUserInfo(code string, c *gin.Context) error {
} }
user.SignupMethod = signupMethod user.SignupMethod = signupMethod
user.Password = existingUser.Password user.Password = existingUser.Password
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
} }
user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email) user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID) userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
utils.SetCookie(c, accessToken) utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
return nil return nil
} }
func processGithubUserInfo(code string, c *gin.Context) error { func processGithubUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.GithubConfig.Exchange(oauth2.NoContext, code) token, err := oauth.OAuthProvider.GithubConfig.Exchange(oauth2.NoContext, code)
if err != nil { if err != nil {
return fmt.Errorf("invalid github exchange code: %s", err.Error()) return fmt.Errorf("invalid github exchange code: %s", err.Error())
@ -128,6 +128,7 @@ func processGithubUserInfo(code string, c *gin.Context) error {
if err != nil { if err != nil {
// user not registered, register user and generate session token // user not registered, register user and generate session token
user.SignupMethod = enum.Github.String() user.SignupMethod = enum.Github.String()
user.Roles = role
} else { } else {
// user exists in db, check if method was google // user exists in db, check if method was google
// if not append google to existing signup method and save it // if not append google to existing signup method and save it
@ -138,26 +139,26 @@ func processGithubUserInfo(code string, c *gin.Context) error {
} }
user.SignupMethod = signupMethod user.SignupMethod = signupMethod
user.Password = existingUser.Password user.Password = existingUser.Password
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
} }
user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email) user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID) userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
utils.SetCookie(c, accessToken) utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
return nil return nil
} }
func processFacebookUserInfo(code string, c *gin.Context) error { func processFacebookUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.FacebookConfig.Exchange(oauth2.NoContext, code) token, err := oauth.OAuthProvider.FacebookConfig.Exchange(oauth2.NoContext, code)
if err != nil { if err != nil {
return fmt.Errorf("invalid facebook exchange code: %s", err.Error()) return fmt.Errorf("invalid facebook exchange code: %s", err.Error())
@ -199,6 +200,7 @@ func processFacebookUserInfo(code string, c *gin.Context) error {
if err != nil { if err != nil {
// user not registered, register user and generate session token // user not registered, register user and generate session token
user.SignupMethod = enum.Github.String() user.SignupMethod = enum.Github.String()
user.Roles = role
} else { } else {
// user exists in db, check if method was google // user exists in db, check if method was google
// if not append google to existing signup method and save it // if not append google to existing signup method and save it
@ -209,20 +211,20 @@ func processFacebookUserInfo(code string, c *gin.Context) error {
} }
user.SignupMethod = signupMethod user.SignupMethod = signupMethod
user.Password = existingUser.Password user.Password = existingUser.Password
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
} }
user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email) user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID) userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
utils.SetCookie(c, accessToken) utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
return nil return nil
@ -238,23 +240,27 @@ func OAuthCallbackHandler() gin.HandlerFunc {
c.JSON(400, gin.H{"error": "invalid oauth state"}) c.JSON(400, gin.H{"error": "invalid oauth state"})
} }
session.DeleteToken(sessionState) session.DeleteToken(sessionState)
// contains random token, redirect url, role
sessionSplit := strings.Split(state, "___") sessionSplit := strings.Split(state, "___")
// TODO validate redirect url // TODO validate redirect url
if len(sessionSplit) != 2 { if len(sessionSplit) < 2 {
c.JSON(400, gin.H{"error": "invalid redirect url"}) c.JSON(400, gin.H{"error": "invalid redirect url"})
return return
} }
role := sessionSplit[2]
redirectURL := sessionSplit[1]
var err error var err error
code := c.Request.FormValue("code") code := c.Request.FormValue("code")
switch provider { switch provider {
case enum.Google.String(): case enum.Google.String():
err = processGoogleUserInfo(code, c) err = processGoogleUserInfo(code, role, c)
case enum.Github.String(): case enum.Github.String():
err = processGithubUserInfo(code, c) err = processGithubUserInfo(code, role, c)
case enum.Facebook.String(): case enum.Facebook.String():
err = processFacebookUserInfo(code, c) err = processFacebookUserInfo(code, role, c)
default: default:
err = fmt.Errorf(`invalid oauth provider`) err = fmt.Errorf(`invalid oauth provider`)
} }
@ -263,6 +269,6 @@ func OAuthCallbackHandler() gin.HandlerFunc {
c.JSON(400, gin.H{"error": err.Error()}) c.JSON(400, gin.H{"error": err.Error()})
return return
} }
c.Redirect(http.StatusTemporaryRedirect, sessionSplit[1]) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
} }
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/oauth"
"github.com/authorizerdev/authorizer/server/session" "github.com/authorizerdev/authorizer/server/session"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -17,6 +18,7 @@ func OAuthLoginHandler() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// TODO validate redirect URL // TODO validate redirect URL
redirectURL := c.Query("redirectURL") redirectURL := c.Query("redirectURL")
role := c.Query("role")
if redirectURL == "" { if redirectURL == "" {
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@ -24,8 +26,21 @@ func OAuthLoginHandler() gin.HandlerFunc {
}) })
return return
} }
if role != "" {
// validate role
if !utils.IsValidRole(constants.ROLES, role) {
c.JSON(400, gin.H{
"error": "invalid role",
})
return
}
} else {
role = constants.DEFAULT_ROLE
}
uuid := uuid.New() uuid := uuid.New()
oauthStateString := uuid.String() + "___" + redirectURL oauthStateString := uuid.String() + "___" + redirectURL + "___" + role
provider := c.Param("oauth_provider") provider := c.Param("oauth_provider")

View File

@ -50,15 +50,9 @@ func VerifyEmailHandler() gin.HandlerFunc {
db.Mgr.DeleteToken(claim.Email) db.Mgr.DeleteToken(claim.Email)
userIdStr := fmt.Sprintf("%v", user.ID) userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, user.Roles)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, user.Roles)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
utils.SetCookie(c, accessToken) utils.SetCookie(c, accessToken)

View File

@ -9,6 +9,7 @@ import (
"github.com/authorizerdev/authorizer/server/handlers" "github.com/authorizerdev/authorizer/server/handlers"
"github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/oauth"
"github.com/authorizerdev/authorizer/server/session" "github.com/authorizerdev/authorizer/server/session"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/gin-contrib/location" "github.com/gin-contrib/location"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -50,6 +51,7 @@ func main() {
db.InitDB() db.InitDB()
session.InitSession() session.InitSession()
oauth.InitOAuth() oauth.InitOAuth()
utils.InitServer()
r := gin.Default() r := gin.Default()
r.Use(location.Default()) r.Use(location.Default())

View File

@ -46,16 +46,19 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e
log.Println("Compare password error:", err) log.Println("Compare password error:", err)
return res, fmt.Errorf(`invalid password`) return res, fmt.Errorf(`invalid password`)
} }
userIdStr := fmt.Sprintf("%v", user.ID) role := constants.DEFAULT_ROLE
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ if params.Role != nil {
ID: userIdStr, // validate role
Email: user.Email, if !utils.IsValidRole(strings.Split(user.Roles, ","), *params.Role) {
}, enum.RefreshToken) return res, fmt.Errorf(`invalid role`)
}
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ role = *params.Role
ID: userIdStr, }
Email: user.Email, userIdStr := fmt.Sprintf("%v", user.ID)
}, enum.AccessToken) refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
@ -71,6 +74,7 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e
LastName: &user.LastName, LastName: &user.LastName,
SignupMethod: user.SignupMethod, SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt, EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt, CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt, UpdatedAt: &user.UpdatedAt,
}, },

View File

@ -2,6 +2,7 @@ package resolvers
import ( import (
"context" "context"
"fmt"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/session" "github.com/authorizerdev/authorizer/server/session"
@ -25,7 +26,8 @@ func Logout(ctx context.Context) (*model.Response, error) {
return res, err return res, err
} }
session.DeleteToken(claim.ID) userId := fmt.Sprintf("%v", claim["id"])
session.DeleteToken(userId)
res = &model.Response{ res = &model.Response{
Message: "Logged out successfully", Message: "Logged out successfully",
} }

View File

@ -3,6 +3,7 @@ package resolvers
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
@ -27,13 +28,15 @@ func Profile(ctx context.Context) (*model.User, error) {
return res, err return res, err
} }
sessionToken := session.GetToken(claim.ID) userID := fmt.Sprintf("%v", claim["id"])
email := fmt.Sprintf("%v", claim["email"])
sessionToken := session.GetToken(userID)
if sessionToken == "" { if sessionToken == "" {
return res, fmt.Errorf(`unauthorized`) return res, fmt.Errorf(`unauthorized`)
} }
user, err := db.Mgr.GetUserByEmail(claim.Email) user, err := db.Mgr.GetUserByEmail(email)
if err != nil { if err != nil {
return res, err return res, err
} }
@ -48,6 +51,7 @@ func Profile(ctx context.Context) (*model.User, error) {
LastName: &user.LastName, LastName: &user.LastName,
SignupMethod: user.SignupMethod, SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt, EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt, CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt, UpdatedAt: &user.UpdatedAt,
} }

View File

@ -35,6 +35,20 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
return res, fmt.Errorf(`invalid email address`) return res, fmt.Errorf(`invalid email address`)
} }
inputRoles := []string{}
if params.Roles != nil && len(params.Roles) > 0 {
// check if roles exists
for _, item := range params.Roles {
inputRoles = append(inputRoles, *item)
}
if !utils.IsValidRolesArray(inputRoles) {
return res, fmt.Errorf(`invalid roles`)
}
} else {
inputRoles = []string{constants.DEFAULT_ROLE}
}
// find user with email // find user with email
existingUser, err := db.Mgr.GetUserByEmail(params.Email) existingUser, err := db.Mgr.GetUserByEmail(params.Email)
if err != nil { if err != nil {
@ -49,6 +63,8 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
Email: params.Email, Email: params.Email,
} }
user.Roles = strings.Join(inputRoles, ",")
password, _ := utils.HashPassword(params.Password) password, _ := utils.HashPassword(params.Password)
user.Password = password user.Password = password
@ -77,6 +93,7 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
LastName: &user.LastName, LastName: &user.LastName,
SignupMethod: user.SignupMethod, SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt, EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt, CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt, UpdatedAt: &user.UpdatedAt,
} }
@ -106,15 +123,9 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
} }
} else { } else {
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
res = &model.AuthResponse{ res = &model.AuthResponse{

View File

@ -3,8 +3,10 @@ package resolvers
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
@ -12,7 +14,7 @@ import (
"github.com/authorizerdev/authorizer/server/utils" "github.com/authorizerdev/authorizer/server/utils"
) )
func Token(ctx context.Context) (*model.AuthResponse, error) { func Token(ctx context.Context, role *string) (*model.AuthResponse, error) {
var res *model.AuthResponse var res *model.AuthResponse
gc, err := utils.GinContextFromContext(ctx) gc, err := utils.GinContextFromContext(ctx)
@ -25,13 +27,19 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
} }
claim, accessTokenErr := utils.VerifyAuthToken(token) claim, accessTokenErr := utils.VerifyAuthToken(token)
expiresAt := claim.ExpiresAt expiresAt := claim["exp"].(int64)
email := fmt.Sprintf("%v", claim["email"])
user, err := db.Mgr.GetUserByEmail(claim.Email) claimRole := fmt.Sprintf("%v", claim[constants.JWT_ROLE_CLAIM])
user, err := db.Mgr.GetUserByEmail(email)
if err != nil { if err != nil {
return res, err return res, err
} }
if role != nil && role != &claimRole {
return res, fmt.Errorf(`unauthorized. invalid role for a given token`)
}
userIdStr := fmt.Sprintf("%v", user.ID) userIdStr := fmt.Sprintf("%v", user.ID)
sessionToken := session.GetToken(userIdStr) sessionToken := session.GetToken(userIdStr)
@ -46,10 +54,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
if accessTokenErr != nil || expiresTimeObj.Sub(currentTimeObj).Minutes() <= 5 { if accessTokenErr != nil || expiresTimeObj.Sub(currentTimeObj).Minutes() <= 5 {
// if access token has expired and refresh/session token is valid // if access token has expired and refresh/session token is valid
// generate new accessToken // generate new accessToken
token, expiresAt, _ = utils.CreateAuthToken(utils.UserAuthInfo{ token, expiresAt, _ = utils.CreateAuthToken(user, enum.AccessToken, claimRole)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
} }
utils.SetCookie(gc, token) utils.SetCookie(gc, token)
res = &model.AuthResponse{ res = &model.AuthResponse{
@ -62,6 +67,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
Image: &user.Image, Image: &user.Image,
FirstName: &user.FirstName, FirstName: &user.FirstName,
LastName: &user.LastName, LastName: &user.LastName,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt, CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt, UpdatedAt: &user.UpdatedAt,
}, },

View File

@ -32,7 +32,8 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model
return res, err return res, err
} }
sessionToken := session.GetToken(claim.ID) id := fmt.Sprintf("%v", claim["id"])
sessionToken := session.GetToken(id)
if sessionToken == "" { if sessionToken == "" {
return res, fmt.Errorf(`unauthorized`) return res, fmt.Errorf(`unauthorized`)
@ -43,7 +44,8 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model
return res, fmt.Errorf("please enter atleast one param to update") return res, fmt.Errorf("please enter atleast one param to update")
} }
user, err := db.Mgr.GetUserByEmail(claim.Email) email := fmt.Sprintf("%v", claim["email"])
user, err := db.Mgr.GetUserByEmail(email)
if err != nil { if err != nil {
return res, err return res, err
} }
@ -120,9 +122,33 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model
go func() { go func() {
utils.SendVerificationMail(newEmail, token) utils.SendVerificationMail(newEmail, token)
}() }()
} }
// TODO this idea needs to be verified otherwise every user can make themselves super admin
// rolesToSave := ""
// if params.Roles != nil && len(params.Roles) > 0 {
// currentRoles := strings.Split(user.Roles, ",")
// inputRoles := []string{}
// for _, item := range params.Roles {
// inputRoles = append(inputRoles, *item)
// }
// if !utils.IsValidRolesArray(inputRoles) {
// return res, fmt.Errorf("invalid list of roles")
// }
// if !utils.IsStringArrayEqual(inputRoles, currentRoles) {
// rolesToSave = strings.Join(inputRoles, ",")
// }
// session.DeleteToken(fmt.Sprintf("%v", user.ID))
// utils.DeleteCookie(gc)
// }
// if rolesToSave != "" {
// user.Roles = rolesToSave
// }
_, err = db.Mgr.UpdateUser(user) _, err = db.Mgr.UpdateUser(user)
if err != nil { if err != nil {
log.Println("Error updating user:", err) log.Println("Error updating user:", err)

View File

@ -3,8 +3,10 @@ package resolvers
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
@ -41,15 +43,9 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut
db.Mgr.DeleteToken(claim.Email) db.Mgr.DeleteToken(claim.Email)
userIdStr := fmt.Sprintf("%v", user.ID) userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{ refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE)
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{ accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE)
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
session.SetToken(userIdStr, refreshToken) session.SetToken(userIdStr, refreshToken)
@ -65,6 +61,7 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut
LastName: &user.LastName, LastName: &user.LastName,
SignupMethod: user.SignupMethod, SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt, EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt, CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt, UpdatedAt: &user.UpdatedAt,
}, },

View File

@ -1,29 +1,32 @@
package utils package utils
import ( import (
"encoding/json"
"fmt" "fmt"
"log" "log"
"strings" "strings"
"time" "time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum" "github.com/authorizerdev/authorizer/server/enum"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
) )
type UserAuthInfo struct { // type UserAuthInfo struct {
Email string `json:"email"` // Email string `json:"email"`
ID string `json:"id"` // ID string `json:"id"`
} // }
type JWTCustomClaim map[string]interface{}
type UserAuthClaim struct { type UserAuthClaim struct {
*jwt.StandardClaims *jwt.StandardClaims
TokenType string `json:"token_type"` *JWTCustomClaim `json:"authorizer"`
UserAuthInfo
} }
func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64, error) { func CreateAuthToken(user db.User, tokenType enum.TokenType, role string) (string, int64, error) {
t := jwt.New(jwt.GetSigningMethod(constants.JWT_TYPE)) t := jwt.New(jwt.GetSigningMethod(constants.JWT_TYPE))
expiryBound := time.Hour expiryBound := time.Hour
if tokenType == enum.RefreshToken { if tokenType == enum.RefreshToken {
@ -33,12 +36,19 @@ func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64
expiresAt := time.Now().Add(expiryBound).Unix() expiresAt := time.Now().Add(expiryBound).Unix()
customClaims := JWTCustomClaim{
"token_type": tokenType.String(),
"email": user.Email,
"id": user.ID,
"allowed_roles": strings.Split(user.Roles, ","),
constants.JWT_ROLE_CLAIM: role,
}
t.Claims = &UserAuthClaim{ t.Claims = &UserAuthClaim{
&jwt.StandardClaims{ &jwt.StandardClaims{
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
}, },
tokenType.String(), &customClaims,
user,
} }
token, err := t.SignedString([]byte(constants.JWT_SECRET)) token, err := t.SignedString([]byte(constants.JWT_SECRET))
@ -63,14 +73,20 @@ func GetAuthToken(gc *gin.Context) (string, error) {
return token, nil return token, nil
} }
func VerifyAuthToken(token string) (*UserAuthClaim, error) { func VerifyAuthToken(token string) (map[string]interface{}, error) {
var res map[string]interface{}
claims := &UserAuthClaim{} claims := &UserAuthClaim{}
_, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { _, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte(constants.JWT_SECRET), nil return []byte(constants.JWT_SECRET), nil
}) })
if err != nil { if err != nil {
return claims, err return res, err
} }
return claims, nil data, _ := json.Marshal(claims.JWTCustomClaim)
json.Unmarshal(data, &res)
res["exp"] = claims.ExpiresAt
return res, nil
} }

20
server/utils/common.go Normal file
View File

@ -0,0 +1,20 @@
package utils
import (
"io"
"os"
)
func WriteToFile(filename string, data string) error {
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
_, err = io.WriteString(file, data)
if err != nil {
return err
}
return file.Sync()
}

View File

@ -0,0 +1,25 @@
package utils
import (
"log"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
)
// any jobs that we want to run at start of server can be executed here
// 1. create roles table and add the roles list from env to table
func InitServer() {
roles := []db.Role{}
for _, val := range constants.ROLES {
roles = append(roles, db.Role{
Role: val,
})
}
err := db.Mgr.SaveRoles(roles)
if err != nil {
log.Println(`Error saving roles`, err)
}
}

View File

@ -1,15 +0,0 @@
package utils
import (
"github.com/authorizerdev/authorizer/server/constants"
"github.com/gin-gonic/gin"
)
func IsSuperAdmin(gc *gin.Context) bool {
secret := gc.Request.Header.Get("x-authorizer-admin-secret")
if secret == "" {
return false
}
return secret == constants.ADMIN_SECRET
}

View File

@ -5,6 +5,7 @@ import (
"strings" "strings"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/gin-gonic/gin"
) )
func IsValidEmail(email string) bool { func IsValidEmail(email string) bool {
@ -29,3 +30,52 @@ func IsValidRedirectURL(url string) bool {
return hasValidURL return hasValidURL
} }
func IsSuperAdmin(gc *gin.Context) bool {
secret := gc.Request.Header.Get("x-authorizer-admin-secret")
if secret == "" {
return false
}
return secret == constants.ADMIN_SECRET
}
func IsValidRolesArray(roles []string) bool {
valid := true
currentRoleMap := map[string]bool{}
for _, currentRole := range constants.ROLES {
currentRoleMap[currentRole] = true
}
for _, inputRole := range roles {
if !currentRoleMap[inputRole] {
valid = false
break
}
}
return valid
}
func IsValidRole(userRoles []string, role string) bool {
valid := false
for _, currentRole := range userRoles {
if role == currentRole {
valid = true
break
}
}
return valid
}
func IsStringArrayEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}