diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index 2ecd7b2..5c38f9c 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -2630,7 +2630,8 @@ input ResendOTPRequest { } input GetUserRequest { - id: String! + id: String + email: String } type Mutation { @@ -15369,7 +15370,7 @@ func (ec *executionContext) unmarshalInputGetUserRequest(ctx context.Context, ob asMap[k] = v } - fieldsInOrder := [...]string{"id"} + fieldsInOrder := [...]string{"id", "email"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -15380,7 +15381,15 @@ func (ec *executionContext) unmarshalInputGetUserRequest(ctx context.Context, ob var err error ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("id")) - it.ID, err = ec.unmarshalNString2string(ctx, v) + it.ID, err = ec.unmarshalOString2áš–string(ctx, v) + if err != nil { + return it, err + } + case "email": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("email")) + it.Email, err = ec.unmarshalOString2áš–string(ctx, v) if err != nil { return it, err } diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index bdc78a8..82b57ea 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -144,7 +144,8 @@ type GenerateJWTKeysResponse struct { } type GetUserRequest struct { - ID string `json:"id"` + ID *string `json:"id"` + Email *string `json:"email"` } type InviteMemberInput struct { diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index 15b5cc9..7691c9d 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -537,7 +537,8 @@ input ResendOTPRequest { } input GetUserRequest { - id: String! + id: String + email: String } type Mutation { diff --git a/server/resolvers/user.go b/server/resolvers/user.go index 956194f..994348c 100644 --- a/server/resolvers/user.go +++ b/server/resolvers/user.go @@ -3,6 +3,7 @@ package resolvers import ( "context" "fmt" + "strings" log "github.com/sirupsen/logrus" @@ -20,17 +21,28 @@ func UserResolver(ctx context.Context, params model.GetUserRequest) (*model.User log.Debug("Failed to get GinContext: ", err) return nil, err } - if !token.IsSuperAdmin(gc) { log.Debug("Not logged in as super admin.") return nil, fmt.Errorf("unauthorized") } - - res, err := db.Provider.GetUserByID(ctx, params.ID) - if err != nil { - log.Debug("Failed to get users: ", err) - return nil, err + // Try getting user by ID + if params.ID != nil && strings.Trim(*params.ID, " ") != "" { + res, err := db.Provider.GetUserByID(ctx, *params.ID) + if err != nil { + log.Debug("Failed to get users by ID: ", err) + return nil, err + } + return res.AsAPIUser(), nil } - - return res.AsAPIUser(), nil + // Try getting user by email + if params.Email != nil && strings.Trim(*params.Email, " ") != "" { + res, err := db.Provider.GetUserByEmail(ctx, *params.Email) + if err != nil { + log.Debug("Failed to get users by email: ", err) + return nil, err + } + return res.AsAPIUser(), nil + } + // Return error if no params are provided + return nil, fmt.Errorf("invalid params, user id or email is required") } diff --git a/server/test/user_test.go b/server/test/user_test.go index f7f3238..0529d29 100644 --- a/server/test/user_test.go +++ b/server/test/user_test.go @@ -8,6 +8,7 @@ import ( "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/refs" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -26,7 +27,7 @@ func userTest(t *testing.T, s TestSetup) { assert.NotEmpty(t, res.User) userRes, err := resolvers.UserResolver(ctx, model.GetUserRequest{ - ID: res.User.ID, + ID: &res.User.ID, }) assert.Nil(t, userRes) assert.NotNil(t, err, "unauthorized") @@ -36,14 +37,36 @@ func userTest(t *testing.T, s TestSetup) { h, err := crypto.EncryptPassword(adminSecret) assert.Nil(t, err) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) - + // Should throw error for invalid params + userRes, err = resolvers.UserResolver(ctx, model.GetUserRequest{}) + assert.Nil(t, userRes) + assert.NotNil(t, err, "invalid params, user id or email is required") + // Should throw error for invalid params with empty id userRes, err = resolvers.UserResolver(ctx, model.GetUserRequest{ - ID: res.User.ID, + ID: refs.NewStringRef(" "), + }) + assert.Nil(t, userRes) + assert.NotNil(t, err, "invalid params, user id or email is required") + // Should throw error for invalid params with empty email + userRes, err = resolvers.UserResolver(ctx, model.GetUserRequest{ + Email: refs.NewStringRef(" "), + }) + assert.Nil(t, userRes) + assert.NotNil(t, err, "invalid params, user id or email is required") + // Should get user by id + userRes, err = resolvers.UserResolver(ctx, model.GetUserRequest{ + ID: &res.User.ID, + }) + assert.Nil(t, err) + assert.Equal(t, res.User.ID, userRes.ID) + assert.Equal(t, email, userRes.Email) + // Should get user by email + userRes, err = resolvers.UserResolver(ctx, model.GetUserRequest{ + Email: &email, }) assert.Nil(t, err) assert.Equal(t, res.User.ID, userRes.ID) assert.Equal(t, email, userRes.Email) - cleanData(email) }) }