Compare commits

...

39 Commits

Author SHA1 Message Date
Lakhan Samani
f2886e6da8 fix: disable other db test for quick test 2022-07-12 11:57:46 +05:30
Lakhan Samani
6b57bce6d9 fix: cassandra + mongo + arangodb issues with webhook 2022-07-12 11:48:42 +05:30
Lakhan Samani
bfbeb6add2 fix: couple session deletion with user deletion 2022-07-12 08:42:32 +05:30
Lakhan Samani
1fe0d65874 feat: add support for planetscale
Resolves #195
2022-07-11 22:37:07 +05:30
Lakhan Samani
bfaa0f9d89 fix: make list webhooks params optional 2022-07-11 22:05:44 +05:30
Lakhan Samani
4f5a6c77f8 Merge pull request #194 from authorizerdev/feat/webhook
feat: add webhook apis + integrate in events
2022-07-11 19:56:48 +05:30
Lakhan Samani
018a13ab3c feat: add tests for webhook resolvers 2022-07-11 19:40:54 +05:30
Lakhan Samani
334041d0e4 fix: delete user event flow 2022-07-11 11:13:32 +05:30
Lakhan Samani
6a8309a231 feat: register event for revoke/enable access + delete user 2022-07-11 11:12:30 +05:30
Lakhan Samani
6347b60753 fix: rename revoke refresh token handler for better reading 2022-07-11 11:10:30 +05:30
Lakhan Samani
bbb064b939 feat: add register event 2022-07-11 10:42:42 +05:30
Lakhan Samani
e91a819067 feat: implement resolvers 2022-07-10 21:49:33 +05:30
Lakhan Samani
09c3eafe6b feat: add mongodb database methods for webhook 2022-07-09 12:23:48 +05:30
Lakhan Samani
bb51775d34 feat: add cassandradb database methods for webhook 2022-07-09 12:16:54 +05:30
Lakhan Samani
6d586b16e4 feat: add arangodb database methods for webhook 2022-07-09 11:44:14 +05:30
Lakhan Samani
e8eb62769e feat: add sql database methods for webhook 2022-07-09 11:21:32 +05:30
Lakhan Samani
0ffb3f67f1 fix: index for arangodb 2022-07-08 19:10:37 +05:30
Lakhan Samani
ec62686fbc feat: add database methods for webhookLog 2022-07-08 19:09:23 +05:30
Lakhan Samani
a8064e79a1 feat: add template for webhook db methods 2022-07-06 10:38:21 +05:30
Lakhan Samani
265331801f feat: add database models 2022-07-04 22:37:13 +05:30
Lakhan Samani
6a74a50493 feat: add support for scylladb
Resolves #177
2022-07-04 21:57:14 +05:30
Lakhan Samani
8c27f20534 Merge pull request #193 from authorizerdev/fix/invalidate-session-token
fix: add provider to token creation
2022-07-01 22:14:50 +05:30
Lakhan Samani
29c6003ea3 fix: remove test log 2022-07-01 22:04:58 +05:30
Lakhan Samani
ae34fc7c2b fix: update_env resolver 2022-07-01 22:02:34 +05:30
Lakhan Samani
2a5d5d43b0 fix: add namespace to session token keys 2022-06-29 22:24:00 +05:30
Lakhan Samani
e6a4670ba9 fix: add provider to token creation 2022-06-29 09:54:12 +05:30
Lakhan Samani
64d64b4099 feat: add ability to disable strong password 2022-06-18 15:31:57 +05:30
Lakhan Samani
88f9a10f21 Merge pull request #192 from authorizerdev/feat/apple-login
feat: add apple login
2022-06-16 09:48:02 +05:30
Lakhan Samani
4e08d4f8fd fix: update app 2022-06-16 09:47:16 +05:30
Lakhan Samani
1c4dda9299 fix: remove debug logs 2022-06-16 07:34:10 +05:30
Lakhan Samani
ab18fa5832 fix: use raw base64 url decoding 2022-06-15 21:55:41 +05:30
Lakhan Samani
484d0c0882 chore: update app 2022-06-14 16:39:06 +05:30
Lakhan Samani
be59c3615f fix: add comment for scope 2022-06-14 15:47:08 +05:30
Lakhan Samani
db351f7771 fix: remove debug logs 2022-06-14 15:45:06 +05:30
Lakhan Samani
91c29c4092 fix: redirect 2022-06-14 15:43:23 +05:30
Lakhan Samani
415b97535e fix: update scope param 2022-06-14 15:05:56 +05:30
Lakhan Samani
7d1272d815 fix: update scope for apple login 2022-06-14 14:41:31 +05:30
Lakhan Samani
c9ba0b13f8 fix: update scope for apple login 2022-06-14 13:37:05 +05:30
Lakhan Samani
fadd9f6168 fix: update scope for apple login 2022-06-14 13:11:39 +05:30
133 changed files with 5266 additions and 602 deletions

30
app/package-lock.json generated
View File

@@ -9,7 +9,7 @@
"version": "1.0.0",
"license": "ISC",
"dependencies": {
"@authorizerdev/authorizer-react": "^0.24.0-beta.1",
"@authorizerdev/authorizer-react": "^0.25.0",
"@types/react": "^17.0.15",
"@types/react-dom": "^17.0.9",
"esbuild": "^0.12.17",
@@ -26,9 +26,9 @@
}
},
"node_modules/@authorizerdev/authorizer-js": {
"version": "0.13.0-beta.2",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-0.13.0-beta.2.tgz",
"integrity": "sha512-3K3Qew/npD375DQdJnYb43aWVOU7giG2+8sHOjy8aXhZ+GlQQ8cGu54lozFYUMDcM1HjQU2N53xLmneONPepSw==",
"version": "0.14.0",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-0.14.0.tgz",
"integrity": "sha512-cpeeFrmG623QPLn+nf+ACHayZYqW8xokIidGikeboBDJtuAAQB50a54/7HwLHriG2FB7WvPuHQ/9LFFX//N1lg==",
"dependencies": {
"node-fetch": "^2.6.1"
},
@@ -37,11 +37,11 @@
}
},
"node_modules/@authorizerdev/authorizer-react": {
"version": "0.24.0-beta.1",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-0.24.0-beta.1.tgz",
"integrity": "sha512-S/Oqc24EfotbrABuv379i/3uCfQPYJQqXrOU9d8AytF++pzG/2dcoIoaMbWZQkATR3m6a5AnhpG6bIB+4NbrUQ==",
"version": "0.25.0",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-0.25.0.tgz",
"integrity": "sha512-Dt2rZf+cGCVb8dqcJ/9l8Trx+QeXnTdfhER6r/cq0iOnFC9MqWzQPB3RgrlUoMLHtZvKNDXIk1HvfD5hSX9lhw==",
"dependencies": {
"@authorizerdev/authorizer-js": "^0.13.0-beta.2",
"@authorizerdev/authorizer-js": "^0.14.0",
"final-form": "^4.20.2",
"react-final-form": "^6.5.3",
"styled-components": "^5.3.0"
@@ -852,19 +852,19 @@
},
"dependencies": {
"@authorizerdev/authorizer-js": {
"version": "0.13.0-beta.2",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-0.13.0-beta.2.tgz",
"integrity": "sha512-3K3Qew/npD375DQdJnYb43aWVOU7giG2+8sHOjy8aXhZ+GlQQ8cGu54lozFYUMDcM1HjQU2N53xLmneONPepSw==",
"version": "0.14.0",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-0.14.0.tgz",
"integrity": "sha512-cpeeFrmG623QPLn+nf+ACHayZYqW8xokIidGikeboBDJtuAAQB50a54/7HwLHriG2FB7WvPuHQ/9LFFX//N1lg==",
"requires": {
"node-fetch": "^2.6.1"
}
},
"@authorizerdev/authorizer-react": {
"version": "0.24.0-beta.1",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-0.24.0-beta.1.tgz",
"integrity": "sha512-S/Oqc24EfotbrABuv379i/3uCfQPYJQqXrOU9d8AytF++pzG/2dcoIoaMbWZQkATR3m6a5AnhpG6bIB+4NbrUQ==",
"version": "0.25.0",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-0.25.0.tgz",
"integrity": "sha512-Dt2rZf+cGCVb8dqcJ/9l8Trx+QeXnTdfhER6r/cq0iOnFC9MqWzQPB3RgrlUoMLHtZvKNDXIk1HvfD5hSX9lhw==",
"requires": {
"@authorizerdev/authorizer-js": "^0.13.0-beta.2",
"@authorizerdev/authorizer-js": "^0.14.0",
"final-form": "^4.20.2",
"react-final-form": "^6.5.3",
"styled-components": "^5.3.0"

View File

@@ -11,7 +11,7 @@
"author": "Lakhan Samani",
"license": "ISC",
"dependencies": {
"@authorizerdev/authorizer-react": "^0.24.0-beta.1",
"@authorizerdev/authorizer-react": "^0.25.0",
"@types/react": "^17.0.15",
"@types/react-dom": "^17.0.9",
"esbuild": "^0.12.17",

View File

@@ -2529,7 +2529,8 @@
"@chakra-ui/css-reset": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/@chakra-ui/css-reset/-/css-reset-1.1.1.tgz",
"integrity": "sha512-+KNNHL4OWqeKia5SL858K3Qbd8WxMij9mWIilBzLD4j2KFrl/+aWFw8syMKth3NmgIibrjsljo+PU3fy2o50dg=="
"integrity": "sha512-+KNNHL4OWqeKia5SL858K3Qbd8WxMij9mWIilBzLD4j2KFrl/+aWFw8syMKth3NmgIibrjsljo+PU3fy2o50dg==",
"requires": {}
},
"@chakra-ui/descendant": {
"version": "2.1.1",
@@ -3133,7 +3134,8 @@
"@graphql-typed-document-node/core": {
"version": "3.1.1",
"resolved": "https://registry.npmjs.org/@graphql-typed-document-node/core/-/core-3.1.1.tgz",
"integrity": "sha512-NQ17ii0rK1b34VZonlmT2QMJFI70m0TRwbknO/ihlbatXyaktDhN/98vBiUU6kNBPljqGqyIrl2T4nY2RpFANg=="
"integrity": "sha512-NQ17ii0rK1b34VZonlmT2QMJFI70m0TRwbknO/ihlbatXyaktDhN/98vBiUU6kNBPljqGqyIrl2T4nY2RpFANg==",
"requires": {}
},
"@popperjs/core": {
"version": "2.11.0",
@@ -3843,7 +3845,8 @@
"react-icons": {
"version": "4.3.1",
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.3.1.tgz",
"integrity": "sha512-cB10MXLTs3gVuXimblAdI71jrJx8njrJZmNMEMC+sQu5B/BIOmlsAjskdqpn81y8UBVEGuHODd7/ci5DvoSzTQ=="
"integrity": "sha512-cB10MXLTs3gVuXimblAdI71jrJx8njrJZmNMEMC+sQu5B/BIOmlsAjskdqpn81y8UBVEGuHODd7/ci5DvoSzTQ==",
"requires": {}
},
"react-is": {
"version": "16.13.1",
@@ -4029,7 +4032,8 @@
"use-callback-ref": {
"version": "1.2.5",
"resolved": "https://registry.npmjs.org/use-callback-ref/-/use-callback-ref-1.2.5.tgz",
"integrity": "sha512-gN3vgMISAgacF7sqsLPByqoePooY3n2emTH59Ur5d/M8eg4WTWu1xp8i8DHjohftIyEx0S08RiYxbffr4j8Peg=="
"integrity": "sha512-gN3vgMISAgacF7sqsLPByqoePooY3n2emTH59Ur5d/M8eg4WTWu1xp8i8DHjohftIyEx0S08RiYxbffr4j8Peg==",
"requires": {}
},
"use-sidecar": {
"version": "1.0.5",

View File

@@ -71,6 +71,18 @@ const Features = ({ variables, setVariables }: any) => {
/>
</Flex>
</Flex>
<Flex>
<Flex w="100%" justifyContent="start" alignItems="center">
<Text fontSize="sm">Disable Strong Password:</Text>
</Flex>
<Flex justifyContent="start" mb={3}>
<InputField
variables={variables}
setVariables={setVariables}
inputType={SwitchInputType.DISABLE_STRONG_PASSWORD}
/>
</Flex>
</Flex>
</Stack>
</div>
);

View File

@@ -67,6 +67,7 @@ export const SwitchInputType = {
DISABLE_BASIC_AUTHENTICATION: 'DISABLE_BASIC_AUTHENTICATION',
DISABLE_SIGN_UP: 'DISABLE_SIGN_UP',
DISABLE_REDIS_FOR_ENV: 'DISABLE_REDIS_FOR_ENV',
DISABLE_STRONG_PASSWORD: 'DISABLE_STRONG_PASSWORD',
};
export const DateInputType = {
@@ -131,6 +132,7 @@ export interface envVarTypes {
DISABLE_EMAIL_VERIFICATION: boolean;
DISABLE_BASIC_AUTHENTICATION: boolean;
DISABLE_SIGN_UP: boolean;
DISABLE_STRONG_PASSWORD: boolean;
OLD_ADMIN_SECRET: string;
DATABASE_NAME: string;
DATABASE_TYPE: string;

View File

@@ -53,6 +53,7 @@ export const EnvVariablesQuery = `
DISABLE_EMAIL_VERIFICATION,
DISABLE_BASIC_AUTHENTICATION,
DISABLE_SIGN_UP,
DISABLE_STRONG_PASSWORD,
DISABLE_REDIS_FOR_ENV,
CUSTOM_ACCESS_TOKEN_SCRIPT,
DATABASE_NAME,

View File

@@ -74,6 +74,7 @@ const Environment = () => {
DISABLE_EMAIL_VERIFICATION: false,
DISABLE_BASIC_AUTHENTICATION: false,
DISABLE_SIGN_UP: false,
DISABLE_STRONG_PASSWORD: false,
OLD_ADMIN_SECRET: '',
DATABASE_NAME: '',
DATABASE_TYPE: '',

View File

@@ -0,0 +1,18 @@
package constants
const (
// AuthRecipeMethodBasicAuth is the basic_auth auth method
AuthRecipeMethodBasicAuth = "basic_auth"
// AuthRecipeMethodMagicLinkLogin is the magic_link_login auth method
AuthRecipeMethodMagicLinkLogin = "magic_link_login"
// AuthRecipeMethodGoogle is the google auth method
AuthRecipeMethodGoogle = "google"
// AuthRecipeMethodGithub is the github auth method
AuthRecipeMethodGithub = "github"
// AuthRecipeMethodFacebook is the facebook auth method
AuthRecipeMethodFacebook = "facebook"
// AuthRecipeMethodLinkedin is the linkedin auth method
AuthRecipeMethodLinkedIn = "linkedin"
// AuthRecipeMethodApple is the apple auth method
AuthRecipeMethodApple = "apple"
)

View File

@@ -23,4 +23,6 @@ const (
DbTypeScyllaDB = "scylladb"
// DbTypeCockroachDB is the cockroach database type
DbTypeCockroachDB = "cockroachdb"
// DbTypePlanetScaleDB is the planetscale database type
DbTypePlanetScaleDB = "planetscale"
)

View File

@@ -115,6 +115,8 @@ const (
EnvKeyDisableSignUp = "DISABLE_SIGN_UP"
// EnvKeyDisableRedisForEnv key for env variable DISABLE_REDIS_FOR_ENV
EnvKeyDisableRedisForEnv = "DISABLE_REDIS_FOR_ENV"
// EnvKeyDisableStrongPassword key for env variable DISABLE_STRONG_PASSWORD
EnvKeyDisableStrongPassword = "DISABLE_STRONG_PASSWORD"
// Slice variables
// EnvKeyRoles key for env variable ROLES

View File

@@ -1,18 +0,0 @@
package constants
const (
// SignupMethodBasicAuth is the basic_auth signup method
SignupMethodBasicAuth = "basic_auth"
// SignupMethodMagicLinkLogin is the magic_link_login signup method
SignupMethodMagicLinkLogin = "magic_link_login"
// SignupMethodGoogle is the google signup method
SignupMethodGoogle = "google"
// SignupMethodGithub is the github signup method
SignupMethodGithub = "github"
// SignupMethodFacebook is the facebook signup method
SignupMethodFacebook = "facebook"
// SignupMethodLinkedin is the linkedin signup method
SignupMethodLinkedIn = "linkedin"
// SignupMethodApple is the apple signup method
SignupMethodApple = "apple"
)

View File

@@ -0,0 +1,18 @@
package constants
const (
// UserLoginWebhookEvent name for login event
UserLoginWebhookEvent = `user.login`
// UserCreatedWebhookEvent name for user creation event
// This is triggered when user entry is created but still not verified
UserCreatedWebhookEvent = `user.created`
// UserSignUpWebhookEvent name for signup event
UserSignUpWebhookEvent = `user.signup`
// UserAccessRevokedWebhookEvent name for user access revoke event
UserAccessRevokedWebhookEvent = `user.access_revoked`
// UserAccessEnabledWebhookEvent name for user access enable event
UserAccessEnabledWebhookEvent = `user.access_enabled`
// UserDeletedWebhookEvent name for user deleted event
UserDeletedWebhookEvent = `user.deleted`
)

View File

@@ -6,6 +6,8 @@ type CollectionList struct {
VerificationRequest string
Session string
Env string
Webhook string
WebhookLog string
}
var (
@@ -17,5 +19,7 @@ var (
VerificationRequest: Prefix + "verification_requests",
Session: Prefix + "sessions",
Env: Prefix + "env",
Webhook: Prefix + "webhook",
WebhookLog: Prefix + "webhook_log",
}
)

View File

@@ -6,8 +6,7 @@ package models
type Session struct {
Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty"` // for arangodb
ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"`
UserID string `gorm:"type:char(36),index:" json:"user_id" bson:"user_id" cql:"user_id"`
User User `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"-" bson:"-" cql:"-"`
UserID string `gorm:"type:char(36)" json:"user_id" bson:"user_id" cql:"user_id"`
UserAgent string `json:"user_agent" bson:"user_agent" cql:"user_agent"`
IP string `json:"ip" bson:"ip" cql:"ip"`
CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"`

View File

@@ -38,8 +38,13 @@ func (user *User) AsAPIUser() *model.User {
email := user.Email
createdAt := user.CreatedAt
updatedAt := user.UpdatedAt
id := user.ID
if strings.Contains(id, Collections.WebhookLog+"/") {
id = strings.TrimPrefix(id, Collections.WebhookLog+"/")
}
return &model.User{
ID: user.ID,
ID: id,
Email: user.Email,
EmailVerified: isEmailVerified,
SignupMethods: user.SignupMethods,

View File

@@ -1,6 +1,10 @@
package models
import "github.com/authorizerdev/authorizer/server/graph/model"
import (
"strings"
"github.com/authorizerdev/authorizer/server/graph/model"
)
// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation
@@ -27,8 +31,13 @@ func (v *VerificationRequest) AsAPIVerificationRequest() *model.VerificationRequ
redirectURI := v.RedirectURI
expires := v.ExpiresAt
identifier := v.Identifier
id := v.ID
if strings.Contains(id, Collections.WebhookLog+"/") {
id = strings.TrimPrefix(id, Collections.WebhookLog+"/")
}
return &model.VerificationRequest{
ID: v.ID,
ID: id,
Token: &token,
Identifier: &identifier,
Expires: &expires,

View File

@@ -0,0 +1,41 @@
package models
import (
"encoding/json"
"strings"
"github.com/authorizerdev/authorizer/server/graph/model"
)
// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation
// Webhook model for db
type Webhook struct {
Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty"` // for arangodb
ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"`
EventName string `gorm:"unique" json:"event_name" bson:"event_name" cql:"event_name"`
EndPoint string `gorm:"type:text" json:"endpoint" bson:"endpoint" cql:"endpoint"`
Headers string `gorm:"type:text" json:"headers" bson:"headers" cql:"headers"`
Enabled bool `json:"enabled" bson:"enabled" cql:"enabled"`
CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"`
UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at"`
}
func (w *Webhook) AsAPIWebhook() *model.Webhook {
headersMap := make(map[string]interface{})
json.Unmarshal([]byte(w.Headers), &headersMap)
id := w.ID
if strings.Contains(id, Collections.Webhook+"/") {
id = strings.TrimPrefix(id, Collections.Webhook+"/")
}
return &model.Webhook{
ID: id,
EventName: &w.EventName,
Endpoint: &w.EndPoint,
Headers: headersMap,
Enabled: &w.Enabled,
CreatedAt: &w.CreatedAt,
UpdatedAt: &w.UpdatedAt,
}
}

View File

@@ -0,0 +1,37 @@
package models
import (
"strings"
"github.com/authorizerdev/authorizer/server/graph/model"
)
// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation
// WebhookLog model for db
type WebhookLog struct {
Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty"` // for arangodb
ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"`
HttpStatus int64 `json:"http_status" bson:"http_status" cql:"http_status"`
Response string `gorm:"type:text" json:"response" bson:"response" cql:"response"`
Request string `gorm:"type:text" json:"request" bson:"request" cql:"request"`
WebhookID string `gorm:"type:char(36)" json:"webhook_id" bson:"webhook_id" cql:"webhook_id"`
CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"`
UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at"`
}
func (w *WebhookLog) AsAPIWebhookLog() *model.WebhookLog {
id := w.ID
if strings.Contains(id, Collections.WebhookLog+"/") {
id = strings.TrimPrefix(id, Collections.WebhookLog+"/")
}
return &model.WebhookLog{
ID: id,
HTTPStatus: &w.HttpStatus,
Response: &w.Response,
Request: &w.Request,
WebhookID: &w.WebhookID,
CreatedAt: &w.CreatedAt,
UpdatedAt: &w.UpdatedAt,
}
}

View File

@@ -1,6 +1,7 @@
package arangodb
import (
"context"
"fmt"
"time"
@@ -11,15 +12,15 @@ import (
)
// AddEnv to save environment information in database
func (p *provider) AddEnv(env models.Env) (models.Env, error) {
func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) {
if env.ID == "" {
env.ID = uuid.New().String()
}
env.CreatedAt = time.Now().Unix()
env.UpdatedAt = time.Now().Unix()
configCollection, _ := p.db.Collection(nil, models.Collections.Env)
meta, err := configCollection.CreateDocument(arangoDriver.WithOverwrite(nil), env)
configCollection, _ := p.db.Collection(ctx, models.Collections.Env)
meta, err := configCollection.CreateDocument(arangoDriver.WithOverwrite(ctx), env)
if err != nil {
return env, err
}
@@ -29,10 +30,10 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
}
// UpdateEnv to update environment information in database
func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix()
collection, _ := p.db.Collection(nil, models.Collections.Env)
meta, err := collection.UpdateDocument(nil, env.Key, env)
collection, _ := p.db.Collection(ctx, models.Collections.Env)
meta, err := collection.UpdateDocument(ctx, env.Key, env)
if err != nil {
return env, err
}
@@ -43,11 +44,11 @@ func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
}
// GetEnv to get environment information from database
func (p *provider) GetEnv() (models.Env, error) {
func (p *provider) GetEnv(ctx context.Context) (models.Env, error) {
var env models.Env
query := fmt.Sprintf("FOR d in %s RETURN d", models.Collections.Env)
cursor, err := p.db.Query(nil, query, nil)
cursor, err := p.db.Query(ctx, query, nil)
if err != nil {
return env, err
}
@@ -60,7 +61,7 @@ func (p *provider) GetEnv() (models.Env, error) {
}
break
}
_, err := cursor.ReadDocument(nil, &env)
_, err := cursor.ReadDocument(ctx, &env)
if err != nil {
return env, err
}

View File

@@ -107,6 +107,33 @@ func NewProvider() (*provider, error) {
}
}
webhookCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.Webhook)
if !webhookCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.Webhook, nil)
if err != nil {
return nil, err
}
}
webhookCollection, _ := arangodb.Collection(nil, models.Collections.Webhook)
webhookCollection.EnsureHashIndex(ctx, []string{"event_name"}, &arangoDriver.EnsureHashIndexOptions{
Unique: true,
Sparse: true,
})
webhookLogCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.WebhookLog)
if !webhookLogCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.WebhookLog, nil)
if err != nil {
return nil, err
}
}
webhookLogCollection, _ := arangodb.Collection(nil, models.Collections.WebhookLog)
webhookLogCollection.EnsureHashIndex(ctx, []string{"webhook_id"}, &arangoDriver.EnsureHashIndexOptions{
Sparse: true,
})
return &provider{
db: arangodb,
}, err

View File

@@ -1,7 +1,7 @@
package arangodb
import (
"fmt"
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
@@ -9,31 +9,17 @@ import (
)
// AddSession to save session information in database
func (p *provider) AddSession(session models.Session) error {
func (p *provider) AddSession(ctx context.Context, session models.Session) error {
if session.ID == "" {
session.ID = uuid.New().String()
}
session.CreatedAt = time.Now().Unix()
session.UpdatedAt = time.Now().Unix()
sessionCollection, _ := p.db.Collection(nil, models.Collections.Session)
_, err := sessionCollection.CreateDocument(nil, session)
sessionCollection, _ := p.db.Collection(ctx, models.Collections.Session)
_, err := sessionCollection.CreateDocument(ctx, session)
if err != nil {
return err
}
return nil
}
// DeleteSession to delete session information from database
func (p *provider) DeleteSession(userId string) error {
query := fmt.Sprintf(`FOR d IN %s FILTER d.user_id == @userId REMOVE { _key: d._key } IN %s`, models.Collections.Session, models.Collections.Session)
bindVars := map[string]interface{}{
"userId": userId,
}
cursor, err := p.db.Query(nil, query, bindVars)
if err != nil {
return err
}
defer cursor.Close()
return nil
}

View File

@@ -15,7 +15,7 @@ import (
)
// AddUser to save user information in database
func (p *provider) AddUser(user models.User) (models.User, error) {
func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) {
if user.ID == "" {
user.ID = uuid.New().String()
}
@@ -30,8 +30,8 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
user.CreatedAt = time.Now().Unix()
user.UpdatedAt = time.Now().Unix()
userCollection, _ := p.db.Collection(nil, models.Collections.User)
meta, err := userCollection.CreateDocument(arangoDriver.WithOverwrite(nil), user)
userCollection, _ := p.db.Collection(ctx, models.Collections.User)
meta, err := userCollection.CreateDocument(arangoDriver.WithOverwrite(ctx), user)
if err != nil {
return user, err
}
@@ -42,10 +42,10 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
}
// UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) {
func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix()
collection, _ := p.db.Collection(nil, models.Collections.User)
meta, err := collection.UpdateDocument(nil, user.Key, user)
collection, _ := p.db.Collection(ctx, models.Collections.User)
meta, err := collection.UpdateDocument(ctx, user.Key, user)
if err != nil {
return user, err
}
@@ -56,24 +56,34 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
}
// DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error {
collection, _ := p.db.Collection(nil, models.Collections.User)
_, err := collection.RemoveDocument(nil, user.Key)
func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
collection, _ := p.db.Collection(ctx, models.Collections.User)
_, err := collection.RemoveDocument(ctx, user.Key)
if err != nil {
return err
}
query := fmt.Sprintf(`FOR d IN %s FILTER d.user_id == @user_id REMOVE { _key: d._key } IN %s`, models.Collections.Session, models.Collections.Session)
bindVars := map[string]interface{}{
"user_id": user.ID,
}
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return err
}
defer cursor.Close()
return nil
}
// ListUsers to get list of users from database
func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) {
func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
var users []*model.User
ctx := driver.WithQueryFullCount(context.Background())
sctx := driver.WithQueryFullCount(ctx)
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.User, pagination.Offset, pagination.Limit)
cursor, err := p.db.Query(ctx, query, nil)
cursor, err := p.db.Query(sctx, query, nil)
if err != nil {
return nil, err
}
@@ -84,7 +94,7 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
for {
var user models.User
meta, err := cursor.ReadDocument(nil, &user)
meta, err := cursor.ReadDocument(ctx, &user)
if arangoDriver.IsNoMoreDocuments(err) {
break
@@ -104,7 +114,7 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
}
// GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(email string) (models.User, error) {
func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User
query := fmt.Sprintf("FOR d in %s FILTER d.email == @email RETURN d", models.Collections.User)
@@ -112,7 +122,7 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
"email": email,
}
cursor, err := p.db.Query(nil, query, bindVars)
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return user, err
}
@@ -125,7 +135,7 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
}
break
}
_, err := cursor.ReadDocument(nil, &user)
_, err := cursor.ReadDocument(ctx, &user)
if err != nil {
return user, err
}
@@ -135,7 +145,7 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
}
// GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(id string) (models.User, error) {
func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
var user models.User
query := fmt.Sprintf("FOR d in %s FILTER d._id == @id LIMIT 1 RETURN d", models.Collections.User)
@@ -143,7 +153,7 @@ func (p *provider) GetUserByID(id string) (models.User, error) {
"id": id,
}
cursor, err := p.db.Query(nil, query, bindVars)
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return user, err
}
@@ -156,7 +166,7 @@ func (p *provider) GetUserByID(id string) (models.User, error) {
}
break
}
_, err := cursor.ReadDocument(nil, &user)
_, err := cursor.ReadDocument(ctx, &user)
if err != nil {
return user, err
}

