feat: add roles based access

This commit is contained in:
Lakhan Samani
2021-09-18 16:56:51 +05:30
parent 195270525c
commit a6ce563d46
14 changed files with 261 additions and 44 deletions

View File

@@ -23,6 +23,10 @@ var (
DISABLE_EMAIL_VERIFICATION = "false" DISABLE_EMAIL_VERIFICATION = "false"
DISABLE_BASIC_AUTHENTICATION = "false" DISABLE_BASIC_AUTHENTICATION = "false"
// ROLES
ROLES = []string{}
DEFAULT_ROLE = ""
// OAuth login // OAuth login
GOOGLE_CLIENT_ID = "" GOOGLE_CLIENT_ID = ""
GOOGLE_CLIENT_SECRET = "" GOOGLE_CLIENT_SECRET = ""

View File

@@ -24,6 +24,7 @@ type Manager interface {
GetVerificationRequests() ([]VerificationRequest, error) GetVerificationRequests() ([]VerificationRequest, error)
GetVerificationByEmail(email string) (VerificationRequest, error) GetVerificationByEmail(email string) (VerificationRequest, error)
DeleteUser(email string) error DeleteUser(email string) error
SaveRoles(roles []Role) error
} }
type manager struct { type manager struct {
@@ -53,7 +54,7 @@ func InitDB() {
if err != nil { if err != nil {
log.Fatal("Failed to init db:", err) log.Fatal("Failed to init db:", err)
} else { } else {
db.AutoMigrate(&User{}, &VerificationRequest{}) db.AutoMigrate(&User{}, &VerificationRequest{}, &Role{})
} }
Mgr = &manager{db: db} Mgr = &manager{db: db}

19
server/db/roles.go Normal file
View File

@@ -0,0 +1,19 @@
package db
import "log"
type Role struct {
ID uint `gorm:"primaryKey"`
Role string
}
// SaveRoles function to save roles
func (mgr *manager) SaveRoles(roles []Role) error {
res := mgr.db.Create(&roles)
if res.Error != nil {
log.Println(`Error saving roles`)
return res.Error
}
return nil
}

View File

@@ -17,6 +17,7 @@ type User struct {
CreatedAt int64 `gorm:"autoCreateTime"` CreatedAt int64 `gorm:"autoCreateTime"`
UpdatedAt int64 `gorm:"autoUpdateTime"` UpdatedAt int64 `gorm:"autoUpdateTime"`
Image string Image string
Roles string
} }
// SaveUser function to add user even with email conflict // SaveUser function to add user even with email conflict

View File

@@ -73,6 +73,7 @@ func InitEnv() {
constants.RESET_PASSWORD_URL = strings.TrimPrefix(os.Getenv("RESET_PASSWORD_URL"), "/") constants.RESET_PASSWORD_URL = strings.TrimPrefix(os.Getenv("RESET_PASSWORD_URL"), "/")
constants.DISABLE_BASIC_AUTHENTICATION = os.Getenv("DISABLE_BASIC_AUTHENTICATION") constants.DISABLE_BASIC_AUTHENTICATION = os.Getenv("DISABLE_BASIC_AUTHENTICATION")
constants.DISABLE_EMAIL_VERIFICATION = os.Getenv("DISABLE_EMAIL_VERIFICATION") constants.DISABLE_EMAIL_VERIFICATION = os.Getenv("DISABLE_EMAIL_VERIFICATION")
constants.DEFAULT_ROLE = os.Getenv("DEFAULT_ROLE")
if constants.ADMIN_SECRET == "" { if constants.ADMIN_SECRET == "" {
panic("root admin secret is required") panic("root admin secret is required")
@@ -143,4 +144,28 @@ func InitEnv() {
constants.DISABLE_EMAIL_VERIFICATION = "false" constants.DISABLE_EMAIL_VERIFICATION = "false"
} }
} }
rolesSplit := strings.Split(os.Getenv("ROLES"), ",")
roles := []string{}
defaultRole := ""
for _, val := range rolesSplit {
trimVal := strings.TrimSpace(val)
if trimVal != "" {
roles = append(roles, trimVal)
}
if trimVal == constants.DEFAULT_ROLE {
defaultRole = trimVal
}
}
if len(roles) > 0 && defaultRole == "" {
panic(`Invalid DEFAULT_ROLE environment. It can be one from give ROLES environment variable value`)
}
if len(roles) == 0 {
roles = []string{"user", "admin"}
constants.DEFAULT_ROLE = "user"
}
constants.ROLES = roles
} }

