Compare commits

...

4 Commits

Author SHA1 Message Date
Lakhan Samani
a124edfaee Add user to validate_session
Resolves #379
2023-08-19 20:45:20 +05:30
Lakhan Samani
5e6b033024 fix microsoft active directory config 2023-08-17 14:20:31 +05:30
Lakhan Samani
171d4e3fff remove unused code 2023-08-14 14:16:54 +05:30
Lakhan Samani
cf96a0087f Fix tests for verifying otp using mfa session 2023-08-14 14:15:52 +05:30
14 changed files with 204 additions and 1311 deletions

View File

@@ -74,7 +74,6 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination *mod
var verificationRequest models.VerificationRequest
err := scanner.Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt)
if err != nil {
fmt.Println("=> getting error here...", err)
return nil, err
}
verificationRequests = append(verificationRequests, verificationRequest.AsAPIVerificationRequest())

File diff suppressed because it is too large Load Diff

View File

@@ -279,6 +279,7 @@ type ComplexityRoot struct {
ValidateSessionResponse struct {
IsValid func(childComplexity int) int
User func(childComplexity int) int
}
VerificationRequest struct {
@@ -1871,6 +1872,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.ValidateSessionResponse.IsValid(childComplexity), true
case "ValidateSessionResponse.user":
if e.complexity.ValidateSessionResponse.User == nil {
break
}
return e.complexity.ValidateSessionResponse.User(childComplexity), true
case "VerificationRequest.created_at":
if e.complexity.VerificationRequest.CreatedAt == nil {
break
@@ -2367,6 +2375,7 @@ type ValidateJWTTokenResponse {
type ValidateSessionResponse {
is_valid: Boolean!
user: User!
}
type GenerateJWTKeysResponse {
@@ -10233,6 +10242,8 @@ func (ec *executionContext) fieldContext_Query_validate_session(ctx context.Cont
switch field.Name {
case "is_valid":
return ec.fieldContext_ValidateSessionResponse_is_valid(ctx, field)
case "user":
return ec.fieldContext_ValidateSessionResponse_user(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type ValidateSessionResponse", field.Name)
},
@@ -12562,6 +12573,92 @@ func (ec *executionContext) fieldContext_ValidateSessionResponse_is_valid(ctx co
return fc, nil
}
func (ec *executionContext) _ValidateSessionResponse_user(ctx context.Context, field graphql.CollectedField, obj *model.ValidateSessionResponse) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_ValidateSessionResponse_user(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.User, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
return graphql.Null
}
res := resTmp.(*model.User)
fc.Result = res
return ec.marshalNUser2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUser(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_ValidateSessionResponse_user(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "ValidateSessionResponse",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
switch field.Name {
case "id":
return ec.fieldContext_User_id(ctx, field)
case "email":
return ec.fieldContext_User_email(ctx, field)
case "email_verified":
return ec.fieldContext_User_email_verified(ctx, field)
case "signup_methods":
return ec.fieldContext_User_signup_methods(ctx, field)
case "given_name":
return ec.fieldContext_User_given_name(ctx, field)
case "family_name":
return ec.fieldContext_User_family_name(ctx, field)
case "middle_name":
return ec.fieldContext_User_middle_name(ctx, field)
case "nickname":
return ec.fieldContext_User_nickname(ctx, field)
case "preferred_username":
return ec.fieldContext_User_preferred_username(ctx, field)
case "gender":
return ec.fieldContext_User_gender(ctx, field)
case "birthdate":
return ec.fieldContext_User_birthdate(ctx, field)
case "phone_number":
return ec.fieldContext_User_phone_number(ctx, field)
case "phone_number_verified":
return ec.fieldContext_User_phone_number_verified(ctx, field)
case "picture":
return ec.fieldContext_User_picture(ctx, field)
case "roles":
return ec.fieldContext_User_roles(ctx, field)
case "created_at":
return ec.fieldContext_User_created_at(ctx, field)
case "updated_at":
return ec.fieldContext_User_updated_at(ctx, field)
case "revoked_timestamp":
return ec.fieldContext_User_revoked_timestamp(ctx, field)
case "is_multi_factor_auth_enabled":
return ec.fieldContext_User_is_multi_factor_auth_enabled(ctx, field)
case "app_data":
return ec.fieldContext_User_app_data(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
},
}
return fc, nil
}
func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_VerificationRequest_id(ctx, field)
if err != nil {
@@ -19668,6 +19765,13 @@ func (ec *executionContext) _ValidateSessionResponse(ctx context.Context, sel as
out.Values[i] = ec._ValidateSessionResponse_is_valid(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}
case "user":
out.Values[i] = ec._ValidateSessionResponse_user(ctx, field, obj)
if out.Values[i] == graphql.Null {
invalids++
}

View File

@@ -469,7 +469,8 @@ type ValidateSessionInput struct {
}
type ValidateSessionResponse struct {
IsValid bool `json:"is_valid"`
IsValid bool `json:"is_valid"`
User *User `json:"user"`
}
type VerificationRequest struct {

View File

@@ -181,6 +181,7 @@ type ValidateJWTTokenResponse {
type ValidateSessionResponse {
is_valid: Boolean!
user: User!
}
type GenerateJWTKeysResponse {

View File

@@ -32,11 +32,11 @@ func OAuthCallbackHandler() gin.HandlerFunc {
return func(ctx *gin.Context) {
provider := ctx.Param("oauth_provider")
state := ctx.Request.FormValue("state")
sessionState, err := memorystore.Provider.GetState(state)
if sessionState == "" || err != nil {
log.Debug("Invalid oauth state: ", state)
ctx.JSON(400, gin.H{"error": "invalid oauth state"})
return
}
// contains random token, redirect url, role
sessionSplit := strings.Split(state, "___")
@@ -46,32 +46,34 @@ func OAuthCallbackHandler() gin.HandlerFunc {
ctx.JSON(400, gin.H{"error": "invalid redirect url"})
return
}
// remove state from store
go memorystore.Provider.RemoveState(state)
stateValue := sessionSplit[0]
redirectURL := sessionSplit[1]
inputRoles := strings.Split(sessionSplit[2], ",")
scopes := strings.Split(sessionSplit[3], ",")
var user *models.User
oauthCode := ctx.Request.FormValue("code")
if oauthCode == "" {
log.Debug("Invalid oauth code: ", oauthCode)
ctx.JSON(400, gin.H{"error": "invalid oauth code"})
return
}
switch provider {
case constants.AuthRecipeMethodGoogle:
user, err = processGoogleUserInfo(oauthCode)
user, err = processGoogleUserInfo(ctx, oauthCode)
case constants.AuthRecipeMethodGithub:
user, err = processGithubUserInfo(oauthCode)
user, err = processGithubUserInfo(ctx, oauthCode)
case constants.AuthRecipeMethodFacebook:
user, err = processFacebookUserInfo(oauthCode)
user, err = processFacebookUserInfo(ctx, oauthCode)
case constants.AuthRecipeMethodLinkedIn:
user, err = processLinkedInUserInfo(oauthCode)
user, err = processLinkedInUserInfo(ctx, oauthCode)
case constants.AuthRecipeMethodApple:
user, err = processAppleUserInfo(oauthCode)
user, err = processAppleUserInfo(ctx, oauthCode)
case constants.AuthRecipeMethodTwitter:
user, err = processTwitterUserInfo(oauthCode, sessionState)
user, err = processTwitterUserInfo(ctx, oauthCode, sessionState)
case constants.AuthRecipeMethodMicrosoft:
user, err = processMicrosoftUserInfo(oauthCode)
user, err = processMicrosoftUserInfo(ctx, oauthCode)
default:
log.Info("Invalid oauth provider")
err = fmt.Errorf(`invalid oauth provider`)
@@ -281,9 +283,8 @@ func OAuthCallbackHandler() gin.HandlerFunc {
}
}
func processGoogleUserInfo(code string) (*models.User, error) {
func processGoogleUserInfo(ctx context.Context, code string) (*models.User, error) {
var user *models.User
ctx := context.Background()
oauth2Token, err := oauth.OAuthProviders.GoogleConfig.Exchange(ctx, code)
if err != nil {
log.Debug("Failed to exchange code for token: ", err)
@@ -313,9 +314,9 @@ func processGoogleUserInfo(code string) (*models.User, error) {
return user, nil
}
func processGithubUserInfo(code string) (*models.User, error) {
func processGithubUserInfo(ctx context.Context, code string) (*models.User, error) {
var user *models.User
oauth2Token, err := oauth.OAuthProviders.GithubConfig.Exchange(context.TODO(), code)
oauth2Token, err := oauth.OAuthProviders.GithubConfig.Exchange(ctx, code)
if err != nil {
log.Debug("Failed to exchange code for token: ", err)
return user, fmt.Errorf("invalid github exchange code: %s", err.Error())
@@ -420,9 +421,9 @@ func processGithubUserInfo(code string) (*models.User, error) {
return user, nil
}
func processFacebookUserInfo(code string) (*models.User, error) {
func processFacebookUserInfo(ctx context.Context, code string) (*models.User, error) {
var user *models.User
oauth2Token, err := oauth.OAuthProviders.FacebookConfig.Exchange(context.TODO(), code)
oauth2Token, err := oauth.OAuthProviders.FacebookConfig.Exchange(ctx, code)
if err != nil {
log.Debug("Invalid facebook exchange code: ", err)
return user, fmt.Errorf("invalid facebook exchange code: %s", err.Error())
@@ -471,9 +472,9 @@ func processFacebookUserInfo(code string) (*models.User, error) {
return user, nil
}
func processLinkedInUserInfo(code string) (*models.User, error) {
func processLinkedInUserInfo(ctx context.Context, code string) (*models.User, error) {
var user *models.User
oauth2Token, err := oauth.OAuthProviders.LinkedInConfig.Exchange(context.TODO(), code)
oauth2Token, err := oauth.OAuthProviders.LinkedInConfig.Exchange(ctx, code)
if err != nil {
log.Debug("Failed to exchange code for token: ", err)
return user, fmt.Errorf("invalid linkedin exchange code: %s", err.Error())
@@ -553,9 +554,9 @@ func processLinkedInUserInfo(code string) (*models.User, error) {
return user, nil
}
func processAppleUserInfo(code string) (*models.User, error) {
func processAppleUserInfo(ctx context.Context, code string) (*models.User, error) {
var user *models.User
oauth2Token, err := oauth.OAuthProviders.AppleConfig.Exchange(context.TODO(), code)
oauth2Token, err := oauth.OAuthProviders.AppleConfig.Exchange(ctx, code)
if err != nil {
log.Debug("Failed to exchange code for token: ", err)
return user, fmt.Errorf("invalid apple exchange code: %s", err.Error())
@@ -606,9 +607,9 @@ func processAppleUserInfo(code string) (*models.User, error) {
return user, err
}
func processTwitterUserInfo(code, verifier string) (*models.User, error) {
func processTwitterUserInfo(ctx context.Context, code, verifier string) (*models.User, error) {
var user *models.User
oauth2Token, err := oauth.OAuthProviders.TwitterConfig.Exchange(context.TODO(), code, oauth2.SetAuthURLParam("code_verifier", verifier))
oauth2Token, err := oauth.OAuthProviders.TwitterConfig.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", verifier))
if err != nil {
log.Debug("Failed to exchange code for token: ", err)
return user, fmt.Errorf("invalid twitter exchange code: %s", err.Error())
@@ -674,24 +675,24 @@ func processTwitterUserInfo(code, verifier string) (*models.User, error) {
}
// process microsoft user information
func processMicrosoftUserInfo(code string) (*models.User, error) {
func processMicrosoftUserInfo(ctx context.Context, code string) (*models.User, error) {
var user *models.User
ctx := context.Background()
oauth2Token, err := oauth.OAuthProviders.MicrosoftConfig.Exchange(ctx, code)
if err != nil {
log.Debug("Failed to exchange code for token: ", err)
return user, fmt.Errorf("invalid google exchange code: %s", err.Error())
return user, fmt.Errorf("invalid microsoft exchange code: %s", err.Error())
}
verifier := oauth.OIDCProviders.MicrosoftOIDC.Verifier(&oidc.Config{ClientID: oauth.OAuthProviders.MicrosoftConfig.ClientID})
// we need to skip issuer check because for common tenant it will return internal issuer which does not match
verifier := oauth.OIDCProviders.MicrosoftOIDC.Verifier(&oidc.Config{
ClientID: oauth.OAuthProviders.MicrosoftConfig.ClientID,
SkipIssuerCheck: true,
})
// Extract the ID Token from OAuth2 token.
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
log.Debug("Failed to extract ID Token from OAuth2 token")
return user, fmt.Errorf("unable to extract id_token")
}
// Parse and verify ID Token payload.
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {

View File

@@ -37,8 +37,9 @@ func ValidateSessionResolver(ctx context.Context, params *model.ValidateSessionI
log := log.WithFields(log.Fields{
"user_id": userID,
})
_, err = db.Provider.GetUserByID(ctx, userID)
user, err := db.Provider.GetUserByID(ctx, userID)
if err != nil {
log.Debug("Failed to get user: ", err)
return nil, err
}
// refresh token has "roles" as claim
@@ -55,5 +56,6 @@ func ValidateSessionResolver(ctx context.Context, params *model.ValidateSessionI
}
return &model.ValidateSessionResponse{
IsValid: true,
User: user.AsAPIUser(),
}, nil
}

View File

@@ -69,7 +69,6 @@ func VerifyOtpResolver(ctx context.Context, params model.VerifyOTPRequest) (*mod
user, err = db.Provider.GetUserByPhoneNumber(ctx, refs.StringValue(params.PhoneNumber))
}
if user == nil || err != nil {
fmt.Println("=> failing here....", err)
log.Debug("Failed to get user by email or phone number: ", err)
return res, err
}

View File

@@ -1,12 +1,18 @@
package test
import (
"fmt"
"strings"
"testing"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/memorystore"
"github.com/authorizerdev/authorizer/server/refs"
"github.com/authorizerdev/authorizer/server/resolvers"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
@@ -48,6 +54,17 @@ func mobileLoginTests(t *testing.T, s TestSetup) {
smsRequest, err := db.Provider.GetOTPByPhoneNumber(ctx, phoneNumber)
assert.NoError(t, err)
assert.NotEmpty(t, smsRequest.Otp)
// Get user by phone number
user, err := db.Provider.GetUserByPhoneNumber(ctx, phoneNumber)
assert.NoError(t, err)
assert.NotNil(t, user)
// Set mfa cookie session
mfaSession := uuid.NewString()
memorystore.Provider.SetMfaSession(user.ID, mfaSession, time.Now().Add(1*time.Minute).Unix())
cookie := fmt.Sprintf("%s=%s;", constants.MfaCookieName+"_session", mfaSession)
cookie = strings.TrimSuffix(cookie, ";")
req, ctx := createContext(s)
req.Header.Set("Cookie", cookie)
verifySMSRequest, err := resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{
PhoneNumber: &phoneNumber,
Otp: smsRequest.Otp,

View File

@@ -1,7 +1,10 @@
package test
import (
"fmt"
"strings"
"testing"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
@@ -9,6 +12,7 @@ import (
"github.com/authorizerdev/authorizer/server/memorystore"
"github.com/authorizerdev/authorizer/server/refs"
"github.com/authorizerdev/authorizer/server/resolvers"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
@@ -79,6 +83,17 @@ func mobileSingupTest(t *testing.T, s TestSetup) {
otp, err := db.Provider.GetOTPByPhoneNumber(ctx, phoneNumber)
assert.Nil(t, err)
assert.NotEmpty(t, otp.Otp)
// Get user by phone number
user, err := db.Provider.GetUserByPhoneNumber(ctx, phoneNumber)
assert.NoError(t, err)
assert.NotNil(t, user)
// Set mfa cookie session
mfaSession := uuid.NewString()
memorystore.Provider.SetMfaSession(user.ID, mfaSession, time.Now().Add(1*time.Minute).Unix())
cookie := fmt.Sprintf("%s=%s;", constants.MfaCookieName+"_session", mfaSession)
cookie = strings.TrimSuffix(cookie, ";")
req, ctx := createContext(s)
req.Header.Set("Cookie", cookie)
otpRes, err := resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{
PhoneNumber: &phoneNumber,
Otp: otp.Otp,

View File

@@ -2,13 +2,18 @@ package test
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/memorystore"
"github.com/authorizerdev/authorizer/server/refs"
"github.com/authorizerdev/authorizer/server/resolvers"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
@@ -89,6 +94,16 @@ func resendOTPTest(t *testing.T, s TestSetup) {
})
assert.Error(t, err)
assert.Nil(t, verifyOtpRes)
// Get user by email
user, err := db.Provider.GetUserByEmail(ctx, email)
assert.NoError(t, err)
assert.NotNil(t, user)
// Set mfa cookie session
mfaSession := uuid.NewString()
memorystore.Provider.SetMfaSession(user.ID, mfaSession, time.Now().Add(1*time.Minute).Unix())
cookie := fmt.Sprintf("%s=%s;", constants.MfaCookieName+"_session", mfaSession)
cookie = strings.TrimSuffix(cookie, ";")
req.Header.Set("Cookie", cookie)
verifyOtpRes, err = resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{
Email: &email,
Otp: newOtp.Otp,

View File

@@ -56,6 +56,7 @@ func validateSessionTests(t *testing.T, s TestSetup) {
res, err = resolvers.ValidateSessionResolver(ctx, &model.ValidateSessionInput{})
assert.Nil(t, err)
assert.True(t, res.IsValid)
assert.Equal(t, res.User.ID, verifyRes.User.ID)
cleanData(email)
})
}

View File

@@ -2,13 +2,18 @@ package test
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/memorystore"
"github.com/authorizerdev/authorizer/server/refs"
"github.com/authorizerdev/authorizer/server/resolvers"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
@@ -63,7 +68,16 @@ func verifyOTPTest(t *testing.T, s TestSetup) {
otp, err := db.Provider.GetOTPByEmail(ctx, email)
assert.NoError(t, err)
assert.NotEmpty(t, otp.Otp)
// Get user by email
user, err := db.Provider.GetUserByEmail(ctx, email)
assert.NoError(t, err)
assert.NotNil(t, user)
// Set mfa cookie session
mfaSession := uuid.NewString()
memorystore.Provider.SetMfaSession(user.ID, mfaSession, time.Now().Add(1*time.Minute).Unix())
cookie := fmt.Sprintf("%s=%s;", constants.MfaCookieName+"_session", mfaSession)
cookie = strings.TrimSuffix(cookie, ";")
req.Header.Set("Cookie", cookie)
verifyOtpRes, err := resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{
Email: &email,
Otp: otp.Otp,

View File

@@ -386,7 +386,6 @@ func CreateIDToken(user *models.User, roles []string, hostname, nonce, atHash, c
userBytes, _ := json.Marshal(&resUser)
var userMap map[string]interface{}
json.Unmarshal(userBytes, &userMap)
fmt.Println("=> userBytes", string(userBytes))
claimKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtRoleClaim)
if err != nil {
claimKey = "roles"