View File

@@ -12,15 +12,15 @@ import (
)
// AddVerification to save verification request in database
func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
if verificationRequest.ID == "" {
verificationRequest.ID = uuid.New().String()
}
verificationRequest.CreatedAt = time.Now().Unix()
verificationRequest.UpdatedAt = time.Now().Unix()
verificationRequestRequestCollection, _ := p.db.Collection(nil, models.Collections.VerificationRequest)
meta, err := verificationRequestRequestCollection.CreateDocument(nil, verificationRequest)
verificationRequestRequestCollection, _ := p.db.Collection(ctx, models.Collections.VerificationRequest)
meta, err := verificationRequestRequestCollection.CreateDocument(ctx, verificationRequest)
if err != nil {
return verificationRequest, err
}
@@ -31,14 +31,14 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
}
// GetVerificationRequestByToken to get verification request from database using token
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) {
func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest
query := fmt.Sprintf("FOR d in %s FILTER d.token == @token LIMIT 1 RETURN d", models.Collections.VerificationRequest)
bindVars := map[string]interface{}{
"token": token,
}
cursor, err := p.db.Query(nil, query, bindVars)
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return verificationRequest, err
}
@@ -51,7 +51,7 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
}
break
}
_, err := cursor.ReadDocument(nil, &verificationRequest)
_, err := cursor.ReadDocument(ctx, &verificationRequest)
if err != nil {
return verificationRequest, err
}
@@ -61,7 +61,7 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
}
// GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) {
func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest
query := fmt.Sprintf("FOR d in %s FILTER d.email == @email FILTER d.identifier == @identifier LIMIT 1 RETURN d", models.Collections.VerificationRequest)
@@ -70,7 +70,7 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
"identifier": identifier,
}
cursor, err := p.db.Query(nil, query, bindVars)
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return verificationRequest, err
}
@@ -83,7 +83,7 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
}
break
}
_, err := cursor.ReadDocument(nil, &verificationRequest)
_, err := cursor.ReadDocument(ctx, &verificationRequest)
if err != nil {
return verificationRequest, err
}
@@ -93,12 +93,12 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
}
// ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) {
func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
var verificationRequests []*model.VerificationRequest
ctx := driver.WithQueryFullCount(context.Background())
sctx := driver.WithQueryFullCount(ctx)
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.VerificationRequest, pagination.Offset, pagination.Limit)
cursor, err := p.db.Query(ctx, query, nil)
cursor, err := p.db.Query(sctx, query, nil)
if err != nil {
return nil, err
}
@@ -109,7 +109,7 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
for {
var verificationRequest models.VerificationRequest
meta, err := cursor.ReadDocument(nil, &verificationRequest)
meta, err := cursor.ReadDocument(ctx, &verificationRequest)
if driver.IsNoMoreDocuments(err) {
break
@@ -130,7 +130,7 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
}
// DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error {
func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error {
collection, _ := p.db.Collection(nil, models.Collections.VerificationRequest)
_, err := collection.RemoveDocument(nil, verificationRequest.Key)
if err != nil {

View File

@@ -0,0 +1,161 @@
package arangodb
import (
"context"
"fmt"
"time"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
if webhook.ID == "" {
webhook.ID = uuid.New().String()
}
webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix()
webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook)
_, err := webhookCollection.CreateDocument(ctx, webhook)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix()
webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook)
meta, err := webhookCollection.UpdateDocument(ctx, webhook.Key, webhook)
if err != nil {
return nil, err
}
webhook.Key = meta.Key
webhook.ID = meta.ID.String()
return webhook.AsAPIWebhook(), nil
}
// ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
webhooks := []*model.Webhook{}
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.Webhook, pagination.Offset, pagination.Limit)
sctx := driver.WithQueryFullCount(ctx)
cursor, err := p.db.Query(sctx, query, nil)
if err != nil {
return nil, err
}
defer cursor.Close()
paginationClone := pagination
paginationClone.Total = cursor.Statistics().FullCount()
for {
var webhook models.Webhook
meta, err := cursor.ReadDocument(ctx, &webhook)
if arangoDriver.IsNoMoreDocuments(err) {
break
} else if err != nil {
return nil, err
}
if meta.Key != "" {
webhooks = append(webhooks, webhook.AsAPIWebhook())
}
}
return &model.Webhooks{
Pagination: &paginationClone,
Webhooks: webhooks,
}, nil
}
// GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook
query := fmt.Sprintf("FOR d in %s FILTER d._key == @webhook_id RETURN d", models.Collections.Webhook)
bindVars := map[string]interface{}{
"webhook_id": webhookID,
}
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return nil, err
}
defer cursor.Close()
for {
if !cursor.HasMore() {
if webhook.Key == "" {
return nil, fmt.Errorf("webhook not found")
}
break
}
_, err := cursor.ReadDocument(ctx, &webhook)
if err != nil {
return nil, err
}
}
return webhook.AsAPIWebhook(), nil
}
// GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
var webhook models.Webhook
query := fmt.Sprintf("FOR d in %s FILTER d.event_name == @event_name RETURN d", models.Collections.Webhook)
bindVars := map[string]interface{}{
"event_name": eventName,
}
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return nil, err
}
defer cursor.Close()
for {
if !cursor.HasMore() {
if webhook.Key == "" {
return nil, fmt.Errorf("webhook not found")
}
break
}
_, err := cursor.ReadDocument(ctx, &webhook)
if err != nil {
return nil, err
}
}
return webhook.AsAPIWebhook(), nil
}
// DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook)
_, err := webhookCollection.RemoveDocument(ctx, webhook.ID)
if err != nil {
return err
}
query := fmt.Sprintf("FOR d IN %s FILTER d.webhook_id == @webhook_id REMOVE { _key: d._key } IN %s", models.Collections.WebhookLog, models.Collections.WebhookLog)
bindVars := map[string]interface{}{
"webhook_id": webhook.ID,
}
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return err
}
defer cursor.Close()
return nil
}

View File

@@ -0,0 +1,75 @@
package arangodb
import (
"context"
"fmt"
"time"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddWebhookLog to add webhook log
func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) {
if webhookLog.ID == "" {
webhookLog.ID = uuid.New().String()
}
webhookLog.Key = webhookLog.ID
webhookLog.CreatedAt = time.Now().Unix()
webhookLog.UpdatedAt = time.Now().Unix()
webhookLogCollection, _ := p.db.Collection(ctx, models.Collections.WebhookLog)
_, err := webhookLogCollection.CreateDocument(ctx, webhookLog)
if err != nil {
return nil, err
}
return webhookLog.AsAPIWebhookLog(), nil
}
// ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) {
webhookLogs := []*model.WebhookLog{}
bindVariables := map[string]interface{}{}
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit)
if webhookID != "" {
query = fmt.Sprintf("FOR d in %s FILTER d.webhook_id == @webhook_id SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit)
bindVariables = map[string]interface{}{
"webhook_id": webhookID,
}
}
sctx := driver.WithQueryFullCount(ctx)
cursor, err := p.db.Query(sctx, query, bindVariables)
if err != nil {
return nil, err
}
defer cursor.Close()
paginationClone := pagination
paginationClone.Total = cursor.Statistics().FullCount()
for {
var webhookLog models.WebhookLog
meta, err := cursor.ReadDocument(ctx, &webhookLog)
if arangoDriver.IsNoMoreDocuments(err) {
break
} else if err != nil {
return nil, err
}
if meta.Key != "" {
webhookLogs = append(webhookLogs, webhookLog.AsAPIWebhookLog())
}
}
return &model.WebhookLogs{
Pagination: &paginationClone,
WebhookLogs: webhookLogs,
}, nil
}

View File