View File

@@ -97,6 +97,7 @@ type ComplexityRoot struct {
ID func(childComplexity int) int ID func(childComplexity int) int
Image func(childComplexity int) int Image func(childComplexity int) int
LastName func(childComplexity int) int LastName func(childComplexity int) int
Roles func(childComplexity int) int
SignupMethod func(childComplexity int) int SignupMethod func(childComplexity int) int
UpdatedAt func(childComplexity int) int UpdatedAt func(childComplexity int) int
} }
@@ -431,6 +432,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.User.LastName(childComplexity), true return e.complexity.User.LastName(childComplexity), true
case "User.roles":
if e.complexity.User.Roles == nil {
break
}
return e.complexity.User.Roles(childComplexity), true
case "User.signupMethod": case "User.signupMethod":
if e.complexity.User.SignupMethod == nil { if e.complexity.User.SignupMethod == nil {
break break
@@ -583,6 +591,7 @@ type User {
image: String image: String
createdAt: Int64 createdAt: Int64
updatedAt: Int64 updatedAt: Int64
roles: [String]
} }
type VerificationRequest { type VerificationRequest {
@@ -618,11 +627,13 @@ input SignUpInput {
password: String! password: String!
confirmPassword: String! confirmPassword: String!
image: String image: String
roles: [String]
} }
input LoginInput { input LoginInput {
email: String! email: String!
password: String! password: String!
role: String
} }
input VerifyEmailInput { input VerifyEmailInput {
@@ -641,6 +652,7 @@ input UpdateProfileInput {
lastName: String lastName: String
image: String image: String
email: String email: String
roles: [String]
} }
input ForgotPasswordInput { input ForgotPasswordInput {
@@ -2249,6 +2261,38 @@ func (ec *executionContext) _User_updatedAt(ctx context.Context, field graphql.C
return ec.marshalOInt642ᚖint64(ctx, field.Selections, res) return ec.marshalOInt642ᚖint64(ctx, field.Selections, res)
} }
func (ec *executionContext) _User_roles(ctx context.Context, field graphql.CollectedField, obj *model.User) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "User",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.Roles, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.([]*string)
fc.Result = res
return ec.marshalOString2ᚕᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) { func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -3625,6 +3669,14 @@ func (ec *executionContext) unmarshalInputLoginInput(ctx context.Context, obj in
if err != nil { if err != nil {
return it, err return it, err
} }
case "role":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role"))
it.Role, err = ec.unmarshalOString2ᚖstring(ctx, v)
if err != nil {
return it, err
}
} }
} }
@@ -3741,6 +3793,14 @@ func (ec *executionContext) unmarshalInputSignUpInput(ctx context.Context, obj i
if err != nil { if err != nil {
return it, err return it, err
} }
case "roles":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles"))
it.Roles, err = ec.unmarshalOString2ᚕᚖstring(ctx, v)
if err != nil {
return it, err
}
} }
} }
@@ -3809,6 +3869,14 @@ func (ec *executionContext) unmarshalInputUpdateProfileInput(ctx context.Context
if err != nil { if err != nil {
return it, err return it, err
} }
case "roles":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles"))
it.Roles, err = ec.unmarshalOString2ᚕᚖstring(ctx, v)
if err != nil {
return it, err
}
} }
} }
@@ -4198,6 +4266,8 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj
out.Values[i] = ec._User_createdAt(ctx, field, obj) out.Values[i] = ec._User_createdAt(ctx, field, obj)
case "updatedAt": case "updatedAt":
out.Values[i] = ec._User_updatedAt(ctx, field, obj) out.Values[i] = ec._User_updatedAt(ctx, field, obj)
case "roles":
out.Values[i] = ec._User_roles(ctx, field, obj)
default: default:
panic("unknown field " + strconv.Quote(field.Name)) panic("unknown field " + strconv.Quote(field.Name))
} }
@@ -5002,6 +5072,42 @@ func (ec *executionContext) marshalOString2string(ctx context.Context, sel ast.S
return graphql.MarshalString(v) return graphql.MarshalString(v)
} }
func (ec *executionContext) unmarshalOString2ᚕᚖstring(ctx context.Context, v interface{}) ([]*string, error) {
if v == nil {
return nil, nil
}
var vSlice []interface{}
if v != nil {
if tmp1, ok := v.([]interface{}); ok {
vSlice = tmp1
} else {
vSlice = []interface{}{v}
}
}
var err error
res := make([]*string, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.unmarshalOString2ᚖstring(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
}
func (ec *executionContext) marshalOString2ᚕᚖstring(ctx context.Context, sel ast.SelectionSet, v []*string) graphql.Marshaler {
if v == nil {
return graphql.Null
}
ret := make(graphql.Array, len(v))
for i := range v {
ret[i] = ec.marshalOString2ᚖstring(ctx, sel, v[i])
}
return ret
}
func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v interface{}) (*string, error) { func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v interface{}) (*string, error) {
if v == nil { if v == nil {
return nil, nil return nil, nil

View File

@@ -23,8 +23,9 @@ type ForgotPasswordInput struct {
} }
type LoginInput struct { type LoginInput struct {
Email string `json:"email"` Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
Role *string `json:"role"`
} }
type Meta struct { type Meta struct {
@@ -52,34 +53,37 @@ type Response struct {
} }
type SignUpInput struct { type SignUpInput struct {
FirstName *string `json:"firstName"` FirstName *string `json:"firstName"`
LastName *string `json:"lastName"` LastName *string `json:"lastName"`
Email string `json:"email"` Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
ConfirmPassword string `json:"confirmPassword"` ConfirmPassword string `json:"confirmPassword"`
Image *string `json:"image"` Image *string `json:"image"`
Roles []*string `json:"roles"`
} }
type UpdateProfileInput struct { type UpdateProfileInput struct {
OldPassword *string `json:"oldPassword"` OldPassword *string `json:"oldPassword"`
NewPassword *string `json:"newPassword"` NewPassword *string `json:"newPassword"`
ConfirmNewPassword *string `json:"confirmNewPassword"` ConfirmNewPassword *string `json:"confirmNewPassword"`
FirstName *string `json:"firstName"` FirstName *string `json:"firstName"`
LastName *string `json:"lastName"` LastName *string `json:"lastName"`
Image *string `json:"image"` Image *string `json:"image"`
Email *string `json:"email"` Email *string `json:"email"`
Roles []*string `json:"roles"`
} }
type User struct { type User struct {
ID string `json:"id"` ID string `json:"id"`
Email string `json:"email"` Email string `json:"email"`
SignupMethod string `json:"signupMethod"` SignupMethod string `json:"signupMethod"`
FirstName *string `json:"firstName"` FirstName *string `json:"firstName"`
LastName *string `json:"lastName"` LastName *string `json:"lastName"`
EmailVerifiedAt *int64 `json:"emailVerifiedAt"` EmailVerifiedAt *int64 `json:"emailVerifiedAt"`
Image *string `json:"image"` Image *string `json:"image"`
CreatedAt *int64 `json:"createdAt"` CreatedAt *int64 `json:"createdAt"`
UpdatedAt *int64 `json:"updatedAt"` UpdatedAt *int64 `json:"updatedAt"`
Roles []*string `json:"roles"`
} }
type VerificationRequest struct { type VerificationRequest struct {

View File

@@ -23,6 +23,7 @@ type User {
image: String image: String
createdAt: Int64 createdAt: Int64
updatedAt: Int64 updatedAt: Int64
roles: [String]
} }
type VerificationRequest { type VerificationRequest {
@@ -58,11 +59,13 @@ input SignUpInput {
password: String! password: String!
confirmPassword: String! confirmPassword: String!
image: String image: String
roles: [String]
} }
input LoginInput { input LoginInput {
email: String! email: String!
password: String! password: String!
role: String
} }
input VerifyEmailInput { input VerifyEmailInput {
@@ -81,6 +84,7 @@ input UpdateProfileInput {
lastName: String lastName: String
image: String image: String
email: String email: String
roles: [String]
} }
input ForgotPasswordInput { input ForgotPasswordInput {

View File

@@ -73,7 +73,5 @@ func (r *Resolver) Mutation() generated.MutationResolver { return &mutationResol
// Query returns generated.QueryResolver implementation. // Query returns generated.QueryResolver implementation.
func (r *Resolver) Query() generated.QueryResolver { return &queryResolver{r} } func (r *Resolver) Query() generated.QueryResolver { return &queryResolver{r} }
type ( type mutationResolver struct{ *Resolver }
mutationResolver struct{ *Resolver } type queryResolver struct{ *Resolver }
queryResolver struct{ *Resolver }
)

View File

@@ -9,6 +9,7 @@ import (
"github.com/authorizerdev/authorizer/server/handlers" "github.com/authorizerdev/authorizer/server/handlers"
"github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/oauth"
"github.com/authorizerdev/authorizer/server/session" "github.com/authorizerdev/authorizer/server/session"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/gin-contrib/location" "github.com/gin-contrib/location"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -50,6 +51,7 @@ func main() {
db.InitDB() db.InitDB()
session.InitSession() session.InitSession()
oauth.InitOAuth() oauth.InitOAuth()
utils.InitServer()
r := gin.Default() r := gin.Default()
r.Use(location.Default()) r.Use(location.Default())

View File

@@ -35,6 +35,15 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
return res, fmt.Errorf(`invalid email address`) return res, fmt.Errorf(`invalid email address`)
} }
if len(params.Roles) > 0 {
// check if roles exists
if !utils.IsValidRolesArray(params.Roles) {
return res, fmt.Errorf(`invalid roles`)
}
} else {
params.Roles = []*string{&constants.DEFAULT_ROLE}
}
// find user with email // find user with email
existingUser, err := db.Mgr.GetUserByEmail(params.Email) existingUser, err := db.Mgr.GetUserByEmail(params.Email)
if err != nil { if err != nil {
@@ -49,6 +58,13 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
Email: params.Email, Email: params.Email,
} }
roles := ""
for _, roleInput := range params.Roles {
roles += *roleInput + ","
}
roles = strings.TrimSuffix(roles, ",")
user.Roles = roles
password, _ := utils.HashPassword(params.Password) password, _ := utils.HashPassword(params.Password)
user.Password = password user.Password = password
@@ -79,6 +95,7 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
EmailVerifiedAt: &user.EmailVerifiedAt, EmailVerifiedAt: &user.EmailVerifiedAt,
CreatedAt: &user.CreatedAt, CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt, UpdatedAt: &user.UpdatedAt,
Roles: params.Roles,
} }
if constants.DISABLE_EMAIL_VERIFICATION != "true" { if constants.DISABLE_EMAIL_VERIFICATION != "true" {

View File

@@ -0,0 +1,25 @@
package utils
import (
"log"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
)
// any jobs that we want to run at start of server can be executed here
// 1. create roles table and add the roles list from env to table
func InitServer() {
roles := []db.Role{}
for _, val := range constants.ROLES {
roles = append(roles, db.Role{
Role: val,
})
}
err := db.Mgr.SaveRoles(roles)
if err != nil {
log.Println(`Error saving roles`, err)
}
}

View File

@@ -1,15 +0,0 @@
package utils
import (
"github.com/authorizerdev/authorizer/server/constants"
"github.com/gin-gonic/gin"
)
func IsSuperAdmin(gc *gin.Context) bool {
secret := gc.Request.Header.Get("x-authorizer-admin-secret")
if secret == "" {
return false
}
return secret == constants.ADMIN_SECRET
}

View File

@@ -5,6 +5,7 @@ import (
"strings" "strings"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/gin-gonic/gin"
) )
func IsValidEmail(email string) bool { func IsValidEmail(email string) bool {
@@ -29,3 +30,28 @@ func IsValidRedirectURL(url string) bool {
return hasValidURL return hasValidURL
} }
func IsSuperAdmin(gc *gin.Context) bool {
secret := gc.Request.Header.Get("x-authorizer-admin-secret")
if secret == "" {
return false
}
return secret == constants.ADMIN_SECRET
}
func IsValidRolesArray(roles []*string) bool {
valid := true
currentRoleMap := map[string]bool{}
for _, currentRole := range constants.ROLES {
currentRoleMap[currentRole] = true
}
for _, inputRole := range roles {
if !currentRoleMap[*inputRole] {
valid = false
break
}
}
return valid
}