feat/role based access (#50)

* feat: add roles based access

* feat: update roles env + todo

* feat: add roles to update profile

* feat: add role based oauth

* feat: validate role for a given token
This commit is contained in:
Lakhan Samani 2021-09-20 10:36:26 +05:30 committed by GitHub
parent 195270525c
commit 21e3425e76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 544 additions and 141 deletions

View File

@ -4,4 +4,7 @@ DATABASE_TYPE=sqlite
ADMIN_SECRET=admin
DISABLE_EMAIL_VERIFICATION=true
JWT_SECRET=random_string
JWT_TYPE=HS256
JWT_TYPE=HS256
ROLES=user,admin
DEFAULT_ROLE=user
JWT_ROLE_CLAIM=role

16
TODO.md
View File

@ -0,0 +1,16 @@
# Task List
# Feature roles
For the first version we will only support setting roles master list via env
- [x] Support following ENV
- [x] `ROLES` -> comma separated list of role names
- [x] `DEFAULT_ROLE` -> default role to assign to users
- [x] Add roles input for signup
- [x] Add roles to update profile mutation
- [x] Add roles input for login
- [x] Return roles to user
- [x] Return roles in users list for super admin
- [x] Add roles to the JWT token generation
- [x] Validate token should also validate the role, if roles to validate again is present in request

View File

@ -23,6 +23,11 @@ var (
DISABLE_EMAIL_VERIFICATION = "false"
DISABLE_BASIC_AUTHENTICATION = "false"
// ROLES
ROLES = []string{}
DEFAULT_ROLE = ""
JWT_ROLE_CLAIM = "role"
// OAuth login
GOOGLE_CLIENT_ID = ""
GOOGLE_CLIENT_SECRET = ""

View File

@ -24,6 +24,7 @@ type Manager interface {
GetVerificationRequests() ([]VerificationRequest, error)
GetVerificationByEmail(email string) (VerificationRequest, error)
DeleteUser(email string) error
SaveRoles(roles []Role) error
}
type manager struct {
@ -53,7 +54,7 @@ func InitDB() {
if err != nil {
log.Fatal("Failed to init db:", err)
} else {
db.AutoMigrate(&User{}, &VerificationRequest{})
db.AutoMigrate(&User{}, &VerificationRequest{}, &Role{})
}
Mgr = &manager{db: db}

19
server/db/roles.go Normal file
View File

@ -0,0 +1,19 @@
package db
import "log"
type Role struct {
ID uint `gorm:"primaryKey"`
Role string
}
// SaveRoles function to save roles
func (mgr *manager) SaveRoles(roles []Role) error {
res := mgr.db.Create(&roles)
if res.Error != nil {
log.Println(`Error saving roles`)
return res.Error
}
return nil
}

View File

@ -17,6 +17,7 @@ type User struct {
CreatedAt int64 `gorm:"autoCreateTime"`
UpdatedAt int64 `gorm:"autoUpdateTime"`
Image string
Roles string
}
// SaveUser function to add user even with email conflict

View File

@ -73,6 +73,8 @@ func InitEnv() {
constants.RESET_PASSWORD_URL = strings.TrimPrefix(os.Getenv("RESET_PASSWORD_URL"), "/")
constants.DISABLE_BASIC_AUTHENTICATION = os.Getenv("DISABLE_BASIC_AUTHENTICATION")
constants.DISABLE_EMAIL_VERIFICATION = os.Getenv("DISABLE_EMAIL_VERIFICATION")
constants.DEFAULT_ROLE = os.Getenv("DEFAULT_ROLE")
constants.JWT_ROLE_CLAIM = os.Getenv("JWT_ROLE_CLAIM")
if constants.ADMIN_SECRET == "" {
panic("root admin secret is required")
@ -143,4 +145,33 @@ func InitEnv() {
constants.DISABLE_EMAIL_VERIFICATION = "false"
}
}
rolesSplit := strings.Split(os.Getenv("ROLES"), ",")
roles := []string{}
defaultRole := ""
for _, val := range rolesSplit {
trimVal := strings.TrimSpace(val)
if trimVal != "" {
roles = append(roles, trimVal)
}
if trimVal == constants.DEFAULT_ROLE {
defaultRole = trimVal
}
}
if len(roles) > 0 && defaultRole == "" {
panic(`Invalid DEFAULT_ROLE environment variable. It can be one from give ROLES environment variable value`)
}
if len(roles) == 0 {
roles = []string{"user", "admin"}
constants.DEFAULT_ROLE = "user"
}
constants.ROLES = roles
if constants.JWT_ROLE_CLAIM == "" {
constants.JWT_ROLE_CLAIM = "role"
}
}

View File

@ -80,7 +80,7 @@ type ComplexityRoot struct {
Query struct {
Meta func(childComplexity int) int
Profile func(childComplexity int) int
Token func(childComplexity int) int
Token func(childComplexity int, role *string) int
Users func(childComplexity int) int
VerificationRequests func(childComplexity int) int
}
@ -97,6 +97,7 @@ type ComplexityRoot struct {
ID func(childComplexity int) int
Image func(childComplexity int) int
LastName func(childComplexity int) int
Roles func(childComplexity int) int
SignupMethod func(childComplexity int) int
UpdatedAt func(childComplexity int) int
}
@ -126,7 +127,7 @@ type MutationResolver interface {
type QueryResolver interface {
Meta(ctx context.Context) (*model.Meta, error)
Users(ctx context.Context) ([]*model.User, error)
Token(ctx context.Context) (*model.AuthResponse, error)
Token(ctx context.Context, role *string) (*model.AuthResponse, error)
Profile(ctx context.Context) (*model.User, error)
VerificationRequests(ctx context.Context) ([]*model.VerificationRequest, error)
}
@ -359,7 +360,12 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
break
}
return e.complexity.Query.Token(childComplexity), true
args, err := ec.field_Query_token_args(context.TODO(), rawArgs)
if err != nil {
return 0, false
}
return e.complexity.Query.Token(childComplexity, args["role"].(*string)), true
case "Query.users":
if e.complexity.Query.Users == nil {
@ -431,6 +437,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.User.LastName(childComplexity), true
case "User.roles":
if e.complexity.User.Roles == nil {
break
}
return e.complexity.User.Roles(childComplexity), true
case "User.signupMethod":
if e.complexity.User.SignupMethod == nil {
break
@ -562,6 +575,8 @@ var sources = []*ast.Source{
#
# https://gqlgen.com/getting-started/
scalar Int64
scalar Map
scalar Any
type Meta {
version: String!
@ -583,6 +598,7 @@ type User {
image: String
createdAt: Int64
updatedAt: Int64
roles: [String!]!
}
type VerificationRequest {
@ -618,11 +634,13 @@ input SignUpInput {
password: String!
confirmPassword: String!
image: String
roles: [String]
}
input LoginInput {
email: String!
password: String!
role: String
}
input VerifyEmailInput {
@ -641,6 +659,7 @@ input UpdateProfileInput {
lastName: String
image: String
email: String
# roles: [String]
}
input ForgotPasswordInput {
@ -672,7 +691,7 @@ type Mutation {
type Query {
meta: Meta!
users: [User!]!
token: AuthResponse
token(role: String): AuthResponse
profile: User!
verificationRequests: [VerificationRequest!]!
}
@ -819,6 +838,21 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs
return args, nil
}
func (ec *executionContext) field_Query_token_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
var arg0 *string
if tmp, ok := rawArgs["role"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role"))
arg0, err = ec.unmarshalOString2ᚖstring(ctx, tmp)
if err != nil {
return nil, err
}
}
args["role"] = arg0
return args, nil
}
func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
@ -1760,9 +1794,16 @@ func (ec *executionContext) _Query_token(ctx context.Context, field graphql.Coll
}
ctx = graphql.WithFieldContext(ctx, fc)
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.field_Query_token_args(ctx, rawArgs)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
fc.Args = args
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Query().Token(rctx)
return ec.resolvers.Query().Token(rctx, args["role"].(*string))
})
if err != nil {
ec.Error(ctx, err)
@ -2249,6 +2290,41 @@ func (ec *executionContext) _User_updatedAt(ctx context.Context, field graphql.C
return ec.marshalOInt642ᚖint64(ctx, field.Selections, res)
}
func (ec *executionContext) _User_roles(ctx context.Context, field graphql.CollectedField, obj *model.User) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
fc := &graphql.FieldContext{
Object: "User",
Field: field,
Args: nil,
IsMethod: false,
IsResolver: false,
}
ctx = graphql.WithFieldContext(ctx, fc)
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.Roles, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.([]string)
fc.Result = res
return ec.marshalNString2ᚕstringᚄ(ctx, field.Selections, res)
}
func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
@ -3625,6 +3701,14 @@ func (ec *executionContext) unmarshalInputLoginInput(ctx context.Context, obj in
if err != nil {
return it, err
}
case "role":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("role"))
it.Role, err = ec.unmarshalOString2ᚖstring(ctx, v)
if err != nil {
return it, err
}
}
}
@ -3741,6 +3825,14 @@ func (ec *executionContext) unmarshalInputSignUpInput(ctx context.Context, obj i
if err != nil {
return it, err
}
case "roles":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles"))
it.Roles, err = ec.unmarshalOString2ᚕᚖstring(ctx, v)
if err != nil {
return it, err
}
}
}
@ -4198,6 +4290,11 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj
out.Values[i] = ec._User_createdAt(ctx, field, obj)
case "updatedAt":
out.Values[i] = ec._User_updatedAt(ctx, field, obj)
case "roles":
out.Values[i] = ec._User_roles(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
default:
panic("unknown field " + strconv.Quote(field.Name))
}
@ -4610,6 +4707,36 @@ func (ec *executionContext) marshalNString2string(ctx context.Context, sel ast.S
return res
}
func (ec *executionContext) unmarshalNString2ᚕstringᚄ(ctx context.Context, v interface{}) ([]string, error) {
var vSlice []interface{}
if v != nil {
if tmp1, ok := v.([]interface{}); ok {
vSlice = tmp1
} else {
vSlice = []interface{}{v}
}
}
var err error
res := make([]string, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.unmarshalNString2string(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
}
func (ec *executionContext) marshalNString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler {
ret := make(graphql.Array, len(v))
for i := range v {
ret[i] = ec.marshalNString2string(ctx, sel, v[i])
}
return ret
}
func (ec *executionContext) unmarshalNUpdateProfileInput2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUpdateProfileInput(ctx context.Context, v interface{}) (model.UpdateProfileInput, error) {
res, err := ec.unmarshalInputUpdateProfileInput(ctx, v)
return res, graphql.ErrorOnPath(ctx, err)
@ -5002,6 +5129,42 @@ func (ec *executionContext) marshalOString2string(ctx context.Context, sel ast.S
return graphql.MarshalString(v)
}
func (ec *executionContext) unmarshalOString2ᚕᚖstring(ctx context.Context, v interface{}) ([]*string, error) {
if v == nil {
return nil, nil
}
var vSlice []interface{}
if v != nil {
if tmp1, ok := v.([]interface{}); ok {
vSlice = tmp1
} else {
vSlice = []interface{}{v}
}
}
var err error
res := make([]*string, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.unmarshalOString2ᚖstring(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
}
func (ec *executionContext) marshalOString2ᚕᚖstring(ctx context.Context, sel ast.SelectionSet, v []*string) graphql.Marshaler {
if v == nil {
return graphql.Null
}
ret := make(graphql.Array, len(v))
for i := range v {
ret[i] = ec.marshalOString2ᚖstring(ctx, sel, v[i])
}
return ret
}
func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v interface{}) (*string, error) {
if v == nil {
return nil, nil

View File

@ -23,8 +23,9 @@ type ForgotPasswordInput struct {
}
type LoginInput struct {
Email string `json:"email"`
Password string `json:"password"`
Email string `json:"email"`
Password string `json:"password"`
Role *string `json:"role"`
}
type Meta struct {
@ -52,12 +53,13 @@ type Response struct {
}
type SignUpInput struct {
FirstName *string `json:"firstName"`
LastName *string `json:"lastName"`
Email string `json:"email"`
Password string `json:"password"`
ConfirmPassword string `json:"confirmPassword"`
Image *string `json:"image"`
FirstName *string `json:"firstName"`
LastName *string `json:"lastName"`
Email string `json:"email"`
Password string `json:"password"`
ConfirmPassword string `json:"confirmPassword"`
Image *string `json:"image"`
Roles []*string `json:"roles"`
}
type UpdateProfileInput struct {
@ -71,15 +73,16 @@ type UpdateProfileInput struct {
}
type User struct {
ID string `json:"id"`
Email string `json:"email"`
SignupMethod string `json:"signupMethod"`
FirstName *string `json:"firstName"`
LastName *string `json:"lastName"`
EmailVerifiedAt *int64 `json:"emailVerifiedAt"`
Image *string `json:"image"`
CreatedAt *int64 `json:"createdAt"`
UpdatedAt *int64 `json:"updatedAt"`
ID string `json:"id"`
Email string `json:"email"`
SignupMethod string `json:"signupMethod"`
FirstName *string `json:"firstName"`
LastName *string `json:"lastName"`
EmailVerifiedAt *int64 `json:"emailVerifiedAt"`
Image *string `json:"image"`
CreatedAt *int64 `json:"createdAt"`
UpdatedAt *int64 `json:"updatedAt"`
Roles []string `json:"roles"`
}
type VerificationRequest struct {

View File

@ -2,6 +2,8 @@
#
# https://gqlgen.com/getting-started/
scalar Int64
scalar Map
scalar Any
type Meta {
version: String!
@ -23,6 +25,7 @@ type User {
image: String
createdAt: Int64
updatedAt: Int64
roles: [String!]!
}
type VerificationRequest {
@ -58,11 +61,13 @@ input SignUpInput {
password: String!
confirmPassword: String!
image: String
roles: [String]
}
input LoginInput {
email: String!
password: String!
role: String
}
input VerifyEmailInput {
@ -81,6 +86,7 @@ input UpdateProfileInput {
lastName: String
image: String
email: String
# roles: [String]
}
input ForgotPasswordInput {
@ -112,7 +118,7 @@ type Mutation {
type Query {
meta: Meta!
users: [User!]!
token: AuthResponse
token(role: String): AuthResponse
profile: User!
verificationRequests: [VerificationRequest!]!
}

View File

@ -55,8 +55,8 @@ func (r *queryResolver) Users(ctx context.Context) ([]*model.User, error) {
return resolvers.Users(ctx)
}
func (r *queryResolver) Token(ctx context.Context) (*model.AuthResponse, error) {
return resolvers.Token(ctx)
func (r *queryResolver) Token(ctx context.Context, role *string) (*model.AuthResponse, error) {
return resolvers.Token(ctx, role)
}
func (r *queryResolver) Profile(ctx context.Context) (*model.User, error) {

View File

@ -25,7 +25,6 @@ func AppHandler() gin.HandlerFunc {
if state == "" {
// cookie, err := utils.GetAuthToken(c)
// log.Println(`cookie`, cookie)
// if err != nil {
// c.JSON(400, gin.H{"error": "invalid state"})
// return
@ -67,13 +66,6 @@ func AppHandler() gin.HandlerFunc {
}
}
log.Println(gin.H{
"data": map[string]string{
"authorizerURL": stateObj.AuthorizerURL,
"redirectURL": stateObj.RedirectURL,
},
})
// debug the request state
if pusher := c.Writer.Pusher(); pusher != nil {
// use pusher.Push() to do server push

View File

@ -19,7 +19,7 @@ import (
"golang.org/x/oauth2"
)
func processGoogleUserInfo(code string, c *gin.Context) error {
func processGoogleUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.GoogleConfig.Exchange(oauth2.NoContext, code)
if err != nil {
return fmt.Errorf("invalid google exchange code: %s", err.Error())
@ -50,6 +50,7 @@ func processGoogleUserInfo(code string, c *gin.Context) error {
if err != nil {
// user not registered, register user and generate session token
user.SignupMethod = enum.Google.String()
user.Roles = role
} else {
// user exists in db, check if method was google
// if not append google to existing signup method and save it
@ -60,27 +61,26 @@ func processGoogleUserInfo(code string, c *gin.Context) error {
}
user.SignupMethod = signupMethod
user.Password = existingUser.Password
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
}
user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken)
return nil
}
func processGithubUserInfo(code string, c *gin.Context) error {
func processGithubUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.GithubConfig.Exchange(oauth2.NoContext, code)
if err != nil {
return fmt.Errorf("invalid github exchange code: %s", err.Error())
@ -128,6 +128,7 @@ func processGithubUserInfo(code string, c *gin.Context) error {
if err != nil {
// user not registered, register user and generate session token
user.SignupMethod = enum.Github.String()
user.Roles = role
} else {
// user exists in db, check if method was google
// if not append google to existing signup method and save it
@ -138,26 +139,26 @@ func processGithubUserInfo(code string, c *gin.Context) error {
}
user.SignupMethod = signupMethod
user.Password = existingUser.Password
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
}
user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken)
return nil
}
func processFacebookUserInfo(code string, c *gin.Context) error {
func processFacebookUserInfo(code string, role string, c *gin.Context) error {
token, err := oauth.OAuthProvider.FacebookConfig.Exchange(oauth2.NoContext, code)
if err != nil {
return fmt.Errorf("invalid facebook exchange code: %s", err.Error())
@ -199,6 +200,7 @@ func processFacebookUserInfo(code string, c *gin.Context) error {
if err != nil {
// user not registered, register user and generate session token
user.SignupMethod = enum.Github.String()
user.Roles = role
} else {
// user exists in db, check if method was google
// if not append google to existing signup method and save it
@ -209,20 +211,20 @@ func processFacebookUserInfo(code string, c *gin.Context) error {
}
user.SignupMethod = signupMethod
user.Password = existingUser.Password
if !utils.IsValidRole(strings.Split(existingUser.Roles, ","), role) {
return fmt.Errorf("invalid role")
}
user.Roles = existingUser.Roles
}
user, _ = db.Mgr.SaveUser(user)
user, _ = db.Mgr.GetUserByEmail(user.Email)
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
utils.SetCookie(c, accessToken)
session.SetToken(userIdStr, refreshToken)
return nil
@ -238,23 +240,27 @@ func OAuthCallbackHandler() gin.HandlerFunc {
c.JSON(400, gin.H{"error": "invalid oauth state"})
}
session.DeleteToken(sessionState)
// contains random token, redirect url, role
sessionSplit := strings.Split(state, "___")
// TODO validate redirect url
if len(sessionSplit) != 2 {
if len(sessionSplit) < 2 {
c.JSON(400, gin.H{"error": "invalid redirect url"})
return
}
role := sessionSplit[2]
redirectURL := sessionSplit[1]
var err error
code := c.Request.FormValue("code")
switch provider {
case enum.Google.String():
err = processGoogleUserInfo(code, c)
err = processGoogleUserInfo(code, role, c)
case enum.Github.String():
err = processGithubUserInfo(code, c)
err = processGithubUserInfo(code, role, c)
case enum.Facebook.String():
err = processFacebookUserInfo(code, c)
err = processFacebookUserInfo(code, role, c)
default:
err = fmt.Errorf(`invalid oauth provider`)
}
@ -263,6 +269,6 @@ func OAuthCallbackHandler() gin.HandlerFunc {
c.JSON(400, gin.H{"error": err.Error()})
return
}
c.Redirect(http.StatusTemporaryRedirect, sessionSplit[1])
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
}
}

View File

@ -7,6 +7,7 @@ import (
"github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/oauth"
"github.com/authorizerdev/authorizer/server/session"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@ -17,6 +18,7 @@ func OAuthLoginHandler() gin.HandlerFunc {
return func(c *gin.Context) {
// TODO validate redirect URL
redirectURL := c.Query("redirectURL")
role := c.Query("role")
if redirectURL == "" {
c.JSON(400, gin.H{
@ -24,8 +26,21 @@ func OAuthLoginHandler() gin.HandlerFunc {
})
return
}
if role != "" {
// validate role
if !utils.IsValidRole(constants.ROLES, role) {
c.JSON(400, gin.H{
"error": "invalid role",
})
return
}
} else {
role = constants.DEFAULT_ROLE
}
uuid := uuid.New()
oauthStateString := uuid.String() + "___" + redirectURL
oauthStateString := uuid.String() + "___" + redirectURL + "___" + role
provider := c.Param("oauth_provider")

View File

@ -50,15 +50,9 @@ func VerifyEmailHandler() gin.HandlerFunc {
db.Mgr.DeleteToken(claim.Email)
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, user.Roles)
accessToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, user.Roles)
session.SetToken(userIdStr, refreshToken)
utils.SetCookie(c, accessToken)

View File

@ -9,6 +9,7 @@ import (
"github.com/authorizerdev/authorizer/server/handlers"
"github.com/authorizerdev/authorizer/server/oauth"
"github.com/authorizerdev/authorizer/server/session"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/gin-contrib/location"
"github.com/gin-gonic/gin"
)
@ -50,6 +51,7 @@ func main() {
db.InitDB()
session.InitSession()
oauth.InitOAuth()
utils.InitServer()
r := gin.Default()
r.Use(location.Default())

View File

@ -46,16 +46,19 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e
log.Println("Compare password error:", err)
return res, fmt.Errorf(`invalid password`)
}
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
role := constants.DEFAULT_ROLE
if params.Role != nil {
// validate role
if !utils.IsValidRole(strings.Split(user.Roles, ","), *params.Role) {
return res, fmt.Errorf(`invalid role`)
}
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
role = *params.Role
}
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, role)
accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, role)
session.SetToken(userIdStr, refreshToken)
@ -71,6 +74,7 @@ func Login(ctx context.Context, params model.LoginInput) (*model.AuthResponse, e
LastName: &user.LastName,
SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt,
},

View File

@ -2,6 +2,7 @@ package resolvers
import (
"context"
"fmt"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/session"
@ -25,7 +26,8 @@ func Logout(ctx context.Context) (*model.Response, error) {
return res, err
}
session.DeleteToken(claim.ID)
userId := fmt.Sprintf("%v", claim["id"])
session.DeleteToken(userId)
res = &model.Response{
Message: "Logged out successfully",
}

View File

@ -3,6 +3,7 @@ package resolvers
import (
"context"
"fmt"
"strings"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
@ -27,13 +28,15 @@ func Profile(ctx context.Context) (*model.User, error) {
return res, err
}
sessionToken := session.GetToken(claim.ID)
userID := fmt.Sprintf("%v", claim["id"])
email := fmt.Sprintf("%v", claim["email"])
sessionToken := session.GetToken(userID)
if sessionToken == "" {
return res, fmt.Errorf(`unauthorized`)
}
user, err := db.Mgr.GetUserByEmail(claim.Email)
user, err := db.Mgr.GetUserByEmail(email)
if err != nil {
return res, err
}
@ -48,6 +51,7 @@ func Profile(ctx context.Context) (*model.User, error) {
LastName: &user.LastName,
SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt,
}

View File

@ -35,6 +35,20 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
return res, fmt.Errorf(`invalid email address`)
}
inputRoles := []string{}
if params.Roles != nil && len(params.Roles) > 0 {
// check if roles exists
for _, item := range params.Roles {
inputRoles = append(inputRoles, *item)
}
if !utils.IsValidRolesArray(inputRoles) {
return res, fmt.Errorf(`invalid roles`)
}
} else {
inputRoles = []string{constants.DEFAULT_ROLE}
}
// find user with email
existingUser, err := db.Mgr.GetUserByEmail(params.Email)
if err != nil {
@ -49,6 +63,8 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
Email: params.Email,
}
user.Roles = strings.Join(inputRoles, ",")
password, _ := utils.HashPassword(params.Password)
user.Password = password
@ -77,6 +93,7 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
LastName: &user.LastName,
SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt,
}
@ -106,15 +123,9 @@ func Signup(ctx context.Context, params model.SignUpInput) (*model.AuthResponse,
}
} else {
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE)
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE)
session.SetToken(userIdStr, refreshToken)
res = &model.AuthResponse{

View File

@ -3,8 +3,10 @@ package resolvers
import (
"context"
"fmt"
"strings"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/graph/model"
@ -12,7 +14,7 @@ import (
"github.com/authorizerdev/authorizer/server/utils"
)
func Token(ctx context.Context) (*model.AuthResponse, error) {
func Token(ctx context.Context, role *string) (*model.AuthResponse, error) {
var res *model.AuthResponse
gc, err := utils.GinContextFromContext(ctx)
@ -25,13 +27,19 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
}
claim, accessTokenErr := utils.VerifyAuthToken(token)
expiresAt := claim.ExpiresAt
expiresAt := claim["exp"].(int64)
email := fmt.Sprintf("%v", claim["email"])
user, err := db.Mgr.GetUserByEmail(claim.Email)
claimRole := fmt.Sprintf("%v", claim[constants.JWT_ROLE_CLAIM])
user, err := db.Mgr.GetUserByEmail(email)
if err != nil {
return res, err
}
if role != nil && role != &claimRole {
return res, fmt.Errorf(`unauthorized. invalid role for a given token`)
}
userIdStr := fmt.Sprintf("%v", user.ID)
sessionToken := session.GetToken(userIdStr)
@ -46,10 +54,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
if accessTokenErr != nil || expiresTimeObj.Sub(currentTimeObj).Minutes() <= 5 {
// if access token has expired and refresh/session token is valid
// generate new accessToken
token, expiresAt, _ = utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
token, expiresAt, _ = utils.CreateAuthToken(user, enum.AccessToken, claimRole)
}
utils.SetCookie(gc, token)
res = &model.AuthResponse{
@ -62,6 +67,7 @@ func Token(ctx context.Context) (*model.AuthResponse, error) {
Image: &user.Image,
FirstName: &user.FirstName,
LastName: &user.LastName,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt,
},

View File

@ -32,7 +32,8 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model
return res, err
}
sessionToken := session.GetToken(claim.ID)
id := fmt.Sprintf("%v", claim["id"])
sessionToken := session.GetToken(id)
if sessionToken == "" {
return res, fmt.Errorf(`unauthorized`)
@ -43,7 +44,8 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model
return res, fmt.Errorf("please enter atleast one param to update")
}
user, err := db.Mgr.GetUserByEmail(claim.Email)
email := fmt.Sprintf("%v", claim["email"])
user, err := db.Mgr.GetUserByEmail(email)
if err != nil {
return res, err
}
@ -120,9 +122,33 @@ func UpdateProfile(ctx context.Context, params model.UpdateProfileInput) (*model
go func() {
utils.SendVerificationMail(newEmail, token)
}()
}
// TODO this idea needs to be verified otherwise every user can make themselves super admin
// rolesToSave := ""
// if params.Roles != nil && len(params.Roles) > 0 {
// currentRoles := strings.Split(user.Roles, ",")
// inputRoles := []string{}
// for _, item := range params.Roles {
// inputRoles = append(inputRoles, *item)
// }
// if !utils.IsValidRolesArray(inputRoles) {
// return res, fmt.Errorf("invalid list of roles")
// }
// if !utils.IsStringArrayEqual(inputRoles, currentRoles) {
// rolesToSave = strings.Join(inputRoles, ",")
// }
// session.DeleteToken(fmt.Sprintf("%v", user.ID))
// utils.DeleteCookie(gc)
// }
// if rolesToSave != "" {
// user.Roles = rolesToSave
// }
_, err = db.Mgr.UpdateUser(user)
if err != nil {
log.Println("Error updating user:", err)

View File

@ -3,8 +3,10 @@ package resolvers
import (
"context"
"fmt"
"strings"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum"
"github.com/authorizerdev/authorizer/server/graph/model"
@ -41,15 +43,9 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut
db.Mgr.DeleteToken(claim.Email)
userIdStr := fmt.Sprintf("%v", user.ID)
refreshToken, _, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.RefreshToken)
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, constants.DEFAULT_ROLE)
accessToken, expiresAt, _ := utils.CreateAuthToken(utils.UserAuthInfo{
ID: userIdStr,
Email: user.Email,
}, enum.AccessToken)
accessToken, expiresAt, _ := utils.CreateAuthToken(user, enum.AccessToken, constants.DEFAULT_ROLE)
session.SetToken(userIdStr, refreshToken)
@ -65,6 +61,7 @@ func VerifyEmail(ctx context.Context, params model.VerifyEmailInput) (*model.Aut
LastName: &user.LastName,
SignupMethod: user.SignupMethod,
EmailVerifiedAt: &user.EmailVerifiedAt,
Roles: strings.Split(user.Roles, ","),
CreatedAt: &user.CreatedAt,
UpdatedAt: &user.UpdatedAt,
},

View File

@ -1,29 +1,32 @@
package utils
import (
"encoding/json"
"fmt"
"log"
"strings"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/enum"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
)
type UserAuthInfo struct {
Email string `json:"email"`
ID string `json:"id"`
}
// type UserAuthInfo struct {
// Email string `json:"email"`
// ID string `json:"id"`
// }
type JWTCustomClaim map[string]interface{}
type UserAuthClaim struct {
*jwt.StandardClaims
TokenType string `json:"token_type"`
UserAuthInfo
*JWTCustomClaim `json:"authorizer"`
}
func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64, error) {
func CreateAuthToken(user db.User, tokenType enum.TokenType, role string) (string, int64, error) {
t := jwt.New(jwt.GetSigningMethod(constants.JWT_TYPE))
expiryBound := time.Hour
if tokenType == enum.RefreshToken {
@ -33,12 +36,19 @@ func CreateAuthToken(user UserAuthInfo, tokenType enum.TokenType) (string, int64
expiresAt := time.Now().Add(expiryBound).Unix()
customClaims := JWTCustomClaim{
"token_type": tokenType.String(),
"email": user.Email,
"id": user.ID,
"allowed_roles": strings.Split(user.Roles, ","),
constants.JWT_ROLE_CLAIM: role,
}
t.Claims = &UserAuthClaim{
&jwt.StandardClaims{
ExpiresAt: expiresAt,
},
tokenType.String(),
user,
&customClaims,
}
token, err := t.SignedString([]byte(constants.JWT_SECRET))
@ -63,14 +73,20 @@ func GetAuthToken(gc *gin.Context) (string, error) {
return token, nil
}
func VerifyAuthToken(token string) (*UserAuthClaim, error) {
func VerifyAuthToken(token string) (map[string]interface{}, error) {
var res map[string]interface{}
claims := &UserAuthClaim{}
_, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte(constants.JWT_SECRET), nil
})
if err != nil {
return claims, err
return res, err
}
return claims, nil
data, _ := json.Marshal(claims.JWTCustomClaim)
json.Unmarshal(data, &res)
res["exp"] = claims.ExpiresAt
return res, nil
}

20
server/utils/common.go Normal file
View File

@ -0,0 +1,20 @@
package utils
import (
"io"
"os"
)
func WriteToFile(filename string, data string) error {
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
_, err = io.WriteString(file, data)
if err != nil {
return err
}
return file.Sync()
}

View File

@ -0,0 +1,25 @@
package utils
import (
"log"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
)
// any jobs that we want to run at start of server can be executed here
// 1. create roles table and add the roles list from env to table
func InitServer() {
roles := []db.Role{}
for _, val := range constants.ROLES {
roles = append(roles, db.Role{
Role: val,
})
}
err := db.Mgr.SaveRoles(roles)
if err != nil {
log.Println(`Error saving roles`, err)
}
}

View File

@ -1,15 +0,0 @@
package utils
import (
"github.com/authorizerdev/authorizer/server/constants"
"github.com/gin-gonic/gin"
)
func IsSuperAdmin(gc *gin.Context) bool {
secret := gc.Request.Header.Get("x-authorizer-admin-secret")
if secret == "" {
return false
}
return secret == constants.ADMIN_SECRET
}

View File

@ -5,6 +5,7 @@ import (
"strings"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/gin-gonic/gin"
)
func IsValidEmail(email string) bool {
@ -29,3 +30,52 @@ func IsValidRedirectURL(url string) bool {
return hasValidURL
}
func IsSuperAdmin(gc *gin.Context) bool {
secret := gc.Request.Header.Get("x-authorizer-admin-secret")
if secret == "" {
return false
}
return secret == constants.ADMIN_SECRET
}
func IsValidRolesArray(roles []string) bool {
valid := true
currentRoleMap := map[string]bool{}
for _, currentRole := range constants.ROLES {
currentRoleMap[currentRole] = true
}
for _, inputRole := range roles {
if !currentRoleMap[inputRole] {
valid = false
break
}
}
return valid
}
func IsValidRole(userRoles []string, role string) bool {
valid := false
for _, currentRole := range userRoles {
if role == currentRole {
valid = true
break
}
}
return valid
}
func IsStringArrayEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}