@@ -1,6 +1,7 @@
package cassandradb
import (
"context"
"fmt"
"time"
@@ -10,7 +11,7 @@ import (
)
// AddEnv to save environment information in database
func (p *provider) AddEnv(env models.Env) (models.Env, error) {
func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) {
if env.ID == "" {
env.ID = uuid.New().String()
}
@@ -27,7 +28,7 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
}
// UpdateEnv to update environment information in database
func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix()
updateEnvQuery := fmt.Sprintf("UPDATE %s SET env = '%s', updated_at = %d WHERE id = '%s'", KeySpace+"."+models.Collections.Env, env.EnvData, env.UpdatedAt, env.ID)
@@ -39,7 +40,7 @@ func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
}
// GetEnv to get environment information from database
func (p *provider) GetEnv() (models.Env, error) {
func (p *provider) GetEnv(ctx context.Context) (models.Env, error) {
var env models.Env
query := fmt.Sprintf("SELECT id, env, hash, created_at, updated_at FROM %s LIMIT 1", KeySpace+"."+models.Collections.Env)

View File

@@ -5,6 +5,7 @@ import (
"crypto/x509"
"fmt"
"strings"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/crypto"
@@ -96,6 +97,8 @@ func NewProvider() (*provider, error) {
NumRetries: 3,
}
cassandraClient.Consistency = gocql.LocalQuorum
cassandraClient.ConnectTimeout = 10 * time.Second
cassandraClient.ProtoVersion = 4
session, err := cassandraClient.CreateSession()
if err != nil {
@@ -140,6 +143,11 @@ func NewProvider() (*provider, error) {
if err != nil {
return nil, err
}
sessionIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_session_user_id ON %s.%s (user_id)", KeySpace, models.Collections.Session)
err = session.Query(sessionIndexQuery).Exec()
if err != nil {
return nil, err
}
userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nickname text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id))", KeySpace, models.Collections.User)
err = session.Query(userCollectionQuery).Exec()
@@ -174,6 +182,28 @@ func NewProvider() (*provider, error) {
return nil, err
}
webhookCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, event_name text, endpoint text, enabled boolean, headers text, updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.Webhook)
err = session.Query(webhookCollectionQuery).Exec()
if err != nil {
return nil, err
}
webhookIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_webhook_event_name ON %s.%s (event_name)", KeySpace, models.Collections.Webhook)
err = session.Query(webhookIndexQuery).Exec()
if err != nil {
return nil, err
}
webhookLogCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, http_status bigint, response text, request text, webhook_id text,updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.WebhookLog)
err = session.Query(webhookLogCollectionQuery).Exec()
if err != nil {
return nil, err
}
webhookLogIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_webhook_log_webhook_id ON %s.%s (webhook_id)", KeySpace, models.Collections.WebhookLog)
err = session.Query(webhookLogIndexQuery).Exec()
if err != nil {
return nil, err
}
return &provider{
db: session,
}, err

View File

@@ -1,6 +1,7 @@
package cassandradb
import (
"context"
"fmt"
"time"
@@ -9,7 +10,7 @@ import (
)
// AddSession to save session information in database
func (p *provider) AddSession(session models.Session) error {
func (p *provider) AddSession(ctx context.Context, session models.Session) error {
if session.ID == "" {
session.ID = uuid.New().String()
}
@@ -24,13 +25,3 @@ func (p *provider) AddSession(session models.Session) error {
}
return nil
}
// DeleteSession to delete session information from database
func (p *provider) DeleteSession(userId string) error {
deleteSessionQuery := fmt.Sprintf("DELETE FROM %s WHERE user_id = '%s'", KeySpace+"."+models.Collections.Session, userId)
err := p.db.Query(deleteSessionQuery).Exec()
if err != nil {
return err
}
return nil
}

View File

@@ -1,6 +1,7 @@
package cassandradb
import (
"context"
"encoding/json"
"fmt"
"reflect"
@@ -16,7 +17,7 @@ import (
)
// AddUser to save user information in database
func (p *provider) AddUser(user models.User) (models.User, error) {
func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) {
if user.ID == "" {
user.ID = uuid.New().String()
}
@@ -79,7 +80,7 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
}
// UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) {
func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix()
bytes, err := json.Marshal(user)
@@ -97,10 +98,11 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
updateFields := ""
for key, value := range userMap {
if value != nil && key != "_id" {
if key == "_id" {
continue
}
if key == "_id" {
if key == "_key" {
continue
}
@@ -130,14 +132,36 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
}
// DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error {
func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, user.ID)
err := p.db.Query(query).Exec()
return err
if err != nil {
return err
}
getSessionsQuery := fmt.Sprintf("SELECT id FROM %s WHERE user_id = '%s' ALLOW FILTERING", KeySpace+"."+models.Collections.Session, user.ID)
scanner := p.db.Query(getSessionsQuery).Iter().Scanner()
sessionIDs := ""
for scanner.Next() {
var wlID string
err = scanner.Scan(&wlID)
if err != nil {
return err
}
sessionIDs += fmt.Sprintf("'%s',", wlID)
}
sessionIDs = strings.TrimSuffix(sessionIDs, ",")
deleteSessionQuery := fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", KeySpace+"."+models.Collections.Session, sessionIDs)
err = p.db.Query(deleteSessionQuery).Exec()
if err != nil {
return err
}
return nil
}
// ListUsers to get list of users from database
func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) {
func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
responseUsers := []*model.User{}
paginationClone := pagination
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.User)
@@ -171,9 +195,9 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
}
// GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(email string) (models.User, error) {
func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, email)
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.User, email)
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt)
if err != nil {
return user, err
@@ -182,7 +206,7 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
}
// GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(id string) (models.User, error) {
func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
var user models.User
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, id)
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt)

View File

@@ -1,6 +1,7 @@
package cassandradb
import (
"context"
"fmt"
"time"
@@ -11,7 +12,7 @@ import (
)
// AddVerification to save verification request in database
func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
if verificationRequest.ID == "" {
verificationRequest.ID = uuid.New().String()
}
@@ -28,7 +29,7 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
}
// GetVerificationRequestByToken to get verification request from database using token
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) {
func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE jwt_token = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, token)
@@ -40,7 +41,7 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
}
// GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) {
func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE email = '%s' AND identifier = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.VerificationRequest, email, identifier)
@@ -53,7 +54,7 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
}
// ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) {
func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
var verificationRequests []*model.VerificationRequest
paginationClone := pagination
@@ -89,7 +90,7 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
}
// DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error {
func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error {
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.VerificationRequest, verificationRequest.ID)
err := p.db.Query(query).Exec()
if err != nil {

View File

@@ -0,0 +1,172 @@
package cassandradb
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strings"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/gocql/gocql"
"github.com/google/uuid"
)
// AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
if webhook.ID == "" {
webhook.ID = uuid.New().String()
}
webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix()
existingHook, _ := p.GetWebhookByEventName(ctx, webhook.EventName)
if existingHook != nil {
return nil, fmt.Errorf("Webhook with %s event_name already exists", webhook.EventName)
}
insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_name, endpoint, headers, enabled, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %t, %d, %d)", KeySpace+"."+models.Collections.Webhook, webhook.ID, webhook.EventName, webhook.EndPoint, webhook.Headers, webhook.Enabled, webhook.CreatedAt, webhook.UpdatedAt)
err := p.db.Query(insertQuery).Exec()
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix()
bytes, err := json.Marshal(webhook)
if err != nil {
return nil, err
}
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
decoder.UseNumber()
webhookMap := map[string]interface{}{}
err = decoder.Decode(&webhookMap)
if err != nil {
return nil, err
}
updateFields := ""
for key, value := range webhookMap {
if key == "_id" {
continue
}
if key == "_key" {
continue
}
if value == nil {
updateFields += fmt.Sprintf("%s = null,", key)
continue
}
valueType := reflect.TypeOf(value)
if valueType.Name() == "string" {
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
} else {
updateFields += fmt.Sprintf("%s = %v, ", key, value)
}
}
updateFields = strings.Trim(updateFields, " ")
updateFields = strings.TrimSuffix(updateFields, ",")
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.Webhook, updateFields, webhook.ID)
err = p.db.Query(query).Exec()
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
webhooks := []*model.Webhook{}
paginationClone := pagination
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.Webhook)
err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total)
if err != nil {
return nil, err
}
// there is no offset in cassandra
// so we fetch till limit + offset
// and return the results from offset to limit
query := fmt.Sprintf("SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.Webhook, pagination.Limit+pagination.Offset)
scanner := p.db.Query(query).Iter().Scanner()
counter := int64(0)
for scanner.Next() {
if counter >= pagination.Offset {
var webhook models.Webhook
err := scanner.Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt)
if err != nil {
return nil, err
}
webhooks = append(webhooks, webhook.AsAPIWebhook())
}
counter++
}
return &model.Webhooks{
Pagination: &paginationClone,
Webhooks: webhooks,
}, nil
}
// GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook
query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1`, KeySpace+"."+models.Collections.Webhook, webhookID)
err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
var webhook models.Webhook
query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.Webhook, eventName)
err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.Webhook, webhook.ID)
err := p.db.Query(query).Exec()
if err != nil {
return err
}
getWebhookLogQuery := fmt.Sprintf("SELECT id FROM %s WHERE webhook_id = '%s' ALLOW FILTERING", KeySpace+"."+models.Collections.WebhookLog, webhook.ID)
scanner := p.db.Query(getWebhookLogQuery).Iter().Scanner()
webhookLogIDs := ""
for scanner.Next() {
var wlID string
err = scanner.Scan(&wlID)
if err != nil {
return err
}
webhookLogIDs += fmt.Sprintf("'%s',", wlID)
}
webhookLogIDs = strings.TrimSuffix(webhookLogIDs, ",")
query = fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", KeySpace+"."+models.Collections.WebhookLog, webhookLogIDs)
err = p.db.Query(query).Exec()
return err
}

View File

@@ -0,0 +1,70 @@
package cassandradb
import (
"context"
"fmt"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/gocql/gocql"
"github.com/google/uuid"
)
// AddWebhookLog to add webhook log
func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) {
if webhookLog.ID == "" {
webhookLog.ID = uuid.New().String()
}
webhookLog.Key = webhookLog.ID
webhookLog.CreatedAt = time.Now().Unix()
webhookLog.UpdatedAt = time.Now().Unix()
insertQuery := fmt.Sprintf("INSERT INTO %s (id, http_status, response, request, webhook_id, created_at, updated_at) VALUES ('%s', %d,'%s', '%s', '%s', %d, %d)", KeySpace+"."+models.Collections.WebhookLog, webhookLog.ID, webhookLog.HttpStatus, webhookLog.Response, webhookLog.Request, webhookLog.WebhookID, webhookLog.CreatedAt, webhookLog.UpdatedAt)
err := p.db.Query(insertQuery).Exec()
if err != nil {
return nil, err
}
return webhookLog.AsAPIWebhookLog(), nil
}
// ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) {
webhookLogs := []*model.WebhookLog{}
paginationClone := pagination
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.WebhookLog)
// there is no offset in cassandra
// so we fetch till limit + offset
// and return the results from offset to limit
query := fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.WebhookLog, pagination.Limit+pagination.Offset)
if webhookID != "" {
totalCountQuery = fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE webhook_id='%s' ALLOW FILTERING`, KeySpace+"."+models.Collections.WebhookLog, webhookID)
query = fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s WHERE webhook_id = '%s' LIMIT %d ALLOW FILTERING", KeySpace+"."+models.Collections.WebhookLog, webhookID, pagination.Limit+pagination.Offset)
}
err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total)
if err != nil {
return nil, err
}
scanner := p.db.Query(query).Iter().Scanner()
counter := int64(0)
for scanner.Next() {
if counter >= pagination.Offset {
var webhookLog models.WebhookLog
err := scanner.Scan(&webhookLog.ID, &webhookLog.HttpStatus, &webhookLog.Response, &webhookLog.Request, &webhookLog.WebhookID, &webhookLog.CreatedAt, &webhookLog.UpdatedAt)
if err != nil {
return nil, err
}
webhookLogs = append(webhookLogs, webhookLog.AsAPIWebhookLog())
}
counter++
}
return &model.WebhookLogs{
Pagination: &paginationClone,
WebhookLogs: webhookLogs,
}, nil
}

View File

@@ -1,6 +1,7 @@
package mongodb
import (
"context"
"fmt"
"time"
@@ -11,7 +12,7 @@ import (
)
// AddEnv to save environment information in database
func (p *provider) AddEnv(env models.Env) (models.Env, error) {
func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) {
if env.ID == "" {
env.ID = uuid.New().String()
}
@@ -20,7 +21,7 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix()
env.Key = env.ID
configCollection := p.db.Collection(models.Collections.Env, options.Collection())
_, err := configCollection.InsertOne(nil, env)
_, err := configCollection.InsertOne(ctx, env)
if err != nil {
return env, err
}
@@ -28,10 +29,10 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
}
// UpdateEnv to update environment information in database
func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix()
configCollection := p.db.Collection(models.Collections.Env, options.Collection())
_, err := configCollection.UpdateOne(nil, bson.M{"_id": bson.M{"$eq": env.ID}}, bson.M{"$set": env}, options.MergeUpdateOptions())
_, err := configCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": env.ID}}, bson.M{"$set": env}, options.MergeUpdateOptions())
if err != nil {
return env, err
}
@@ -39,14 +40,14 @@ func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
}
// GetEnv to get environment information from database
func (p *provider) GetEnv() (models.Env, error) {
func (p *provider) GetEnv(ctx context.Context) (models.Env, error) {
var env models.Env
configCollection := p.db.Collection(models.Collections.Env, options.Collection())
cursor, err := configCollection.Find(nil, bson.M{}, options.Find())
cursor, err := configCollection.Find(ctx, bson.M{}, options.Find())
if err != nil {
return env, err
}
defer cursor.Close(nil)
defer cursor.Close(ctx)
for cursor.Next(nil) {
err := cursor.Decode(&env)

View File

@@ -83,6 +83,24 @@ func NewProvider() (*provider, error) {
mongodb.CreateCollection(ctx, models.Collections.Env, options.CreateCollection())
mongodb.CreateCollection(ctx, models.Collections.Webhook, options.CreateCollection())
webhookCollection := mongodb.Collection(models.Collections.Webhook, options.Collection())
webhookCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{
{
Keys: bson.M{"event_name": 1},
Options: options.Index().SetUnique(true).SetSparse(true),
},
}, options.CreateIndexes())
mongodb.CreateCollection(ctx, models.Collections.WebhookLog, options.CreateCollection())
webhookLogCollection := mongodb.Collection(models.Collections.WebhookLog, options.Collection())
webhookLogCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{
{
Keys: bson.M{"webhook_id": 1},
Options: options.Index().SetSparse(true),
},
}, options.CreateIndexes())
return &provider{
db: mongodb,
}, nil

View File

@@ -1,16 +1,16 @@
package mongodb
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/google/uuid"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/options"
)
// AddSession to save session information in database
func (p *provider) AddSession(session models.Session) error {
func (p *provider) AddSession(ctx context.Context, session models.Session) error {
if session.ID == "" {
session.ID = uuid.New().String()
}
@@ -19,17 +19,7 @@ func (p *provider) AddSession(session models.Session) error {
session.CreatedAt = time.Now().Unix()
session.UpdatedAt = time.Now().Unix()
sessionCollection := p.db.Collection(models.Collections.Session, options.Collection())
_, err := sessionCollection.InsertOne(nil, session)
if err != nil {
return err
}
return nil
}
// DeleteSession to delete session information from database
func (p *provider) DeleteSession(userId string) error {
sessionCollection := p.db.Collection(models.Collections.Session, options.Collection())
_, err := sessionCollection.DeleteMany(nil, bson.M{"user_id": userId}, options.Delete())
_, err := sessionCollection.InsertOne(ctx, session)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package mongodb
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/constants"
@@ -13,7 +14,7 @@ import (
)
// AddUser to save user information in database
func (p *provider) AddUser(user models.User) (models.User, error) {
func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) {
if user.ID == "" {
user.ID = uuid.New().String()
}
@@ -29,7 +30,7 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix()
user.Key = user.ID
userCollection := p.db.Collection(models.Collections.User, options.Collection())
_, err := userCollection.InsertOne(nil, user)
_, err := userCollection.InsertOne(ctx, user)
if err != nil {
return user, err
}
@@ -38,10 +39,10 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
}
// UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) {
func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix()
userCollection := p.db.Collection(models.Collections.User, options.Collection())
_, err := userCollection.UpdateOne(nil, bson.M{"_id": bson.M{"$eq": user.ID}}, bson.M{"$set": user}, options.MergeUpdateOptions())
_, err := userCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": user.ID}}, bson.M{"$set": user}, options.MergeUpdateOptions())
if err != nil {
return user, err
}
@@ -49,9 +50,15 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
}
// DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error {
func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
userCollection := p.db.Collection(models.Collections.User, options.Collection())
_, err := userCollection.DeleteOne(nil, bson.M{"_id": user.ID}, options.Delete())
_, err := userCollection.DeleteOne(ctx, bson.M{"_id": user.ID}, options.Delete())
if err != nil {
return err
}
sessionCollection := p.db.Collection(models.Collections.Session, options.Collection())
_, err = sessionCollection.DeleteMany(ctx, bson.M{"user_id": user.ID}, options.Delete())
if err != nil {
return err
}
@@ -60,7 +67,7 @@ func (p *provider) DeleteUser(user models.User) error {
}
// ListUsers to get list of users from database
func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) {
func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
var users []*model.User
opts := options.Find()
opts.SetLimit(pagination.Limit)
@@ -70,20 +77,20 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
paginationClone := pagination
userCollection := p.db.Collection(models.Collections.User, options.Collection())
count, err := userCollection.CountDocuments(nil, bson.M{}, options.Count())
count, err := userCollection.CountDocuments(ctx, bson.M{}, options.Count())
if err != nil {
return nil, err
}
paginationClone.Total = count
cursor, err := userCollection.Find(nil, bson.M{}, opts)
cursor, err := userCollection.Find(ctx, bson.M{}, opts)
if err != nil {
return nil, err
}
defer cursor.Close(nil)
defer cursor.Close(ctx)
for cursor.Next(nil) {
for cursor.Next(ctx) {
var user models.User
err := cursor.Decode(&user)
if err != nil {
@@ -99,10 +106,10 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
}
// GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(email string) (models.User, error) {
func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User
userCollection := p.db.Collection(models.Collections.User, options.Collection())
err := userCollection.FindOne(nil, bson.M{"email": email}).Decode(&user)
err := userCollection.FindOne(ctx, bson.M{"email": email}).Decode(&user)
if err != nil {
return user, err
}
@@ -111,11 +118,11 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
}
// GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(id string) (models.User, error) {
func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
var user models.User
userCollection := p.db.Collection(models.Collections.User, options.Collection())
err := userCollection.FindOne(nil, bson.M{"_id": id}).Decode(&user)
err := userCollection.FindOne(ctx, bson.M{"_id": id}).Decode(&user)
if err != nil {
return user, err
}

View File

@@ -1,6 +1,7 @@
package mongodb
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
@@ -11,7 +12,7 @@ import (
)
// AddVerification to save verification request in database
func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
if verificationRequest.ID == "" {
verificationRequest.ID = uuid.New().String()
@@ -19,7 +20,7 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
verificationRequest.UpdatedAt = time.Now().Unix()
verificationRequest.Key = verificationRequest.ID
verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection())
_, err := verificationRequestCollection.InsertOne(nil, verificationRequest)
_, err := verificationRequestCollection.InsertOne(ctx, verificationRequest)
if err != nil {
return verificationRequest, err
}
@@ -29,11 +30,11 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
}
// GetVerificationRequestByToken to get verification request from database using token
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) {
func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest
verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection())
err := verificationRequestCollection.FindOne(nil, bson.M{"token": token}).Decode(&verificationRequest)
err := verificationRequestCollection.FindOne(ctx, bson.M{"token": token}).Decode(&verificationRequest)
if err != nil {
return verificationRequest, err
}
@@ -42,11 +43,11 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
}
// GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) {
func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest
verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection())
err := verificationRequestCollection.FindOne(nil, bson.M{"email": email, "identifier": identifier}).Decode(&verificationRequest)
err := verificationRequestCollection.FindOne(ctx, bson.M{"email": email, "identifier": identifier}).Decode(&verificationRequest)
if err != nil {
return verificationRequest, err
}
@@ -55,7 +56,7 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
}
// ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) {
func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
var verificationRequests []*model.VerificationRequest
opts := options.Find()
@@ -65,17 +66,17 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection())
verificationRequestCollectionCount, err := verificationRequestCollection.CountDocuments(nil, bson.M{})
verificationRequestCollectionCount, err := verificationRequestCollection.CountDocuments(ctx, bson.M{})
paginationClone := pagination
paginationClone.Total = verificationRequestCollectionCount
cursor, err := verificationRequestCollection.Find(nil, bson.M{}, opts)
cursor, err := verificationRequestCollection.Find(ctx, bson.M{}, opts)
if err != nil {
return nil, err
}
defer cursor.Close(nil)
defer cursor.Close(ctx)
for cursor.Next(nil) {
for cursor.Next(ctx) {
var verificationRequest models.VerificationRequest
err := cursor.Decode(&verificationRequest)
if err != nil {
@@ -91,9 +92,9 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
}
// DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error {
func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error {
verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection())
_, err := verificationRequestCollection.DeleteOne(nil, bson.M{"_id": verificationRequest.ID}, options.Delete())
_, err := verificationRequestCollection.DeleteOne(ctx, bson.M{"_id": verificationRequest.ID}, options.Delete())
if err != nil {
return err
}

View File

@@ -0,0 +1,120 @@
package mongodb
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/options"
)
// AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
if webhook.ID == "" {
webhook.ID = uuid.New().String()
}
webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix()
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
_, err := webhookCollection.InsertOne(ctx, webhook)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix()
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
_, err := webhookCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": webhook.ID}}, bson.M{"$set": webhook}, options.MergeUpdateOptions())
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
var webhooks []*model.Webhook
opts := options.Find()
opts.SetLimit(pagination.Limit)
opts.SetSkip(pagination.Offset)
opts.SetSort(bson.M{"created_at": -1})
paginationClone := pagination
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
count, err := webhookCollection.CountDocuments(ctx, bson.M{}, options.Count())
if err != nil {
return nil, err
}
paginationClone.Total = count
cursor, err := webhookCollection.Find(ctx, bson.M{}, opts)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)
for cursor.Next(ctx) {
var webhook models.Webhook
err := cursor.Decode(&webhook)
if err != nil {
return nil, err
}
webhooks = append(webhooks, webhook.AsAPIWebhook())
}
return &model.Webhooks{
Pagination: &paginationClone,
Webhooks: webhooks,
}, nil
}
// GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
err := webhookCollection.FindOne(ctx, bson.M{"_id": webhookID}).Decode(&webhook)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
var webhook models.Webhook
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
err := webhookCollection.FindOne(ctx, bson.M{"event_name": eventName}).Decode(&webhook)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
_, err := webhookCollection.DeleteOne(nil, bson.M{"_id": webhook.ID}, options.Delete())
if err != nil {
return err
}
webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection())
_, err = webhookLogCollection.DeleteMany(nil, bson.M{"webhook_id": webhook.ID}, options.Delete())
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,74 @@
package mongodb
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/options"
)
// AddWebhookLog to add webhook log
func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) {
if webhookLog.ID == "" {
webhookLog.ID = uuid.New().String()
}
webhookLog.Key = webhookLog.ID
webhookLog.CreatedAt = time.Now().Unix()
webhookLog.UpdatedAt = time.Now().Unix()
webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection())
_, err := webhookLogCollection.InsertOne(ctx, webhookLog)
if err != nil {
return nil, err
}
return webhookLog.AsAPIWebhookLog(), nil
}
// ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) {
webhookLogs := []*model.WebhookLog{}
opts := options.Find()
opts.SetLimit(pagination.Limit)
opts.SetSkip(pagination.Offset)
opts.SetSort(bson.M{"created_at": -1})
paginationClone := pagination
query := bson.M{}
if webhookID != "" {
query = bson.M{"webhook_id": webhookID}
}
webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection())
count, err := webhookLogCollection.CountDocuments(ctx, query, options.Count())
if err != nil {
return nil, err
}
paginationClone.Total = count
cursor, err := webhookLogCollection.Find(ctx, query, opts)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)
for cursor.Next(ctx) {
var webhookLog models.WebhookLog
err := cursor.Decode(&webhookLog)
if err != nil {
return nil, err
}
webhookLogs = append(webhookLogs, webhookLog.AsAPIWebhookLog())
}
return &model.WebhookLogs{
Pagination: &paginationClone,
WebhookLogs: webhookLogs,
}, nil
}

