From ba0cf189de056ca3f5fe1dab6af9e2b042638a5d Mon Sep 17 00:00:00 2001 From: catusax Date: Mon, 24 Jul 2023 11:58:36 +0800 Subject: [PATCH] userid ass mfa session key --- server/memorystore/providers/inmemory/store.go | 14 +++++++------- server/memorystore/providers/providers.go | 8 ++++---- server/memorystore/providers/redis/store.go | 16 ++++++++-------- server/resolvers/login.go | 2 +- server/resolvers/mobile_login.go | 12 +++++++++++- server/resolvers/verify_otp.go | 13 +++++++++++++ 6 files changed, 44 insertions(+), 21 deletions(-) diff --git a/server/memorystore/providers/inmemory/store.go b/server/memorystore/providers/inmemory/store.go index d03a9df..b20fb62 100644 --- a/server/memorystore/providers/inmemory/store.go +++ b/server/memorystore/providers/inmemory/store.go @@ -42,15 +42,15 @@ func (c *provider) DeleteSessionForNamespace(namespace string) error { return nil } -// SetMfaSession sets the mfa session with key and value of email -func (c *provider) SetMfaSession(email, key string, expiration int64) error { - c.mfasessionStore.Set(email, key, email, expiration) +// SetMfaSession sets the mfa session with key and value of userId +func (c *provider) SetMfaSession(userId, key string, expiration int64) error { + c.mfasessionStore.Set(userId, key, userId, expiration) return nil } // GetMfaSession returns value of given mfa session -func (c *provider) GetMfaSession(email, key string) (string, error) { - val := c.mfasessionStore.Get(email, key) +func (c *provider) GetMfaSession(userId, key string) (string, error) { + val := c.mfasessionStore.Get(userId, key) if val == "" { return "", fmt.Errorf("Not found") } @@ -58,8 +58,8 @@ func (c *provider) GetMfaSession(email, key string) (string, error) { } // DeleteMfaSession deletes given mfa session from in-memory store. -func (c *provider) DeleteMfaSession(email, key string) error { - c.mfasessionStore.Remove(email, key) +func (c *provider) DeleteMfaSession(userId, key string) error { + c.mfasessionStore.Remove(userId, key) return nil } diff --git a/server/memorystore/providers/providers.go b/server/memorystore/providers/providers.go index 6b3eba0..331e34a 100644 --- a/server/memorystore/providers/providers.go +++ b/server/memorystore/providers/providers.go @@ -12,12 +12,12 @@ type Provider interface { DeleteAllUserSessions(userId string) error // DeleteSessionForNamespace deletes the session for a given namespace DeleteSessionForNamespace(namespace string) error - // SetMfaSession sets the mfa session with key and value of email - SetMfaSession(email, key string, expiration int64) error + // SetMfaSession sets the mfa session with key and value of userId + SetMfaSession(userId, key string, expiration int64) error // GetMfaSession returns value of given mfa session - GetMfaSession(email, key string) (string, error) + GetMfaSession(userId, key string) (string, error) // DeleteMfaSession deletes given mfa session from in-memory store. - DeleteMfaSession(email, key string) error + DeleteMfaSession(userId, key string) error // SetState sets the login state (key, value form) in the session store SetState(key, state string) error diff --git a/server/memorystore/providers/redis/store.go b/server/memorystore/providers/redis/store.go index d42e2c0..a6ff08f 100644 --- a/server/memorystore/providers/redis/store.go +++ b/server/memorystore/providers/redis/store.go @@ -93,12 +93,12 @@ func (c *provider) DeleteSessionForNamespace(namespace string) error { return nil } -// SetMfaSession sets the mfa session with key and value of email -func (c *provider) SetMfaSession(email, key string, expiration int64) error { +// SetMfaSession sets the mfa session with key and value of userId +func (c *provider) SetMfaSession(userId, key string, expiration int64) error { currentTime := time.Now() expireTime := time.Unix(expiration, 0) duration := expireTime.Sub(currentTime) - err := c.store.Set(c.ctx, fmt.Sprintf("%s%s:%s", mfaSessionPrefix, email, key), email, duration).Err() + err := c.store.Set(c.ctx, fmt.Sprintf("%s%s:%s", mfaSessionPrefix, userId, key), userId, duration).Err() if err != nil { log.Debug("Error saving user session to redis: ", err) return err @@ -106,9 +106,9 @@ func (c *provider) SetMfaSession(email, key string, expiration int64) error { return nil } - // GetMfaSession returns value of given mfa session -func (c *provider) GetMfaSession(email, key string) (string, error) { - data, err := c.store.Get(c.ctx, fmt.Sprintf("%s%s:%s", mfaSessionPrefix, email, key)).Result() +// GetMfaSession returns value of given mfa session +func (c *provider) GetMfaSession(userId, key string) (string, error) { + data, err := c.store.Get(c.ctx, fmt.Sprintf("%s%s:%s", mfaSessionPrefix, userId, key)).Result() if err != nil { return "", err } @@ -116,8 +116,8 @@ func (c *provider) GetMfaSession(email, key string) (string, error) { } // DeleteMfaSession deletes given mfa session from in-memory store. -func (c *provider) DeleteMfaSession(email, key string) error { - if err := c.store.Del(c.ctx, fmt.Sprintf("%s%s:%s", mfaSessionPrefix, email, key)).Err(); err != nil { +func (c *provider) DeleteMfaSession(userId, key string) error { + if err := c.store.Del(c.ctx, fmt.Sprintf("%s%s:%s", mfaSessionPrefix, userId, key)).Err(); err != nil { log.Debug("Error deleting user session from redis: ", err) // continue } diff --git a/server/resolvers/login.go b/server/resolvers/login.go index 299c6ef..c588ca7 100644 --- a/server/resolvers/login.go +++ b/server/resolvers/login.go @@ -125,7 +125,7 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes } mfaSession := uuid.NewString() - err = memorystore.Provider.SetMfaSession(params.Email, mfaSession, expires) + err = memorystore.Provider.SetMfaSession(user.ID, mfaSession, expires) if err != nil { log.Debug("Failed to add mfasession: ", err) return nil, err diff --git a/server/resolvers/mobile_login.go b/server/resolvers/mobile_login.go index fc131b0..89c3825 100644 --- a/server/resolvers/mobile_login.go +++ b/server/resolvers/mobile_login.go @@ -122,15 +122,25 @@ func MobileLoginResolver(ctx context.Context, params model.MobileLoginInput) (*m smsBody := strings.Builder{} smsBody.WriteString("Your verification code is: ") smsBody.WriteString(smsCode) + expires := time.Now().Add(duration).Unix() _, err := db.Provider.UpsertOTP(ctx, &models.OTP{ PhoneNumber: params.PhoneNumber, Otp: smsCode, - ExpiresAt: time.Now().Add(duration).Unix(), + ExpiresAt: expires, }) if err != nil { log.Debug("error while upserting OTP: ", err.Error()) return nil, err } + + mfaSession := uuid.NewString() + err = memorystore.Provider.SetMfaSession(user.ID, mfaSession, expires) + if err != nil { + log.Debug("Failed to add mfasession: ", err) + return nil, err + } + cookie.SetMfaSession(gc, mfaSession) + go func() { utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, constants.AuthRecipeMethodMobileBasicAuth, *user) smsproviders.SendSMS(params.PhoneNumber, smsBody.String()) diff --git a/server/resolvers/verify_otp.go b/server/resolvers/verify_otp.go index 124ee3f..2982859 100644 --- a/server/resolvers/verify_otp.go +++ b/server/resolvers/verify_otp.go @@ -27,6 +27,13 @@ func VerifyOtpResolver(ctx context.Context, params model.VerifyOTPRequest) (*mod log.Debug("Failed to get GinContext: ", err) return res, err } + + mfaSession, err := cookie.GetMfaSession(gc) + if err != nil { + log.Debug("Failed to get otp request by email: ", err) + return res, fmt.Errorf(`invalid session: %s`, err.Error()) + } + if refs.StringValue(params.Email) == "" && refs.StringValue(params.PhoneNumber) == "" { log.Debug("Email or phone number is required") return res, fmt.Errorf(`email or phone_number is required`) @@ -68,6 +75,12 @@ func VerifyOtpResolver(ctx context.Context, params model.VerifyOTPRequest) (*mod log.Debug("Failed to get user by email: ", err) return res, err } + + if _, err := memorystore.Provider.GetMfaSession(user.ID, mfaSession); err != nil { + log.Debug("Failed to get mfa session: ", err) + return res, fmt.Errorf(`invalid session: %s`, err.Error()) + } + isSignUp := user.EmailVerifiedAt == nil && user.PhoneNumberVerifiedAt == nil // TODO - Add Login method in DB when we introduce OTP for social media login loginMethod := constants.AuthRecipeMethodBasicAuth