View File

@@ -1,6 +1,7 @@
package provider_template
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
@@ -8,7 +9,7 @@ import (
)
// AddEnv to save environment information in database
func (p *provider) AddEnv(env models.Env) (models.Env, error) {
func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) {
if env.ID == "" {
env.ID = uuid.New().String()
}
@@ -19,13 +20,13 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
}
// UpdateEnv to update environment information in database
func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix()
return env, nil
}
// GetEnv to get environment information from database
func (p *provider) GetEnv() (models.Env, error) {
func (p *provider) GetEnv(ctx context.Context) (models.Env, error) {
var env models.Env
return env, nil

View File

@@ -1,6 +1,7 @@
package provider_template
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
@@ -8,7 +9,7 @@ import (
)
// AddSession to save session information in database
func (p *provider) AddSession(session models.Session) error {
func (p *provider) AddSession(ctx context.Context, session models.Session) error {
if session.ID == "" {
session.ID = uuid.New().String()
}
@@ -19,6 +20,6 @@ func (p *provider) AddSession(session models.Session) error {
}
// DeleteSession to delete session information from database
func (p *provider) DeleteSession(userId string) error {
func (p *provider) DeleteSession(ctx context.Context, userId string) error {
return nil
}

View File

@@ -1,6 +1,7 @@
package provider_template
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/constants"
@@ -11,7 +12,7 @@ import (
)
// AddUser to save user information in database
func (p *provider) AddUser(user models.User) (models.User, error) {
func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) {
if user.ID == "" {
user.ID = uuid.New().String()
}
@@ -31,30 +32,30 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
}
// UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) {
func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix()
return user, nil
}
// DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error {
func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
return nil
}
// ListUsers to get list of users from database
func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) {
func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
return nil, nil
}
// GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(email string) (models.User, error) {
func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User
return user, nil
}
// GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(id string) (models.User, error) {
func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
var user models.User
return user, nil

View File

@@ -1,6 +1,7 @@
package provider_template
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
@@ -9,7 +10,7 @@ import (
)
// AddVerification to save verification request in database
func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
if verificationRequest.ID == "" {
verificationRequest.ID = uuid.New().String()
}
@@ -21,25 +22,25 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
}
// GetVerificationRequestByToken to get verification request from database using token
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) {
func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest
return verificationRequest, nil
}
// GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) {
func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest
return verificationRequest, nil
}
// ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) {
func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
return nil, nil
}
// DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error {
func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error {
return nil
}

View File

@@ -0,0 +1,49 @@
package provider_template
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
if webhook.ID == "" {
webhook.ID = uuid.New().String()
}
webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix()
return webhook.AsAPIWebhook(), nil
}
// UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix()
return webhook.AsAPIWebhook(), nil
}
// ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
return nil, nil
}
// GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
return nil, nil
}
// GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
return nil, nil
}
// DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
// Also delete webhook logs for given webhook id
return nil
}

View File

@@ -0,0 +1,27 @@
package provider_template
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddWebhookLog to add webhook log
func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) {
if webhookLog.ID == "" {
webhookLog.ID = uuid.New().String()
}
webhookLog.Key = webhookLog.ID
webhookLog.CreatedAt = time.Now().Unix()
webhookLog.UpdatedAt = time.Now().Unix()
return webhookLog.AsAPIWebhookLog(), nil
}
// ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) {
return nil, nil
}

View File

@@ -1,44 +1,62 @@
package providers
import (
"context"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
)
type Provider interface {
// AddUser to save user information in database
AddUser(user models.User) (models.User, error)
AddUser(ctx context.Context, user models.User) (models.User, error)
// UpdateUser to update user information in database
UpdateUser(user models.User) (models.User, error)
UpdateUser(ctx context.Context, user models.User) (models.User, error)
// DeleteUser to delete user information from database
DeleteUser(user models.User) error
DeleteUser(ctx context.Context, user models.User) error
// ListUsers to get list of users from database
ListUsers(pagination model.Pagination) (*model.Users, error)
ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error)
// GetUserByEmail to get user information from database using email address
GetUserByEmail(email string) (models.User, error)
GetUserByEmail(ctx context.Context, email string) (models.User, error)
// GetUserByID to get user information from database using user ID
GetUserByID(id string) (models.User, error)
GetUserByID(ctx context.Context, id string) (models.User, error)
// AddVerification to save verification request in database
AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error)
AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error)
// GetVerificationRequestByToken to get verification request from database using token
GetVerificationRequestByToken(token string) (models.VerificationRequest, error)
GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error)
// GetVerificationRequestByEmail to get verification request by email from database
GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error)
GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error)
// ListVerificationRequests to get list of verification requests from database
ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error)
ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error)
// DeleteVerificationRequest to delete verification request from database
DeleteVerificationRequest(verificationRequest models.VerificationRequest) error
DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error
// AddSession to save session information in database
AddSession(session models.Session) error
// DeleteSession to delete session information from database
DeleteSession(userId string) error
AddSession(ctx context.Context, session models.Session) error
// AddEnv to save environment information in database
AddEnv(env models.Env) (models.Env, error)
AddEnv(ctx context.Context, env models.Env) (models.Env, error)
// UpdateEnv to update environment information in database
UpdateEnv(env models.Env) (models.Env, error)
UpdateEnv(ctx context.Context, env models.Env) (models.Env, error)
// GetEnv to get environment information from database
GetEnv() (models.Env, error)
GetEnv(ctx context.Context) (models.Env, error)
// AddWebhook to add webhook
AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error)
// UpdateWebhook to update webhook
UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error)
// ListWebhooks to list webhook
ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error)
// GetWebhookByID to get webhook by id
GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error)
// GetWebhookByEventName to get webhook by event_name
GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error)
// DeleteWebhook to delete webhook
DeleteWebhook(ctx context.Context, webhook *model.Webhook) error
// AddWebhookLog to add webhook log
AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error)
// ListWebhookLogs to list webhook logs
ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error)
}

View File

@@ -1,6 +1,7 @@
package sql
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
@@ -8,7 +9,7 @@ import (
)
// AddEnv to save environment information in database
func (p *provider) AddEnv(env models.Env) (models.Env, error) {
func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) {
if env.ID == "" {
env.ID = uuid.New().String()
}
@@ -25,7 +26,7 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
}
// UpdateEnv to update environment information in database
func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix()
result := p.db.Save(&env)
@@ -36,7 +37,7 @@ func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
}
// GetEnv to get environment information from database
func (p *provider) GetEnv() (models.Env, error) {
func (p *provider) GetEnv(ctx context.Context) (models.Env, error) {
var env models.Env
result := p.db.First(&env)

View File

@@ -50,7 +50,7 @@ func NewProvider() (*provider, error) {
sqlDB, err = gorm.Open(postgres.Open(dbURL), ormConfig)
case constants.DbTypeSqlite:
sqlDB, err = gorm.Open(sqlite.Open(dbURL), ormConfig)
case constants.DbTypeMysql, constants.DbTypeMariaDB:
case constants.DbTypeMysql, constants.DbTypeMariaDB, constants.DbTypePlanetScaleDB:
sqlDB, err = gorm.Open(mysql.Open(dbURL), ormConfig)
case constants.DbTypeSqlserver:
sqlDB, err = gorm.Open(sqlserver.Open(dbURL), ormConfig)
@@ -60,7 +60,7 @@ func NewProvider() (*provider, error) {
return nil, err
}
err = sqlDB.AutoMigrate(&models.User{}, &models.VerificationRequest{}, &models.Session{}, &models.Env{})
err = sqlDB.AutoMigrate(&models.User{}, &models.VerificationRequest{}, &models.Session{}, &models.Env{}, &models.Webhook{}, models.WebhookLog{})
if err != nil {
return nil, err
}

View File

@@ -1,6 +1,7 @@
package sql
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
@@ -9,7 +10,7 @@ import (
)
// AddSession to save session information in database
func (p *provider) AddSession(session models.Session) error {
func (p *provider) AddSession(ctx context.Context, session models.Session) error {
if session.ID == "" {
session.ID = uuid.New().String()
}
@@ -26,13 +27,3 @@ func (p *provider) AddSession(session models.Session) error {
}
return nil
}
// DeleteSession to delete session information from database
func (p *provider) DeleteSession(userId string) error {
result := p.db.Where("user_id = ?", userId).Delete(&models.Session{})
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -1,6 +1,7 @@
package sql
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/constants"
@@ -12,7 +13,7 @@ import (
)
// AddUser to save user information in database
func (p *provider) AddUser(user models.User) (models.User, error) {
func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) {
if user.ID == "" {
user.ID = uuid.New().String()
}
@@ -42,7 +43,7 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
}
// UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) {
func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix()
result := p.db.Save(&user)
@@ -55,18 +56,23 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
}
// DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error {
func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
result := p.db.Delete(&user)
if result.Error != nil {
return result.Error
}
result = p.db.Where("user_id = ?", user.ID).Delete(&models.Session{})
if result.Error != nil {
return result.Error
}
return nil
}
// ListUsers to get list of users from database
func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) {
func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
var users []models.User
result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&users)
if result.Error != nil {
@@ -94,7 +100,7 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
}
// GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(email string) (models.User, error) {
func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User
result := p.db.Where("email = ?", email).First(&user)
if result.Error != nil {
@@ -105,7 +111,7 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
}
// GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(id string) (models.User, error) {
func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
var user models.User
result := p.db.Where("id = ?", id).First(&user)

View File

@@ -1,6 +1,7 @@
package sql
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
@@ -10,7 +11,7 @@ import (
)
// AddVerification to save verification request in database
func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
if verificationRequest.ID == "" {
verificationRequest.ID = uuid.New().String()
}
@@ -31,7 +32,7 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
}
// GetVerificationRequestByToken to get verification request from database using token
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) {
func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest
result := p.db.Where("token = ?", token).First(&verificationRequest)
@@ -43,7 +44,7 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
}
// GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) {
func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest
result := p.db.Where("email = ? AND identifier = ?", email, identifier).First(&verificationRequest)
@@ -56,7 +57,7 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
}
// ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) {
func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
var verificationRequests []models.VerificationRequest
result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&verificationRequests)
@@ -85,7 +86,7 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
}
// DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error {
func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error {
result := p.db.Delete(&verificationRequest)
if result.Error != nil {

View File

@@ -0,0 +1,104 @@
package sql
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
if webhook.ID == "" {
webhook.ID = uuid.New().String()
}
webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix()
res := p.db.Create(&webhook)
if res.Error != nil {
return nil, res.Error
}
return webhook.AsAPIWebhook(), nil
}
// UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix()
result := p.db.Save(&webhook)
if result.Error != nil {
return nil, result.Error
}
return webhook.AsAPIWebhook(), nil
}
// ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
var webhooks []models.Webhook
result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&webhooks)
if result.Error != nil {
return nil, result.Error
}
var total int64
totalRes := p.db.Model(&models.Webhook{}).Count(&total)
if totalRes.Error != nil {
return nil, totalRes.Error
}
paginationClone := pagination
paginationClone.Total = total
responseWebhooks := []*model.Webhook{}
for _, w := range webhooks {
responseWebhooks = append(responseWebhooks, w.AsAPIWebhook())
}
return &model.Webhooks{
Pagination: &paginationClone,
Webhooks: responseWebhooks,
}, nil
}
// GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook
result := p.db.Where("id = ?", webhookID).First(&webhook)
if result.Error != nil {
return nil, result.Error
}
return webhook.AsAPIWebhook(), nil
}
// GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
var webhook models.Webhook
result := p.db.Where("event_name = ?", eventName).First(&webhook)
if result.Error != nil {
return nil, result.Error
}
return webhook.AsAPIWebhook(), nil
}
// DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
result := p.db.Delete(&models.Webhook{
ID: webhook.ID,
})
if result.Error != nil {
return result.Error
}
result = p.db.Where("webhook_id = ?", webhook.ID).Delete(&models.WebhookLog{})
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -0,0 +1,68 @@
package sql
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// AddWebhookLog to add webhook log
func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) {
if webhookLog.ID == "" {
webhookLog.ID = uuid.New().String()
}
webhookLog.Key = webhookLog.ID
webhookLog.CreatedAt = time.Now().Unix()
webhookLog.UpdatedAt = time.Now().Unix()
res := p.db.Clauses(
clause.OnConflict{
DoNothing: true,
}).Create(&webhookLog)
if res.Error != nil {
return nil, res.Error
}
return webhookLog.AsAPIWebhookLog(), nil
}
// ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) {
var webhookLogs []models.WebhookLog
var result *gorm.DB
var totalRes *gorm.DB
var total int64
if webhookID != "" {
result = p.db.Where("webhook_id = ?", webhookID).Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&webhookLogs)
totalRes = p.db.Where("webhook_id = ?", webhookID).Model(&models.WebhookLog{}).Count(&total)
} else {
result = p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&webhookLogs)
totalRes = p.db.Model(&models.WebhookLog{}).Count(&total)
}
if result.Error != nil {
return nil, result.Error
}
if totalRes.Error != nil {
return nil, totalRes.Error
}
paginationClone := pagination
paginationClone.Total = total
responseWebhookLogs := []*model.WebhookLog{}
for _, w := range webhookLogs {
responseWebhookLogs = append(responseWebhookLogs, w.AsAPIWebhookLog())
}
return &model.WebhookLogs{
WebhookLogs: responseWebhookLogs,
Pagination: &paginationClone,
}, nil
}

14
server/env/env.go vendored
View File

@@ -83,6 +83,7 @@ func InitAllEnv() error {
osDisableLoginPage := os.Getenv(constants.EnvKeyDisableLoginPage)
osDisableSignUp := os.Getenv(constants.EnvKeyDisableSignUp)
osDisableRedisForEnv := os.Getenv(constants.EnvKeyDisableRedisForEnv)
osDisableStrongPassword := os.Getenv(constants.EnvKeyDisableStrongPassword)
// os slice vars
osAllowedOrigins := os.Getenv(constants.EnvKeyAllowedOrigins)
@@ -476,6 +477,19 @@ func InitAllEnv() error {
}
}
if _, ok := envData[constants.EnvKeyDisableStrongPassword]; !ok {
envData[constants.EnvKeyDisableStrongPassword] = osDisableStrongPassword == "true"
}
if osDisableStrongPassword != "" {
boolValue, err := strconv.ParseBool(osDisableStrongPassword)
if err != nil {
return err
}
if boolValue != envData[constants.EnvKeyDisableStrongPassword].(bool) {
envData[constants.EnvKeyDisableStrongPassword] = boolValue
}
}
// no need to add nil check as its already done above
if envData[constants.EnvKeySmtpHost] == "" || envData[constants.EnvKeySmtpUsername] == "" || envData[constants.EnvKeySmtpPassword] == "" || envData[constants.EnvKeySenderEmail] == "" && envData[constants.EnvKeySmtpPort] == "" {
envData[constants.EnvKeyDisableEmailVerification] = true

View File

@@ -1,6 +1,7 @@
package env
import (
"context"
"encoding/json"
"os"
"reflect"
@@ -58,7 +59,8 @@ func fixBackwardCompatibility(data map[string]interface{}) (bool, map[string]int
// GetEnvData returns the env data from database
func GetEnvData() (map[string]interface{}, error) {
var result map[string]interface{}
env, err := db.Provider.GetEnv()
ctx := context.Background()
env, err := db.Provider.GetEnv(ctx)
// config not found in db
if err != nil {
log.Debug("Error while getting env data from db: ", err)
@@ -108,9 +110,10 @@ func GetEnvData() (map[string]interface{}, error) {
// PersistEnv persists the environment variables to the database
func PersistEnv() error {
env, err := db.Provider.GetEnv()
ctx := context.Background()
env, err := db.Provider.GetEnv(ctx)
// config not found in db
if err != nil {
if err != nil || env.EnvData == "" {
// AES encryption needs 32 bit key only, so we chop off last 4 characters from 36 bit uuid
hash := uuid.New().String()[:36-4]
err := memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, hash)
@@ -137,7 +140,7 @@ func PersistEnv() error {
EnvData: encryptedConfig,
}
env, err = db.Provider.AddEnv(env)
env, err = db.Provider.AddEnv(ctx, env)
if err != nil {
log.Debug("Error while persisting env data to db: ", err)
return err
@@ -171,7 +174,7 @@ func PersistEnv() error {
err = json.Unmarshal(decryptedConfigs, &storeData)
if err != nil {
log.Debug("Error while unmarshalling env data: ", err)
log.Debug("Error while un-marshalling env data: ", err)
return err
}
@@ -198,7 +201,7 @@ func PersistEnv() error {
envValue := strings.TrimSpace(os.Getenv(key))
if envValue != "" {
switch key {
case constants.EnvKeyIsProd, constants.EnvKeyDisableBasicAuthentication, constants.EnvKeyDisableEmailVerification, constants.EnvKeyDisableLoginPage, constants.EnvKeyDisableMagicLinkLogin, constants.EnvKeyDisableSignUp, constants.EnvKeyDisableRedisForEnv:
case constants.EnvKeyIsProd, constants.EnvKeyDisableBasicAuthentication, constants.EnvKeyDisableEmailVerification, constants.EnvKeyDisableLoginPage, constants.EnvKeyDisableMagicLinkLogin, constants.EnvKeyDisableSignUp, constants.EnvKeyDisableRedisForEnv, constants.EnvKeyDisableStrongPassword:
if envValueBool, err := strconv.ParseBool(envValue); err == nil {
if value.(bool) != envValueBool {
storeData[key] = envValueBool
@@ -251,7 +254,7 @@ func PersistEnv() error {
}
env.EnvData = encryptedConfig
_, err = db.Provider.UpdateEnv(env)
_, err = db.Provider.UpdateEnv(ctx, env)
if err != nil {
log.Debug("Failed to Update Config: ", err)
return err

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,13 @@
package model
type AddWebhookRequest struct {
EventName string `json:"event_name"`
Endpoint string `json:"endpoint"`
Enabled bool `json:"enabled"`
Headers map[string]interface{} `json:"headers"`
}
type AdminLoginInput struct {
AdminSecret string `json:"admin_secret"`
}
@@ -55,6 +62,7 @@ type Env struct {
DisableLoginPage bool `json:"DISABLE_LOGIN_PAGE"`
DisableSignUp bool `json:"DISABLE_SIGN_UP"`
DisableRedisForEnv bool `json:"DISABLE_REDIS_FOR_ENV"`
DisableStrongPassword bool `json:"DISABLE_STRONG_PASSWORD"`
Roles []string `json:"ROLES"`
ProtectedRoles []string `json:"PROTECTED_ROLES"`
DefaultRoles []string `json:"DEFAULT_ROLES"`
@@ -99,6 +107,11 @@ type InviteMemberInput struct {
RedirectURI *string `json:"redirect_uri"`
}
type ListWebhookLogRequest struct {
Pagination *PaginationInput `json:"pagination"`
WebhookID *string `json:"webhook_id"`
}
type LoginInput struct {
Email string `json:"email"`
Password string `json:"password"`
@@ -126,6 +139,7 @@ type Meta struct {
IsBasicAuthenticationEnabled bool `json:"is_basic_authentication_enabled"`
IsMagicLinkLoginEnabled bool `json:"is_magic_link_login_enabled"`
IsSignUpEnabled bool `json:"is_sign_up_enabled"`
IsStrongPasswordEnabled bool `json:"is_strong_password_enabled"`
}
type OAuthRevokeInput struct {
@@ -185,6 +199,17 @@ type SignUpInput struct {
RedirectURI *string `json:"redirect_uri"`
}
type TestEndpointRequest struct {
Endpoint string `json:"endpoint"`
EventName string `json:"event_name"`
Headers map[string]interface{} `json:"headers"`
}
type TestEndpointResponse struct {
HTTPStatus *int64 `json:"http_status"`
Response map[string]interface{} `json:"response"`
}
type UpdateAccessInput struct {
UserID string `json:"user_id"`
}
@@ -212,6 +237,7 @@ type UpdateEnvInput struct {
DisableLoginPage *bool `json:"DISABLE_LOGIN_PAGE"`
DisableSignUp *bool `json:"DISABLE_SIGN_UP"`
DisableRedisForEnv *bool `json:"DISABLE_REDIS_FOR_ENV"`
DisableStrongPassword *bool `json:"DISABLE_STRONG_PASSWORD"`
Roles []string `json:"ROLES"`
ProtectedRoles []string `json:"PROTECTED_ROLES"`
DefaultRoles []string `json:"DEFAULT_ROLES"`
@@ -260,6 +286,14 @@ type UpdateUserInput struct {
Roles []*string `json:"roles"`
}
type UpdateWebhookRequest struct {
ID string `json:"id"`
EventName *string `json:"event_name"`
Endpoint *string `json:"endpoint"`
Enabled *bool `json:"enabled"`
Headers map[string]interface{} `json:"headers"`
}
type User struct {
ID string `json:"id"`
Email string `json:"email"`
@@ -316,3 +350,37 @@ type VerificationRequests struct {
type VerifyEmailInput struct {
Token string `json:"token"`
}
type Webhook struct {
ID string `json:"id"`
EventName *string `json:"event_name"`
Endpoint *string `json:"endpoint"`
Enabled *bool `json:"enabled"`
Headers map[string]interface{} `json:"headers"`
CreatedAt *int64 `json:"created_at"`
UpdatedAt *int64 `json:"updated_at"`
}
type WebhookLog struct {
ID string `json:"id"`
HTTPStatus *int64 `json:"http_status"`
Response *string `json:"response"`
Request *string `json:"request"`
WebhookID *string `json:"webhook_id"`
CreatedAt *int64 `json:"created_at"`
UpdatedAt *int64 `json:"updated_at"`
}
type WebhookLogs struct {
Pagination *Pagination `json:"pagination"`
WebhookLogs []*WebhookLog `json:"webhook_logs"`
}
type WebhookRequest struct {
ID string `json:"id"`
}
type Webhooks struct {
Pagination *Pagination `json:"pagination"`
Webhooks []*Webhook `json:"webhooks"`
}

View File

@@ -24,6 +24,7 @@ type Meta {
is_basic_authentication_enabled: Boolean!
is_magic_link_login_enabled: Boolean!
is_sign_up_enabled: Boolean!
is_strong_password_enabled: Boolean!
}
type User {
@@ -120,6 +121,7 @@ type Env {
DISABLE_LOGIN_PAGE: Boolean!
DISABLE_SIGN_UP: Boolean!
DISABLE_REDIS_FOR_ENV: Boolean!
DISABLE_STRONG_PASSWORD: Boolean!
ROLES: [String!]
PROTECTED_ROLES: [String!]
DEFAULT_ROLES: [String!]
@@ -171,6 +173,7 @@ input UpdateEnvInput {
DISABLE_LOGIN_PAGE: Boolean
DISABLE_SIGN_UP: Boolean
DISABLE_REDIS_FOR_ENV: Boolean
DISABLE_STRONG_PASSWORD: Boolean
ROLES: [String!]
PROTECTED_ROLES: [String!]
DEFAULT_ROLES: [String!]
@@ -321,6 +324,71 @@ input GenerateJWTKeysInput {
type: String!
}
type Webhook {
id: ID!
event_name: String
endpoint: String
enabled: Boolean
headers: Map
created_at: Int64
updated_at: Int64
}
type Webhooks {
pagination: Pagination!
webhooks: [Webhook!]!
}
type WebhookLog {
id: ID!
http_status: Int64
response: String
request: String
webhook_id: ID
created_at: Int64
updated_at: Int64
}
type TestEndpointResponse {
http_status: Int64
response: Map
}
input ListWebhookLogRequest {
pagination: PaginationInput
webhook_id: String
}
input AddWebhookRequest {
event_name: String!
endpoint: String!
enabled: Boolean!
headers: Map
}
input UpdateWebhookRequest {
id: ID!
event_name: String
endpoint: String
enabled: Boolean
headers: Map
}
input WebhookRequest {
id: ID!
}
input TestEndpointRequest {
endpoint: String!
event_name: String!
headers: Map
}
type WebhookLogs {
pagination: Pagination!
webhook_logs: [WebhookLog!]!
}
type Mutation {
signup(params: SignUpInput!): AuthResponse!
login(params: LoginInput!): AuthResponse!
@@ -343,6 +411,10 @@ type Mutation {
_revoke_access(param: UpdateAccessInput!): Response!
_enable_access(param: UpdateAccessInput!): Response!
_generate_jwt_keys(params: GenerateJWTKeysInput!): GenerateJWTKeysResponse!
_add_webhook(params: AddWebhookRequest!): Response!
_update_webhook(params: UpdateWebhookRequest!): Response!
_delete_webhook(params: WebhookRequest!): Response!
_test_endpoint(params: TestEndpointRequest!): TestEndpointResponse!
}
type Query {
@@ -355,4 +427,7 @@ type Query {
_verification_requests(params: PaginatedInput): VerificationRequests!
_admin_session: Response!
_env: Env!
_webhook(params: WebhookRequest!): Webhook!
_webhooks(params: PaginatedInput): Webhooks!
_webhook_logs(params: ListWebhookLogRequest): WebhookLogs!
}

View File

@@ -91,6 +91,22 @@ func (r *mutationResolver) GenerateJwtKeys(ctx context.Context, params model.Gen
return resolvers.GenerateJWTKeysResolver(ctx, params)
}
func (r *mutationResolver) AddWebhook(ctx context.Context, params model.AddWebhookRequest) (*model.Response, error) {
return resolvers.AddWebhookResolver(ctx, params)
}
func (r *mutationResolver) UpdateWebhook(ctx context.Context, params model.UpdateWebhookRequest) (*model.Response, error) {
return resolvers.UpdateWebhookResolver(ctx, params)
}
func (r *mutationResolver) DeleteWebhook(ctx context.Context, params model.WebhookRequest) (*model.Response, error) {
return resolvers.DeleteWebhookResolver(ctx, params)
}
func (r *mutationResolver) TestEndpoint(ctx context.Context, params model.TestEndpointRequest) (*model.TestEndpointResponse, error) {
return resolvers.TestEndpointResolver(ctx, params)
}
func (r *queryResolver) Meta(ctx context.Context) (*model.Meta, error) {
return resolvers.MetaResolver(ctx)
}
@@ -123,6 +139,18 @@ func (r *queryResolver) Env(ctx context.Context) (*model.Env, error) {
return resolvers.EnvResolver(ctx)
}
func (r *queryResolver) Webhook(ctx context.Context, params model.WebhookRequest) (*model.Webhook, error) {
return resolvers.WebhookResolver(ctx, params)
}
func (r *queryResolver) Webhooks(ctx context.Context, params *model.PaginatedInput) (*model.Webhooks, error) {
return resolvers.WebhooksResolver(ctx, params)
}
func (r *queryResolver) WebhookLogs(ctx context.Context, params *model.ListWebhookLogRequest) (*model.WebhookLogs, error) {
return resolvers.WebhookLogsResolver(ctx, params)
}
// Mutation returns generated.MutationResolver implementation.
func (r *Resolver) Mutation() generated.MutationResolver { return &mutationResolver{r} }

View File

@@ -199,7 +199,7 @@ func AuthorizeHandler() gin.HandlerFunc {
return
}
userID := claims.Subject
user, err := db.Provider.GetUserByID(userID)
user, err := db.Provider.GetUserByID(gc, userID)
if err != nil {
if isQuery {
gc.Redirect(http.StatusFound, loginURL)
@@ -218,13 +218,18 @@ func AuthorizeHandler() gin.HandlerFunc {
return
}
sessionKey := user.ID
if claims.LoginMethod != "" {
sessionKey = claims.LoginMethod + ":" + user.ID
}
// if user is logged in
// based on the response type, generate the response
// based on the response type code, generate the response
if isResponseTypeCode {
// rollover the session for security
go memorystore.Provider.DeleteUserSession(user.ID, claims.Nonce)
go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce)
nonce := uuid.New().String()
newSessionTokenData, newSessionToken, err := token.CreateSessionToken(user, nonce, claims.Roles, scope)
newSessionTokenData, newSessionToken, err := token.CreateSessionToken(user, nonce, claims.Roles, scope, claims.LoginMethod)
if err != nil {
if isQuery {
gc.Redirect(http.StatusFound, loginURL)
@@ -262,7 +267,7 @@ func AuthorizeHandler() gin.HandlerFunc {
if isResponseTypeToken {
// rollover the session for security
authToken, err := token.CreateAuthToken(gc, user, claims.Roles, scope)
authToken, err := token.CreateAuthToken(gc, user, claims.Roles, scope, claims.LoginMethod)
if err != nil {
if isQuery {
gc.Redirect(http.StatusFound, loginURL)
@@ -280,9 +285,10 @@ func AuthorizeHandler() gin.HandlerFunc {
}
return
}
go memorystore.Provider.DeleteUserSession(user.ID, claims.Nonce)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
cookie.SetSession(gc, authToken.FingerPrintHash)
expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix()
@@ -305,7 +311,7 @@ func AuthorizeHandler() gin.HandlerFunc {
if authToken.RefreshToken != nil {
res["refresh_token"] = authToken.RefreshToken.Token
params += "&refresh_token=" + authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
}
if isQuery {

View File

@@ -2,6 +2,7 @@ package handlers
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
@@ -17,7 +18,6 @@ import (
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/cookie"
"github.com/authorizerdev/authorizer/server/crypto"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/memorystore"
@@ -28,21 +28,21 @@ import (
// OAuthCallbackHandler handles the OAuth callback for various oauth providers
func OAuthCallbackHandler() gin.HandlerFunc {
return func(c *gin.Context) {
provider := c.Param("oauth_provider")
state := c.Request.FormValue("state")
return func(ctx *gin.Context) {
provider := ctx.Param("oauth_provider")
state := ctx.Request.FormValue("state")
sessionState, err := memorystore.Provider.GetState(state)
if sessionState == "" || err != nil {
log.Debug("Invalid oauth state: ", state)
c.JSON(400, gin.H{"error": "invalid oauth state"})
ctx.JSON(400, gin.H{"error": "invalid oauth state"})
}
// contains random token, redirect url, role
sessionSplit := strings.Split(state, "___")
if len(sessionSplit) < 3 {
log.Debug("Unable to get redirect url from state: ", state)
c.JSON(400, gin.H{"error": "invalid redirect url"})
ctx.JSON(400, gin.H{"error": "invalid redirect url"})
return
}
@@ -55,17 +55,17 @@ func OAuthCallbackHandler() gin.HandlerFunc {
scopes := strings.Split(sessionSplit[3], ",")
user := models.User{}
code := c.Request.FormValue("code")
code := ctx.Request.FormValue("code")
switch provider {
case constants.SignupMethodGoogle:
case constants.AuthRecipeMethodGoogle:
user, err = processGoogleUserInfo(code)
case constants.SignupMethodGithub:
case constants.AuthRecipeMethodGithub:
user, err = processGithubUserInfo(code)
case constants.SignupMethodFacebook:
case constants.AuthRecipeMethodFacebook:
user, err = processFacebookUserInfo(code)
case constants.SignupMethodLinkedIn:
case constants.AuthRecipeMethodLinkedIn:
user, err = processLinkedInUserInfo(code)
case constants.SignupMethodApple:
case constants.AuthRecipeMethodApple:
user, err = processAppleUserInfo(code)
default:
log.Info("Invalid oauth provider")
@@ -74,23 +74,24 @@ func OAuthCallbackHandler() gin.HandlerFunc {
if err != nil {
log.Debug("Failed to process user info: ", err)
c.JSON(400, gin.H{"error": err.Error()})
ctx.JSON(400, gin.H{"error": err.Error()})
return
}
existingUser, err := db.Provider.GetUserByEmail(user.Email)
existingUser, err := db.Provider.GetUserByEmail(ctx, user.Email)
log := log.WithField("user", user.Email)
isSignUp := false
if err != nil {
isSignupDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp)
if err != nil {
log.Debug("Failed to get signup disabled env variable: ", err)
c.JSON(400, gin.H{"error": err.Error()})
ctx.JSON(400, gin.H{"error": err.Error()})
return
}
if isSignupDisabled {
log.Debug("Failed to signup as disabled")
c.JSON(400, gin.H{"error": "signup is disabled for this instance"})
ctx.JSON(400, gin.H{"error": "signup is disabled for this instance"})
return
}
// user not registered, register user and generate session token
@@ -113,19 +114,20 @@ func OAuthCallbackHandler() gin.HandlerFunc {
if hasProtectedRole {
log.Debug("Signup is not allowed with protected roles:", inputRoles)
c.JSON(400, gin.H{"error": "invalid role"})
ctx.JSON(400, gin.H{"error": "invalid role"})
return
}
user.Roles = strings.Join(inputRoles, ",")
now := time.Now().Unix()
user.EmailVerifiedAt = &now
user, _ = db.Provider.AddUser(user)
user, _ = db.Provider.AddUser(ctx, user)
isSignUp = true
} else {
user = existingUser
if user.RevokedTimestamp != nil {
log.Debug("User access revoked at: ", user.RevokedTimestamp)
c.JSON(400, gin.H{"error": "user access has been revoked"})
ctx.JSON(400, gin.H{"error": "user access has been revoked"})
return
}
@@ -175,7 +177,7 @@ func OAuthCallbackHandler() gin.HandlerFunc {
if hasProtectedRole {
log.Debug("Invalid role. User is using protected unassigned role")
c.JSON(400, gin.H{"error": "invalid role"})
ctx.JSON(400, gin.H{"error": "invalid role"})
return
} else {
user.Roles = existingUser.Roles + "," + strings.Join(unasignedRoles, ",")
@@ -184,18 +186,18 @@ func OAuthCallbackHandler() gin.HandlerFunc {
user.Roles = existingUser.Roles
}
user, err = db.Provider.UpdateUser(user)
user, err = db.Provider.UpdateUser(ctx, user)
if err != nil {
log.Debug("Failed to update user: ", err)
c.JSON(500, gin.H{"error": err.Error()})
ctx.JSON(500, gin.H{"error": err.Error()})
return
}
}
authToken, err := token.CreateAuthToken(c, user, inputRoles, scopes)
authToken, err := token.CreateAuthToken(ctx, user, inputRoles, scopes, provider)
if err != nil {
log.Debug("Failed to create auth token: ", err)
c.JSON(500, gin.H{"error": err.Error()})
ctx.JSON(500, gin.H{"error": err.Error()})
}
expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix()
@@ -205,27 +207,35 @@ func OAuthCallbackHandler() gin.HandlerFunc {
params := "access_token=" + authToken.AccessToken.Token + "&token_type=bearer&expires_in=" + strconv.FormatInt(expiresIn, 10) + "&state=" + stateValue + "&id_token=" + authToken.IDToken.Token
cookie.SetSession(c, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
sessionKey := provider + ":" + user.ID
cookie.SetSession(ctx, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
if authToken.RefreshToken != nil {
params = params + `&refresh_token=` + authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
}
go db.Provider.AddSession(models.Session{
UserID: user.ID,
UserAgent: utils.GetUserAgent(c.Request),
IP: utils.GetIP(c.Request),
})
go func() {
if isSignUp {
utils.RegisterEvent(ctx, constants.UserSignUpWebhookEvent, provider, user)
} else {
utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, provider, user)
}
db.Provider.AddSession(ctx, models.Session{
UserID: user.ID,
UserAgent: utils.GetUserAgent(ctx.Request),
IP: utils.GetIP(ctx.Request),
})
}()
if strings.Contains(redirectURL, "?") {
redirectURL = redirectURL + "&" + params
} else {
redirectURL = redirectURL + "?" + strings.TrimPrefix(params, "&")
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
ctx.Redirect(http.StatusFound, redirectURL)
}
}
@@ -462,8 +472,6 @@ func processAppleUserInfo(code string) (models.User, error) {
return user, fmt.Errorf("invalid apple exchange code: %s", err.Error())
}
fmt.Println("=> token", oauth2Token.AccessToken)
// Extract the ID Token from OAuth2 token.
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
@@ -471,36 +479,39 @@ func processAppleUserInfo(code string) (models.User, error) {
return user, fmt.Errorf("unable to extract id_token")
}
fmt.Println("=> rawIDToken", rawIDToken)
tokenSplit := strings.Split(rawIDToken, ".")
claimsData := tokenSplit[1]
decodedClaimsData, err := crypto.DecryptB64(claimsData)
decodedClaimsData, err := base64.RawURLEncoding.DecodeString(claimsData)
if err != nil {
log.Debug("Failed to decrypt claims data: ", err)
log.Debugf("Failed to decrypt claims %s: %s", claimsData, err.Error())
return user, fmt.Errorf("failed to decrypt claims data: %s", err.Error())
}
fmt.Println("=> decoded claims data", decodedClaimsData)
claims := make(map[string]interface{})
err = json.Unmarshal([]byte(decodedClaimsData), &claims)
err = json.Unmarshal(decodedClaimsData, &claims)
if err != nil {
log.Debug("Failed to unmarshal claims data: ", err)
return user, fmt.Errorf("failed to unmarshal claims data: %s", err.Error())
}
fmt.Println("=> claims", claims)
if val, ok := claims["email"]; !ok {
log.Debug("Failed to extract email from claims")
return user, fmt.Errorf("unable to extract email")
log.Debug("Failed to extract email from claims.")
return user, fmt.Errorf("unable to extract email, please check the scopes enabled for your app. It needs `email`, `name` scopes")
} else {
user.Email = val.(string)
}
if val, ok := claims["name"]; ok {
givenName := val.(string)
user.GivenName = &givenName
nameData := val.(map[string]interface{})
if nameVal, ok := nameData["firstName"]; ok {
givenName := nameVal.(string)
user.GivenName = &givenName
}
if nameVal, ok := nameData["lastName"]; ok {
familyName := nameVal.(string)
user.FamilyName = &familyName
}
}
return user, err

View File

@@ -100,13 +100,13 @@ func OAuthLoginHandler() gin.HandlerFunc {
provider := c.Param("oauth_provider")
isProviderConfigured := true
switch provider {
case constants.SignupMethodGoogle:
case constants.AuthRecipeMethodGoogle:
if oauth.OAuthProviders.GoogleConfig == nil {
log.Debug("Google OAuth provider is not configured")
isProviderConfigured = false
break
}
err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGoogle)
err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodGoogle)
if err != nil {
log.Debug("Error setting state: ", err)
c.JSON(500, gin.H{
@@ -115,16 +115,16 @@ func OAuthLoginHandler() gin.HandlerFunc {
return
}
// during the init of OAuthProvider authorizer url might be empty
oauth.OAuthProviders.GoogleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodGoogle
oauth.OAuthProviders.GoogleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodGoogle
url := oauth.OAuthProviders.GoogleConfig.AuthCodeURL(oauthStateString)
c.Redirect(http.StatusTemporaryRedirect, url)
case constants.SignupMethodGithub:
case constants.AuthRecipeMethodGithub:
if oauth.OAuthProviders.GithubConfig == nil {
log.Debug("Github OAuth provider is not configured")
isProviderConfigured = false
break
}
err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGithub)
err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodGithub)
if err != nil {
log.Debug("Error setting state: ", err)
c.JSON(500, gin.H{
@@ -132,16 +132,16 @@ func OAuthLoginHandler() gin.HandlerFunc {
})
return
}
oauth.OAuthProviders.GithubConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodGithub
oauth.OAuthProviders.GithubConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodGithub
url := oauth.OAuthProviders.GithubConfig.AuthCodeURL(oauthStateString)
c.Redirect(http.StatusTemporaryRedirect, url)
case constants.SignupMethodFacebook:
case constants.AuthRecipeMethodFacebook:
if oauth.OAuthProviders.FacebookConfig == nil {
log.Debug("Facebook OAuth provider is not configured")
isProviderConfigured = false
break
}
err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodFacebook)
err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodFacebook)
if err != nil {
log.Debug("Error setting state: ", err)
c.JSON(500, gin.H{
@@ -149,16 +149,16 @@ func OAuthLoginHandler() gin.HandlerFunc {
})
return
}
oauth.OAuthProviders.FacebookConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodFacebook
oauth.OAuthProviders.FacebookConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodFacebook
url := oauth.OAuthProviders.FacebookConfig.AuthCodeURL(oauthStateString)
c.Redirect(http.StatusTemporaryRedirect, url)
case constants.SignupMethodLinkedIn:
case constants.AuthRecipeMethodLinkedIn:
if oauth.OAuthProviders.LinkedInConfig == nil {
log.Debug("Linkedin OAuth provider is not configured")
isProviderConfigured = false
break
}
err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodLinkedIn)
err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodLinkedIn)
if err != nil {
log.Debug("Error setting state: ", err)
c.JSON(500, gin.H{
@@ -166,16 +166,16 @@ func OAuthLoginHandler() gin.HandlerFunc {
})
return
}
oauth.OAuthProviders.LinkedInConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodLinkedIn
oauth.OAuthProviders.LinkedInConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodLinkedIn
url := oauth.OAuthProviders.LinkedInConfig.AuthCodeURL(oauthStateString)
c.Redirect(http.StatusTemporaryRedirect, url)
case constants.SignupMethodApple:
case constants.AuthRecipeMethodApple:
if oauth.OAuthProviders.AppleConfig == nil {
log.Debug("Apple OAuth provider is not configured")
isProviderConfigured = false
break
}
err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodApple)
err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodApple)
if err != nil {
log.Debug("Error setting state: ", err)
c.JSON(500, gin.H{
@@ -183,8 +183,10 @@ func OAuthLoginHandler() gin.HandlerFunc {
})
return
}
oauth.OAuthProviders.AppleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodApple
url := oauth.OAuthProviders.AppleConfig.AuthCodeURL(oauthStateString, oauth2.SetAuthURLParam("response_mode", "form_post"))
oauth.OAuthProviders.AppleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodApple
// there is scope encoding issue with oauth2 and how apple expects, hence added scope manually
// check: https://github.com/golang/oauth2/issues/449
url := oauth.OAuthProviders.AppleConfig.AuthCodeURL(oauthStateString, oauth2.SetAuthURLParam("response_mode", "form_post")) + "&scope=name email"
c.Redirect(http.StatusTemporaryRedirect, url)
default:
log.Debug("Invalid oauth provider: ", provider)

View File

@@ -12,8 +12,8 @@ import (
"github.com/authorizerdev/authorizer/server/token"
)
// Revoke handler to revoke refresh token
func RevokeHandler() gin.HandlerFunc {
// RevokeRefreshTokenHandler handler to revoke refresh token
func RevokeRefreshTokenHandler() gin.HandlerFunc {
return func(gc *gin.Context) {
var reqBody map[string]string
if err := gc.BindJSON(&reqBody); err != nil {
@@ -56,7 +56,14 @@ func RevokeHandler() gin.HandlerFunc {
return
}
memorystore.Provider.DeleteUserSession(claims["sub"].(string), claims["nonce"].(string))
userID := claims["sub"].(string)
loginMethod := claims["login_method"]
sessionToken := userID
if loginMethod != nil && loginMethod != "" {
sessionToken = loginMethod.(string) + ":" + userID
}
memorystore.Provider.DeleteUserSession(sessionToken, claims["nonce"].(string))
gc.JSON(http.StatusOK, gin.H{
"message": "Token revoked successfully",

View File

@@ -72,6 +72,9 @@ func TokenHandler() gin.HandlerFunc {
var userID string
var roles, scope []string
loginMethod := ""
sessionKey := ""
if isAuthorizationCodeGrant {
if codeVerifier == "" {
@@ -134,8 +137,13 @@ func TokenHandler() gin.HandlerFunc {
userID = claims.Subject
roles = claims.Roles
scope = claims.Scope
loginMethod = claims.LoginMethod
// rollover the session for security
go memorystore.Provider.DeleteUserSession(userID, claims.Nonce)
sessionKey = userID
if loginMethod != "" {
sessionKey = loginMethod + ":" + userID
}
go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce)
} else {
// validate refresh token
if refreshToken == "" {
@@ -155,6 +163,7 @@ func TokenHandler() gin.HandlerFunc {
})
}
userID = claims["sub"].(string)
loginMethod := claims["login_method"]
rolesInterface := claims["roles"].([]interface{})
scopeInterface := claims["scope"].([]interface{})
for _, v := range rolesInterface {
@@ -163,11 +172,25 @@ func TokenHandler() gin.HandlerFunc {
for _, v := range scopeInterface {
scope = append(scope, v.(string))
}
sessionKey = userID
if loginMethod != nil && loginMethod != "" {
sessionKey = loginMethod.(string) + ":" + sessionKey
}
// remove older refresh token and rotate it for security
go memorystore.Provider.DeleteUserSession(userID, claims["nonce"].(string))
go memorystore.Provider.DeleteUserSession(sessionKey, claims["nonce"].(string))
}
user, err := db.Provider.GetUserByID(userID)
if sessionKey == "" {
log.Debug("Error getting sessionKey: ", sessionKey, loginMethod)
gc.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"error_description": "User not found",
})
return
}
user, err := db.Provider.GetUserByID(gc, userID)
if err != nil {
log.Debug("Error getting user: ", err)
gc.JSON(http.StatusUnauthorized, gin.H{
@@ -177,7 +200,7 @@ func TokenHandler() gin.HandlerFunc {
return
}
authToken, err := token.CreateAuthToken(gc, user, roles, scope)
authToken, err := token.CreateAuthToken(gc, user, roles, scope, loginMethod)
if err != nil {
log.Debug("Error creating auth token: ", err)
gc.JSON(http.StatusUnauthorized, gin.H{
@@ -186,8 +209,8 @@ func TokenHandler() gin.HandlerFunc {
})
return
}
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
cookie.SetSession(gc, authToken.FingerPrintHash)
expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix()
@@ -205,7 +228,7 @@ func TokenHandler() gin.HandlerFunc {
if authToken.RefreshToken != nil {
res["refresh_token"] = authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
}
gc.JSON(http.StatusOK, res)

View File

@@ -31,7 +31,7 @@ func UserInfoHandler() gin.HandlerFunc {
}
userID := claims["sub"].(string)
user, err := db.Provider.GetUserByID(userID)
user, err := db.Provider.GetUserByID(gc, userID)
if err != nil {
log.Debug("Error getting user: ", err)
gc.JSON(http.StatusUnauthorized, gin.H{

View File

@@ -33,7 +33,7 @@ func VerifyEmailHandler() gin.HandlerFunc {
return
}
verificationRequest, err := db.Provider.GetVerificationRequestByToken(tokenInQuery)
verificationRequest, err := db.Provider.GetVerificationRequestByToken(c, tokenInQuery)
if err != nil {
log.Debug("Error getting verification request: ", err)
errorRes["error_description"] = err.Error()
@@ -58,7 +58,7 @@ func VerifyEmailHandler() gin.HandlerFunc {
return
}
user, err := db.Provider.GetUserByEmail(verificationRequest.Email)
user, err := db.Provider.GetUserByEmail(c, verificationRequest.Email)
if err != nil {
log.Debug("Error getting user: ", err)
errorRes["error_description"] = err.Error()
@@ -66,14 +66,16 @@ func VerifyEmailHandler() gin.HandlerFunc {
return
}
isSignUp := false
// update email_verified_at in users table
if user.EmailVerifiedAt == nil {
now := time.Now().Unix()
user.EmailVerifiedAt = &now
db.Provider.UpdateUser(user)
isSignUp = true
db.Provider.UpdateUser(c, user)
}
// delete from verification table
db.Provider.DeleteVerificationRequest(verificationRequest)
db.Provider.DeleteVerificationRequest(c, verificationRequest)
state := strings.TrimSpace(c.Query("state"))
redirectURL := strings.TrimSpace(c.Query("redirect_uri"))
@@ -92,7 +94,11 @@ func VerifyEmailHandler() gin.HandlerFunc {
} else {
scope = strings.Split(scopeString, " ")
}
authToken, err := token.CreateAuthToken(c, user, roles, scope)
loginMethod := constants.AuthRecipeMethodBasicAuth
if verificationRequest.Identifier == constants.VerificationTypeMagicLinkLogin {
loginMethod = constants.AuthRecipeMethodMagicLinkLogin
}
authToken, err := token.CreateAuthToken(c, user, roles, scope, loginMethod)
if err != nil {
log.Debug("Error creating auth token: ", err)
errorRes["error_description"] = err.Error()
@@ -107,13 +113,14 @@ func VerifyEmailHandler() gin.HandlerFunc {
params := "access_token=" + authToken.AccessToken.Token + "&token_type=bearer&expires_in=" + strconv.FormatInt(expiresIn, 10) + "&state=" + state + "&id_token=" + authToken.IDToken.Token
sessionKey := loginMethod + ":" + user.ID
cookie.SetSession(c, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
if authToken.RefreshToken != nil {
params = params + `&refresh_token=` + authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
}
if redirectURL == "" {
@@ -126,11 +133,19 @@ func VerifyEmailHandler() gin.HandlerFunc {
redirectURL = redirectURL + "?" + strings.TrimPrefix(params, "&")
}
go db.Provider.AddSession(models.Session{
UserID: user.ID,
UserAgent: utils.GetUserAgent(c.Request),
IP: utils.GetIP(c.Request),
})
go func() {
if isSignUp {
utils.RegisterEvent(c, constants.UserSignUpWebhookEvent, loginMethod, user)
} else {
utils.RegisterEvent(c, constants.UserLoginWebhookEvent, loginMethod, user)
}
db.Provider.AddSession(c, models.Session{
UserID: user.ID,
UserAgent: utils.GetUserAgent(c.Request),
IP: utils.GetIP(c.Request),
})
}()
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
}

View File

@@ -30,6 +30,7 @@ func InitMemStore() error {
constants.EnvKeyDisableEmailVerification: false,
constants.EnvKeyDisableLoginPage: false,
constants.EnvKeyDisableSignUp: false,
constants.EnvKeyDisableStrongPassword: false,
}
requiredEnvs := RequiredEnvStoreObj.GetRequiredEnv()

View File

@@ -26,24 +26,34 @@ func (c *provider) GetUserSession(userId, sessionToken string) (string, error) {
// DeleteAllUserSessions deletes all the user sessions from in-memory store.
func (c *provider) DeleteAllUserSessions(userId string) error {
if os.Getenv("ENV") != constants.TestEnv {
c.mutex.Lock()
defer c.mutex.Unlock()
namespaces := []string{
constants.AuthRecipeMethodBasicAuth,
constants.AuthRecipeMethodMagicLinkLogin,
constants.AuthRecipeMethodApple,
constants.AuthRecipeMethodFacebook,
constants.AuthRecipeMethodGithub,
constants.AuthRecipeMethodGoogle,
constants.AuthRecipeMethodLinkedIn,
}
for _, namespace := range namespaces {
c.sessionStore.RemoveAll(namespace + ":" + userId)
}
c.sessionStore.RemoveAll(userId)
return nil
}
// DeleteUserSession deletes the user session from the in-memory store.
func (c *provider) DeleteUserSession(userId, sessionToken string) error {
if os.Getenv("ENV") != constants.TestEnv {
c.mutex.Lock()
defer c.mutex.Unlock()
}
c.sessionStore.Remove(userId, sessionToken)
return nil
}
// DeleteSessionForNamespace to delete session for a given namespace example google,github
func (c *provider) DeleteSessionForNamespace(namespace string) error {
c.sessionStore.RemoveByNamespace(namespace)
return nil
}
// SetState sets the state in the in-memory store.
func (c *provider) SetState(key, state string) error {
if os.Getenv("ENV") != constants.TestEnv {

View File

@@ -1,10 +1,7 @@
package stores
import (
"os"
"sync"
"github.com/authorizerdev/authorizer/server/constants"
)
// EnvStore struct to store the env variables
@@ -23,12 +20,10 @@ func NewEnvStore() *EnvStore {
// UpdateEnvStore to update the whole env store object
func (e *EnvStore) UpdateStore(store map[string]interface{}) {
if os.Getenv("ENV") != constants.TestEnv {
e.mutex.Lock()
defer e.mutex.Unlock()
}
// just override the keys + new keys
e.mutex.Lock()
defer e.mutex.Unlock()
// just override the keys + new keys
for key, value := range store {
e.store[key] = value
}
@@ -46,9 +41,8 @@ func (e *EnvStore) Get(key string) interface{} {
// Set sets the value of the key in env store
func (e *EnvStore) Set(key string, value interface{}) {
if os.Getenv("ENV") != constants.TestEnv {
e.mutex.Lock()
defer e.mutex.Unlock()
}
e.mutex.Lock()
defer e.mutex.Unlock()
e.store[key] = value
}

View File

@@ -1,10 +1,8 @@
package stores
import (
"os"
"strings"
"sync"
"github.com/authorizerdev/authorizer/server/constants"
)
// SessionStore struct to store the env variables
@@ -28,10 +26,9 @@ func (s *SessionStore) Get(key, subKey string) string {
// Set sets the value of the key in state store
func (s *SessionStore) Set(key string, subKey, value string) {
if os.Getenv("ENV") != constants.TestEnv {
s.mutex.Lock()
defer s.mutex.Unlock()
}
s.mutex.Lock()
defer s.mutex.Unlock()
if _, ok := s.store[key]; !ok {
s.store[key] = make(map[string]string)
}
@@ -40,19 +37,15 @@ func (s *SessionStore) Set(key string, subKey, value string) {
// RemoveAll all values for given key
func (s *SessionStore) RemoveAll(key string) {
if os.Getenv("ENV") != constants.TestEnv {
s.mutex.Lock()
defer s.mutex.Unlock()
}
s.mutex.Lock()
defer s.mutex.Unlock()
delete(s.store, key)
}
// Remove value for given key and subkey
func (s *SessionStore) Remove(key, subKey string) {
if os.Getenv("ENV") != constants.TestEnv {
s.mutex.Lock()
defer s.mutex.Unlock()
}
s.mutex.Lock()
defer s.mutex.Unlock()
if _, ok := s.store[key]; ok {
delete(s.store[key], subKey)
}
@@ -65,3 +58,15 @@ func (s *SessionStore) GetAll(key string) map[string]string {
}
return s.store[key]
}
// RemoveByNamespace to delete session for a given namespace example google,github
func (s *SessionStore) RemoveByNamespace(namespace string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
for key := range s.store {
if strings.Contains(key, namespace+":") {
delete(s.store, key)
}
}
return nil
}

View File

@@ -1,10 +1,7 @@
package stores
import (
"os"
"sync"
"github.com/authorizerdev/authorizer/server/constants"
)
// StateStore struct to store the env variables
@@ -28,19 +25,16 @@ func (s *StateStore) Get(key string) string {
// Set sets the value of the key in state store
func (s *StateStore) Set(key string, value string) {
if os.Getenv("ENV") != constants.TestEnv {
s.mutex.Lock()
defer s.mutex.Unlock()
}
s.mutex.Lock()
defer s.mutex.Unlock()
s.store[key] = value
}
// Remove removes the key from state store
func (s *StateStore) Remove(key string) {
if os.Getenv("ENV") != constants.TestEnv {
s.mutex.Lock()
defer s.mutex.Unlock()
}
s.mutex.Lock()
defer s.mutex.Unlock()
delete(s.store, key)
}

View File

@@ -12,6 +12,8 @@ type Provider interface {
DeleteUserSession(userId, key string) error
// DeleteAllSessions deletes all the sessions from the session store
DeleteAllUserSessions(userId string) error
// DeleteSessionForNamespace deletes the session for a given namespace
DeleteSessionForNamespace(namespace string) error
// SetState sets the login state (key, value form) in the session store
SetState(key, state string) error

View File

@@ -63,11 +63,48 @@ func (c *provider) DeleteUserSession(userId, key string) error {
// DeleteAllUserSessions deletes all the user session from redis
func (c *provider) DeleteAllUserSessions(userID string) error {
err := c.store.Del(c.ctx, userID).Err()
if err != nil {
log.Debug("Error deleting all user sessions from redis: ", err)
return err
namespaces := []string{
constants.AuthRecipeMethodBasicAuth,
constants.AuthRecipeMethodMagicLinkLogin,
constants.AuthRecipeMethodApple,
constants.AuthRecipeMethodFacebook,
constants.AuthRecipeMethodGithub,
constants.AuthRecipeMethodGoogle,
constants.AuthRecipeMethodLinkedIn,
}
for _, namespace := range namespaces {
err := c.store.Del(c.ctx, namespace+":"+userID).Err()
if err != nil {
log.Debug("Error deleting all user sessions from redis: ", err)
return err
}
}
return nil
}
// DeleteSessionForNamespace to delete session for a given namespace example google,github
func (c *provider) DeleteSessionForNamespace(namespace string) error {
var cursor uint64
for {
keys := []string{}
keys, cursor, err := c.store.Scan(c.ctx, cursor, namespace+":*", 0).Result()
if err != nil {
log.Debugf("Error scanning keys for %s namespace: %s", namespace, err.Error())
return err
}
for _, key := range keys {
err := c.store.Del(c.ctx, key).Err()
if err != nil {
log.Debugf("Error deleting sessions for %s namespace: %s", namespace, err.Error())
return err
}
}
if cursor == 0 { // no more keys
break
}
}
return nil
}
@@ -123,7 +160,7 @@ func (c *provider) GetEnvStore() (map[string]interface{}, error) {
return nil, err
}
for key, value := range data {
if key == constants.EnvKeyDisableBasicAuthentication || key == constants.EnvKeyDisableEmailVerification || key == constants.EnvKeyDisableLoginPage || key == constants.EnvKeyDisableMagicLinkLogin || key == constants.EnvKeyDisableRedisForEnv || key == constants.EnvKeyDisableSignUp {
if key == constants.EnvKeyDisableBasicAuthentication || key == constants.EnvKeyDisableEmailVerification || key == constants.EnvKeyDisableLoginPage || key == constants.EnvKeyDisableMagicLinkLogin || key == constants.EnvKeyDisableRedisForEnv || key == constants.EnvKeyDisableSignUp || key == constants.EnvKeyDisableStrongPassword {
boolValue, err := strconv.ParseBool(value)
if err != nil {
return res, err

View File

@@ -130,7 +130,6 @@ func InitOAuth() error {
AuthURL: "https://appleid.apple.com/auth/authorize",
TokenURL: "https://appleid.apple.com/auth/token",
},
Scopes: []string{"name", "email"},
}
}

View File

@@ -0,0 +1,54 @@
package resolvers
import (
"context"
"encoding/json"
"fmt"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/token"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/authorizerdev/authorizer/server/validators"
log "github.com/sirupsen/logrus"
)
// AddWebhookResolver resolver for add webhook mutation
func AddWebhookResolver(ctx context.Context, params model.AddWebhookRequest) (*model.Response, error) {
gc, err := utils.GinContextFromContext(ctx)
if err != nil {
log.Debug("Failed to get GinContext: ", err)
return nil, err
}
if !token.IsSuperAdmin(gc) {
log.Debug("Not logged in as super admin")
return nil, fmt.Errorf("unauthorized")
}
if !validators.IsValidWebhookEventName(params.EventName) {
log.Debug("Invalid Event Name: ", params.EventName)
return nil, fmt.Errorf("invalid event name %s", params.EventName)
}
headerBytes, err := json.Marshal(params.Headers)
if err != nil {
return nil, err
}
_, err = db.Provider.AddWebhook(ctx, models.Webhook{
EventName: params.EventName,
EndPoint: params.Endpoint,
Enabled: params.Enabled,
Headers: string(headerBytes),
})
if err != nil {
log.Debug("Failed to add webhook: ", err)
return nil, err
}
return &model.Response{
Message: `Webhook added successfully`,
}, nil
}

View File

@@ -58,7 +58,7 @@ func AdminSignupResolver(ctx context.Context, params model.AdminSignupInput) (*m
return res, err
}
env, err := db.Provider.GetEnv()
env, err := db.Provider.GetEnv(ctx)
if err != nil {
log.Debug("Failed to get env: ", err)
return res, err
@@ -71,7 +71,7 @@ func AdminSignupResolver(ctx context.Context, params model.AdminSignupInput) (*m
}
env.EnvData = envData
if _, err := db.Provider.UpdateEnv(env); err != nil {
if _, err := db.Provider.UpdateEnv(ctx, env); err != nil {
log.Debug("Failed to update env: ", err)
return res, err
}

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/memorystore"
@@ -32,15 +33,13 @@ func DeleteUserResolver(ctx context.Context, params model.DeleteUserInput) (*mod
"email": params.Email,
})
user, err := db.Provider.GetUserByEmail(params.Email)
user, err := db.Provider.GetUserByEmail(ctx, params.Email)
if err != nil {
log.Debug("Failed to get user from DB: ", err)
return res, err
}
go memorystore.Provider.DeleteAllUserSessions(user.ID)
err = db.Provider.DeleteUser(user)
err = db.Provider.DeleteUser(ctx, user)
if err != nil {
log.Debug("Failed to delete user: ", err)
return res, err
@@ -50,5 +49,10 @@ func DeleteUserResolver(ctx context.Context, params model.DeleteUserInput) (*mod
Message: `user deleted successfully`,
}
go func() {
memorystore.Provider.DeleteAllUserSessions(user.ID)
utils.RegisterEvent(ctx, constants.UserDeletedWebhookEvent, "", user)
}()
return res, nil
}

View File

@@ -0,0 +1,49 @@
package resolvers
import (
"context"
"fmt"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/token"
"github.com/authorizerdev/authorizer/server/utils"
log "github.com/sirupsen/logrus"
)
// DeleteWebhookResolver resolver to delete webhook and its relevant logs
func DeleteWebhookResolver(ctx context.Context, params model.WebhookRequest) (*model.Response, error) {
gc, err := utils.GinContextFromContext(ctx)
if err != nil {
log.Debug("Failed to get GinContext: ", err)
return nil, err
}
if !token.IsSuperAdmin(gc) {
log.Debug("Not logged in as super admin")
return nil, fmt.Errorf("unauthorized")
}
if params.ID == "" {
log.Debug("webhookID is required")
return nil, fmt.Errorf("webhook ID required")
}
log := log.WithField("webhook_id", params.ID)
webhook, err := db.Provider.GetWebhookByID(ctx, params.ID)
if err != nil {
log.Debug("failed to get webhook: ", err)
return nil, err
}
err = db.Provider.DeleteWebhook(ctx, webhook)
if err != nil {
log.Debug("failed to delete webhook: ", err)
return nil, err
}
return &model.Response{
Message: "Webhook deleted successfully",
}, nil
}

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/token"
@@ -31,7 +32,7 @@ func EnableAccessResolver(ctx context.Context, params model.UpdateAccessInput) (
"user_id": params.UserID,
})
user, err := db.Provider.GetUserByID(params.UserID)
user, err := db.Provider.GetUserByID(ctx, params.UserID)
if err != nil {
log.Debug("Failed to get user from DB: ", err)
return res, err
@@ -39,7 +40,7 @@ func EnableAccessResolver(ctx context.Context, params model.UpdateAccessInput) (
user.RevokedTimestamp = nil
user, err = db.Provider.UpdateUser(user)
user, err = db.Provider.UpdateUser(ctx, user)
if err != nil {
log.Debug("Failed to update user: ", err)
return res, err
@@ -49,5 +50,7 @@ func EnableAccessResolver(ctx context.Context, params model.UpdateAccessInput) (
Message: `user access enabled successfully`,
}
go utils.RegisterEvent(ctx, constants.UserAccessEnabledWebhookEvent, "", user)
return res, nil
}

View File

@@ -168,6 +168,7 @@ func EnvResolver(ctx context.Context) (*model.Env, error) {
res.DisableMagicLinkLogin = store[constants.EnvKeyDisableMagicLinkLogin].(bool)
res.DisableLoginPage = store[constants.EnvKeyDisableLoginPage].(bool)
res.DisableSignUp = store[constants.EnvKeyDisableSignUp].(bool)
res.DisableStrongPassword = store[constants.EnvKeyDisableStrongPassword].(bool)
return res, nil
}

View File

@@ -49,7 +49,7 @@ func ForgotPasswordResolver(ctx context.Context, params model.ForgotPasswordInpu
log := log.WithFields(log.Fields{
"email": params.Email,
})
_, err = db.Provider.GetUserByEmail(params.Email)
_, err = db.Provider.GetUserByEmail(ctx, params.Email)
if err != nil {
log.Debug("User not found: ", err)
return res, fmt.Errorf(`user with this email not found`)
@@ -71,7 +71,7 @@ func ForgotPasswordResolver(ctx context.Context, params model.ForgotPasswordInpu
log.Debug("Failed to create verification token", err)
return res, err
}
_, err = db.Provider.AddVerificationRequest(models.VerificationRequest{
_, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{
Token: verificationToken,
Identifier: constants.VerificationTypeForgotPassword,
ExpiresAt: time.Now().Add(time.Minute * 30).Unix(),

View File

@@ -70,7 +70,7 @@ func InviteMembersResolver(ctx context.Context, params model.InviteMemberInput)
// for each emails check if emails exists in db
newEmails := []string{}
for _, email := range emails {
_, err := db.Provider.GetUserByEmail(email)
_, err := db.Provider.GetUserByEmail(ctx, email)
if err != nil {
log.Debugf("User with %s email not found, so inviting user", email)
newEmails = append(newEmails, email)
@@ -129,24 +129,24 @@ func InviteMembersResolver(ctx context.Context, params model.InviteMemberInput)
// use magic link login if that option is on
if !isMagicLinkLoginDisabled {
user.SignupMethods = constants.SignupMethodMagicLinkLogin
user.SignupMethods = constants.AuthRecipeMethodMagicLinkLogin
verificationRequest.Identifier = constants.VerificationTypeMagicLinkLogin
} else {
// use basic authentication if that option is on
user.SignupMethods = constants.SignupMethodBasicAuth
user.SignupMethods = constants.AuthRecipeMethodBasicAuth
verificationRequest.Identifier = constants.VerificationTypeForgotPassword
verifyEmailURL = appURL + "/setup-password"
}
user, err = db.Provider.AddUser(user)
user, err = db.Provider.AddUser(ctx, user)
if err != nil {
log.Debugf("Error adding user: %s, err: %v", email, err)
return nil, err
}
_, err = db.Provider.AddVerificationRequest(verificationRequest)
_, err = db.Provider.AddVerificationRequest(ctx, verificationRequest)
if err != nil {
log.Debugf("Error adding verification request: %s, err: %v", email, err)
return nil, err

View File

@@ -45,7 +45,7 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes
"email": params.Email,
})
params.Email = strings.ToLower(params.Email)
user, err := db.Provider.GetUserByEmail(params.Email)
user, err := db.Provider.GetUserByEmail(ctx, params.Email)
if err != nil {
log.Debug("Failed to get user by email: ", err)
return res, fmt.Errorf(`user with this email not found`)
@@ -56,7 +56,7 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes
return res, fmt.Errorf(`user access has been revoked`)
}
if !strings.Contains(user.SignupMethods, constants.SignupMethodBasicAuth) {
if !strings.Contains(user.SignupMethods, constants.AuthRecipeMethodBasicAuth) {
log.Debug("User signup method is not basic auth")
return res, fmt.Errorf(`user has not signed up email & password`)
}
@@ -97,7 +97,7 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes
scope = params.Scope
}
authToken, err := token.CreateAuthToken(gc, user, roles, scope)
authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth)
if err != nil {
log.Debug("Failed to create auth token", err)
return res, err
@@ -117,19 +117,23 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes
}
cookie.SetSession(gc, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
sessionStoreKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID
memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
if authToken.RefreshToken != nil {
res.RefreshToken = &authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
}
go db.Provider.AddSession(models.Session{
UserID: user.ID,
UserAgent: utils.GetUserAgent(gc.Request),
IP: utils.GetIP(gc.Request),
})
go func() {
utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, constants.AuthRecipeMethodBasicAuth, user)
db.Provider.AddSession(ctx, models.Session{
UserID: user.ID,
UserAgent: utils.GetUserAgent(gc.Request),
IP: utils.GetIP(gc.Request),
})
}()
return res, nil
}

View File

@@ -41,7 +41,12 @@ func LogoutResolver(ctx context.Context) (*model.Response, error) {
return nil, err
}
memorystore.Provider.DeleteUserSession(sessionData.Subject, sessionData.Nonce)
sessionKey := sessionData.Subject
if sessionData.LoginMethod != "" {
sessionKey = sessionData.LoginMethod + ":" + sessionData.Subject
}
memorystore.Provider.DeleteUserSession(sessionKey, sessionData.Nonce)
cookie.DeleteSession(gc)
res := &model.Response{

View File

@@ -59,7 +59,7 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu
}
// find user with email
existingUser, err := db.Provider.GetUserByEmail(params.Email)
existingUser, err := db.Provider.GetUserByEmail(ctx, params.Email)
if err != nil {
isSignupDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp)
if err != nil {
@@ -70,7 +70,7 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu
return res, fmt.Errorf(`signup is disabled for this instance`)
}
user.SignupMethods = constants.SignupMethodMagicLinkLogin
user.SignupMethods = constants.AuthRecipeMethodMagicLinkLogin
// define roles for new user
if len(params.Roles) > 0 {
// check if roles exists
@@ -99,7 +99,8 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu
}
user.Roles = strings.Join(inputRoles, ",")
user, _ = db.Provider.AddUser(user)
user, _ = db.Provider.AddUser(ctx, user)
go utils.RegisterEvent(ctx, constants.UserCreatedWebhookEvent, constants.AuthRecipeMethodMagicLinkLogin, user)
} else {
user = existingUser
// There multiple scenarios with roles here in magic link login
@@ -158,12 +159,12 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu
}
signupMethod := existingUser.SignupMethods
if !strings.Contains(signupMethod, constants.SignupMethodMagicLinkLogin) {
signupMethod = signupMethod + "," + constants.SignupMethodMagicLinkLogin
if !strings.Contains(signupMethod, constants.AuthRecipeMethodMagicLinkLogin) {
signupMethod = signupMethod + "," + constants.AuthRecipeMethodMagicLinkLogin
}
user.SignupMethods = signupMethod
user, _ = db.Provider.UpdateUser(user)
user, _ = db.Provider.UpdateUser(ctx, user)
if err != nil {
log.Debug("Failed to update user: ", err)
}
@@ -205,7 +206,7 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu
if err != nil {
log.Debug("Failed to create verification token: ", err)
}
_, err = db.Provider.AddVerificationRequest(models.VerificationRequest{
_, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{
Token: verificationToken,
Identifier: verificationType,
ExpiresAt: time.Now().Add(time.Minute * 30).Unix(),

View File

@@ -101,6 +101,12 @@ func MetaResolver(ctx context.Context) (*model.Meta, error) {
isSignUpDisabled = true
}
isStrongPasswordDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableStrongPassword)
if err != nil {
log.Debug("Failed to get Disable Signup from environment variable", err)
isSignUpDisabled = true
}
metaInfo := model.Meta{
Version: constants.VERSION,
ClientID: clientID,
@@ -113,6 +119,7 @@ func MetaResolver(ctx context.Context) (*model.Meta, error) {
IsEmailVerificationEnabled: !isEmailVerificationDisabled,
IsMagicLinkLoginEnabled: !isMagicLinkLoginDisabled,
IsSignUpEnabled: !isSignUpDisabled,
IsStrongPasswordEnabled: !isStrongPasswordDisabled,
}
return &metaInfo, nil
}

View File

@@ -38,7 +38,7 @@ func ProfileResolver(ctx context.Context) (*model.User, error) {
log := log.WithFields(log.Fields{
"user_id": userID,
})
user, err := db.Provider.GetUserByID(userID)
user, err := db.Provider.GetUserByID(ctx, userID)
if err != nil {
log.Debug("Failed to get user: ", err)
return res, err

View File

@@ -39,14 +39,14 @@ func ResendVerifyEmailResolver(ctx context.Context, params model.ResendVerifyEma
return res, fmt.Errorf("invalid identifier")
}
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(params.Email, params.Identifier)
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, params.Email, params.Identifier)
if err != nil {
log.Debug("Failed to get verification request: ", err)
return res, fmt.Errorf(`verification request not found`)
}
// delete current verification and create new one
err = db.Provider.DeleteVerificationRequest(verificationRequest)
err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest)
if err != nil {
log.Debug("Failed to delete verification request: ", err)
}
@@ -62,7 +62,7 @@ func ResendVerifyEmailResolver(ctx context.Context, params model.ResendVerifyEma
if err != nil {
log.Debug("Failed to create verification token: ", err)
}
_, err = db.Provider.AddVerificationRequest(models.VerificationRequest{
_, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{
Token: verificationToken,
Identifier: params.Identifier,
ExpiresAt: time.Now().Add(time.Minute * 30).Unix(),

View File

@@ -39,7 +39,7 @@ func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput)
return res, fmt.Errorf(`basic authentication is disabled for this instance`)
}
verificationRequest, err := db.Provider.GetVerificationRequestByToken(params.Token)
verificationRequest, err := db.Provider.GetVerificationRequestByToken(ctx, params.Token)
if err != nil {
log.Debug("Failed to get verification request: ", err)
return res, fmt.Errorf(`invalid token`)
@@ -50,9 +50,9 @@ func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput)
return res, fmt.Errorf(`passwords don't match`)
}
if !validators.IsValidPassword(params.Password) {
if err := validators.IsValidPassword(params.Password); err != nil {
log.Debug("Invalid password")
return res, fmt.Errorf(`password is not valid. It needs to be at least 6 characters long and contain at least one number, one uppercase letter, one lowercase letter and one special character`)
return res, err
}
// verify if token exists in db
@@ -72,7 +72,7 @@ func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput)
log := log.WithFields(log.Fields{
"email": email,
})
user, err := db.Provider.GetUserByEmail(email)
user, err := db.Provider.GetUserByEmail(ctx, email)
if err != nil {
log.Debug("Failed to get user: ", err)
return res, err
@@ -82,8 +82,8 @@ func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput)
user.Password = &password
signupMethod := user.SignupMethods
if !strings.Contains(signupMethod, constants.SignupMethodBasicAuth) {
signupMethod = signupMethod + "," + constants.SignupMethodBasicAuth
if !strings.Contains(signupMethod, constants.AuthRecipeMethodBasicAuth) {
signupMethod = signupMethod + "," + constants.AuthRecipeMethodBasicAuth
}
user.SignupMethods = signupMethod
@@ -94,13 +94,13 @@ func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput)
}
// delete from verification table
err = db.Provider.DeleteVerificationRequest(verificationRequest)
err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest)
if err != nil {
log.Debug("Failed to delete verification request: ", err)
return res, err
}
_, err = db.Provider.UpdateUser(user)
_, err = db.Provider.UpdateUser(ctx, user)
if err != nil {
log.Debug("Failed to update user: ", err)
return res, err

View File

@@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/memorystore"
@@ -32,7 +33,7 @@ func RevokeAccessResolver(ctx context.Context, params model.UpdateAccessInput) (
log := log.WithFields(log.Fields{
"user_id": params.UserID,
})
user, err := db.Provider.GetUserByID(params.UserID)
user, err := db.Provider.GetUserByID(ctx, params.UserID)
if err != nil {
log.Debug("Failed to get user by ID: ", err)
return res, err
@@ -41,13 +42,16 @@ func RevokeAccessResolver(ctx context.Context, params model.UpdateAccessInput) (
now := time.Now().Unix()
user.RevokedTimestamp = &now
user, err = db.Provider.UpdateUser(user)
user, err = db.Provider.UpdateUser(ctx, user)
if err != nil {
log.Debug("Failed to update user: ", err)
return res, err
}
go memorystore.Provider.DeleteAllUserSessions(user.ID)
go func() {
memorystore.Provider.DeleteAllUserSessions(user.ID)
utils.RegisterEvent(ctx, constants.UserAccessRevokedWebhookEvent, "", user)
}()
res = &model.Response{
Message: `user access revoked successfully`,

View File

@@ -46,7 +46,7 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod
"user_id": userID,
})
user, err := db.Provider.GetUserByID(userID)
user, err := db.Provider.GetUserByID(ctx, userID)
if err != nil {
return res, err
}
@@ -70,14 +70,18 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod
scope = params.Scope
}
authToken, err := token.CreateAuthToken(gc, user, claimRoles, scope)
authToken, err := token.CreateAuthToken(gc, user, claimRoles, scope, claims.LoginMethod)
if err != nil {
log.Debug("Failed to create auth token: ", err)
return res, err
}
// rollover the session for security
go memorystore.Provider.DeleteUserSession(userID, claims.Nonce)
sessionKey := userID
if claims.LoginMethod != "" {
sessionKey = claims.LoginMethod + ":" + userID
}
go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce)
expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix()
if expiresIn <= 0 {
@@ -93,12 +97,12 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod
}
cookie.SetSession(gc, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
if authToken.RefreshToken != nil {
res.RefreshToken = &authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
}
return res, nil
}

View File

@@ -58,9 +58,9 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR
return res, fmt.Errorf(`password and confirm password does not match`)
}
if !validators.IsValidPassword(params.Password) {
if err := validators.IsValidPassword(params.Password); err != nil {
log.Debug("Invalid password")
return res, fmt.Errorf(`password is not valid. It needs to be at least 6 characters long and contain at least one number, one uppercase letter, one lowercase letter and one special character`)
return res, err
}
params.Email = strings.ToLower(params.Email)
@@ -74,7 +74,7 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR
"email": params.Email,
})
// find user with email
existingUser, err := db.Provider.GetUserByEmail(params.Email)
existingUser, err := db.Provider.GetUserByEmail(ctx, params.Email)
if err != nil {
log.Debug("Failed to get user by email: ", err)
}
@@ -157,7 +157,7 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR
user.Picture = params.Picture
}
user.SignupMethods = constants.SignupMethodBasicAuth
user.SignupMethods = constants.AuthRecipeMethodBasicAuth
isEmailVerificationDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification)
if err != nil {
log.Debug("Error getting email verification disabled: ", err)
@@ -167,7 +167,7 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR
now := time.Now().Unix()
user.EmailVerifiedAt = &now
}
user, err = db.Provider.AddUser(user)
user, err = db.Provider.AddUser(ctx, user)
if err != nil {
log.Debug("Failed to add user: ", err)
return res, err
@@ -193,7 +193,7 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR
log.Debug("Failed to create verification token: ", err)
return res, err
}
_, err = db.Provider.AddVerificationRequest(models.VerificationRequest{
_, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{
Token: verificationToken,
Identifier: verificationType,
ExpiresAt: time.Now().Add(time.Minute * 30).Unix(),
@@ -207,7 +207,10 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR
}
// exec it as go routin so that we can reduce the api latency
go email.SendVerificationMail(params.Email, verificationToken, hostname)
go func() {
email.SendVerificationMail(params.Email, verificationToken, hostname)
utils.RegisterEvent(ctx, constants.UserCreatedWebhookEvent, constants.AuthRecipeMethodBasicAuth, user)
}()
res = &model.AuthResponse{
Message: `Verification email has been sent. Please check your inbox`,
@@ -219,18 +222,12 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR
scope = params.Scope
}
authToken, err := token.CreateAuthToken(gc, user, roles, scope)
authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth)
if err != nil {
log.Debug("Failed to create auth token: ", err)
return res, err
}
go db.Provider.AddSession(models.Session{
UserID: user.ID,
UserAgent: utils.GetUserAgent(gc.Request),
IP: utils.GetIP(gc.Request),
})
expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix()
if expiresIn <= 0 {
expiresIn = 1
@@ -243,14 +240,24 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR
User: userToReturn,
}
sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID
cookie.SetSession(gc, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
if authToken.RefreshToken != nil {
res.RefreshToken = &authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
}
go func() {
utils.RegisterEvent(ctx, constants.UserSignUpWebhookEvent, constants.AuthRecipeMethodBasicAuth, user)
db.Provider.AddSession(ctx, models.Session{
UserID: user.ID,
UserAgent: utils.GetUserAgent(gc.Request),
IP: utils.GetIP(gc.Request),
})
}()
}
return res, nil

View File

@@ -0,0 +1,109 @@
package resolvers
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"time"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/token"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/authorizerdev/authorizer/server/validators"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
)
// TestEndpointResolver resolver to test webhook endpoints
func TestEndpointResolver(ctx context.Context, params model.TestEndpointRequest) (*model.TestEndpointResponse, error) {
gc, err := utils.GinContextFromContext(ctx)
if err != nil {
log.Debug("Failed to get GinContext: ", err)
return nil, err
}
if !token.IsSuperAdmin(gc) {
log.Debug("Not logged in as super admin")
return nil, fmt.Errorf("unauthorized")
}
if !validators.IsValidWebhookEventName(params.EventName) {
log.Debug("Invalid event name: ", params.EventName)
return nil, fmt.Errorf("invalid event_name %s", params.EventName)
}
user := model.User{
ID: uuid.NewString(),
Email: "test_endpoint@foo.com",
EmailVerified: true,
SignupMethods: constants.AuthRecipeMethodMagicLinkLogin,
GivenName: utils.NewStringRef("Foo"),
FamilyName: utils.NewStringRef("Bar"),
}
userBytes, err := json.Marshal(user)
if err != nil {
log.Debug("error marshalling user obj: ", err)
return nil, err
}
userMap := map[string]interface{}{}
err = json.Unmarshal(userBytes, &userMap)
if err != nil {
log.Debug("error un-marshalling user obj: ", err)
return nil, err
}
reqBody := map[string]interface{}{
"event_name": constants.UserLoginWebhookEvent,
"user": userMap,
}
if params.EventName == constants.UserLoginWebhookEvent {
reqBody["auth_recipe"] = constants.AuthRecipeMethodMagicLinkLogin
}
requestBody, err := json.Marshal(reqBody)
if err != nil {
log.Debug("error marshalling requestBody obj: ", err)
return nil, err
}
req, err := http.NewRequest("POST", params.Endpoint, bytes.NewBuffer(requestBody))
if err != nil {
log.Debug("error creating post request: ", err)
return nil, err
}
req.Header.Set("Content-Type", "application/json")
for key, val := range params.Headers {
req.Header.Set(key, val.(string))
}
client := &http.Client{Timeout: time.Second * 30}
resp, err := client.Do(req)
if err != nil {
log.Debug("error making request: ", err)
return nil, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Debug("error reading response: ", err)
return nil, err
}
response := map[string]interface{}{}
if err := json.Unmarshal(body, &response); err != nil {
log.Debug("error un-marshalling response: ", err)
return nil, err
}
statusCode := int64(resp.StatusCode)
return &model.TestEndpointResponse{
HTTPStatus: &statusCode,
Response: response,
}, nil
}

View File

@@ -21,6 +21,54 @@ import (
"github.com/authorizerdev/authorizer/server/utils"
)
// check if login methods have been disabled
// remove the session tokens for those methods
func clearSessionIfRequired(currentData, updatedData map[string]interface{}) {
isCurrentBasicAuthEnabled := !currentData[constants.EnvKeyDisableBasicAuthentication].(bool)
isCurrentMagicLinkLoginEnabled := !currentData[constants.EnvKeyDisableMagicLinkLogin].(bool)
isCurrentAppleLoginEnabled := currentData[constants.EnvKeyAppleClientID] != nil && currentData[constants.EnvKeyAppleClientSecret] != nil && currentData[constants.EnvKeyAppleClientID].(string) != "" && currentData[constants.EnvKeyAppleClientSecret].(string) != ""
isCurrentFacebookLoginEnabled := currentData[constants.EnvKeyFacebookClientID] != nil && currentData[constants.EnvKeyFacebookClientSecret] != nil && currentData[constants.EnvKeyFacebookClientID].(string) != "" && currentData[constants.EnvKeyFacebookClientSecret].(string) != ""
isCurrentGoogleLoginEnabled := currentData[constants.EnvKeyGoogleClientID] != nil && currentData[constants.EnvKeyGoogleClientSecret] != nil && currentData[constants.EnvKeyGoogleClientID].(string) != "" && currentData[constants.EnvKeyGoogleClientSecret].(string) != ""
isCurrentGithubLoginEnabled := currentData[constants.EnvKeyGithubClientID] != nil && currentData[constants.EnvKeyGithubClientSecret] != nil && currentData[constants.EnvKeyGithubClientID].(string) != "" && currentData[constants.EnvKeyGithubClientSecret].(string) != ""
isCurrentLinkedInLoginEnabled := currentData[constants.EnvKeyLinkedInClientID] != nil && currentData[constants.EnvKeyLinkedInClientSecret] != nil && currentData[constants.EnvKeyLinkedInClientID].(string) != "" && currentData[constants.EnvKeyLinkedInClientSecret].(string) != ""
isUpdatedBasicAuthEnabled := !updatedData[constants.EnvKeyDisableBasicAuthentication].(bool)
isUpdatedMagicLinkLoginEnabled := !updatedData[constants.EnvKeyDisableMagicLinkLogin].(bool)
isUpdatedAppleLoginEnabled := updatedData[constants.EnvKeyAppleClientID] != nil && updatedData[constants.EnvKeyAppleClientSecret] != nil && updatedData[constants.EnvKeyAppleClientID].(string) != "" && updatedData[constants.EnvKeyAppleClientSecret].(string) != ""
isUpdatedFacebookLoginEnabled := updatedData[constants.EnvKeyFacebookClientID] != nil && updatedData[constants.EnvKeyFacebookClientSecret] != nil && updatedData[constants.EnvKeyFacebookClientID].(string) != "" && updatedData[constants.EnvKeyFacebookClientSecret].(string) != ""
isUpdatedGoogleLoginEnabled := updatedData[constants.EnvKeyGoogleClientID] != nil && updatedData[constants.EnvKeyGoogleClientSecret] != nil && updatedData[constants.EnvKeyGoogleClientID].(string) != "" && updatedData[constants.EnvKeyGoogleClientSecret].(string) != ""
isUpdatedGithubLoginEnabled := updatedData[constants.EnvKeyGithubClientID] != nil && updatedData[constants.EnvKeyGithubClientSecret] != nil && updatedData[constants.EnvKeyGithubClientID].(string) != "" && updatedData[constants.EnvKeyGithubClientSecret].(string) != ""
isUpdatedLinkedInLoginEnabled := updatedData[constants.EnvKeyLinkedInClientID] != nil && updatedData[constants.EnvKeyLinkedInClientSecret] != nil && updatedData[constants.EnvKeyLinkedInClientID].(string) != "" && updatedData[constants.EnvKeyLinkedInClientSecret].(string) != ""
if isCurrentBasicAuthEnabled && !isUpdatedBasicAuthEnabled {
memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodBasicAuth)
}
if isCurrentMagicLinkLoginEnabled && !isUpdatedMagicLinkLoginEnabled {
memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodMagicLinkLogin)
}
if isCurrentAppleLoginEnabled && !isUpdatedAppleLoginEnabled {
memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodApple)
}
if isCurrentFacebookLoginEnabled && !isUpdatedFacebookLoginEnabled {
memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodFacebook)
}
if isCurrentGoogleLoginEnabled && !isUpdatedGoogleLoginEnabled {
memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodGoogle)
}
if isCurrentGithubLoginEnabled && !isUpdatedGithubLoginEnabled {
memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodGithub)
}
if isCurrentLinkedInLoginEnabled && !isUpdatedLinkedInLoginEnabled {
memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodLinkedIn)
}
}
// UpdateEnvResolver is a resolver for update config mutation
// This is admin only mutation
func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model.Response, error) {
@@ -37,12 +85,19 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model
return res, fmt.Errorf("unauthorized")
}
updatedData, err := memorystore.Provider.GetEnvStore()
currentData, err := memorystore.Provider.GetEnvStore()
if err != nil {
log.Debug("Failed to get env store: ", err)
return res, err
}
// clone currentData in new var
// that will be updated based on the req
updatedData := make(map[string]interface{})
for key, val := range currentData {
updatedData[key] = val
}
isJWTUpdated := false
algo := updatedData[constants.EnvKeyJwtType].(string)
if params.JwtType != nil {
@@ -210,6 +265,8 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model
}
}
go clearSessionIfRequired(currentData, updatedData)
// Update local store
memorystore.Provider.UpdateEnvStore(updatedData)
jwk, err := crypto.GenerateJWKBasedOnEnv()
@@ -224,19 +281,13 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model
return res, err
}
// TODO check how to update session store based on env change.
// err = sessionstore.InitSession()
// if err != nil {
// log.Debug("Failed to init session store: ", err)
// return res, err
// }
err = oauth.InitOAuth()
if err != nil {
return res, err
}
// Fetch the current db store and update it
env, err := db.Provider.GetEnv()
env, err := db.Provider.GetEnv(ctx)
if err != nil {
log.Debug("Failed to get env: ", err)
return res, err
@@ -263,7 +314,7 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model
}
env.EnvData = encryptedConfig
_, err = db.Provider.UpdateEnv(env)
_, err = db.Provider.UpdateEnv(ctx, env)
if err != nil {
log.Debug("Failed to update env: ", err)
return res, err

View File

@@ -55,7 +55,7 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput)
"user_id": userID,
})
user, err := db.Provider.GetUserByID(userID)
user, err := db.Provider.GetUserByID(ctx, userID)
if err != nil {
log.Debug("Failed to get user by id: ", err)
return res, err
@@ -135,7 +135,7 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput)
return res, fmt.Errorf("invalid new email address")
}
// check if user with new email exists
_, err := db.Provider.GetUserByEmail(newEmail)
_, err := db.Provider.GetUserByEmail(ctx, newEmail)
// err = nil means user exists
if err == nil {
log.Debug("Failed to get user by email: ", newEmail)
@@ -168,7 +168,7 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput)
log.Debug("Failed to create verification token: ", err)
return res, err
}
_, err = db.Provider.AddVerificationRequest(models.VerificationRequest{
_, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{
Token: verificationToken,
Identifier: verificationType,
ExpiresAt: time.Now().Add(time.Minute * 30).Unix(),
@@ -186,7 +186,7 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput)
}
}
_, err = db.Provider.UpdateUser(user)
_, err = db.Provider.UpdateUser(ctx, user)
if err != nil {
log.Debug("Failed to update user: ", err)
return res, err

View File

@@ -50,7 +50,7 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod
return res, fmt.Errorf("please enter atleast one param to update")
}
user, err := db.Provider.GetUserByID(params.ID)
user, err := db.Provider.GetUserByID(ctx, params.ID)
if err != nil {
log.Debug("Failed to get user by id: ", err)
return res, fmt.Errorf(`User not found`)
@@ -105,7 +105,7 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod
}
newEmail := strings.ToLower(*params.Email)
// check if user with new email exists
_, err = db.Provider.GetUserByEmail(newEmail)
_, err = db.Provider.GetUserByEmail(ctx, newEmail)
// err = nil means user exists
if err == nil {
log.Debug("User with email already exists: ", newEmail)
@@ -130,7 +130,7 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod
if err != nil {
log.Debug("Failed to create verification token: ", err)
}
_, err = db.Provider.AddVerificationRequest(models.VerificationRequest{
_, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{
Token: verificationToken,
Identifier: verificationType,
ExpiresAt: time.Now().Add(time.Minute * 30).Unix(),
@@ -189,7 +189,7 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod
user.Roles = rolesToSave
}
user, err = db.Provider.UpdateUser(user)
user, err = db.Provider.UpdateUser(ctx, user)
if err != nil {
log.Debug("Failed to update user: ", err)
return res, err

View File

@@ -0,0 +1,93 @@
package resolvers
import (
"context"
"encoding/json"
"fmt"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/token"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/authorizerdev/authorizer/server/validators"
log "github.com/sirupsen/logrus"
)
// UpdateWebhookResolver resolver for update webhook mutation
func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookRequest) (*model.Response, error) {
gc, err := utils.GinContextFromContext(ctx)
if err != nil {
log.Debug("Failed to get GinContext: ", err)
return nil, err
}
if !token.IsSuperAdmin(gc) {
log.Debug("Not logged in as super admin")
return nil, fmt.Errorf("unauthorized")
}
webhook, err := db.Provider.GetWebhookByID(ctx, params.ID)
if err != nil {
log.Debug("failed to get webhook: ", err)
return nil, err
}
headersString := ""
if webhook.Headers != nil {
headerBytes, err := json.Marshal(webhook.Headers)
if err != nil {
log.Debug("failed to marshall source headers: ", err)
}
headersString = string(headerBytes)
}
webhookDetails := models.Webhook{
ID: webhook.ID,
Key: webhook.ID,
EventName: utils.StringValue(webhook.EventName),
EndPoint: utils.StringValue(webhook.Endpoint),
Enabled: utils.BoolValue(webhook.Enabled),
Headers: headersString,
CreatedAt: *webhook.CreatedAt,
}
if params.EventName != nil && webhookDetails.EventName != utils.StringValue(params.EventName) {
if isValid := validators.IsValidWebhookEventName(utils.StringValue(params.EventName)); !isValid {
log.Debug("invalid event name: ", utils.StringValue(params.EventName))
return nil, fmt.Errorf("invalid event name %s", utils.StringValue(params.EventName))
}
webhookDetails.EventName = utils.StringValue(params.EventName)
}
if params.Endpoint != nil && webhookDetails.EndPoint != utils.StringValue(params.Endpoint) {
webhookDetails.EventName = utils.StringValue(params.EventName)
}
if params.Enabled != nil && webhookDetails.Enabled != utils.BoolValue(params.Enabled) {
webhookDetails.Enabled = utils.BoolValue(params.Enabled)
}
if params.Headers != nil {
for key, val := range params.Headers {
webhook.Headers[key] = val
}
headerBytes, err := json.Marshal(webhook.Headers)
if err != nil {
log.Debug("failed to marshall headers: ", err)
return nil, err
}
webhookDetails.Headers = string(headerBytes)
}
_, err = db.Provider.UpdateWebhook(ctx, webhookDetails)
if err != nil {
return nil, err
}
return &model.Response{
Message: `Webhook updated successfully.`,
}, nil
}

View File

@@ -28,7 +28,7 @@ func UsersResolver(ctx context.Context, params *model.PaginatedInput) (*model.Us
pagination := utils.GetPagination(params)
res, err := db.Provider.ListUsers(pagination)
res, err := db.Provider.ListUsers(ctx, pagination)
if err != nil {
log.Debug("Failed to get users: ", err)
return nil, err

View File

@@ -50,7 +50,12 @@ func ValidateJwtTokenResolver(ctx context.Context, params model.ValidateJWTToken
// access_token and refresh_token should be validated from session store as well
if tokenType == constants.TokenTypeAccessToken || tokenType == constants.TokenTypeRefreshToken {
nonce = claims["nonce"].(string)
token, err := memorystore.Provider.GetUserSession(userID, tokenType+"_"+claims["nonce"].(string))
loginMethod := claims["login_method"]
sessionKey := userID
if loginMethod != nil && loginMethod != "" {
sessionKey = loginMethod.(string) + ":" + userID
}
token, err := memorystore.Provider.GetUserSession(sessionKey, tokenType+"_"+claims["nonce"].(string))
if err != nil || token == "" {
log.Debug("Failed to get user session: ", err)
return nil, errors.New("invalid token")

View File

@@ -28,7 +28,7 @@ func VerificationRequestsResolver(ctx context.Context, params *model.PaginatedIn
pagination := utils.GetPagination(params)
res, err := db.Provider.ListVerificationRequests(pagination)
res, err := db.Provider.ListVerificationRequests(ctx, pagination)
if err != nil {
log.Debug("Failed to get verification requests: ", err)
return nil, err

Some files were not shown because too many files have changed in this diff Show More