Compare commits

..

56 Commits

Author SHA1 Message Date
Lakhan Samani
a102924fd7 fix: remove debug logs 2022-07-17 13:39:23 +05:30
Lakhan Samani
a48b809a89 feat: add tests for email template resolvers 2022-07-17 13:37:34 +05:30
Lakhan Samani
cd46da60a0 feat: implement resolvers for email template 2022-07-17 12:32:01 +05:30
Lakhan Samani
50f52a99b4 fix: github user emails 2022-07-17 12:07:17 +05:30
Lakhan Samani
1f7eee43e2 feat: add email template implementation for arangodb provider 2022-07-17 11:37:04 +05:30
Lakhan Samani
7c441fff14 feat: add email template implementation for cassandra provider 2022-07-17 11:21:51 +05:30
Lakhan Samani
647cc1d9bf feat: add email template implementation for mongodb provider 2022-07-17 11:01:47 +05:30
Lakhan Samani
97b1d8d66f Merge branch 'main' into feat/add-email-template-apis 2022-07-17 10:39:02 +05:30
Lakhan Samani
2cce1c4e93 fix: webhook update headers 2022-07-17 10:36:16 +05:30
Lakhan Samani
610896b6f5 fix: refs for email templatE 2022-07-15 22:13:00 +05:30
Lakhan Samani
aa12757155 Merge branch 'main' of https://github.com/authorizerdev/authorizer into feat/add-email-template-apis 2022-07-15 22:11:58 +05:30
Lakhan Samani
847c364ad1 fix: refs 2022-07-15 22:11:08 +05:30
Lakhan Samani
e2294c24d0 feat: add email template implementation for sql provider 2022-07-15 12:35:35 +05:30
Lakhan Samani
283e570ebb feat: init email template schema for all providers 2022-07-15 10:23:45 +05:30
Lakhan Samani
14c74f6566 feat: add email template schema 2022-07-15 10:12:24 +05:30
Lakhan Samani
fed092bb65 fix: invite email template 2022-07-13 21:16:31 +05:30
Lakhan Samani
6d28290605 Merge pull request #199 from authorizerdev/fix/password-changing
fix(update_profile): changing password if not signed up via basic auth
2022-07-13 20:46:56 +05:30
Lakhan Samani
2de0ea57d0 fix(update_profile): changing password if not signed up via basic
Resolves #198
2022-07-13 20:45:21 +05:30
Lakhan Samani
f2886e6da8 fix: disable other db test for quick test 2022-07-12 11:57:46 +05:30
Lakhan Samani
6b57bce6d9 fix: cassandra + mongo + arangodb issues with webhook 2022-07-12 11:48:42 +05:30
Lakhan Samani
bfbeb6add2 fix: couple session deletion with user deletion 2022-07-12 08:42:32 +05:30
Lakhan Samani
1fe0d65874 feat: add support for planetscale
Resolves #195
2022-07-11 22:37:07 +05:30
Lakhan Samani
bfaa0f9d89 fix: make list webhooks params optional 2022-07-11 22:05:44 +05:30
Lakhan Samani
4f5a6c77f8 Merge pull request #194 from authorizerdev/feat/webhook
feat: add webhook apis + integrate in events
2022-07-11 19:56:48 +05:30
Lakhan Samani
018a13ab3c feat: add tests for webhook resolvers 2022-07-11 19:40:54 +05:30
Lakhan Samani
334041d0e4 fix: delete user event flow 2022-07-11 11:13:32 +05:30
Lakhan Samani
6a8309a231 feat: register event for revoke/enable access + delete user 2022-07-11 11:12:30 +05:30
Lakhan Samani
6347b60753 fix: rename revoke refresh token handler for better reading 2022-07-11 11:10:30 +05:30
Lakhan Samani
bbb064b939 feat: add register event 2022-07-11 10:42:42 +05:30
Lakhan Samani
e91a819067 feat: implement resolvers 2022-07-10 21:49:33 +05:30
Lakhan Samani
09c3eafe6b feat: add mongodb database methods for webhook 2022-07-09 12:23:48 +05:30
Lakhan Samani
bb51775d34 feat: add cassandradb database methods for webhook 2022-07-09 12:16:54 +05:30
Lakhan Samani
6d586b16e4 feat: add arangodb database methods for webhook 2022-07-09 11:44:14 +05:30
Lakhan Samani
e8eb62769e feat: add sql database methods for webhook 2022-07-09 11:21:32 +05:30
Lakhan Samani
0ffb3f67f1 fix: index for arangodb 2022-07-08 19:10:37 +05:30
Lakhan Samani
ec62686fbc feat: add database methods for webhookLog 2022-07-08 19:09:23 +05:30
Lakhan Samani
a8064e79a1 feat: add template for webhook db methods 2022-07-06 10:38:21 +05:30
Lakhan Samani
265331801f feat: add database models 2022-07-04 22:37:13 +05:30
Lakhan Samani
6a74a50493 feat: add support for scylladb
Resolves #177
2022-07-04 21:57:14 +05:30
Lakhan Samani
8c27f20534 Merge pull request #193 from authorizerdev/fix/invalidate-session-token
fix: add provider to token creation
2022-07-01 22:14:50 +05:30
Lakhan Samani
29c6003ea3 fix: remove test log 2022-07-01 22:04:58 +05:30
Lakhan Samani
ae34fc7c2b fix: update_env resolver 2022-07-01 22:02:34 +05:30
Lakhan Samani
2a5d5d43b0 fix: add namespace to session token keys 2022-06-29 22:24:00 +05:30
Lakhan Samani
e6a4670ba9 fix: add provider to token creation 2022-06-29 09:54:12 +05:30
Lakhan Samani
64d64b4099 feat: add ability to disable strong password 2022-06-18 15:31:57 +05:30
Lakhan Samani
88f9a10f21 Merge pull request #192 from authorizerdev/feat/apple-login
feat: add apple login
2022-06-16 09:48:02 +05:30
Lakhan Samani
4e08d4f8fd fix: update app 2022-06-16 09:47:16 +05:30
Lakhan Samani
1c4dda9299 fix: remove debug logs 2022-06-16 07:34:10 +05:30
Lakhan Samani
ab18fa5832 fix: use raw base64 url decoding 2022-06-15 21:55:41 +05:30
Lakhan Samani
484d0c0882 chore: update app 2022-06-14 16:39:06 +05:30
Lakhan Samani
be59c3615f fix: add comment for scope 2022-06-14 15:47:08 +05:30
Lakhan Samani
db351f7771 fix: remove debug logs 2022-06-14 15:45:06 +05:30
Lakhan Samani
91c29c4092 fix: redirect 2022-06-14 15:43:23 +05:30
Lakhan Samani
415b97535e fix: update scope param 2022-06-14 15:05:56 +05:30
Lakhan Samani
7d1272d815 fix: update scope for apple login 2022-06-14 14:41:31 +05:30
Lakhan Samani
c9ba0b13f8 fix: update scope for apple login 2022-06-14 13:37:05 +05:30
154 changed files with 7568 additions and 718 deletions

View File

@@ -10,7 +10,16 @@ build-dashboard:
clean: clean:
rm -rf build rm -rf build
test: test:
rm -rf server/test/test.db && rm -rf test.db && cd server && go clean --testcache && go test -p 1 -v ./test rm -rf server/test/test.db && rm -rf test.db && cd server && go clean --testcache && TEST_DBS="sqlite" go test -p 1 -v ./test
test-all-db:
rm -rf server/test/test.db && rm -rf test.db
docker run -d --name authorizer_scylla_db -p 9042:9042 scylladb/scylla
docker run -d --name authorizer_mongodb_db -p 27017:27017 mongo:4.4.15
docker run -d --name authorizer_arangodb -p 8529:8529 -e ARANGO_NO_AUTH=1 arangodb/arangodb:3.8.4
cd server && go clean --testcache && TEST_DBS="sqlite,mongodb,arangodb,scylladb" go test -p 1 -v ./test
docker rm -vf authorizer_mongodb_db
docker rm -vf authorizer_scylla_db
docker rm -vf authorizer_arangodb
generate: generate:
cd server && go get github.com/99designs/gqlgen/cmd@v0.14.0 && go run github.com/99designs/gqlgen generate cd server && go get github.com/99designs/gqlgen/cmd@v0.14.0 && go run github.com/99designs/gqlgen generate

30
app/package-lock.json generated
View File

@@ -9,7 +9,7 @@
"version": "1.0.0", "version": "1.0.0",
"license": "ISC", "license": "ISC",
"dependencies": { "dependencies": {
"@authorizerdev/authorizer-react": "^0.24.0-beta.1", "@authorizerdev/authorizer-react": "^0.25.0",
"@types/react": "^17.0.15", "@types/react": "^17.0.15",
"@types/react-dom": "^17.0.9", "@types/react-dom": "^17.0.9",
"esbuild": "^0.12.17", "esbuild": "^0.12.17",
@@ -26,9 +26,9 @@
} }
}, },
"node_modules/@authorizerdev/authorizer-js": { "node_modules/@authorizerdev/authorizer-js": {
"version": "0.13.0-beta.2", "version": "0.14.0",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-0.13.0-beta.2.tgz", "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-0.14.0.tgz",
"integrity": "sha512-3K3Qew/npD375DQdJnYb43aWVOU7giG2+8sHOjy8aXhZ+GlQQ8cGu54lozFYUMDcM1HjQU2N53xLmneONPepSw==", "integrity": "sha512-cpeeFrmG623QPLn+nf+ACHayZYqW8xokIidGikeboBDJtuAAQB50a54/7HwLHriG2FB7WvPuHQ/9LFFX//N1lg==",
"dependencies": { "dependencies": {
"node-fetch": "^2.6.1" "node-fetch": "^2.6.1"
}, },
@@ -37,11 +37,11 @@
} }
}, },
"node_modules/@authorizerdev/authorizer-react": { "node_modules/@authorizerdev/authorizer-react": {
"version": "0.24.0-beta.1", "version": "0.25.0",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-0.24.0-beta.1.tgz", "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-0.25.0.tgz",
"integrity": "sha512-S/Oqc24EfotbrABuv379i/3uCfQPYJQqXrOU9d8AytF++pzG/2dcoIoaMbWZQkATR3m6a5AnhpG6bIB+4NbrUQ==", "integrity": "sha512-Dt2rZf+cGCVb8dqcJ/9l8Trx+QeXnTdfhER6r/cq0iOnFC9MqWzQPB3RgrlUoMLHtZvKNDXIk1HvfD5hSX9lhw==",
"dependencies": { "dependencies": {
"@authorizerdev/authorizer-js": "^0.13.0-beta.2", "@authorizerdev/authorizer-js": "^0.14.0",
"final-form": "^4.20.2", "final-form": "^4.20.2",
"react-final-form": "^6.5.3", "react-final-form": "^6.5.3",
"styled-components": "^5.3.0" "styled-components": "^5.3.0"
@@ -852,19 +852,19 @@
}, },
"dependencies": { "dependencies": {
"@authorizerdev/authorizer-js": { "@authorizerdev/authorizer-js": {
"version": "0.13.0-beta.2", "version": "0.14.0",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-0.13.0-beta.2.tgz", "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-0.14.0.tgz",
"integrity": "sha512-3K3Qew/npD375DQdJnYb43aWVOU7giG2+8sHOjy8aXhZ+GlQQ8cGu54lozFYUMDcM1HjQU2N53xLmneONPepSw==", "integrity": "sha512-cpeeFrmG623QPLn+nf+ACHayZYqW8xokIidGikeboBDJtuAAQB50a54/7HwLHriG2FB7WvPuHQ/9LFFX//N1lg==",
"requires": { "requires": {
"node-fetch": "^2.6.1" "node-fetch": "^2.6.1"
} }
}, },
"@authorizerdev/authorizer-react": { "@authorizerdev/authorizer-react": {
"version": "0.24.0-beta.1", "version": "0.25.0",
"resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-0.24.0-beta.1.tgz", "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-0.25.0.tgz",
"integrity": "sha512-S/Oqc24EfotbrABuv379i/3uCfQPYJQqXrOU9d8AytF++pzG/2dcoIoaMbWZQkATR3m6a5AnhpG6bIB+4NbrUQ==", "integrity": "sha512-Dt2rZf+cGCVb8dqcJ/9l8Trx+QeXnTdfhER6r/cq0iOnFC9MqWzQPB3RgrlUoMLHtZvKNDXIk1HvfD5hSX9lhw==",
"requires": { "requires": {
"@authorizerdev/authorizer-js": "^0.13.0-beta.2", "@authorizerdev/authorizer-js": "^0.14.0",
"final-form": "^4.20.2", "final-form": "^4.20.2",
"react-final-form": "^6.5.3", "react-final-form": "^6.5.3",
"styled-components": "^5.3.0" "styled-components": "^5.3.0"

View File

@@ -11,7 +11,7 @@
"author": "Lakhan Samani", "author": "Lakhan Samani",
"license": "ISC", "license": "ISC",
"dependencies": { "dependencies": {
"@authorizerdev/authorizer-react": "^0.24.0-beta.1", "@authorizerdev/authorizer-react": "^0.25.0",
"@types/react": "^17.0.15", "@types/react": "^17.0.15",
"@types/react-dom": "^17.0.9", "@types/react-dom": "^17.0.9",
"esbuild": "^0.12.17", "esbuild": "^0.12.17",

View File

@@ -2529,7 +2529,8 @@
"@chakra-ui/css-reset": { "@chakra-ui/css-reset": {
"version": "1.1.1", "version": "1.1.1",
"resolved": "https://registry.npmjs.org/@chakra-ui/css-reset/-/css-reset-1.1.1.tgz", "resolved": "https://registry.npmjs.org/@chakra-ui/css-reset/-/css-reset-1.1.1.tgz",
"integrity": "sha512-+KNNHL4OWqeKia5SL858K3Qbd8WxMij9mWIilBzLD4j2KFrl/+aWFw8syMKth3NmgIibrjsljo+PU3fy2o50dg==" "integrity": "sha512-+KNNHL4OWqeKia5SL858K3Qbd8WxMij9mWIilBzLD4j2KFrl/+aWFw8syMKth3NmgIibrjsljo+PU3fy2o50dg==",
"requires": {}
}, },
"@chakra-ui/descendant": { "@chakra-ui/descendant": {
"version": "2.1.1", "version": "2.1.1",
@@ -3133,7 +3134,8 @@
"@graphql-typed-document-node/core": { "@graphql-typed-document-node/core": {
"version": "3.1.1", "version": "3.1.1",
"resolved": "https://registry.npmjs.org/@graphql-typed-document-node/core/-/core-3.1.1.tgz", "resolved": "https://registry.npmjs.org/@graphql-typed-document-node/core/-/core-3.1.1.tgz",
"integrity": "sha512-NQ17ii0rK1b34VZonlmT2QMJFI70m0TRwbknO/ihlbatXyaktDhN/98vBiUU6kNBPljqGqyIrl2T4nY2RpFANg==" "integrity": "sha512-NQ17ii0rK1b34VZonlmT2QMJFI70m0TRwbknO/ihlbatXyaktDhN/98vBiUU6kNBPljqGqyIrl2T4nY2RpFANg==",
"requires": {}
}, },
"@popperjs/core": { "@popperjs/core": {
"version": "2.11.0", "version": "2.11.0",
@@ -3843,7 +3845,8 @@
"react-icons": { "react-icons": {
"version": "4.3.1", "version": "4.3.1",
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.3.1.tgz", "resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.3.1.tgz",
"integrity": "sha512-cB10MXLTs3gVuXimblAdI71jrJx8njrJZmNMEMC+sQu5B/BIOmlsAjskdqpn81y8UBVEGuHODd7/ci5DvoSzTQ==" "integrity": "sha512-cB10MXLTs3gVuXimblAdI71jrJx8njrJZmNMEMC+sQu5B/BIOmlsAjskdqpn81y8UBVEGuHODd7/ci5DvoSzTQ==",
"requires": {}
}, },
"react-is": { "react-is": {
"version": "16.13.1", "version": "16.13.1",
@@ -4029,7 +4032,8 @@
"use-callback-ref": { "use-callback-ref": {
"version": "1.2.5", "version": "1.2.5",
"resolved": "https://registry.npmjs.org/use-callback-ref/-/use-callback-ref-1.2.5.tgz", "resolved": "https://registry.npmjs.org/use-callback-ref/-/use-callback-ref-1.2.5.tgz",
"integrity": "sha512-gN3vgMISAgacF7sqsLPByqoePooY3n2emTH59Ur5d/M8eg4WTWu1xp8i8DHjohftIyEx0S08RiYxbffr4j8Peg==" "integrity": "sha512-gN3vgMISAgacF7sqsLPByqoePooY3n2emTH59Ur5d/M8eg4WTWu1xp8i8DHjohftIyEx0S08RiYxbffr4j8Peg==",
"requires": {}
}, },
"use-sidecar": { "use-sidecar": {
"version": "1.0.5", "version": "1.0.5",

View File

@@ -71,6 +71,18 @@ const Features = ({ variables, setVariables }: any) => {
/> />
</Flex> </Flex>
</Flex> </Flex>
<Flex>
<Flex w="100%" justifyContent="start" alignItems="center">
<Text fontSize="sm">Disable Strong Password:</Text>
</Flex>
<Flex justifyContent="start" mb={3}>
<InputField
variables={variables}
setVariables={setVariables}
inputType={SwitchInputType.DISABLE_STRONG_PASSWORD}
/>
</Flex>
</Flex>
</Stack> </Stack>
</div> </div>
); );

View File

@@ -67,6 +67,7 @@ export const SwitchInputType = {
DISABLE_BASIC_AUTHENTICATION: 'DISABLE_BASIC_AUTHENTICATION', DISABLE_BASIC_AUTHENTICATION: 'DISABLE_BASIC_AUTHENTICATION',
DISABLE_SIGN_UP: 'DISABLE_SIGN_UP', DISABLE_SIGN_UP: 'DISABLE_SIGN_UP',
DISABLE_REDIS_FOR_ENV: 'DISABLE_REDIS_FOR_ENV', DISABLE_REDIS_FOR_ENV: 'DISABLE_REDIS_FOR_ENV',
DISABLE_STRONG_PASSWORD: 'DISABLE_STRONG_PASSWORD',
}; };
export const DateInputType = { export const DateInputType = {
@@ -131,6 +132,7 @@ export interface envVarTypes {
DISABLE_EMAIL_VERIFICATION: boolean; DISABLE_EMAIL_VERIFICATION: boolean;
DISABLE_BASIC_AUTHENTICATION: boolean; DISABLE_BASIC_AUTHENTICATION: boolean;
DISABLE_SIGN_UP: boolean; DISABLE_SIGN_UP: boolean;
DISABLE_STRONG_PASSWORD: boolean;
OLD_ADMIN_SECRET: string; OLD_ADMIN_SECRET: string;
DATABASE_NAME: string; DATABASE_NAME: string;
DATABASE_TYPE: string; DATABASE_TYPE: string;

View File

@@ -53,6 +53,7 @@ export const EnvVariablesQuery = `
DISABLE_EMAIL_VERIFICATION, DISABLE_EMAIL_VERIFICATION,
DISABLE_BASIC_AUTHENTICATION, DISABLE_BASIC_AUTHENTICATION,
DISABLE_SIGN_UP, DISABLE_SIGN_UP,
DISABLE_STRONG_PASSWORD,
DISABLE_REDIS_FOR_ENV, DISABLE_REDIS_FOR_ENV,
CUSTOM_ACCESS_TOKEN_SCRIPT, CUSTOM_ACCESS_TOKEN_SCRIPT,
DATABASE_NAME, DATABASE_NAME,

View File

@@ -74,6 +74,7 @@ const Environment = () => {
DISABLE_EMAIL_VERIFICATION: false, DISABLE_EMAIL_VERIFICATION: false,
DISABLE_BASIC_AUTHENTICATION: false, DISABLE_BASIC_AUTHENTICATION: false,
DISABLE_SIGN_UP: false, DISABLE_SIGN_UP: false,
DISABLE_STRONG_PASSWORD: false,
OLD_ADMIN_SECRET: '', OLD_ADMIN_SECRET: '',
DATABASE_NAME: '', DATABASE_NAME: '',
DATABASE_TYPE: '', DATABASE_TYPE: '',

View File

@@ -0,0 +1,18 @@
package constants
const (
// AuthRecipeMethodBasicAuth is the basic_auth auth method
AuthRecipeMethodBasicAuth = "basic_auth"
// AuthRecipeMethodMagicLinkLogin is the magic_link_login auth method
AuthRecipeMethodMagicLinkLogin = "magic_link_login"
// AuthRecipeMethodGoogle is the google auth method
AuthRecipeMethodGoogle = "google"
// AuthRecipeMethodGithub is the github auth method
AuthRecipeMethodGithub = "github"
// AuthRecipeMethodFacebook is the facebook auth method
AuthRecipeMethodFacebook = "facebook"
// AuthRecipeMethodLinkedin is the linkedin auth method
AuthRecipeMethodLinkedIn = "linkedin"
// AuthRecipeMethodApple is the apple auth method
AuthRecipeMethodApple = "apple"
)

View File

@@ -23,4 +23,6 @@ const (
DbTypeScyllaDB = "scylladb" DbTypeScyllaDB = "scylladb"
// DbTypeCockroachDB is the cockroach database type // DbTypeCockroachDB is the cockroach database type
DbTypeCockroachDB = "cockroachdb" DbTypeCockroachDB = "cockroachdb"
// DbTypePlanetScaleDB is the planetscale database type
DbTypePlanetScaleDB = "planetscale"
) )

View File

@@ -115,6 +115,8 @@ const (
EnvKeyDisableSignUp = "DISABLE_SIGN_UP" EnvKeyDisableSignUp = "DISABLE_SIGN_UP"
// EnvKeyDisableRedisForEnv key for env variable DISABLE_REDIS_FOR_ENV // EnvKeyDisableRedisForEnv key for env variable DISABLE_REDIS_FOR_ENV
EnvKeyDisableRedisForEnv = "DISABLE_REDIS_FOR_ENV" EnvKeyDisableRedisForEnv = "DISABLE_REDIS_FOR_ENV"
// EnvKeyDisableStrongPassword key for env variable DISABLE_STRONG_PASSWORD
EnvKeyDisableStrongPassword = "DISABLE_STRONG_PASSWORD"
// Slice variables // Slice variables
// EnvKeyRoles key for env variable ROLES // EnvKeyRoles key for env variable ROLES

View File

@@ -8,6 +8,9 @@ const (
FacebookUserInfoURL = "https://graph.facebook.com/me?fields=id,first_name,last_name,name,email,picture&access_token=" FacebookUserInfoURL = "https://graph.facebook.com/me?fields=id,first_name,last_name,name,email,picture&access_token="
// Ref: https://docs.github.com/en/developers/apps/building-github-apps/identifying-and-authorizing-users-for-github-apps#3-your-github-app-accesses-the-api-with-the-users-access-token // Ref: https://docs.github.com/en/developers/apps/building-github-apps/identifying-and-authorizing-users-for-github-apps#3-your-github-app-accesses-the-api-with-the-users-access-token
GithubUserInfoURL = "https://api.github.com/user" GithubUserInfoURL = "https://api.github.com/user"
// Get github user emails when user info email is empty Ref: https://stackoverflow.com/a/35387123
GithubUserEmails = "https://api/github.com/user/emails"
// Ref: https://docs.microsoft.com/en-us/linkedin/shared/integrations/people/profile-api // Ref: https://docs.microsoft.com/en-us/linkedin/shared/integrations/people/profile-api
LinkedInUserInfoURL = "https://api.linkedin.com/v2/me?projection=(id,localizedFirstName,localizedLastName,emailAddress,profilePicture(displayImage~:playableStreams))" LinkedInUserInfoURL = "https://api.linkedin.com/v2/me?projection=(id,localizedFirstName,localizedLastName,emailAddress,profilePicture(displayImage~:playableStreams))"
LinkedInEmailURL = "https://api.linkedin.com/v2/emailAddress?q=members&projection=(elements*(handle~))" LinkedInEmailURL = "https://api.linkedin.com/v2/emailAddress?q=members&projection=(elements*(handle~))"

View File

@@ -1,18 +0,0 @@
package constants
const (
// SignupMethodBasicAuth is the basic_auth signup method
SignupMethodBasicAuth = "basic_auth"
// SignupMethodMagicLinkLogin is the magic_link_login signup method
SignupMethodMagicLinkLogin = "magic_link_login"
// SignupMethodGoogle is the google signup method
SignupMethodGoogle = "google"
// SignupMethodGithub is the github signup method
SignupMethodGithub = "github"
// SignupMethodFacebook is the facebook signup method
SignupMethodFacebook = "facebook"
// SignupMethodLinkedin is the linkedin signup method
SignupMethodLinkedIn = "linkedin"
// SignupMethodApple is the apple signup method
SignupMethodApple = "apple"
)

View File

@@ -0,0 +1,18 @@
package constants
const (
// UserLoginWebhookEvent name for login event
UserLoginWebhookEvent = `user.login`
// UserCreatedWebhookEvent name for user creation event
// This is triggered when user entry is created but still not verified
UserCreatedWebhookEvent = `user.created`
// UserSignUpWebhookEvent name for signup event
UserSignUpWebhookEvent = `user.signup`
// UserAccessRevokedWebhookEvent name for user access revoke event
UserAccessRevokedWebhookEvent = `user.access_revoked`
// UserAccessEnabledWebhookEvent name for user access enable event
UserAccessEnabledWebhookEvent = `user.access_enabled`
// UserDeletedWebhookEvent name for user deleted event
UserDeletedWebhookEvent = `user.deleted`
)

View File

@@ -0,0 +1,33 @@
package models
import (
"strings"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/refs"
)
// EmailTemplate model for database
type EmailTemplate struct {
Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty"` // for arangodb
ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"`
EventName string `gorm:"unique" json:"event_name" bson:"event_name" cql:"event_name"`
Template string `gorm:"type:text" json:"template" bson:"template" cql:"template"`
CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"`
UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at"`
}
// AsAPIEmailTemplate to return email template as graphql response object
func (e *EmailTemplate) AsAPIEmailTemplate() *model.EmailTemplate {
id := e.ID
if strings.Contains(id, Collections.EmailTemplate+"/") {
id = strings.TrimPrefix(id, Collections.EmailTemplate+"/")
}
return &model.EmailTemplate{
ID: id,
EventName: e.EventName,
Template: e.Template,
CreatedAt: refs.NewInt64Ref(e.CreatedAt),
UpdatedAt: refs.NewInt64Ref(e.UpdatedAt),
}
}

View File

@@ -6,6 +6,9 @@ type CollectionList struct {
VerificationRequest string VerificationRequest string
Session string Session string
Env string Env string
Webhook string
WebhookLog string
EmailTemplate string
} }
var ( var (
@@ -17,5 +20,8 @@ var (
VerificationRequest: Prefix + "verification_requests", VerificationRequest: Prefix + "verification_requests",
Session: Prefix + "sessions", Session: Prefix + "sessions",
Env: Prefix + "env", Env: Prefix + "env",
Webhook: Prefix + "webhook",
WebhookLog: Prefix + "webhook_log",
EmailTemplate: Prefix + "email_template",
} }
) )

View File

@@ -6,8 +6,7 @@ package models
type Session struct { type Session struct {
Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty"` // for arangodb Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty"` // for arangodb
ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"` ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"`
UserID string `gorm:"type:char(36),index:" json:"user_id" bson:"user_id" cql:"user_id"` UserID string `gorm:"type:char(36)" json:"user_id" bson:"user_id" cql:"user_id"`
User User `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"-" bson:"-" cql:"-"`
UserAgent string `json:"user_agent" bson:"user_agent" cql:"user_agent"` UserAgent string `json:"user_agent" bson:"user_agent" cql:"user_agent"`
IP string `json:"ip" bson:"ip" cql:"ip"` IP string `json:"ip" bson:"ip" cql:"ip"`
CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"` CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"`

View File

@@ -4,6 +4,7 @@ import (
"strings" "strings"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/refs"
) )
// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation // Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation
@@ -35,11 +36,13 @@ type User struct {
func (user *User) AsAPIUser() *model.User { func (user *User) AsAPIUser() *model.User {
isEmailVerified := user.EmailVerifiedAt != nil isEmailVerified := user.EmailVerifiedAt != nil
isPhoneVerified := user.PhoneNumberVerifiedAt != nil isPhoneVerified := user.PhoneNumberVerifiedAt != nil
email := user.Email
createdAt := user.CreatedAt id := user.ID
updatedAt := user.UpdatedAt if strings.Contains(id, Collections.WebhookLog+"/") {
id = strings.TrimPrefix(id, Collections.WebhookLog+"/")
}
return &model.User{ return &model.User{
ID: user.ID, ID: id,
Email: user.Email, Email: user.Email,
EmailVerified: isEmailVerified, EmailVerified: isEmailVerified,
SignupMethods: user.SignupMethods, SignupMethods: user.SignupMethods,
@@ -47,7 +50,7 @@ func (user *User) AsAPIUser() *model.User {
FamilyName: user.FamilyName, FamilyName: user.FamilyName,
MiddleName: user.MiddleName, MiddleName: user.MiddleName,
Nickname: user.Nickname, Nickname: user.Nickname,
PreferredUsername: &email, PreferredUsername: refs.NewStringRef(user.Email),
Gender: user.Gender, Gender: user.Gender,
Birthdate: user.Birthdate, Birthdate: user.Birthdate,
PhoneNumber: user.PhoneNumber, PhoneNumber: user.PhoneNumber,
@@ -55,7 +58,7 @@ func (user *User) AsAPIUser() *model.User {
Picture: user.Picture, Picture: user.Picture,
Roles: strings.Split(user.Roles, ","), Roles: strings.Split(user.Roles, ","),
RevokedTimestamp: user.RevokedTimestamp, RevokedTimestamp: user.RevokedTimestamp,
CreatedAt: &createdAt, CreatedAt: refs.NewInt64Ref(user.CreatedAt),
UpdatedAt: &updatedAt, UpdatedAt: refs.NewInt64Ref(user.UpdatedAt),
} }
} }

View File

@@ -1,6 +1,11 @@
package models package models
import "github.com/authorizerdev/authorizer/server/graph/model" import (
"strings"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/refs"
)
// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation // Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation
@@ -19,23 +24,20 @@ type VerificationRequest struct {
} }
func (v *VerificationRequest) AsAPIVerificationRequest() *model.VerificationRequest { func (v *VerificationRequest) AsAPIVerificationRequest() *model.VerificationRequest {
token := v.Token id := v.ID
createdAt := v.CreatedAt if strings.Contains(id, Collections.WebhookLog+"/") {
updatedAt := v.UpdatedAt id = strings.TrimPrefix(id, Collections.WebhookLog+"/")
email := v.Email }
nonce := v.Nonce
redirectURI := v.RedirectURI
expires := v.ExpiresAt
identifier := v.Identifier
return &model.VerificationRequest{ return &model.VerificationRequest{
ID: v.ID, ID: id,
Token: &token, Token: refs.NewStringRef(v.Token),
Identifier: &identifier, Identifier: refs.NewStringRef(v.Identifier),
Expires: &expires, Expires: refs.NewInt64Ref(v.ExpiresAt),
Email: &email, Email: refs.NewStringRef(v.Email),
Nonce: &nonce, Nonce: refs.NewStringRef(v.Nonce),
RedirectURI: &redirectURI, RedirectURI: refs.NewStringRef(v.RedirectURI),
CreatedAt: &createdAt, CreatedAt: refs.NewInt64Ref(v.CreatedAt),
UpdatedAt: &updatedAt, UpdatedAt: refs.NewInt64Ref(v.UpdatedAt),
} }
} }

View File

@@ -0,0 +1,44 @@
package models
import (
"encoding/json"
"strings"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/refs"
)
// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation
// Webhook model for db
type Webhook struct {
Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty"` // for arangodb
ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"`
EventName string `gorm:"unique" json:"event_name" bson:"event_name" cql:"event_name"`
EndPoint string `gorm:"type:text" json:"endpoint" bson:"endpoint" cql:"endpoint"`
Headers string `gorm:"type:text" json:"headers" bson:"headers" cql:"headers"`
Enabled bool `json:"enabled" bson:"enabled" cql:"enabled"`
CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"`
UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at"`
}
// AsAPIWebhook to return webhook as graphql response object
func (w *Webhook) AsAPIWebhook() *model.Webhook {
headersMap := make(map[string]interface{})
json.Unmarshal([]byte(w.Headers), &headersMap)
id := w.ID
if strings.Contains(id, Collections.Webhook+"/") {
id = strings.TrimPrefix(id, Collections.Webhook+"/")
}
return &model.Webhook{
ID: id,
EventName: refs.NewStringRef(w.EventName),
Endpoint: refs.NewStringRef(w.EndPoint),
Headers: headersMap,
Enabled: refs.NewBoolRef(w.Enabled),
CreatedAt: refs.NewInt64Ref(w.CreatedAt),
UpdatedAt: refs.NewInt64Ref(w.UpdatedAt),
}
}

View File

@@ -0,0 +1,39 @@
package models
import (
"strings"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/refs"
)
// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation
// WebhookLog model for db
type WebhookLog struct {
Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty"` // for arangodb
ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"`
HttpStatus int64 `json:"http_status" bson:"http_status" cql:"http_status"`
Response string `gorm:"type:text" json:"response" bson:"response" cql:"response"`
Request string `gorm:"type:text" json:"request" bson:"request" cql:"request"`
WebhookID string `gorm:"type:char(36)" json:"webhook_id" bson:"webhook_id" cql:"webhook_id"`
CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"`
UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at"`
}
// AsAPIWebhookLog to return webhook log as graphql response object
func (w *WebhookLog) AsAPIWebhookLog() *model.WebhookLog {
id := w.ID
if strings.Contains(id, Collections.WebhookLog+"/") {
id = strings.TrimPrefix(id, Collections.WebhookLog+"/")
}
return &model.WebhookLog{
ID: id,
HTTPStatus: refs.NewInt64Ref(w.HttpStatus),
Response: refs.NewStringRef(w.Response),
Request: refs.NewStringRef(w.Request),
WebhookID: refs.NewStringRef(w.WebhookID),
CreatedAt: refs.NewInt64Ref(w.CreatedAt),
UpdatedAt: refs.NewInt64Ref(w.UpdatedAt),
}
}

View File

@@ -0,0 +1,151 @@
package arangodb
import (
"context"
"fmt"
"time"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddEmailTemplate to add EmailTemplate
func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) {
if emailTemplate.ID == "" {
emailTemplate.ID = uuid.New().String()
}
emailTemplate.Key = emailTemplate.ID
emailTemplate.CreatedAt = time.Now().Unix()
emailTemplate.UpdatedAt = time.Now().Unix()
emailTemplateCollection, _ := p.db.Collection(ctx, models.Collections.EmailTemplate)
_, err := emailTemplateCollection.CreateDocument(ctx, emailTemplate)
if err != nil {
return nil, err
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// UpdateEmailTemplate to update EmailTemplate
func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) {
emailTemplate.UpdatedAt = time.Now().Unix()
emailTemplateCollection, _ := p.db.Collection(ctx, models.Collections.EmailTemplate)
meta, err := emailTemplateCollection.UpdateDocument(ctx, emailTemplate.Key, emailTemplate)
if err != nil {
return nil, err
}
emailTemplate.Key = meta.Key
emailTemplate.ID = meta.ID.String()
return emailTemplate.AsAPIEmailTemplate(), nil
}
// ListEmailTemplates to list EmailTemplate
func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) {
emailTemplates := []*model.EmailTemplate{}
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.EmailTemplate, pagination.Offset, pagination.Limit)
sctx := driver.WithQueryFullCount(ctx)
cursor, err := p.db.Query(sctx, query, nil)
if err != nil {
return nil, err
}
defer cursor.Close()
paginationClone := pagination
paginationClone.Total = cursor.Statistics().FullCount()
for {
var emailTemplate models.EmailTemplate
meta, err := cursor.ReadDocument(ctx, &emailTemplate)
if arangoDriver.IsNoMoreDocuments(err) {
break
} else if err != nil {
return nil, err
}
if meta.Key != "" {
emailTemplates = append(emailTemplates, emailTemplate.AsAPIEmailTemplate())
}
}
return &model.EmailTemplates{
Pagination: &paginationClone,
EmailTemplates: emailTemplates,
}, nil
}
// GetEmailTemplateByID to get EmailTemplate by id
func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) {
var emailTemplate models.EmailTemplate
query := fmt.Sprintf("FOR d in %s FILTER d._key == @email_template_id RETURN d", models.Collections.EmailTemplate)
bindVars := map[string]interface{}{
"email_template_id": emailTemplateID,
}
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return nil, err
}
defer cursor.Close()
for {
if !cursor.HasMore() {
if emailTemplate.Key == "" {
return nil, fmt.Errorf("email template not found")
}
break
}
_, err := cursor.ReadDocument(ctx, &emailTemplate)
if err != nil {
return nil, err
}
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// GetEmailTemplateByEventName to get EmailTemplate by event_name
func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) {
var emailTemplate models.EmailTemplate
query := fmt.Sprintf("FOR d in %s FILTER d.event_name == @event_name RETURN d", models.Collections.EmailTemplate)
bindVars := map[string]interface{}{
"event_name": eventName,
}
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return nil, err
}
defer cursor.Close()
for {
if !cursor.HasMore() {
if emailTemplate.Key == "" {
return nil, fmt.Errorf("email template not found")
}
break
}
_, err := cursor.ReadDocument(ctx, &emailTemplate)
if err != nil {
return nil, err
}
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// DeleteEmailTemplate to delete EmailTemplate
func (p *provider) DeleteEmailTemplate(ctx context.Context, emailTemplate *model.EmailTemplate) error {
eventTemplateCollection, _ := p.db.Collection(ctx, models.Collections.EmailTemplate)
_, err := eventTemplateCollection.RemoveDocument(ctx, emailTemplate.ID)
if err != nil {
return err
}
return nil
}

View File

@@ -1,6 +1,7 @@
package arangodb package arangodb
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@@ -11,15 +12,15 @@ import (
) )
// AddEnv to save environment information in database // AddEnv to save environment information in database
func (p *provider) AddEnv(env models.Env) (models.Env, error) { func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) {
if env.ID == "" { if env.ID == "" {
env.ID = uuid.New().String() env.ID = uuid.New().String()
} }
env.CreatedAt = time.Now().Unix() env.CreatedAt = time.Now().Unix()
env.UpdatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix()
configCollection, _ := p.db.Collection(nil, models.Collections.Env) configCollection, _ := p.db.Collection(ctx, models.Collections.Env)
meta, err := configCollection.CreateDocument(arangoDriver.WithOverwrite(nil), env) meta, err := configCollection.CreateDocument(arangoDriver.WithOverwrite(ctx), env)
if err != nil { if err != nil {
return env, err return env, err
} }
@@ -29,10 +30,10 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
} }
// UpdateEnv to update environment information in database // UpdateEnv to update environment information in database
func (p *provider) UpdateEnv(env models.Env) (models.Env, error) { func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix()
collection, _ := p.db.Collection(nil, models.Collections.Env) collection, _ := p.db.Collection(ctx, models.Collections.Env)
meta, err := collection.UpdateDocument(nil, env.Key, env) meta, err := collection.UpdateDocument(ctx, env.Key, env)
if err != nil { if err != nil {
return env, err return env, err
} }
@@ -43,11 +44,11 @@ func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
} }
// GetEnv to get environment information from database // GetEnv to get environment information from database
func (p *provider) GetEnv() (models.Env, error) { func (p *provider) GetEnv(ctx context.Context) (models.Env, error) {
var env models.Env var env models.Env
query := fmt.Sprintf("FOR d in %s RETURN d", models.Collections.Env) query := fmt.Sprintf("FOR d in %s RETURN d", models.Collections.Env)
cursor, err := p.db.Query(nil, query, nil) cursor, err := p.db.Query(ctx, query, nil)
if err != nil { if err != nil {
return env, err return env, err
} }
@@ -60,7 +61,7 @@ func (p *provider) GetEnv() (models.Env, error) {
} }
break break
} }
_, err := cursor.ReadDocument(nil, &env) _, err := cursor.ReadDocument(ctx, &env)
if err != nil { if err != nil {
return env, err return env, err
} }

View File

@@ -107,6 +107,47 @@ func NewProvider() (*provider, error) {
} }
} }
webhookCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.Webhook)
if !webhookCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.Webhook, nil)
if err != nil {
return nil, err
}
}
webhookCollection, _ := arangodb.Collection(nil, models.Collections.Webhook)
webhookCollection.EnsureHashIndex(ctx, []string{"event_name"}, &arangoDriver.EnsureHashIndexOptions{
Unique: true,
Sparse: true,
})
webhookLogCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.WebhookLog)
if !webhookLogCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.WebhookLog, nil)
if err != nil {
return nil, err
}
}
webhookLogCollection, _ := arangodb.Collection(nil, models.Collections.WebhookLog)
webhookLogCollection.EnsureHashIndex(ctx, []string{"webhook_id"}, &arangoDriver.EnsureHashIndexOptions{
Sparse: true,
})
emailTemplateCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.EmailTemplate)
if !emailTemplateCollectionExists {
_, err = arangodb.CreateCollection(ctx, models.Collections.EmailTemplate, nil)
if err != nil {
return nil, err
}
}
emailTemplateCollection, _ := arangodb.Collection(nil, models.Collections.EmailTemplate)
emailTemplateCollection.EnsureHashIndex(ctx, []string{"event_name"}, &arangoDriver.EnsureHashIndexOptions{
Unique: true,
Sparse: true,
})
return &provider{ return &provider{
db: arangodb, db: arangodb,
}, err }, err

View File

@@ -1,7 +1,7 @@
package arangodb package arangodb
import ( import (
"fmt" "context"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@@ -9,31 +9,17 @@ import (
) )
// AddSession to save session information in database // AddSession to save session information in database
func (p *provider) AddSession(session models.Session) error { func (p *provider) AddSession(ctx context.Context, session models.Session) error {
if session.ID == "" { if session.ID == "" {
session.ID = uuid.New().String() session.ID = uuid.New().String()
} }
session.CreatedAt = time.Now().Unix() session.CreatedAt = time.Now().Unix()
session.UpdatedAt = time.Now().Unix() session.UpdatedAt = time.Now().Unix()
sessionCollection, _ := p.db.Collection(nil, models.Collections.Session) sessionCollection, _ := p.db.Collection(ctx, models.Collections.Session)
_, err := sessionCollection.CreateDocument(nil, session) _, err := sessionCollection.CreateDocument(ctx, session)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
// DeleteSession to delete session information from database
func (p *provider) DeleteSession(userId string) error {
query := fmt.Sprintf(`FOR d IN %s FILTER d.user_id == @userId REMOVE { _key: d._key } IN %s`, models.Collections.Session, models.Collections.Session)
bindVars := map[string]interface{}{
"userId": userId,
}
cursor, err := p.db.Query(nil, query, bindVars)
if err != nil {
return err
}
defer cursor.Close()
return nil
}

View File

@@ -15,7 +15,7 @@ import (
) )
// AddUser to save user information in database // AddUser to save user information in database
func (p *provider) AddUser(user models.User) (models.User, error) { func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) {
if user.ID == "" { if user.ID == "" {
user.ID = uuid.New().String() user.ID = uuid.New().String()
} }
@@ -30,8 +30,8 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
user.CreatedAt = time.Now().Unix() user.CreatedAt = time.Now().Unix()
user.UpdatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix()
userCollection, _ := p.db.Collection(nil, models.Collections.User) userCollection, _ := p.db.Collection(ctx, models.Collections.User)
meta, err := userCollection.CreateDocument(arangoDriver.WithOverwrite(nil), user) meta, err := userCollection.CreateDocument(arangoDriver.WithOverwrite(ctx), user)
if err != nil { if err != nil {
return user, err return user, err
} }
@@ -42,10 +42,10 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
} }
// UpdateUser to update user information in database // UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) { func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix()
collection, _ := p.db.Collection(nil, models.Collections.User) collection, _ := p.db.Collection(ctx, models.Collections.User)
meta, err := collection.UpdateDocument(nil, user.Key, user) meta, err := collection.UpdateDocument(ctx, user.Key, user)
if err != nil { if err != nil {
return user, err return user, err
} }
@@ -56,24 +56,34 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
} }
// DeleteUser to delete user information from database // DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error { func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
collection, _ := p.db.Collection(nil, models.Collections.User) collection, _ := p.db.Collection(ctx, models.Collections.User)
_, err := collection.RemoveDocument(nil, user.Key) _, err := collection.RemoveDocument(ctx, user.Key)
if err != nil { if err != nil {
return err return err
} }
query := fmt.Sprintf(`FOR d IN %s FILTER d.user_id == @user_id REMOVE { _key: d._key } IN %s`, models.Collections.Session, models.Collections.Session)
bindVars := map[string]interface{}{
"user_id": user.ID,
}
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return err
}
defer cursor.Close()
return nil return nil
} }
// ListUsers to get list of users from database // ListUsers to get list of users from database
func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) { func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
var users []*model.User var users []*model.User
ctx := driver.WithQueryFullCount(context.Background()) sctx := driver.WithQueryFullCount(ctx)
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.User, pagination.Offset, pagination.Limit) query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.User, pagination.Offset, pagination.Limit)
cursor, err := p.db.Query(ctx, query, nil) cursor, err := p.db.Query(sctx, query, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -84,7 +94,7 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
for { for {
var user models.User var user models.User
meta, err := cursor.ReadDocument(nil, &user) meta, err := cursor.ReadDocument(ctx, &user)
if arangoDriver.IsNoMoreDocuments(err) { if arangoDriver.IsNoMoreDocuments(err) {
break break
@@ -104,7 +114,7 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
} }
// GetUserByEmail to get user information from database using email address // GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(email string) (models.User, error) { func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User var user models.User
query := fmt.Sprintf("FOR d in %s FILTER d.email == @email RETURN d", models.Collections.User) query := fmt.Sprintf("FOR d in %s FILTER d.email == @email RETURN d", models.Collections.User)
@@ -112,7 +122,7 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
"email": email, "email": email,
} }
cursor, err := p.db.Query(nil, query, bindVars) cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil { if err != nil {
return user, err return user, err
} }
@@ -125,7 +135,7 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
} }
break break
} }
_, err := cursor.ReadDocument(nil, &user) _, err := cursor.ReadDocument(ctx, &user)
if err != nil { if err != nil {
return user, err return user, err
} }
@@ -135,7 +145,7 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
} }
// GetUserByID to get user information from database using user ID // GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(id string) (models.User, error) { func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
var user models.User var user models.User
query := fmt.Sprintf("FOR d in %s FILTER d._id == @id LIMIT 1 RETURN d", models.Collections.User) query := fmt.Sprintf("FOR d in %s FILTER d._id == @id LIMIT 1 RETURN d", models.Collections.User)
@@ -143,7 +153,7 @@ func (p *provider) GetUserByID(id string) (models.User, error) {
"id": id, "id": id,
} }
cursor, err := p.db.Query(nil, query, bindVars) cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil { if err != nil {
return user, err return user, err
} }
@@ -156,7 +166,7 @@ func (p *provider) GetUserByID(id string) (models.User, error) {
} }
break break
} }
_, err := cursor.ReadDocument(nil, &user) _, err := cursor.ReadDocument(ctx, &user)
if err != nil { if err != nil {
return user, err return user, err
} }

View File

@@ -12,15 +12,15 @@ import (
) )
// AddVerification to save verification request in database // AddVerification to save verification request in database
func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
if verificationRequest.ID == "" { if verificationRequest.ID == "" {
verificationRequest.ID = uuid.New().String() verificationRequest.ID = uuid.New().String()
} }
verificationRequest.CreatedAt = time.Now().Unix() verificationRequest.CreatedAt = time.Now().Unix()
verificationRequest.UpdatedAt = time.Now().Unix() verificationRequest.UpdatedAt = time.Now().Unix()
verificationRequestRequestCollection, _ := p.db.Collection(nil, models.Collections.VerificationRequest) verificationRequestRequestCollection, _ := p.db.Collection(ctx, models.Collections.VerificationRequest)
meta, err := verificationRequestRequestCollection.CreateDocument(nil, verificationRequest) meta, err := verificationRequestRequestCollection.CreateDocument(ctx, verificationRequest)
if err != nil { if err != nil {
return verificationRequest, err return verificationRequest, err
} }
@@ -31,14 +31,14 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
} }
// GetVerificationRequestByToken to get verification request from database using token // GetVerificationRequestByToken to get verification request from database using token
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
query := fmt.Sprintf("FOR d in %s FILTER d.token == @token LIMIT 1 RETURN d", models.Collections.VerificationRequest) query := fmt.Sprintf("FOR d in %s FILTER d.token == @token LIMIT 1 RETURN d", models.Collections.VerificationRequest)
bindVars := map[string]interface{}{ bindVars := map[string]interface{}{
"token": token, "token": token,
} }
cursor, err := p.db.Query(nil, query, bindVars) cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil { if err != nil {
return verificationRequest, err return verificationRequest, err
} }
@@ -51,7 +51,7 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
} }
break break
} }
_, err := cursor.ReadDocument(nil, &verificationRequest) _, err := cursor.ReadDocument(ctx, &verificationRequest)
if err != nil { if err != nil {
return verificationRequest, err return verificationRequest, err
} }
@@ -61,7 +61,7 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
} }
// GetVerificationRequestByEmail to get verification request by email from database // GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
query := fmt.Sprintf("FOR d in %s FILTER d.email == @email FILTER d.identifier == @identifier LIMIT 1 RETURN d", models.Collections.VerificationRequest) query := fmt.Sprintf("FOR d in %s FILTER d.email == @email FILTER d.identifier == @identifier LIMIT 1 RETURN d", models.Collections.VerificationRequest)
@@ -70,7 +70,7 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
"identifier": identifier, "identifier": identifier,
} }
cursor, err := p.db.Query(nil, query, bindVars) cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil { if err != nil {
return verificationRequest, err return verificationRequest, err
} }
@@ -83,7 +83,7 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
} }
break break
} }
_, err := cursor.ReadDocument(nil, &verificationRequest) _, err := cursor.ReadDocument(ctx, &verificationRequest)
if err != nil { if err != nil {
return verificationRequest, err return verificationRequest, err
} }
@@ -93,12 +93,12 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
} }
// ListVerificationRequests to get list of verification requests from database // ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) { func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
var verificationRequests []*model.VerificationRequest var verificationRequests []*model.VerificationRequest
ctx := driver.WithQueryFullCount(context.Background()) sctx := driver.WithQueryFullCount(ctx)
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.VerificationRequest, pagination.Offset, pagination.Limit) query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.VerificationRequest, pagination.Offset, pagination.Limit)
cursor, err := p.db.Query(ctx, query, nil) cursor, err := p.db.Query(sctx, query, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -109,7 +109,7 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
for { for {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
meta, err := cursor.ReadDocument(nil, &verificationRequest) meta, err := cursor.ReadDocument(ctx, &verificationRequest)
if driver.IsNoMoreDocuments(err) { if driver.IsNoMoreDocuments(err) {
break break
@@ -130,7 +130,7 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
} }
// DeleteVerificationRequest to delete verification request from database // DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error { func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error {
collection, _ := p.db.Collection(nil, models.Collections.VerificationRequest) collection, _ := p.db.Collection(nil, models.Collections.VerificationRequest)
_, err := collection.RemoveDocument(nil, verificationRequest.Key) _, err := collection.RemoveDocument(nil, verificationRequest.Key)
if err != nil { if err != nil {

View File

@@ -0,0 +1,161 @@
package arangodb
import (
"context"
"fmt"
"time"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
if webhook.ID == "" {
webhook.ID = uuid.New().String()
}
webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix()
webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook)
_, err := webhookCollection.CreateDocument(ctx, webhook)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix()
webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook)
meta, err := webhookCollection.UpdateDocument(ctx, webhook.Key, webhook)
if err != nil {
return nil, err
}
webhook.Key = meta.Key
webhook.ID = meta.ID.String()
return webhook.AsAPIWebhook(), nil
}
// ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
webhooks := []*model.Webhook{}
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.Webhook, pagination.Offset, pagination.Limit)
sctx := driver.WithQueryFullCount(ctx)
cursor, err := p.db.Query(sctx, query, nil)
if err != nil {
return nil, err
}
defer cursor.Close()
paginationClone := pagination
paginationClone.Total = cursor.Statistics().FullCount()
for {
var webhook models.Webhook
meta, err := cursor.ReadDocument(ctx, &webhook)
if arangoDriver.IsNoMoreDocuments(err) {
break
} else if err != nil {
return nil, err
}
if meta.Key != "" {
webhooks = append(webhooks, webhook.AsAPIWebhook())
}
}
return &model.Webhooks{
Pagination: &paginationClone,
Webhooks: webhooks,
}, nil
}
// GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook
query := fmt.Sprintf("FOR d in %s FILTER d._key == @webhook_id RETURN d", models.Collections.Webhook)
bindVars := map[string]interface{}{
"webhook_id": webhookID,
}
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return nil, err
}
defer cursor.Close()
for {
if !cursor.HasMore() {
if webhook.Key == "" {
return nil, fmt.Errorf("webhook not found")
}
break
}
_, err := cursor.ReadDocument(ctx, &webhook)
if err != nil {
return nil, err
}
}
return webhook.AsAPIWebhook(), nil
}
// GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
var webhook models.Webhook
query := fmt.Sprintf("FOR d in %s FILTER d.event_name == @event_name RETURN d", models.Collections.Webhook)
bindVars := map[string]interface{}{
"event_name": eventName,
}
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return nil, err
}
defer cursor.Close()
for {
if !cursor.HasMore() {
if webhook.Key == "" {
return nil, fmt.Errorf("webhook not found")
}
break
}
_, err := cursor.ReadDocument(ctx, &webhook)
if err != nil {
return nil, err
}
}
return webhook.AsAPIWebhook(), nil
}
// DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook)
_, err := webhookCollection.RemoveDocument(ctx, webhook.ID)
if err != nil {
return err
}
query := fmt.Sprintf("FOR d IN %s FILTER d.webhook_id == @webhook_id REMOVE { _key: d._key } IN %s", models.Collections.WebhookLog, models.Collections.WebhookLog)
bindVars := map[string]interface{}{
"webhook_id": webhook.ID,
}
cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil {
return err
}
defer cursor.Close()
return nil
}

View File

@@ -0,0 +1,75 @@
package arangodb
import (
"context"
"fmt"
"time"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddWebhookLog to add webhook log
func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) {
if webhookLog.ID == "" {
webhookLog.ID = uuid.New().String()
}
webhookLog.Key = webhookLog.ID
webhookLog.CreatedAt = time.Now().Unix()
webhookLog.UpdatedAt = time.Now().Unix()
webhookLogCollection, _ := p.db.Collection(ctx, models.Collections.WebhookLog)
_, err := webhookLogCollection.CreateDocument(ctx, webhookLog)
if err != nil {
return nil, err
}
return webhookLog.AsAPIWebhookLog(), nil
}
// ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) {
webhookLogs := []*model.WebhookLog{}
bindVariables := map[string]interface{}{}
query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit)
if webhookID != "" {
query = fmt.Sprintf("FOR d in %s FILTER d.webhook_id == @webhook_id SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit)
bindVariables = map[string]interface{}{
"webhook_id": webhookID,
}
}
sctx := driver.WithQueryFullCount(ctx)
cursor, err := p.db.Query(sctx, query, bindVariables)
if err != nil {
return nil, err
}
defer cursor.Close()
paginationClone := pagination
paginationClone.Total = cursor.Statistics().FullCount()
for {
var webhookLog models.WebhookLog
meta, err := cursor.ReadDocument(ctx, &webhookLog)
if arangoDriver.IsNoMoreDocuments(err) {
break
} else if err != nil {
return nil, err
}
if meta.Key != "" {
webhookLogs = append(webhookLogs, webhookLog.AsAPIWebhookLog())
}
}
return &model.WebhookLogs{
Pagination: &paginationClone,
WebhookLogs: webhookLogs,
}, nil
}

View File

@@ -0,0 +1,159 @@
package cassandradb
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strings"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/gocql/gocql"
"github.com/google/uuid"
)
// AddEmailTemplate to add EmailTemplate
func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) {
if emailTemplate.ID == "" {
emailTemplate.ID = uuid.New().String()
}
emailTemplate.Key = emailTemplate.ID
emailTemplate.CreatedAt = time.Now().Unix()
emailTemplate.UpdatedAt = time.Now().Unix()
existingEmailTemplate, _ := p.GetEmailTemplateByEventName(ctx, emailTemplate.EventName)
if existingEmailTemplate != nil {
return nil, fmt.Errorf("Email template with %s event_name already exists", emailTemplate.EventName)
}
insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_name, template, created_at, updated_at) VALUES ('%s', '%s', '%s', %d, %d)", KeySpace+"."+models.Collections.EmailTemplate, emailTemplate.ID, emailTemplate.EventName, emailTemplate.Template, emailTemplate.CreatedAt, emailTemplate.UpdatedAt)
err := p.db.Query(insertQuery).Exec()
if err != nil {
return nil, err
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// UpdateEmailTemplate to update EmailTemplate
func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) {
emailTemplate.UpdatedAt = time.Now().Unix()
bytes, err := json.Marshal(emailTemplate)
if err != nil {
return nil, err
}
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
decoder.UseNumber()
emailTemplateMap := map[string]interface{}{}
err = decoder.Decode(&emailTemplateMap)
if err != nil {
return nil, err
}
updateFields := ""
for key, value := range emailTemplateMap {
if key == "_id" {
continue
}
if key == "_key" {
continue
}
if value == nil {
updateFields += fmt.Sprintf("%s = null,", key)
continue
}
valueType := reflect.TypeOf(value)
if valueType.Name() == "string" {
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
} else {
updateFields += fmt.Sprintf("%s = %v, ", key, value)
}
}
updateFields = strings.Trim(updateFields, " ")
updateFields = strings.TrimSuffix(updateFields, ",")
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.EmailTemplate, updateFields, emailTemplate.ID)
err = p.db.Query(query).Exec()
if err != nil {
return nil, err
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// ListEmailTemplates to list EmailTemplate
func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) {
emailTemplates := []*model.EmailTemplate{}
paginationClone := pagination
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.EmailTemplate)
err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total)
if err != nil {
return nil, err
}
// there is no offset in cassandra
// so we fetch till limit + offset
// and return the results from offset to limit
query := fmt.Sprintf("SELECT id, event_name, template, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.EmailTemplate, pagination.Limit+pagination.Offset)
scanner := p.db.Query(query).Iter().Scanner()
counter := int64(0)
for scanner.Next() {
if counter >= pagination.Offset {
var emailTemplate models.EmailTemplate
err := scanner.Scan(&emailTemplate.ID, &emailTemplate.EventName, &emailTemplate.Template, &emailTemplate.CreatedAt, &emailTemplate.UpdatedAt)
if err != nil {
return nil, err
}
emailTemplates = append(emailTemplates, emailTemplate.AsAPIEmailTemplate())
}
counter++
}
return &model.EmailTemplates{
Pagination: &paginationClone,
EmailTemplates: emailTemplates,
}, nil
}
// GetEmailTemplateByID to get EmailTemplate by id
func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) {
var emailTemplate models.EmailTemplate
query := fmt.Sprintf(`SELECT id, event_name, template, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1`, KeySpace+"."+models.Collections.EmailTemplate, emailTemplateID)
err := p.db.Query(query).Consistency(gocql.One).Scan(&emailTemplate.ID, &emailTemplate.EventName, &emailTemplate.Template, &emailTemplate.CreatedAt, &emailTemplate.UpdatedAt)
if err != nil {
return nil, err
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// GetEmailTemplateByEventName to get EmailTemplate by event_name
func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) {
var emailTemplate models.EmailTemplate
query := fmt.Sprintf(`SELECT id, event_name, template, created_at, updated_at FROM %s WHERE event_name = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.EmailTemplate, eventName)
err := p.db.Query(query).Consistency(gocql.One).Scan(&emailTemplate.ID, &emailTemplate.EventName, &emailTemplate.Template, &emailTemplate.CreatedAt, &emailTemplate.UpdatedAt)
if err != nil {
return nil, err
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// DeleteEmailTemplate to delete EmailTemplate
func (p *provider) DeleteEmailTemplate(ctx context.Context, emailTemplate *model.EmailTemplate) error {
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.EmailTemplate, emailTemplate.ID)
err := p.db.Query(query).Exec()
if err != nil {
return err
}
return nil
}

View File

@@ -1,6 +1,7 @@
package cassandradb package cassandradb
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@@ -10,7 +11,7 @@ import (
) )
// AddEnv to save environment information in database // AddEnv to save environment information in database
func (p *provider) AddEnv(env models.Env) (models.Env, error) { func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) {
if env.ID == "" { if env.ID == "" {
env.ID = uuid.New().String() env.ID = uuid.New().String()
} }
@@ -27,7 +28,7 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
} }
// UpdateEnv to update environment information in database // UpdateEnv to update environment information in database
func (p *provider) UpdateEnv(env models.Env) (models.Env, error) { func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix()
updateEnvQuery := fmt.Sprintf("UPDATE %s SET env = '%s', updated_at = %d WHERE id = '%s'", KeySpace+"."+models.Collections.Env, env.EnvData, env.UpdatedAt, env.ID) updateEnvQuery := fmt.Sprintf("UPDATE %s SET env = '%s', updated_at = %d WHERE id = '%s'", KeySpace+"."+models.Collections.Env, env.EnvData, env.UpdatedAt, env.ID)
@@ -39,7 +40,7 @@ func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
} }
// GetEnv to get environment information from database // GetEnv to get environment information from database
func (p *provider) GetEnv() (models.Env, error) { func (p *provider) GetEnv(ctx context.Context) (models.Env, error) {
var env models.Env var env models.Env
query := fmt.Sprintf("SELECT id, env, hash, created_at, updated_at FROM %s LIMIT 1", KeySpace+"."+models.Collections.Env) query := fmt.Sprintf("SELECT id, env, hash, created_at, updated_at FROM %s LIMIT 1", KeySpace+"."+models.Collections.Env)

View File

@@ -5,6 +5,7 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/crypto"
@@ -96,6 +97,8 @@ func NewProvider() (*provider, error) {
NumRetries: 3, NumRetries: 3,
} }
cassandraClient.Consistency = gocql.LocalQuorum cassandraClient.Consistency = gocql.LocalQuorum
cassandraClient.ConnectTimeout = 10 * time.Second
cassandraClient.ProtoVersion = 4
session, err := cassandraClient.CreateSession() session, err := cassandraClient.CreateSession()
if err != nil { if err != nil {
@@ -140,6 +143,11 @@ func NewProvider() (*provider, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
sessionIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_session_user_id ON %s.%s (user_id)", KeySpace, models.Collections.Session)
err = session.Query(sessionIndexQuery).Exec()
if err != nil {
return nil, err
}
userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nickname text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id))", KeySpace, models.Collections.User) userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nickname text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id))", KeySpace, models.Collections.User)
err = session.Query(userCollectionQuery).Exec() err = session.Query(userCollectionQuery).Exec()
@@ -174,6 +182,39 @@ func NewProvider() (*provider, error) {
return nil, err return nil, err
} }
webhookCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, event_name text, endpoint text, enabled boolean, headers text, updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.Webhook)
err = session.Query(webhookCollectionQuery).Exec()
if err != nil {
return nil, err
}
webhookIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_webhook_event_name ON %s.%s (event_name)", KeySpace, models.Collections.Webhook)
err = session.Query(webhookIndexQuery).Exec()
if err != nil {
return nil, err
}
webhookLogCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, http_status bigint, response text, request text, webhook_id text,updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.WebhookLog)
err = session.Query(webhookLogCollectionQuery).Exec()
if err != nil {
return nil, err
}
webhookLogIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_webhook_log_webhook_id ON %s.%s (webhook_id)", KeySpace, models.Collections.WebhookLog)
err = session.Query(webhookLogIndexQuery).Exec()
if err != nil {
return nil, err
}
emailTemplateCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, event_name text, template text, updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.EmailTemplate)
err = session.Query(emailTemplateCollectionQuery).Exec()
if err != nil {
return nil, err
}
emailTemplateIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_email_template_event_name ON %s.%s (event_name)", KeySpace, models.Collections.EmailTemplate)
err = session.Query(emailTemplateIndexQuery).Exec()
if err != nil {
return nil, err
}
return &provider{ return &provider{
db: session, db: session,
}, err }, err

View File

@@ -1,6 +1,7 @@
package cassandradb package cassandradb
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@@ -9,7 +10,7 @@ import (
) )
// AddSession to save session information in database // AddSession to save session information in database
func (p *provider) AddSession(session models.Session) error { func (p *provider) AddSession(ctx context.Context, session models.Session) error {
if session.ID == "" { if session.ID == "" {
session.ID = uuid.New().String() session.ID = uuid.New().String()
} }
@@ -24,13 +25,3 @@ func (p *provider) AddSession(session models.Session) error {
} }
return nil return nil
} }
// DeleteSession to delete session information from database
func (p *provider) DeleteSession(userId string) error {
deleteSessionQuery := fmt.Sprintf("DELETE FROM %s WHERE user_id = '%s'", KeySpace+"."+models.Collections.Session, userId)
err := p.db.Query(deleteSessionQuery).Exec()
if err != nil {
return err
}
return nil
}

View File

@@ -1,6 +1,7 @@
package cassandradb package cassandradb
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect" "reflect"
@@ -16,7 +17,7 @@ import (
) )
// AddUser to save user information in database // AddUser to save user information in database
func (p *provider) AddUser(user models.User) (models.User, error) { func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) {
if user.ID == "" { if user.ID == "" {
user.ID = uuid.New().String() user.ID = uuid.New().String()
} }
@@ -79,7 +80,7 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
} }
// UpdateUser to update user information in database // UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) { func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix()
bytes, err := json.Marshal(user) bytes, err := json.Marshal(user)
@@ -97,10 +98,11 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
updateFields := "" updateFields := ""
for key, value := range userMap { for key, value := range userMap {
if value != nil && key != "_id" { if key == "_id" {
continue
} }
if key == "_id" { if key == "_key" {
continue continue
} }
@@ -130,14 +132,36 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
} }
// DeleteUser to delete user information from database // DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error { func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, user.ID) query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, user.ID)
err := p.db.Query(query).Exec() err := p.db.Query(query).Exec()
return err if err != nil {
return err
}
getSessionsQuery := fmt.Sprintf("SELECT id FROM %s WHERE user_id = '%s' ALLOW FILTERING", KeySpace+"."+models.Collections.Session, user.ID)
scanner := p.db.Query(getSessionsQuery).Iter().Scanner()
sessionIDs := ""
for scanner.Next() {
var wlID string
err = scanner.Scan(&wlID)
if err != nil {
return err
}
sessionIDs += fmt.Sprintf("'%s',", wlID)
}
sessionIDs = strings.TrimSuffix(sessionIDs, ",")
deleteSessionQuery := fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", KeySpace+"."+models.Collections.Session, sessionIDs)
err = p.db.Query(deleteSessionQuery).Exec()
if err != nil {
return err
}
return nil
} }
// ListUsers to get list of users from database // ListUsers to get list of users from database
func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) { func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
responseUsers := []*model.User{} responseUsers := []*model.User{}
paginationClone := pagination paginationClone := pagination
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.User) totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.User)
@@ -171,9 +195,9 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
} }
// GetUserByEmail to get user information from database using email address // GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(email string) (models.User, error) { func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User var user models.User
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, email) query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.User, email)
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt) err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt)
if err != nil { if err != nil {
return user, err return user, err
@@ -182,7 +206,7 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
} }
// GetUserByID to get user information from database using user ID // GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(id string) (models.User, error) { func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
var user models.User var user models.User
query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, id) query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, id)
err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt) err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt)

View File

@@ -1,6 +1,7 @@
package cassandradb package cassandradb
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@@ -11,7 +12,7 @@ import (
) )
// AddVerification to save verification request in database // AddVerification to save verification request in database
func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
if verificationRequest.ID == "" { if verificationRequest.ID == "" {
verificationRequest.ID = uuid.New().String() verificationRequest.ID = uuid.New().String()
} }
@@ -28,7 +29,7 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
} }
// GetVerificationRequestByToken to get verification request from database using token // GetVerificationRequestByToken to get verification request from database using token
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE jwt_token = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, token) query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE jwt_token = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, token)
@@ -40,7 +41,7 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
} }
// GetVerificationRequestByEmail to get verification request by email from database // GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE email = '%s' AND identifier = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.VerificationRequest, email, identifier) query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE email = '%s' AND identifier = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.VerificationRequest, email, identifier)
@@ -53,7 +54,7 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
} }
// ListVerificationRequests to get list of verification requests from database // ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) { func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
var verificationRequests []*model.VerificationRequest var verificationRequests []*model.VerificationRequest
paginationClone := pagination paginationClone := pagination
@@ -89,7 +90,7 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
} }
// DeleteVerificationRequest to delete verification request from database // DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error { func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error {
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.VerificationRequest, verificationRequest.ID) query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.VerificationRequest, verificationRequest.ID)
err := p.db.Query(query).Exec() err := p.db.Query(query).Exec()
if err != nil { if err != nil {

View File

@@ -0,0 +1,172 @@
package cassandradb
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strings"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/gocql/gocql"
"github.com/google/uuid"
)
// AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
if webhook.ID == "" {
webhook.ID = uuid.New().String()
}
webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix()
existingHook, _ := p.GetWebhookByEventName(ctx, webhook.EventName)
if existingHook != nil {
return nil, fmt.Errorf("Webhook with %s event_name already exists", webhook.EventName)
}
insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_name, endpoint, headers, enabled, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %t, %d, %d)", KeySpace+"."+models.Collections.Webhook, webhook.ID, webhook.EventName, webhook.EndPoint, webhook.Headers, webhook.Enabled, webhook.CreatedAt, webhook.UpdatedAt)
err := p.db.Query(insertQuery).Exec()
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix()
bytes, err := json.Marshal(webhook)
if err != nil {
return nil, err
}
// use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling
decoder := json.NewDecoder(strings.NewReader(string(bytes)))
decoder.UseNumber()
webhookMap := map[string]interface{}{}
err = decoder.Decode(&webhookMap)
if err != nil {
return nil, err
}
updateFields := ""
for key, value := range webhookMap {
if key == "_id" {
continue
}
if key == "_key" {
continue
}
if value == nil {
updateFields += fmt.Sprintf("%s = null,", key)
continue
}
valueType := reflect.TypeOf(value)
if valueType.Name() == "string" {
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
} else {
updateFields += fmt.Sprintf("%s = %v, ", key, value)
}
}
updateFields = strings.Trim(updateFields, " ")
updateFields = strings.TrimSuffix(updateFields, ",")
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.Webhook, updateFields, webhook.ID)
err = p.db.Query(query).Exec()
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
webhooks := []*model.Webhook{}
paginationClone := pagination
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.Webhook)
err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total)
if err != nil {
return nil, err
}
// there is no offset in cassandra
// so we fetch till limit + offset
// and return the results from offset to limit
query := fmt.Sprintf("SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.Webhook, pagination.Limit+pagination.Offset)
scanner := p.db.Query(query).Iter().Scanner()
counter := int64(0)
for scanner.Next() {
if counter >= pagination.Offset {
var webhook models.Webhook
err := scanner.Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt)
if err != nil {
return nil, err
}
webhooks = append(webhooks, webhook.AsAPIWebhook())
}
counter++
}
return &model.Webhooks{
Pagination: &paginationClone,
Webhooks: webhooks,
}, nil
}
// GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook
query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1`, KeySpace+"."+models.Collections.Webhook, webhookID)
err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
var webhook models.Webhook
query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.Webhook, eventName)
err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.Webhook, webhook.ID)
err := p.db.Query(query).Exec()
if err != nil {
return err
}
getWebhookLogQuery := fmt.Sprintf("SELECT id FROM %s WHERE webhook_id = '%s' ALLOW FILTERING", KeySpace+"."+models.Collections.WebhookLog, webhook.ID)
scanner := p.db.Query(getWebhookLogQuery).Iter().Scanner()
webhookLogIDs := ""
for scanner.Next() {
var wlID string
err = scanner.Scan(&wlID)
if err != nil {
return err
}
webhookLogIDs += fmt.Sprintf("'%s',", wlID)
}
webhookLogIDs = strings.TrimSuffix(webhookLogIDs, ",")
query = fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", KeySpace+"."+models.Collections.WebhookLog, webhookLogIDs)
err = p.db.Query(query).Exec()
return err
}

View File

@@ -0,0 +1,70 @@
package cassandradb
import (
"context"
"fmt"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/gocql/gocql"
"github.com/google/uuid"
)
// AddWebhookLog to add webhook log
func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) {
if webhookLog.ID == "" {
webhookLog.ID = uuid.New().String()
}
webhookLog.Key = webhookLog.ID
webhookLog.CreatedAt = time.Now().Unix()
webhookLog.UpdatedAt = time.Now().Unix()
insertQuery := fmt.Sprintf("INSERT INTO %s (id, http_status, response, request, webhook_id, created_at, updated_at) VALUES ('%s', %d,'%s', '%s', '%s', %d, %d)", KeySpace+"."+models.Collections.WebhookLog, webhookLog.ID, webhookLog.HttpStatus, webhookLog.Response, webhookLog.Request, webhookLog.WebhookID, webhookLog.CreatedAt, webhookLog.UpdatedAt)
err := p.db.Query(insertQuery).Exec()
if err != nil {
return nil, err
}
return webhookLog.AsAPIWebhookLog(), nil
}
// ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) {
webhookLogs := []*model.WebhookLog{}
paginationClone := pagination
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.WebhookLog)
// there is no offset in cassandra
// so we fetch till limit + offset
// and return the results from offset to limit
query := fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.WebhookLog, pagination.Limit+pagination.Offset)
if webhookID != "" {
totalCountQuery = fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE webhook_id='%s' ALLOW FILTERING`, KeySpace+"."+models.Collections.WebhookLog, webhookID)
query = fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s WHERE webhook_id = '%s' LIMIT %d ALLOW FILTERING", KeySpace+"."+models.Collections.WebhookLog, webhookID, pagination.Limit+pagination.Offset)
}
err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total)
if err != nil {
return nil, err
}
scanner := p.db.Query(query).Iter().Scanner()
counter := int64(0)
for scanner.Next() {
if counter >= pagination.Offset {
var webhookLog models.WebhookLog
err := scanner.Scan(&webhookLog.ID, &webhookLog.HttpStatus, &webhookLog.Response, &webhookLog.Request, &webhookLog.WebhookID, &webhookLog.CreatedAt, &webhookLog.UpdatedAt)
if err != nil {
return nil, err
}
webhookLogs = append(webhookLogs, webhookLog.AsAPIWebhookLog())
}
counter++
}
return &model.WebhookLogs{
Pagination: &paginationClone,
WebhookLogs: webhookLogs,
}, nil
}

View File

@@ -0,0 +1,115 @@
package mongodb
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/options"
)
// AddEmailTemplate to add EmailTemplate
func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) {
if emailTemplate.ID == "" {
emailTemplate.ID = uuid.New().String()
}
emailTemplate.Key = emailTemplate.ID
emailTemplate.CreatedAt = time.Now().Unix()
emailTemplate.UpdatedAt = time.Now().Unix()
emailTemplateCollection := p.db.Collection(models.Collections.EmailTemplate, options.Collection())
_, err := emailTemplateCollection.InsertOne(ctx, emailTemplate)
if err != nil {
return nil, err
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// UpdateEmailTemplate to update EmailTemplate
func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) {
emailTemplate.UpdatedAt = time.Now().Unix()
emailTemplateCollection := p.db.Collection(models.Collections.EmailTemplate, options.Collection())
_, err := emailTemplateCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": emailTemplate.ID}}, bson.M{"$set": emailTemplate}, options.MergeUpdateOptions())
if err != nil {
return nil, err
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// ListEmailTemplates to list EmailTemplate
func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) {
var emailTemplates []*model.EmailTemplate
opts := options.Find()
opts.SetLimit(pagination.Limit)
opts.SetSkip(pagination.Offset)
opts.SetSort(bson.M{"created_at": -1})
paginationClone := pagination
emailTemplateCollection := p.db.Collection(models.Collections.EmailTemplate, options.Collection())
count, err := emailTemplateCollection.CountDocuments(ctx, bson.M{}, options.Count())
if err != nil {
return nil, err
}
paginationClone.Total = count
cursor, err := emailTemplateCollection.Find(ctx, bson.M{}, opts)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)
for cursor.Next(ctx) {
var emailTemplate models.EmailTemplate
err := cursor.Decode(&emailTemplate)
if err != nil {
return nil, err
}
emailTemplates = append(emailTemplates, emailTemplate.AsAPIEmailTemplate())
}
return &model.EmailTemplates{
Pagination: &paginationClone,
EmailTemplates: emailTemplates,
}, nil
}
// GetEmailTemplateByID to get EmailTemplate by id
func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) {
var emailTemplate models.EmailTemplate
emailTemplateCollection := p.db.Collection(models.Collections.EmailTemplate, options.Collection())
err := emailTemplateCollection.FindOne(ctx, bson.M{"_id": emailTemplateID}).Decode(&emailTemplate)
if err != nil {
return nil, err
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// GetEmailTemplateByEventName to get EmailTemplate by event_name
func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) {
var emailTemplate models.EmailTemplate
emailTemplateCollection := p.db.Collection(models.Collections.EmailTemplate, options.Collection())
err := emailTemplateCollection.FindOne(ctx, bson.M{"event_name": eventName}).Decode(&emailTemplate)
if err != nil {
return nil, err
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// DeleteEmailTemplate to delete EmailTemplate
func (p *provider) DeleteEmailTemplate(ctx context.Context, emailTemplate *model.EmailTemplate) error {
emailTemplateCollection := p.db.Collection(models.Collections.EmailTemplate, options.Collection())
_, err := emailTemplateCollection.DeleteOne(nil, bson.M{"_id": emailTemplate.ID}, options.Delete())
if err != nil {
return err
}
return nil
}

View File

@@ -1,6 +1,7 @@
package mongodb package mongodb
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@@ -11,7 +12,7 @@ import (
) )
// AddEnv to save environment information in database // AddEnv to save environment information in database
func (p *provider) AddEnv(env models.Env) (models.Env, error) { func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) {
if env.ID == "" { if env.ID == "" {
env.ID = uuid.New().String() env.ID = uuid.New().String()
} }
@@ -20,7 +21,7 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix()
env.Key = env.ID env.Key = env.ID
configCollection := p.db.Collection(models.Collections.Env, options.Collection()) configCollection := p.db.Collection(models.Collections.Env, options.Collection())
_, err := configCollection.InsertOne(nil, env) _, err := configCollection.InsertOne(ctx, env)
if err != nil { if err != nil {
return env, err return env, err
} }
@@ -28,10 +29,10 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
} }
// UpdateEnv to update environment information in database // UpdateEnv to update environment information in database
func (p *provider) UpdateEnv(env models.Env) (models.Env, error) { func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix()
configCollection := p.db.Collection(models.Collections.Env, options.Collection()) configCollection := p.db.Collection(models.Collections.Env, options.Collection())
_, err := configCollection.UpdateOne(nil, bson.M{"_id": bson.M{"$eq": env.ID}}, bson.M{"$set": env}, options.MergeUpdateOptions()) _, err := configCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": env.ID}}, bson.M{"$set": env}, options.MergeUpdateOptions())
if err != nil { if err != nil {
return env, err return env, err
} }
@@ -39,14 +40,14 @@ func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
} }
// GetEnv to get environment information from database // GetEnv to get environment information from database
func (p *provider) GetEnv() (models.Env, error) { func (p *provider) GetEnv(ctx context.Context) (models.Env, error) {
var env models.Env var env models.Env
configCollection := p.db.Collection(models.Collections.Env, options.Collection()) configCollection := p.db.Collection(models.Collections.Env, options.Collection())
cursor, err := configCollection.Find(nil, bson.M{}, options.Find()) cursor, err := configCollection.Find(ctx, bson.M{}, options.Find())
if err != nil { if err != nil {
return env, err return env, err
} }
defer cursor.Close(nil) defer cursor.Close(ctx)
for cursor.Next(nil) { for cursor.Next(nil) {
err := cursor.Decode(&env) err := cursor.Decode(&env)

View File

@@ -83,6 +83,33 @@ func NewProvider() (*provider, error) {
mongodb.CreateCollection(ctx, models.Collections.Env, options.CreateCollection()) mongodb.CreateCollection(ctx, models.Collections.Env, options.CreateCollection())
mongodb.CreateCollection(ctx, models.Collections.Webhook, options.CreateCollection())
webhookCollection := mongodb.Collection(models.Collections.Webhook, options.Collection())
webhookCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{
{
Keys: bson.M{"event_name": 1},
Options: options.Index().SetUnique(true).SetSparse(true),
},
}, options.CreateIndexes())
mongodb.CreateCollection(ctx, models.Collections.WebhookLog, options.CreateCollection())
webhookLogCollection := mongodb.Collection(models.Collections.WebhookLog, options.Collection())
webhookLogCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{
{
Keys: bson.M{"webhook_id": 1},
Options: options.Index().SetSparse(true),
},
}, options.CreateIndexes())
mongodb.CreateCollection(ctx, models.Collections.EmailTemplate, options.CreateCollection())
emailTemplateCollection := mongodb.Collection(models.Collections.EmailTemplate, options.Collection())
emailTemplateCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{
{
Keys: bson.M{"event_name": 1},
Options: options.Index().SetUnique(true).SetSparse(true),
},
}, options.CreateIndexes())
return &provider{ return &provider{
db: mongodb, db: mongodb,
}, nil }, nil

View File

@@ -1,16 +1,16 @@
package mongodb package mongodb
import ( import (
"context"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
"github.com/google/uuid" "github.com/google/uuid"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
) )
// AddSession to save session information in database // AddSession to save session information in database
func (p *provider) AddSession(session models.Session) error { func (p *provider) AddSession(ctx context.Context, session models.Session) error {
if session.ID == "" { if session.ID == "" {
session.ID = uuid.New().String() session.ID = uuid.New().String()
} }
@@ -19,17 +19,7 @@ func (p *provider) AddSession(session models.Session) error {
session.CreatedAt = time.Now().Unix() session.CreatedAt = time.Now().Unix()
session.UpdatedAt = time.Now().Unix() session.UpdatedAt = time.Now().Unix()
sessionCollection := p.db.Collection(models.Collections.Session, options.Collection()) sessionCollection := p.db.Collection(models.Collections.Session, options.Collection())
_, err := sessionCollection.InsertOne(nil, session) _, err := sessionCollection.InsertOne(ctx, session)
if err != nil {
return err
}
return nil
}
// DeleteSession to delete session information from database
func (p *provider) DeleteSession(userId string) error {
sessionCollection := p.db.Collection(models.Collections.Session, options.Collection())
_, err := sessionCollection.DeleteMany(nil, bson.M{"user_id": userId}, options.Delete())
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,6 +1,7 @@
package mongodb package mongodb
import ( import (
"context"
"time" "time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
@@ -13,7 +14,7 @@ import (
) )
// AddUser to save user information in database // AddUser to save user information in database
func (p *provider) AddUser(user models.User) (models.User, error) { func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) {
if user.ID == "" { if user.ID == "" {
user.ID = uuid.New().String() user.ID = uuid.New().String()
} }
@@ -29,7 +30,7 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix()
user.Key = user.ID user.Key = user.ID
userCollection := p.db.Collection(models.Collections.User, options.Collection()) userCollection := p.db.Collection(models.Collections.User, options.Collection())
_, err := userCollection.InsertOne(nil, user) _, err := userCollection.InsertOne(ctx, user)
if err != nil { if err != nil {
return user, err return user, err
} }
@@ -38,10 +39,10 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
} }
// UpdateUser to update user information in database // UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) { func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix()
userCollection := p.db.Collection(models.Collections.User, options.Collection()) userCollection := p.db.Collection(models.Collections.User, options.Collection())
_, err := userCollection.UpdateOne(nil, bson.M{"_id": bson.M{"$eq": user.ID}}, bson.M{"$set": user}, options.MergeUpdateOptions()) _, err := userCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": user.ID}}, bson.M{"$set": user}, options.MergeUpdateOptions())
if err != nil { if err != nil {
return user, err return user, err
} }
@@ -49,9 +50,15 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
} }
// DeleteUser to delete user information from database // DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error { func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
userCollection := p.db.Collection(models.Collections.User, options.Collection()) userCollection := p.db.Collection(models.Collections.User, options.Collection())
_, err := userCollection.DeleteOne(nil, bson.M{"_id": user.ID}, options.Delete()) _, err := userCollection.DeleteOne(ctx, bson.M{"_id": user.ID}, options.Delete())
if err != nil {
return err
}
sessionCollection := p.db.Collection(models.Collections.Session, options.Collection())
_, err = sessionCollection.DeleteMany(ctx, bson.M{"user_id": user.ID}, options.Delete())
if err != nil { if err != nil {
return err return err
} }
@@ -60,7 +67,7 @@ func (p *provider) DeleteUser(user models.User) error {
} }
// ListUsers to get list of users from database // ListUsers to get list of users from database
func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) { func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
var users []*model.User var users []*model.User
opts := options.Find() opts := options.Find()
opts.SetLimit(pagination.Limit) opts.SetLimit(pagination.Limit)
@@ -70,20 +77,20 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
paginationClone := pagination paginationClone := pagination
userCollection := p.db.Collection(models.Collections.User, options.Collection()) userCollection := p.db.Collection(models.Collections.User, options.Collection())
count, err := userCollection.CountDocuments(nil, bson.M{}, options.Count()) count, err := userCollection.CountDocuments(ctx, bson.M{}, options.Count())
if err != nil { if err != nil {
return nil, err return nil, err
} }
paginationClone.Total = count paginationClone.Total = count
cursor, err := userCollection.Find(nil, bson.M{}, opts) cursor, err := userCollection.Find(ctx, bson.M{}, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cursor.Close(nil) defer cursor.Close(ctx)
for cursor.Next(nil) { for cursor.Next(ctx) {
var user models.User var user models.User
err := cursor.Decode(&user) err := cursor.Decode(&user)
if err != nil { if err != nil {
@@ -99,10 +106,10 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
} }
// GetUserByEmail to get user information from database using email address // GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(email string) (models.User, error) { func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User var user models.User
userCollection := p.db.Collection(models.Collections.User, options.Collection()) userCollection := p.db.Collection(models.Collections.User, options.Collection())
err := userCollection.FindOne(nil, bson.M{"email": email}).Decode(&user) err := userCollection.FindOne(ctx, bson.M{"email": email}).Decode(&user)
if err != nil { if err != nil {
return user, err return user, err
} }
@@ -111,11 +118,11 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
} }
// GetUserByID to get user information from database using user ID // GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(id string) (models.User, error) { func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
var user models.User var user models.User
userCollection := p.db.Collection(models.Collections.User, options.Collection()) userCollection := p.db.Collection(models.Collections.User, options.Collection())
err := userCollection.FindOne(nil, bson.M{"_id": id}).Decode(&user) err := userCollection.FindOne(ctx, bson.M{"_id": id}).Decode(&user)
if err != nil { if err != nil {
return user, err return user, err
} }

View File

@@ -1,6 +1,7 @@
package mongodb package mongodb
import ( import (
"context"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@@ -11,7 +12,7 @@ import (
) )
// AddVerification to save verification request in database // AddVerification to save verification request in database
func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
if verificationRequest.ID == "" { if verificationRequest.ID == "" {
verificationRequest.ID = uuid.New().String() verificationRequest.ID = uuid.New().String()
@@ -19,7 +20,7 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
verificationRequest.UpdatedAt = time.Now().Unix() verificationRequest.UpdatedAt = time.Now().Unix()
verificationRequest.Key = verificationRequest.ID verificationRequest.Key = verificationRequest.ID
verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection()) verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection())
_, err := verificationRequestCollection.InsertOne(nil, verificationRequest) _, err := verificationRequestCollection.InsertOne(ctx, verificationRequest)
if err != nil { if err != nil {
return verificationRequest, err return verificationRequest, err
} }
@@ -29,11 +30,11 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
} }
// GetVerificationRequestByToken to get verification request from database using token // GetVerificationRequestByToken to get verification request from database using token
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection()) verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection())
err := verificationRequestCollection.FindOne(nil, bson.M{"token": token}).Decode(&verificationRequest) err := verificationRequestCollection.FindOne(ctx, bson.M{"token": token}).Decode(&verificationRequest)
if err != nil { if err != nil {
return verificationRequest, err return verificationRequest, err
} }
@@ -42,11 +43,11 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
} }
// GetVerificationRequestByEmail to get verification request by email from database // GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection()) verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection())
err := verificationRequestCollection.FindOne(nil, bson.M{"email": email, "identifier": identifier}).Decode(&verificationRequest) err := verificationRequestCollection.FindOne(ctx, bson.M{"email": email, "identifier": identifier}).Decode(&verificationRequest)
if err != nil { if err != nil {
return verificationRequest, err return verificationRequest, err
} }
@@ -55,7 +56,7 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
} }
// ListVerificationRequests to get list of verification requests from database // ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) { func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
var verificationRequests []*model.VerificationRequest var verificationRequests []*model.VerificationRequest
opts := options.Find() opts := options.Find()
@@ -65,17 +66,17 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection()) verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection())
verificationRequestCollectionCount, err := verificationRequestCollection.CountDocuments(nil, bson.M{}) verificationRequestCollectionCount, err := verificationRequestCollection.CountDocuments(ctx, bson.M{})
paginationClone := pagination paginationClone := pagination
paginationClone.Total = verificationRequestCollectionCount paginationClone.Total = verificationRequestCollectionCount
cursor, err := verificationRequestCollection.Find(nil, bson.M{}, opts) cursor, err := verificationRequestCollection.Find(ctx, bson.M{}, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cursor.Close(nil) defer cursor.Close(ctx)
for cursor.Next(nil) { for cursor.Next(ctx) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
err := cursor.Decode(&verificationRequest) err := cursor.Decode(&verificationRequest)
if err != nil { if err != nil {
@@ -91,9 +92,9 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
} }
// DeleteVerificationRequest to delete verification request from database // DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error { func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error {
verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection()) verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection())
_, err := verificationRequestCollection.DeleteOne(nil, bson.M{"_id": verificationRequest.ID}, options.Delete()) _, err := verificationRequestCollection.DeleteOne(ctx, bson.M{"_id": verificationRequest.ID}, options.Delete())
if err != nil { if err != nil {
return err return err
} }

View File

@@ -0,0 +1,120 @@
package mongodb
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/options"
)
// AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
if webhook.ID == "" {
webhook.ID = uuid.New().String()
}
webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix()
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
_, err := webhookCollection.InsertOne(ctx, webhook)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix()
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
_, err := webhookCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": webhook.ID}}, bson.M{"$set": webhook}, options.MergeUpdateOptions())
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
var webhooks []*model.Webhook
opts := options.Find()
opts.SetLimit(pagination.Limit)
opts.SetSkip(pagination.Offset)
opts.SetSort(bson.M{"created_at": -1})
paginationClone := pagination
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
count, err := webhookCollection.CountDocuments(ctx, bson.M{}, options.Count())
if err != nil {
return nil, err
}
paginationClone.Total = count
cursor, err := webhookCollection.Find(ctx, bson.M{}, opts)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)
for cursor.Next(ctx) {
var webhook models.Webhook
err := cursor.Decode(&webhook)
if err != nil {
return nil, err
}
webhooks = append(webhooks, webhook.AsAPIWebhook())
}
return &model.Webhooks{
Pagination: &paginationClone,
Webhooks: webhooks,
}, nil
}
// GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
err := webhookCollection.FindOne(ctx, bson.M{"_id": webhookID}).Decode(&webhook)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
var webhook models.Webhook
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
err := webhookCollection.FindOne(ctx, bson.M{"event_name": eventName}).Decode(&webhook)
if err != nil {
return nil, err
}
return webhook.AsAPIWebhook(), nil
}
// DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
_, err := webhookCollection.DeleteOne(nil, bson.M{"_id": webhook.ID}, options.Delete())
if err != nil {
return err
}
webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection())
_, err = webhookLogCollection.DeleteMany(nil, bson.M{"webhook_id": webhook.ID}, options.Delete())
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,74 @@
package mongodb
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/options"
)
// AddWebhookLog to add webhook log
func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) {
if webhookLog.ID == "" {
webhookLog.ID = uuid.New().String()
}
webhookLog.Key = webhookLog.ID
webhookLog.CreatedAt = time.Now().Unix()
webhookLog.UpdatedAt = time.Now().Unix()
webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection())
_, err := webhookLogCollection.InsertOne(ctx, webhookLog)
if err != nil {
return nil, err
}
return webhookLog.AsAPIWebhookLog(), nil
}
// ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) {
webhookLogs := []*model.WebhookLog{}
opts := options.Find()
opts.SetLimit(pagination.Limit)
opts.SetSkip(pagination.Offset)
opts.SetSort(bson.M{"created_at": -1})
paginationClone := pagination
query := bson.M{}
if webhookID != "" {
query = bson.M{"webhook_id": webhookID}
}
webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection())
count, err := webhookLogCollection.CountDocuments(ctx, query, options.Count())
if err != nil {
return nil, err
}
paginationClone.Total = count
cursor, err := webhookLogCollection.Find(ctx, query, opts)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)
for cursor.Next(ctx) {
var webhookLog models.WebhookLog
err := cursor.Decode(&webhookLog)
if err != nil {
return nil, err
}
webhookLogs = append(webhookLogs, webhookLog.AsAPIWebhookLog())
}
return &model.WebhookLogs{
Pagination: &paginationClone,
WebhookLogs: webhookLogs,
}, nil
}

View File

@@ -0,0 +1,48 @@
package provider_template
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddEmailTemplate to add EmailTemplate
func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) {
if emailTemplate.ID == "" {
emailTemplate.ID = uuid.New().String()
}
emailTemplate.Key = emailTemplate.ID
emailTemplate.CreatedAt = time.Now().Unix()
emailTemplate.UpdatedAt = time.Now().Unix()
return emailTemplate.AsAPIEmailTemplate(), nil
}
// UpdateEmailTemplate to update EmailTemplate
func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) {
emailTemplate.UpdatedAt = time.Now().Unix()
return emailTemplate.AsAPIEmailTemplate(), nil
}
// ListEmailTemplates to list EmailTemplate
func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) {
return nil, nil
}
// GetEmailTemplateByID to get EmailTemplate by id
func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) {
return nil, nil
}
// GetEmailTemplateByEventName to get EmailTemplate by event_name
func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) {
return nil, nil
}
// DeleteEmailTemplate to delete EmailTemplate
func (p *provider) DeleteEmailTemplate(ctx context.Context, emailTemplate *model.EmailTemplate) error {
return nil
}

View File

@@ -1,6 +1,7 @@
package provider_template package provider_template
import ( import (
"context"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@@ -8,7 +9,7 @@ import (
) )
// AddEnv to save environment information in database // AddEnv to save environment information in database
func (p *provider) AddEnv(env models.Env) (models.Env, error) { func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) {
if env.ID == "" { if env.ID == "" {
env.ID = uuid.New().String() env.ID = uuid.New().String()
} }
@@ -19,13 +20,13 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
} }
// UpdateEnv to update environment information in database // UpdateEnv to update environment information in database
func (p *provider) UpdateEnv(env models.Env) (models.Env, error) { func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix()
return env, nil return env, nil
} }
// GetEnv to get environment information from database // GetEnv to get environment information from database
func (p *provider) GetEnv() (models.Env, error) { func (p *provider) GetEnv(ctx context.Context) (models.Env, error) {
var env models.Env var env models.Env
return env, nil return env, nil

View File

@@ -1,6 +1,7 @@
package provider_template package provider_template
import ( import (
"context"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@@ -8,7 +9,7 @@ import (
) )
// AddSession to save session information in database // AddSession to save session information in database
func (p *provider) AddSession(session models.Session) error { func (p *provider) AddSession(ctx context.Context, session models.Session) error {
if session.ID == "" { if session.ID == "" {
session.ID = uuid.New().String() session.ID = uuid.New().String()
} }
@@ -19,6 +20,6 @@ func (p *provider) AddSession(session models.Session) error {
} }
// DeleteSession to delete session information from database // DeleteSession to delete session information from database
func (p *provider) DeleteSession(userId string) error { func (p *provider) DeleteSession(ctx context.Context, userId string) error {
return nil return nil
} }

View File

@@ -1,6 +1,7 @@
package provider_template package provider_template
import ( import (
"context"
"time" "time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
@@ -11,7 +12,7 @@ import (
) )
// AddUser to save user information in database // AddUser to save user information in database
func (p *provider) AddUser(user models.User) (models.User, error) { func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) {
if user.ID == "" { if user.ID == "" {
user.ID = uuid.New().String() user.ID = uuid.New().String()
} }
@@ -31,30 +32,30 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
} }
// UpdateUser to update user information in database // UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) { func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix()
return user, nil return user, nil
} }
// DeleteUser to delete user information from database // DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error { func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
return nil return nil
} }
// ListUsers to get list of users from database // ListUsers to get list of users from database
func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) { func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
return nil, nil return nil, nil
} }
// GetUserByEmail to get user information from database using email address // GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(email string) (models.User, error) { func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User var user models.User
return user, nil return user, nil
} }
// GetUserByID to get user information from database using user ID // GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(id string) (models.User, error) { func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
var user models.User var user models.User
return user, nil return user, nil

View File

@@ -1,6 +1,7 @@
package provider_template package provider_template
import ( import (
"context"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@@ -9,7 +10,7 @@ import (
) )
// AddVerification to save verification request in database // AddVerification to save verification request in database
func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
if verificationRequest.ID == "" { if verificationRequest.ID == "" {
verificationRequest.ID = uuid.New().String() verificationRequest.ID = uuid.New().String()
} }
@@ -21,25 +22,25 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
} }
// GetVerificationRequestByToken to get verification request from database using token // GetVerificationRequestByToken to get verification request from database using token
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
return verificationRequest, nil return verificationRequest, nil
} }
// GetVerificationRequestByEmail to get verification request by email from database // GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
return verificationRequest, nil return verificationRequest, nil
} }
// ListVerificationRequests to get list of verification requests from database // ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) { func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
return nil, nil return nil, nil
} }
// DeleteVerificationRequest to delete verification request from database // DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error { func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error {
return nil return nil
} }

View File

@@ -0,0 +1,49 @@
package provider_template
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
if webhook.ID == "" {
webhook.ID = uuid.New().String()
}
webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix()
return webhook.AsAPIWebhook(), nil
}
// UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix()
return webhook.AsAPIWebhook(), nil
}
// ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
return nil, nil
}
// GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
return nil, nil
}
// GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
return nil, nil
}
// DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
// Also delete webhook logs for given webhook id
return nil
}

View File

@@ -0,0 +1,27 @@
package provider_template
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddWebhookLog to add webhook log
func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) {
if webhookLog.ID == "" {
webhookLog.ID = uuid.New().String()
}
webhookLog.Key = webhookLog.ID
webhookLog.CreatedAt = time.Now().Unix()
webhookLog.UpdatedAt = time.Now().Unix()
return webhookLog.AsAPIWebhookLog(), nil
}
// ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) {
return nil, nil
}

View File

@@ -1,44 +1,75 @@
package providers package providers
import ( import (
"context"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
) )
type Provider interface { type Provider interface {
// AddUser to save user information in database // AddUser to save user information in database
AddUser(user models.User) (models.User, error) AddUser(ctx context.Context, user models.User) (models.User, error)
// UpdateUser to update user information in database // UpdateUser to update user information in database
UpdateUser(user models.User) (models.User, error) UpdateUser(ctx context.Context, user models.User) (models.User, error)
// DeleteUser to delete user information from database // DeleteUser to delete user information from database
DeleteUser(user models.User) error DeleteUser(ctx context.Context, user models.User) error
// ListUsers to get list of users from database // ListUsers to get list of users from database
ListUsers(pagination model.Pagination) (*model.Users, error) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error)
// GetUserByEmail to get user information from database using email address // GetUserByEmail to get user information from database using email address
GetUserByEmail(email string) (models.User, error) GetUserByEmail(ctx context.Context, email string) (models.User, error)
// GetUserByID to get user information from database using user ID // GetUserByID to get user information from database using user ID
GetUserByID(id string) (models.User, error) GetUserByID(ctx context.Context, id string) (models.User, error)
// AddVerification to save verification request in database // AddVerification to save verification request in database
AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error)
// GetVerificationRequestByToken to get verification request from database using token // GetVerificationRequestByToken to get verification request from database using token
GetVerificationRequestByToken(token string) (models.VerificationRequest, error) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error)
// GetVerificationRequestByEmail to get verification request by email from database // GetVerificationRequestByEmail to get verification request by email from database
GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error)
// ListVerificationRequests to get list of verification requests from database // ListVerificationRequests to get list of verification requests from database
ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error)
// DeleteVerificationRequest to delete verification request from database // DeleteVerificationRequest to delete verification request from database
DeleteVerificationRequest(verificationRequest models.VerificationRequest) error DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error
// AddSession to save session information in database // AddSession to save session information in database
AddSession(session models.Session) error AddSession(ctx context.Context, session models.Session) error
// DeleteSession to delete session information from database
DeleteSession(userId string) error
// AddEnv to save environment information in database // AddEnv to save environment information in database
AddEnv(env models.Env) (models.Env, error) AddEnv(ctx context.Context, env models.Env) (models.Env, error)
// UpdateEnv to update environment information in database // UpdateEnv to update environment information in database
UpdateEnv(env models.Env) (models.Env, error) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error)
// GetEnv to get environment information from database // GetEnv to get environment information from database
GetEnv() (models.Env, error) GetEnv(ctx context.Context) (models.Env, error)
// AddWebhook to add webhook
AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error)
// UpdateWebhook to update webhook
UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error)
// ListWebhooks to list webhook
ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error)
// GetWebhookByID to get webhook by id
GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error)
// GetWebhookByEventName to get webhook by event_name
GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error)
// DeleteWebhook to delete webhook
DeleteWebhook(ctx context.Context, webhook *model.Webhook) error
// AddWebhookLog to add webhook log
AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error)
// ListWebhookLogs to list webhook logs
ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error)
// AddEmailTemplate to add EmailTemplate
AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error)
// UpdateEmailTemplate to update EmailTemplate
UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error)
// ListEmailTemplates to list EmailTemplate
ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error)
// GetEmailTemplateByID to get EmailTemplate by id
GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error)
// GetEmailTemplateByEventName to get EmailTemplate by event_name
GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error)
// DeleteEmailTemplate to delete EmailTemplate
DeleteEmailTemplate(ctx context.Context, emailTemplate *model.EmailTemplate) error
} }

View File

@@ -0,0 +1,100 @@
package sql
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddEmailTemplate to add EmailTemplate
func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) {
if emailTemplate.ID == "" {
emailTemplate.ID = uuid.New().String()
}
emailTemplate.Key = emailTemplate.ID
emailTemplate.CreatedAt = time.Now().Unix()
emailTemplate.UpdatedAt = time.Now().Unix()
res := p.db.Create(&emailTemplate)
if res.Error != nil {
return nil, res.Error
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// UpdateEmailTemplate to update EmailTemplate
func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) {
emailTemplate.UpdatedAt = time.Now().Unix()
res := p.db.Save(&emailTemplate)
if res.Error != nil {
return nil, res.Error
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// ListEmailTemplates to list EmailTemplate
func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) {
var emailTemplates []models.EmailTemplate
result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&emailTemplates)
if result.Error != nil {
return nil, result.Error
}
var total int64
totalRes := p.db.Model(&models.EmailTemplate{}).Count(&total)
if totalRes.Error != nil {
return nil, totalRes.Error
}
paginationClone := pagination
paginationClone.Total = total
responseEmailTemplates := []*model.EmailTemplate{}
for _, w := range emailTemplates {
responseEmailTemplates = append(responseEmailTemplates, w.AsAPIEmailTemplate())
}
return &model.EmailTemplates{
Pagination: &paginationClone,
EmailTemplates: responseEmailTemplates,
}, nil
}
// GetEmailTemplateByID to get EmailTemplate by id
func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) {
var emailTemplate models.EmailTemplate
result := p.db.Where("id = ?", emailTemplateID).First(&emailTemplate)
if result.Error != nil {
return nil, result.Error
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// GetEmailTemplateByEventName to get EmailTemplate by event_name
func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) {
var emailTemplate models.EmailTemplate
result := p.db.Where("event_name = ?", eventName).First(&emailTemplate)
if result.Error != nil {
return nil, result.Error
}
return emailTemplate.AsAPIEmailTemplate(), nil
}
// DeleteEmailTemplate to delete EmailTemplate
func (p *provider) DeleteEmailTemplate(ctx context.Context, emailTemplate *model.EmailTemplate) error {
result := p.db.Delete(&models.EmailTemplate{
ID: emailTemplate.ID,
})
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -1,6 +1,7 @@
package sql package sql
import ( import (
"context"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@@ -8,7 +9,7 @@ import (
) )
// AddEnv to save environment information in database // AddEnv to save environment information in database
func (p *provider) AddEnv(env models.Env) (models.Env, error) { func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) {
if env.ID == "" { if env.ID == "" {
env.ID = uuid.New().String() env.ID = uuid.New().String()
} }
@@ -25,7 +26,7 @@ func (p *provider) AddEnv(env models.Env) (models.Env, error) {
} }
// UpdateEnv to update environment information in database // UpdateEnv to update environment information in database
func (p *provider) UpdateEnv(env models.Env) (models.Env, error) { func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) {
env.UpdatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix()
result := p.db.Save(&env) result := p.db.Save(&env)
@@ -36,7 +37,7 @@ func (p *provider) UpdateEnv(env models.Env) (models.Env, error) {
} }
// GetEnv to get environment information from database // GetEnv to get environment information from database
func (p *provider) GetEnv() (models.Env, error) { func (p *provider) GetEnv(ctx context.Context) (models.Env, error) {
var env models.Env var env models.Env
result := p.db.First(&env) result := p.db.First(&env)

View File

@@ -50,7 +50,7 @@ func NewProvider() (*provider, error) {
sqlDB, err = gorm.Open(postgres.Open(dbURL), ormConfig) sqlDB, err = gorm.Open(postgres.Open(dbURL), ormConfig)
case constants.DbTypeSqlite: case constants.DbTypeSqlite:
sqlDB, err = gorm.Open(sqlite.Open(dbURL), ormConfig) sqlDB, err = gorm.Open(sqlite.Open(dbURL), ormConfig)
case constants.DbTypeMysql, constants.DbTypeMariaDB: case constants.DbTypeMysql, constants.DbTypeMariaDB, constants.DbTypePlanetScaleDB:
sqlDB, err = gorm.Open(mysql.Open(dbURL), ormConfig) sqlDB, err = gorm.Open(mysql.Open(dbURL), ormConfig)
case constants.DbTypeSqlserver: case constants.DbTypeSqlserver:
sqlDB, err = gorm.Open(sqlserver.Open(dbURL), ormConfig) sqlDB, err = gorm.Open(sqlserver.Open(dbURL), ormConfig)
@@ -60,7 +60,7 @@ func NewProvider() (*provider, error) {
return nil, err return nil, err
} }
err = sqlDB.AutoMigrate(&models.User{}, &models.VerificationRequest{}, &models.Session{}, &models.Env{}) err = sqlDB.AutoMigrate(&models.User{}, &models.VerificationRequest{}, &models.Session{}, &models.Env{}, &models.Webhook{}, models.WebhookLog{}, models.EmailTemplate{})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,6 +1,7 @@
package sql package sql
import ( import (
"context"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@@ -9,7 +10,7 @@ import (
) )
// AddSession to save session information in database // AddSession to save session information in database
func (p *provider) AddSession(session models.Session) error { func (p *provider) AddSession(ctx context.Context, session models.Session) error {
if session.ID == "" { if session.ID == "" {
session.ID = uuid.New().String() session.ID = uuid.New().String()
} }
@@ -26,13 +27,3 @@ func (p *provider) AddSession(session models.Session) error {
} }
return nil return nil
} }
// DeleteSession to delete session information from database
func (p *provider) DeleteSession(userId string) error {
result := p.db.Where("user_id = ?", userId).Delete(&models.Session{})
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -1,6 +1,7 @@
package sql package sql
import ( import (
"context"
"time" "time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
@@ -12,7 +13,7 @@ import (
) )
// AddUser to save user information in database // AddUser to save user information in database
func (p *provider) AddUser(user models.User) (models.User, error) { func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) {
if user.ID == "" { if user.ID == "" {
user.ID = uuid.New().String() user.ID = uuid.New().String()
} }
@@ -42,7 +43,7 @@ func (p *provider) AddUser(user models.User) (models.User, error) {
} }
// UpdateUser to update user information in database // UpdateUser to update user information in database
func (p *provider) UpdateUser(user models.User) (models.User, error) { func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) {
user.UpdatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix()
result := p.db.Save(&user) result := p.db.Save(&user)
@@ -55,18 +56,23 @@ func (p *provider) UpdateUser(user models.User) (models.User, error) {
} }
// DeleteUser to delete user information from database // DeleteUser to delete user information from database
func (p *provider) DeleteUser(user models.User) error { func (p *provider) DeleteUser(ctx context.Context, user models.User) error {
result := p.db.Delete(&user) result := p.db.Delete(&user)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
result = p.db.Where("user_id = ?", user.ID).Delete(&models.Session{})
if result.Error != nil {
return result.Error
}
return nil return nil
} }
// ListUsers to get list of users from database // ListUsers to get list of users from database
func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) { func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) {
var users []models.User var users []models.User
result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&users) result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&users)
if result.Error != nil { if result.Error != nil {
@@ -94,7 +100,7 @@ func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error)
} }
// GetUserByEmail to get user information from database using email address // GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(email string) (models.User, error) { func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) {
var user models.User var user models.User
result := p.db.Where("email = ?", email).First(&user) result := p.db.Where("email = ?", email).First(&user)
if result.Error != nil { if result.Error != nil {
@@ -105,7 +111,7 @@ func (p *provider) GetUserByEmail(email string) (models.User, error) {
} }
// GetUserByID to get user information from database using user ID // GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(id string) (models.User, error) { func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) {
var user models.User var user models.User
result := p.db.Where("id = ?", id).First(&user) result := p.db.Where("id = ?", id).First(&user)

View File

@@ -1,6 +1,7 @@
package sql package sql
import ( import (
"context"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@@ -10,7 +11,7 @@ import (
) )
// AddVerification to save verification request in database // AddVerification to save verification request in database
func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) {
if verificationRequest.ID == "" { if verificationRequest.ID == "" {
verificationRequest.ID = uuid.New().String() verificationRequest.ID = uuid.New().String()
} }
@@ -31,7 +32,7 @@ func (p *provider) AddVerificationRequest(verificationRequest models.Verificatio
} }
// GetVerificationRequestByToken to get verification request from database using token // GetVerificationRequestByToken to get verification request from database using token
func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
result := p.db.Where("token = ?", token).First(&verificationRequest) result := p.db.Where("token = ?", token).First(&verificationRequest)
@@ -43,7 +44,7 @@ func (p *provider) GetVerificationRequestByToken(token string) (models.Verificat
} }
// GetVerificationRequestByEmail to get verification request by email from database // GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) { func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) {
var verificationRequest models.VerificationRequest var verificationRequest models.VerificationRequest
result := p.db.Where("email = ? AND identifier = ?", email, identifier).First(&verificationRequest) result := p.db.Where("email = ? AND identifier = ?", email, identifier).First(&verificationRequest)
@@ -56,7 +57,7 @@ func (p *provider) GetVerificationRequestByEmail(email string, identifier string
} }
// ListVerificationRequests to get list of verification requests from database // ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) { func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) {
var verificationRequests []models.VerificationRequest var verificationRequests []models.VerificationRequest
result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&verificationRequests) result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&verificationRequests)
@@ -85,7 +86,7 @@ func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model
} }
// DeleteVerificationRequest to delete verification request from database // DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error { func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error {
result := p.db.Delete(&verificationRequest) result := p.db.Delete(&verificationRequest)
if result.Error != nil { if result.Error != nil {

View File

@@ -0,0 +1,104 @@
package sql
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
)
// AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
if webhook.ID == "" {
webhook.ID = uuid.New().String()
}
webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix()
res := p.db.Create(&webhook)
if res.Error != nil {
return nil, res.Error
}
return webhook.AsAPIWebhook(), nil
}
// UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix()
result := p.db.Save(&webhook)
if result.Error != nil {
return nil, result.Error
}
return webhook.AsAPIWebhook(), nil
}
// ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
var webhooks []models.Webhook
result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&webhooks)
if result.Error != nil {
return nil, result.Error
}
var total int64
totalRes := p.db.Model(&models.Webhook{}).Count(&total)
if totalRes.Error != nil {
return nil, totalRes.Error
}
paginationClone := pagination
paginationClone.Total = total
responseWebhooks := []*model.Webhook{}
for _, w := range webhooks {
responseWebhooks = append(responseWebhooks, w.AsAPIWebhook())
}
return &model.Webhooks{
Pagination: &paginationClone,
Webhooks: responseWebhooks,
}, nil
}
// GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook
result := p.db.Where("id = ?", webhookID).First(&webhook)
if result.Error != nil {
return nil, result.Error
}
return webhook.AsAPIWebhook(), nil
}
// GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) {
var webhook models.Webhook
result := p.db.Where("event_name = ?", eventName).First(&webhook)
if result.Error != nil {
return nil, result.Error
}
return webhook.AsAPIWebhook(), nil
}
// DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
result := p.db.Delete(&models.Webhook{
ID: webhook.ID,
})
if result.Error != nil {
return result.Error
}
result = p.db.Where("webhook_id = ?", webhook.ID).Delete(&models.WebhookLog{})
if result.Error != nil {
return result.Error
}
return nil
}

View File

@@ -0,0 +1,68 @@
package sql
import (
"context"
"time"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// AddWebhookLog to add webhook log
func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) {
if webhookLog.ID == "" {
webhookLog.ID = uuid.New().String()
}
webhookLog.Key = webhookLog.ID
webhookLog.CreatedAt = time.Now().Unix()
webhookLog.UpdatedAt = time.Now().Unix()
res := p.db.Clauses(
clause.OnConflict{
DoNothing: true,
}).Create(&webhookLog)
if res.Error != nil {
return nil, res.Error
}
return webhookLog.AsAPIWebhookLog(), nil
}
// ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) {
var webhookLogs []models.WebhookLog
var result *gorm.DB
var totalRes *gorm.DB
var total int64
if webhookID != "" {
result = p.db.Where("webhook_id = ?", webhookID).Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&webhookLogs)
totalRes = p.db.Where("webhook_id = ?", webhookID).Model(&models.WebhookLog{}).Count(&total)
} else {
result = p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&webhookLogs)
totalRes = p.db.Model(&models.WebhookLog{}).Count(&total)
}
if result.Error != nil {
return nil, result.Error
}
if totalRes.Error != nil {
return nil, totalRes.Error
}
paginationClone := pagination
paginationClone.Total = total
responseWebhookLogs := []*model.WebhookLog{}
for _, w := range webhookLogs {
responseWebhookLogs = append(responseWebhookLogs, w.AsAPIWebhookLog())
}
return &model.WebhookLogs{
WebhookLogs: responseWebhookLogs,
Pagination: &paginationClone,
}, nil
}

View File

@@ -70,7 +70,7 @@ func InviteEmail(toEmail, token, verificationURL, redirectURI string) error {
<tr style="background: rgb(249,250,251);padding: 10px;margin-bottom:10px;border-radius:5px;"> <tr style="background: rgb(249,250,251);padding: 10px;margin-bottom:10px;border-radius:5px;">
<td class="esd-block-text es-m-txt-c es-p15t" align="center" style="padding:10px;padding-bottom:30px;"> <td class="esd-block-text es-m-txt-c es-p15t" align="center" style="padding:10px;padding-bottom:30px;">
<p>Hi there 👋</p> <p>Hi there 👋</p>
<p>Join us! You are invited to sign-up for <b>{{.org_name}}</b>. Please accept the invitation by clicking the clicking the button below.</p> <br/> <p>Join us! You are invited to sign-up for <b>{{.org_name}}</b>. Please accept the invitation by clicking the button below.</p> <br/>
<a <a
clicktracking="off" href="{{.verification_url}}" class="es-button" target="_blank" style="text-decoration: none;padding:10px 15px;background-color: rgba(59,130,246,1);color: #fff;font-size: 1em;border-radius:5px;">Get Started</a> clicktracking="off" href="{{.verification_url}}" class="es-button" target="_blank" style="text-decoration: none;padding:10px 15px;background-color: rgba(59,130,246,1);color: #fff;font-size: 1em;border-radius:5px;">Get Started</a>
</td> </td>

14
server/env/env.go vendored
View File

@@ -83,6 +83,7 @@ func InitAllEnv() error {
osDisableLoginPage := os.Getenv(constants.EnvKeyDisableLoginPage) osDisableLoginPage := os.Getenv(constants.EnvKeyDisableLoginPage)
osDisableSignUp := os.Getenv(constants.EnvKeyDisableSignUp) osDisableSignUp := os.Getenv(constants.EnvKeyDisableSignUp)
osDisableRedisForEnv := os.Getenv(constants.EnvKeyDisableRedisForEnv) osDisableRedisForEnv := os.Getenv(constants.EnvKeyDisableRedisForEnv)
osDisableStrongPassword := os.Getenv(constants.EnvKeyDisableStrongPassword)
// os slice vars // os slice vars
osAllowedOrigins := os.Getenv(constants.EnvKeyAllowedOrigins) osAllowedOrigins := os.Getenv(constants.EnvKeyAllowedOrigins)
@@ -476,6 +477,19 @@ func InitAllEnv() error {
} }
} }
if _, ok := envData[constants.EnvKeyDisableStrongPassword]; !ok {
envData[constants.EnvKeyDisableStrongPassword] = osDisableStrongPassword == "true"
}
if osDisableStrongPassword != "" {
boolValue, err := strconv.ParseBool(osDisableStrongPassword)
if err != nil {
return err
}
if boolValue != envData[constants.EnvKeyDisableStrongPassword].(bool) {
envData[constants.EnvKeyDisableStrongPassword] = boolValue
}
}
// no need to add nil check as its already done above // no need to add nil check as its already done above
if envData[constants.EnvKeySmtpHost] == "" || envData[constants.EnvKeySmtpUsername] == "" || envData[constants.EnvKeySmtpPassword] == "" || envData[constants.EnvKeySenderEmail] == "" && envData[constants.EnvKeySmtpPort] == "" { if envData[constants.EnvKeySmtpHost] == "" || envData[constants.EnvKeySmtpUsername] == "" || envData[constants.EnvKeySmtpPassword] == "" || envData[constants.EnvKeySenderEmail] == "" && envData[constants.EnvKeySmtpPort] == "" {
envData[constants.EnvKeyDisableEmailVerification] = true envData[constants.EnvKeyDisableEmailVerification] = true

View File

@@ -1,6 +1,7 @@
package env package env
import ( import (
"context"
"encoding/json" "encoding/json"
"os" "os"
"reflect" "reflect"
@@ -58,7 +59,8 @@ func fixBackwardCompatibility(data map[string]interface{}) (bool, map[string]int
// GetEnvData returns the env data from database // GetEnvData returns the env data from database
func GetEnvData() (map[string]interface{}, error) { func GetEnvData() (map[string]interface{}, error) {
var result map[string]interface{} var result map[string]interface{}
env, err := db.Provider.GetEnv() ctx := context.Background()
env, err := db.Provider.GetEnv(ctx)
// config not found in db // config not found in db
if err != nil { if err != nil {
log.Debug("Error while getting env data from db: ", err) log.Debug("Error while getting env data from db: ", err)
@@ -108,9 +110,10 @@ func GetEnvData() (map[string]interface{}, error) {
// PersistEnv persists the environment variables to the database // PersistEnv persists the environment variables to the database
func PersistEnv() error { func PersistEnv() error {
env, err := db.Provider.GetEnv() ctx := context.Background()
env, err := db.Provider.GetEnv(ctx)
// config not found in db // config not found in db
if err != nil { if err != nil || env.EnvData == "" {
// AES encryption needs 32 bit key only, so we chop off last 4 characters from 36 bit uuid // AES encryption needs 32 bit key only, so we chop off last 4 characters from 36 bit uuid
hash := uuid.New().String()[:36-4] hash := uuid.New().String()[:36-4]
err := memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, hash) err := memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, hash)
@@ -137,7 +140,7 @@ func PersistEnv() error {
EnvData: encryptedConfig, EnvData: encryptedConfig,
} }
env, err = db.Provider.AddEnv(env) env, err = db.Provider.AddEnv(ctx, env)
if err != nil { if err != nil {
log.Debug("Error while persisting env data to db: ", err) log.Debug("Error while persisting env data to db: ", err)
return err return err
@@ -171,7 +174,7 @@ func PersistEnv() error {
err = json.Unmarshal(decryptedConfigs, &storeData) err = json.Unmarshal(decryptedConfigs, &storeData)
if err != nil { if err != nil {
log.Debug("Error while unmarshalling env data: ", err) log.Debug("Error while un-marshalling env data: ", err)
return err return err
} }
@@ -198,7 +201,7 @@ func PersistEnv() error {
envValue := strings.TrimSpace(os.Getenv(key)) envValue := strings.TrimSpace(os.Getenv(key))
if envValue != "" { if envValue != "" {
switch key { switch key {
case constants.EnvKeyIsProd, constants.EnvKeyDisableBasicAuthentication, constants.EnvKeyDisableEmailVerification, constants.EnvKeyDisableLoginPage, constants.EnvKeyDisableMagicLinkLogin, constants.EnvKeyDisableSignUp, constants.EnvKeyDisableRedisForEnv: case constants.EnvKeyIsProd, constants.EnvKeyDisableBasicAuthentication, constants.EnvKeyDisableEmailVerification, constants.EnvKeyDisableLoginPage, constants.EnvKeyDisableMagicLinkLogin, constants.EnvKeyDisableSignUp, constants.EnvKeyDisableRedisForEnv, constants.EnvKeyDisableStrongPassword:
if envValueBool, err := strconv.ParseBool(envValue); err == nil { if envValueBool, err := strconv.ParseBool(envValue); err == nil {
if value.(bool) != envValueBool { if value.(bool) != envValueBool {
storeData[key] = envValueBool storeData[key] = envValueBool
@@ -251,7 +254,7 @@ func PersistEnv() error {
} }
env.EnvData = encryptedConfig env.EnvData = encryptedConfig
_, err = db.Provider.UpdateEnv(env) _, err = db.Provider.UpdateEnv(ctx, env)
if err != nil { if err != nil {
log.Debug("Failed to Update Config: ", err) log.Debug("Failed to Update Config: ", err)
return err return err

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,18 @@
package model package model
type AddEmailTemplateRequest struct {
EventName string `json:"event_name"`
Template string `json:"template"`
}
type AddWebhookRequest struct {
EventName string `json:"event_name"`
Endpoint string `json:"endpoint"`
Enabled bool `json:"enabled"`
Headers map[string]interface{} `json:"headers"`
}
type AdminLoginInput struct { type AdminLoginInput struct {
AdminSecret string `json:"admin_secret"` AdminSecret string `json:"admin_secret"`
} }
@@ -19,10 +31,27 @@ type AuthResponse struct {
User *User `json:"user"` User *User `json:"user"`
} }
type DeleteEmailTemplateRequest struct {
ID string `json:"id"`
}
type DeleteUserInput struct { type DeleteUserInput struct {
Email string `json:"email"` Email string `json:"email"`
} }
type EmailTemplate struct {
ID string `json:"id"`
EventName string `json:"event_name"`
Template string `json:"template"`
CreatedAt *int64 `json:"created_at"`
UpdatedAt *int64 `json:"updated_at"`
}
type EmailTemplates struct {
Pagination *Pagination `json:"pagination"`
EmailTemplates []*EmailTemplate `json:"EmailTemplates"`
}
type Env struct { type Env struct {
AccessTokenExpiryTime *string `json:"ACCESS_TOKEN_EXPIRY_TIME"` AccessTokenExpiryTime *string `json:"ACCESS_TOKEN_EXPIRY_TIME"`
AdminSecret *string `json:"ADMIN_SECRET"` AdminSecret *string `json:"ADMIN_SECRET"`
@@ -55,6 +84,7 @@ type Env struct {
DisableLoginPage bool `json:"DISABLE_LOGIN_PAGE"` DisableLoginPage bool `json:"DISABLE_LOGIN_PAGE"`
DisableSignUp bool `json:"DISABLE_SIGN_UP"` DisableSignUp bool `json:"DISABLE_SIGN_UP"`
DisableRedisForEnv bool `json:"DISABLE_REDIS_FOR_ENV"` DisableRedisForEnv bool `json:"DISABLE_REDIS_FOR_ENV"`
DisableStrongPassword bool `json:"DISABLE_STRONG_PASSWORD"`
Roles []string `json:"ROLES"` Roles []string `json:"ROLES"`
ProtectedRoles []string `json:"PROTECTED_ROLES"` ProtectedRoles []string `json:"PROTECTED_ROLES"`
DefaultRoles []string `json:"DEFAULT_ROLES"` DefaultRoles []string `json:"DEFAULT_ROLES"`
@@ -99,6 +129,11 @@ type InviteMemberInput struct {
RedirectURI *string `json:"redirect_uri"` RedirectURI *string `json:"redirect_uri"`
} }
type ListWebhookLogRequest struct {
Pagination *PaginationInput `json:"pagination"`
WebhookID *string `json:"webhook_id"`
}
type LoginInput struct { type LoginInput struct {
Email string `json:"email"` Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
@@ -126,6 +161,7 @@ type Meta struct {
IsBasicAuthenticationEnabled bool `json:"is_basic_authentication_enabled"` IsBasicAuthenticationEnabled bool `json:"is_basic_authentication_enabled"`
IsMagicLinkLoginEnabled bool `json:"is_magic_link_login_enabled"` IsMagicLinkLoginEnabled bool `json:"is_magic_link_login_enabled"`
IsSignUpEnabled bool `json:"is_sign_up_enabled"` IsSignUpEnabled bool `json:"is_sign_up_enabled"`
IsStrongPasswordEnabled bool `json:"is_strong_password_enabled"`
} }
type OAuthRevokeInput struct { type OAuthRevokeInput struct {
@@ -185,10 +221,27 @@ type SignUpInput struct {
RedirectURI *string `json:"redirect_uri"` RedirectURI *string `json:"redirect_uri"`
} }
type TestEndpointRequest struct {
Endpoint string `json:"endpoint"`
EventName string `json:"event_name"`
Headers map[string]interface{} `json:"headers"`
}
type TestEndpointResponse struct {
HTTPStatus *int64 `json:"http_status"`
Response map[string]interface{} `json:"response"`
}
type UpdateAccessInput struct { type UpdateAccessInput struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
} }
type UpdateEmailTemplateRequest struct {
ID string `json:"id"`
EventName *string `json:"event_name"`
Template *string `json:"template"`
}
type UpdateEnvInput struct { type UpdateEnvInput struct {
AccessTokenExpiryTime *string `json:"ACCESS_TOKEN_EXPIRY_TIME"` AccessTokenExpiryTime *string `json:"ACCESS_TOKEN_EXPIRY_TIME"`
AdminSecret *string `json:"ADMIN_SECRET"` AdminSecret *string `json:"ADMIN_SECRET"`
@@ -212,6 +265,7 @@ type UpdateEnvInput struct {
DisableLoginPage *bool `json:"DISABLE_LOGIN_PAGE"` DisableLoginPage *bool `json:"DISABLE_LOGIN_PAGE"`
DisableSignUp *bool `json:"DISABLE_SIGN_UP"` DisableSignUp *bool `json:"DISABLE_SIGN_UP"`
DisableRedisForEnv *bool `json:"DISABLE_REDIS_FOR_ENV"` DisableRedisForEnv *bool `json:"DISABLE_REDIS_FOR_ENV"`
DisableStrongPassword *bool `json:"DISABLE_STRONG_PASSWORD"`
Roles []string `json:"ROLES"` Roles []string `json:"ROLES"`
ProtectedRoles []string `json:"PROTECTED_ROLES"` ProtectedRoles []string `json:"PROTECTED_ROLES"`
DefaultRoles []string `json:"DEFAULT_ROLES"` DefaultRoles []string `json:"DEFAULT_ROLES"`
@@ -260,6 +314,14 @@ type UpdateUserInput struct {
Roles []*string `json:"roles"` Roles []*string `json:"roles"`
} }
type UpdateWebhookRequest struct {
ID string `json:"id"`
EventName *string `json:"event_name"`
Endpoint *string `json:"endpoint"`
Enabled *bool `json:"enabled"`
Headers map[string]interface{} `json:"headers"`
}
type User struct { type User struct {
ID string `json:"id"` ID string `json:"id"`
Email string `json:"email"` Email string `json:"email"`
@@ -316,3 +378,37 @@ type VerificationRequests struct {
type VerifyEmailInput struct { type VerifyEmailInput struct {
Token string `json:"token"` Token string `json:"token"`
} }
type Webhook struct {
ID string `json:"id"`
EventName *string `json:"event_name"`
Endpoint *string `json:"endpoint"`
Enabled *bool `json:"enabled"`
Headers map[string]interface{} `json:"headers"`
CreatedAt *int64 `json:"created_at"`
UpdatedAt *int64 `json:"updated_at"`
}
type WebhookLog struct {
ID string `json:"id"`
HTTPStatus *int64 `json:"http_status"`
Response *string `json:"response"`
Request *string `json:"request"`
WebhookID *string `json:"webhook_id"`
CreatedAt *int64 `json:"created_at"`
UpdatedAt *int64 `json:"updated_at"`
}
type WebhookLogs struct {
Pagination *Pagination `json:"pagination"`
WebhookLogs []*WebhookLog `json:"webhook_logs"`
}
type WebhookRequest struct {
ID string `json:"id"`
}
type Webhooks struct {
Pagination *Pagination `json:"pagination"`
Webhooks []*Webhook `json:"webhooks"`
}

View File

@@ -24,6 +24,7 @@ type Meta {
is_basic_authentication_enabled: Boolean! is_basic_authentication_enabled: Boolean!
is_magic_link_login_enabled: Boolean! is_magic_link_login_enabled: Boolean!
is_sign_up_enabled: Boolean! is_sign_up_enabled: Boolean!
is_strong_password_enabled: Boolean!
} }
type User { type User {
@@ -120,6 +121,7 @@ type Env {
DISABLE_LOGIN_PAGE: Boolean! DISABLE_LOGIN_PAGE: Boolean!
DISABLE_SIGN_UP: Boolean! DISABLE_SIGN_UP: Boolean!
DISABLE_REDIS_FOR_ENV: Boolean! DISABLE_REDIS_FOR_ENV: Boolean!
DISABLE_STRONG_PASSWORD: Boolean!
ROLES: [String!] ROLES: [String!]
PROTECTED_ROLES: [String!] PROTECTED_ROLES: [String!]
DEFAULT_ROLES: [String!] DEFAULT_ROLES: [String!]
@@ -148,6 +150,54 @@ type GenerateJWTKeysResponse {
private_key: String private_key: String
} }
type Webhook {
id: ID!
event_name: String
endpoint: String
enabled: Boolean
headers: Map
created_at: Int64
updated_at: Int64
}
type Webhooks {
pagination: Pagination!
webhooks: [Webhook!]!
}
type WebhookLog {
id: ID!
http_status: Int64
response: String
request: String
webhook_id: ID
created_at: Int64
updated_at: Int64
}
type TestEndpointResponse {
http_status: Int64
response: Map
}
type WebhookLogs {
pagination: Pagination!
webhook_logs: [WebhookLog!]!
}
type EmailTemplate {
id: ID!
event_name: String!
template: String!
created_at: Int64
updated_at: Int64
}
type EmailTemplates {
pagination: Pagination!
EmailTemplates: [EmailTemplate!]!
}
input UpdateEnvInput { input UpdateEnvInput {
ACCESS_TOKEN_EXPIRY_TIME: String ACCESS_TOKEN_EXPIRY_TIME: String
ADMIN_SECRET: String ADMIN_SECRET: String
@@ -171,6 +221,7 @@ input UpdateEnvInput {
DISABLE_LOGIN_PAGE: Boolean DISABLE_LOGIN_PAGE: Boolean
DISABLE_SIGN_UP: Boolean DISABLE_SIGN_UP: Boolean
DISABLE_REDIS_FOR_ENV: Boolean DISABLE_REDIS_FOR_ENV: Boolean
DISABLE_STRONG_PASSWORD: Boolean
ROLES: [String!] ROLES: [String!]
PROTECTED_ROLES: [String!] PROTECTED_ROLES: [String!]
DEFAULT_ROLES: [String!] DEFAULT_ROLES: [String!]
@@ -321,6 +372,51 @@ input GenerateJWTKeysInput {
type: String! type: String!
} }
input ListWebhookLogRequest {
pagination: PaginationInput
webhook_id: String
}
input AddWebhookRequest {
event_name: String!
endpoint: String!
enabled: Boolean!
headers: Map
}
input UpdateWebhookRequest {
id: ID!
event_name: String
endpoint: String
enabled: Boolean
headers: Map
}
input WebhookRequest {
id: ID!
}
input TestEndpointRequest {
endpoint: String!
event_name: String!
headers: Map
}
input AddEmailTemplateRequest {
event_name: String!
template: String!
}
input UpdateEmailTemplateRequest {
id: ID!
event_name: String
template: String
}
input DeleteEmailTemplateRequest {
id: ID!
}
type Mutation { type Mutation {
signup(params: SignUpInput!): AuthResponse! signup(params: SignUpInput!): AuthResponse!
login(params: LoginInput!): AuthResponse! login(params: LoginInput!): AuthResponse!
@@ -343,6 +439,13 @@ type Mutation {
_revoke_access(param: UpdateAccessInput!): Response! _revoke_access(param: UpdateAccessInput!): Response!
_enable_access(param: UpdateAccessInput!): Response! _enable_access(param: UpdateAccessInput!): Response!
_generate_jwt_keys(params: GenerateJWTKeysInput!): GenerateJWTKeysResponse! _generate_jwt_keys(params: GenerateJWTKeysInput!): GenerateJWTKeysResponse!
_add_webhook(params: AddWebhookRequest!): Response!
_update_webhook(params: UpdateWebhookRequest!): Response!
_delete_webhook(params: WebhookRequest!): Response!
_test_endpoint(params: TestEndpointRequest!): TestEndpointResponse!
_add_email_template(params: AddEmailTemplateRequest!): Response!
_update_email_template(params: UpdateEmailTemplateRequest!): Response!
_delete_email_template(params: DeleteEmailTemplateRequest!): Response!
} }
type Query { type Query {
@@ -355,4 +458,8 @@ type Query {
_verification_requests(params: PaginatedInput): VerificationRequests! _verification_requests(params: PaginatedInput): VerificationRequests!
_admin_session: Response! _admin_session: Response!
_env: Env! _env: Env!
_webhook(params: WebhookRequest!): Webhook!
_webhooks(params: PaginatedInput): Webhooks!
_webhook_logs(params: ListWebhookLogRequest): WebhookLogs!
_email_templates(params: PaginatedInput): EmailTemplates!
} }

View File

@@ -91,6 +91,34 @@ func (r *mutationResolver) GenerateJwtKeys(ctx context.Context, params model.Gen
return resolvers.GenerateJWTKeysResolver(ctx, params) return resolvers.GenerateJWTKeysResolver(ctx, params)
} }
func (r *mutationResolver) AddWebhook(ctx context.Context, params model.AddWebhookRequest) (*model.Response, error) {
return resolvers.AddWebhookResolver(ctx, params)
}
func (r *mutationResolver) UpdateWebhook(ctx context.Context, params model.UpdateWebhookRequest) (*model.Response, error) {
return resolvers.UpdateWebhookResolver(ctx, params)
}
func (r *mutationResolver) DeleteWebhook(ctx context.Context, params model.WebhookRequest) (*model.Response, error) {
return resolvers.DeleteWebhookResolver(ctx, params)
}
func (r *mutationResolver) TestEndpoint(ctx context.Context, params model.TestEndpointRequest) (*model.TestEndpointResponse, error) {
return resolvers.TestEndpointResolver(ctx, params)
}
func (r *mutationResolver) AddEmailTemplate(ctx context.Context, params model.AddEmailTemplateRequest) (*model.Response, error) {
return resolvers.AddEmailTemplateResolver(ctx, params)
}
func (r *mutationResolver) UpdateEmailTemplate(ctx context.Context, params model.UpdateEmailTemplateRequest) (*model.Response, error) {
return resolvers.UpdateEmailTemplateResolver(ctx, params)
}
func (r *mutationResolver) DeleteEmailTemplate(ctx context.Context, params model.DeleteEmailTemplateRequest) (*model.Response, error) {
return resolvers.DeleteEmailTemplateResolver(ctx, params)
}
func (r *queryResolver) Meta(ctx context.Context) (*model.Meta, error) { func (r *queryResolver) Meta(ctx context.Context) (*model.Meta, error) {
return resolvers.MetaResolver(ctx) return resolvers.MetaResolver(ctx)
} }
@@ -123,11 +151,29 @@ func (r *queryResolver) Env(ctx context.Context) (*model.Env, error) {
return resolvers.EnvResolver(ctx) return resolvers.EnvResolver(ctx)
} }
func (r *queryResolver) Webhook(ctx context.Context, params model.WebhookRequest) (*model.Webhook, error) {
return resolvers.WebhookResolver(ctx, params)
}
func (r *queryResolver) Webhooks(ctx context.Context, params *model.PaginatedInput) (*model.Webhooks, error) {
return resolvers.WebhooksResolver(ctx, params)
}
func (r *queryResolver) WebhookLogs(ctx context.Context, params *model.ListWebhookLogRequest) (*model.WebhookLogs, error) {
return resolvers.WebhookLogsResolver(ctx, params)
}
func (r *queryResolver) EmailTemplates(ctx context.Context, params *model.PaginatedInput) (*model.EmailTemplates, error) {
return resolvers.EmailTemplatesResolver(ctx, params)
}
// Mutation returns generated.MutationResolver implementation. // Mutation returns generated.MutationResolver implementation.
func (r *Resolver) Mutation() generated.MutationResolver { return &mutationResolver{r} } func (r *Resolver) Mutation() generated.MutationResolver { return &mutationResolver{r} }
// Query returns generated.QueryResolver implementation. // Query returns generated.QueryResolver implementation.
func (r *Resolver) Query() generated.QueryResolver { return &queryResolver{r} } func (r *Resolver) Query() generated.QueryResolver { return &queryResolver{r} }
type mutationResolver struct{ *Resolver } type (
type queryResolver struct{ *Resolver } mutationResolver struct{ *Resolver }
queryResolver struct{ *Resolver }
)

View File

@@ -199,7 +199,7 @@ func AuthorizeHandler() gin.HandlerFunc {
return return
} }
userID := claims.Subject userID := claims.Subject
user, err := db.Provider.GetUserByID(userID) user, err := db.Provider.GetUserByID(gc, userID)
if err != nil { if err != nil {
if isQuery { if isQuery {
gc.Redirect(http.StatusFound, loginURL) gc.Redirect(http.StatusFound, loginURL)
@@ -218,13 +218,18 @@ func AuthorizeHandler() gin.HandlerFunc {
return return
} }
sessionKey := user.ID
if claims.LoginMethod != "" {
sessionKey = claims.LoginMethod + ":" + user.ID
}
// if user is logged in // if user is logged in
// based on the response type, generate the response // based on the response type code, generate the response
if isResponseTypeCode { if isResponseTypeCode {
// rollover the session for security // rollover the session for security
go memorystore.Provider.DeleteUserSession(user.ID, claims.Nonce) go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce)
nonce := uuid.New().String() nonce := uuid.New().String()
newSessionTokenData, newSessionToken, err := token.CreateSessionToken(user, nonce, claims.Roles, scope) newSessionTokenData, newSessionToken, err := token.CreateSessionToken(user, nonce, claims.Roles, scope, claims.LoginMethod)
if err != nil { if err != nil {
if isQuery { if isQuery {
gc.Redirect(http.StatusFound, loginURL) gc.Redirect(http.StatusFound, loginURL)
@@ -262,7 +267,7 @@ func AuthorizeHandler() gin.HandlerFunc {
if isResponseTypeToken { if isResponseTypeToken {
// rollover the session for security // rollover the session for security
authToken, err := token.CreateAuthToken(gc, user, claims.Roles, scope) authToken, err := token.CreateAuthToken(gc, user, claims.Roles, scope, claims.LoginMethod)
if err != nil { if err != nil {
if isQuery { if isQuery {
gc.Redirect(http.StatusFound, loginURL) gc.Redirect(http.StatusFound, loginURL)
@@ -280,9 +285,10 @@ func AuthorizeHandler() gin.HandlerFunc {
} }
return return
} }
go memorystore.Provider.DeleteUserSession(user.ID, claims.Nonce)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
cookie.SetSession(gc, authToken.FingerPrintHash) cookie.SetSession(gc, authToken.FingerPrintHash)
expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix()
@@ -305,7 +311,7 @@ func AuthorizeHandler() gin.HandlerFunc {
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res["refresh_token"] = authToken.RefreshToken.Token res["refresh_token"] = authToken.RefreshToken.Token
params += "&refresh_token=" + authToken.RefreshToken.Token params += "&refresh_token=" + authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
} }
if isQuery { if isQuery {

View File

@@ -2,6 +2,7 @@ package handlers
import ( import (
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@@ -17,7 +18,6 @@ import (
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/cookie"
"github.com/authorizerdev/authorizer/server/crypto"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/memorystore"
@@ -28,21 +28,21 @@ import (
// OAuthCallbackHandler handles the OAuth callback for various oauth providers // OAuthCallbackHandler handles the OAuth callback for various oauth providers
func OAuthCallbackHandler() gin.HandlerFunc { func OAuthCallbackHandler() gin.HandlerFunc {
return func(c *gin.Context) { return func(ctx *gin.Context) {
provider := c.Param("oauth_provider") provider := ctx.Param("oauth_provider")
state := c.Request.FormValue("state") state := ctx.Request.FormValue("state")
sessionState, err := memorystore.Provider.GetState(state) sessionState, err := memorystore.Provider.GetState(state)
if sessionState == "" || err != nil { if sessionState == "" || err != nil {
log.Debug("Invalid oauth state: ", state) log.Debug("Invalid oauth state: ", state)
c.JSON(400, gin.H{"error": "invalid oauth state"}) ctx.JSON(400, gin.H{"error": "invalid oauth state"})
} }
// contains random token, redirect url, role // contains random token, redirect url, role
sessionSplit := strings.Split(state, "___") sessionSplit := strings.Split(state, "___")
if len(sessionSplit) < 3 { if len(sessionSplit) < 3 {
log.Debug("Unable to get redirect url from state: ", state) log.Debug("Unable to get redirect url from state: ", state)
c.JSON(400, gin.H{"error": "invalid redirect url"}) ctx.JSON(400, gin.H{"error": "invalid redirect url"})
return return
} }
@@ -55,17 +55,17 @@ func OAuthCallbackHandler() gin.HandlerFunc {
scopes := strings.Split(sessionSplit[3], ",") scopes := strings.Split(sessionSplit[3], ",")
user := models.User{} user := models.User{}
code := c.Request.FormValue("code") code := ctx.Request.FormValue("code")
switch provider { switch provider {
case constants.SignupMethodGoogle: case constants.AuthRecipeMethodGoogle:
user, err = processGoogleUserInfo(code) user, err = processGoogleUserInfo(code)
case constants.SignupMethodGithub: case constants.AuthRecipeMethodGithub:
user, err = processGithubUserInfo(code) user, err = processGithubUserInfo(code)
case constants.SignupMethodFacebook: case constants.AuthRecipeMethodFacebook:
user, err = processFacebookUserInfo(code) user, err = processFacebookUserInfo(code)
case constants.SignupMethodLinkedIn: case constants.AuthRecipeMethodLinkedIn:
user, err = processLinkedInUserInfo(code) user, err = processLinkedInUserInfo(code)
case constants.SignupMethodApple: case constants.AuthRecipeMethodApple:
user, err = processAppleUserInfo(code) user, err = processAppleUserInfo(code)
default: default:
log.Info("Invalid oauth provider") log.Info("Invalid oauth provider")
@@ -74,23 +74,24 @@ func OAuthCallbackHandler() gin.HandlerFunc {
if err != nil { if err != nil {
log.Debug("Failed to process user info: ", err) log.Debug("Failed to process user info: ", err)
c.JSON(400, gin.H{"error": err.Error()}) ctx.JSON(400, gin.H{"error": err.Error()})
return return
} }
existingUser, err := db.Provider.GetUserByEmail(user.Email) existingUser, err := db.Provider.GetUserByEmail(ctx, user.Email)
log := log.WithField("user", user.Email) log := log.WithField("user", user.Email)
isSignUp := false
if err != nil { if err != nil {
isSignupDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp) isSignupDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp)
if err != nil { if err != nil {
log.Debug("Failed to get signup disabled env variable: ", err) log.Debug("Failed to get signup disabled env variable: ", err)
c.JSON(400, gin.H{"error": err.Error()}) ctx.JSON(400, gin.H{"error": err.Error()})
return return
} }
if isSignupDisabled { if isSignupDisabled {
log.Debug("Failed to signup as disabled") log.Debug("Failed to signup as disabled")
c.JSON(400, gin.H{"error": "signup is disabled for this instance"}) ctx.JSON(400, gin.H{"error": "signup is disabled for this instance"})
return return
} }
// user not registered, register user and generate session token // user not registered, register user and generate session token
@@ -113,19 +114,20 @@ func OAuthCallbackHandler() gin.HandlerFunc {
if hasProtectedRole { if hasProtectedRole {
log.Debug("Signup is not allowed with protected roles:", inputRoles) log.Debug("Signup is not allowed with protected roles:", inputRoles)
c.JSON(400, gin.H{"error": "invalid role"}) ctx.JSON(400, gin.H{"error": "invalid role"})
return return
} }
user.Roles = strings.Join(inputRoles, ",") user.Roles = strings.Join(inputRoles, ",")
now := time.Now().Unix() now := time.Now().Unix()
user.EmailVerifiedAt = &now user.EmailVerifiedAt = &now
user, _ = db.Provider.AddUser(user) user, _ = db.Provider.AddUser(ctx, user)
isSignUp = true
} else { } else {
user = existingUser user = existingUser
if user.RevokedTimestamp != nil { if user.RevokedTimestamp != nil {
log.Debug("User access revoked at: ", user.RevokedTimestamp) log.Debug("User access revoked at: ", user.RevokedTimestamp)
c.JSON(400, gin.H{"error": "user access has been revoked"}) ctx.JSON(400, gin.H{"error": "user access has been revoked"})
return return
} }
@@ -175,7 +177,7 @@ func OAuthCallbackHandler() gin.HandlerFunc {
if hasProtectedRole { if hasProtectedRole {
log.Debug("Invalid role. User is using protected unassigned role") log.Debug("Invalid role. User is using protected unassigned role")
c.JSON(400, gin.H{"error": "invalid role"}) ctx.JSON(400, gin.H{"error": "invalid role"})
return return
} else { } else {
user.Roles = existingUser.Roles + "," + strings.Join(unasignedRoles, ",") user.Roles = existingUser.Roles + "," + strings.Join(unasignedRoles, ",")
@@ -184,18 +186,18 @@ func OAuthCallbackHandler() gin.HandlerFunc {
user.Roles = existingUser.Roles user.Roles = existingUser.Roles
} }
user, err = db.Provider.UpdateUser(user) user, err = db.Provider.UpdateUser(ctx, user)
if err != nil { if err != nil {
log.Debug("Failed to update user: ", err) log.Debug("Failed to update user: ", err)
c.JSON(500, gin.H{"error": err.Error()}) ctx.JSON(500, gin.H{"error": err.Error()})
return return
} }
} }
authToken, err := token.CreateAuthToken(c, user, inputRoles, scopes) authToken, err := token.CreateAuthToken(ctx, user, inputRoles, scopes, provider)
if err != nil { if err != nil {
log.Debug("Failed to create auth token: ", err) log.Debug("Failed to create auth token: ", err)
c.JSON(500, gin.H{"error": err.Error()}) ctx.JSON(500, gin.H{"error": err.Error()})
} }
expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix()
@@ -205,27 +207,35 @@ func OAuthCallbackHandler() gin.HandlerFunc {
params := "access_token=" + authToken.AccessToken.Token + "&token_type=bearer&expires_in=" + strconv.FormatInt(expiresIn, 10) + "&state=" + stateValue + "&id_token=" + authToken.IDToken.Token params := "access_token=" + authToken.AccessToken.Token + "&token_type=bearer&expires_in=" + strconv.FormatInt(expiresIn, 10) + "&state=" + stateValue + "&id_token=" + authToken.IDToken.Token
cookie.SetSession(c, authToken.FingerPrintHash) sessionKey := provider + ":" + user.ID
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) cookie.SetSession(ctx, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
params = params + `&refresh_token=` + authToken.RefreshToken.Token params = params + `&refresh_token=` + authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
} }
go db.Provider.AddSession(models.Session{ go func() {
UserID: user.ID, if isSignUp {
UserAgent: utils.GetUserAgent(c.Request), utils.RegisterEvent(ctx, constants.UserSignUpWebhookEvent, provider, user)
IP: utils.GetIP(c.Request), } else {
}) utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, provider, user)
}
db.Provider.AddSession(ctx, models.Session{
UserID: user.ID,
UserAgent: utils.GetUserAgent(ctx.Request),
IP: utils.GetIP(ctx.Request),
})
}()
if strings.Contains(redirectURL, "?") { if strings.Contains(redirectURL, "?") {
redirectURL = redirectURL + "&" + params redirectURL = redirectURL + "&" + params
} else { } else {
redirectURL = redirectURL + "?" + strings.TrimPrefix(params, "&") redirectURL = redirectURL + "?" + strings.TrimPrefix(params, "&")
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) ctx.Redirect(http.StatusFound, redirectURL)
} }
} }
@@ -310,12 +320,60 @@ func processGithubUserInfo(code string) (models.User, error) {
} }
picture := userRawData["avatar_url"] picture := userRawData["avatar_url"]
email := userRawData["email"]
if email == "" {
type GithubUserEmails struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
// fetch using /users/email endpoint
req, err := http.NewRequest("GET", constants.GithubUserEmails, nil)
if err != nil {
log.Debug("Failed to create github emails request: ", err)
return user, fmt.Errorf("error creating github user info request: %s", err.Error())
}
req.Header = http.Header{
"Authorization": []string{fmt.Sprintf("token %s", oauth2Token.AccessToken)},
}
response, err := client.Do(req)
if err != nil {
log.Debug("Failed to request github user email: ", err)
return user, err
}
defer response.Body.Close()
body, err := ioutil.ReadAll(response.Body)
if err != nil {
log.Debug("Failed to read github user email response body: ", err)
return user, fmt.Errorf("failed to read github response body: %s", err.Error())
}
if response.StatusCode >= 400 {
log.Debug("Failed to request github user email: ", string(body))
return user, fmt.Errorf("failed to request github user info: %s", string(body))
}
emailData := []GithubUserEmails{}
err = json.Unmarshal(body, &emailData)
if err != nil {
log.Debug("Failed to parse github user email: ", err)
}
for _, userEmail := range emailData {
email = userEmail.Email
if userEmail.Primary {
break
}
}
}
user = models.User{ user = models.User{
GivenName: &firstName, GivenName: &firstName,
FamilyName: &lastName, FamilyName: &lastName,
Picture: &picture, Picture: &picture,
Email: userRawData["email"], Email: email,
} }
return user, nil return user, nil
@@ -462,8 +520,6 @@ func processAppleUserInfo(code string) (models.User, error) {
return user, fmt.Errorf("invalid apple exchange code: %s", err.Error()) return user, fmt.Errorf("invalid apple exchange code: %s", err.Error())
} }
fmt.Println("=> token", oauth2Token.AccessToken)
// Extract the ID Token from OAuth2 token. // Extract the ID Token from OAuth2 token.
rawIDToken, ok := oauth2Token.Extra("id_token").(string) rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok { if !ok {
@@ -471,39 +527,39 @@ func processAppleUserInfo(code string) (models.User, error) {
return user, fmt.Errorf("unable to extract id_token") return user, fmt.Errorf("unable to extract id_token")
} }
fmt.Println("=> rawIDToken", rawIDToken)
tokenSplit := strings.Split(rawIDToken, ".") tokenSplit := strings.Split(rawIDToken, ".")
claimsData := tokenSplit[1] claimsData := tokenSplit[1]
decodedClaimsData, err := crypto.DecryptB64(claimsData) decodedClaimsData, err := base64.RawURLEncoding.DecodeString(claimsData)
if err != nil { if err != nil {
log.Debug("Failed to decrypt claims data: ", err) log.Debugf("Failed to decrypt claims %s: %s", claimsData, err.Error())
return user, fmt.Errorf("failed to decrypt claims data: %s", err.Error()) return user, fmt.Errorf("failed to decrypt claims data: %s", err.Error())
} }
fmt.Println("=> decoded claims data", decodedClaimsData)
claims := make(map[string]interface{}) claims := make(map[string]interface{})
err = json.Unmarshal([]byte(decodedClaimsData), &claims) err = json.Unmarshal(decodedClaimsData, &claims)
if err != nil { if err != nil {
log.Debug("Failed to unmarshal claims data: ", err) log.Debug("Failed to unmarshal claims data: ", err)
return user, fmt.Errorf("failed to unmarshal claims data: %s", err.Error()) return user, fmt.Errorf("failed to unmarshal claims data: %s", err.Error())
} }
fmt.Println("=> claims", claims)
if val, ok := claims["email"]; !ok { if val, ok := claims["email"]; !ok {
log.Debug("Failed to extract email from claims") log.Debug("Failed to extract email from claims.")
return user, fmt.Errorf("unable to extract email") return user, fmt.Errorf("unable to extract email, please check the scopes enabled for your app. It needs `email`, `name` scopes")
} else { } else {
user.Email = val.(string) user.Email = val.(string)
} }
if val, ok := claims["name"]; ok { if val, ok := claims["name"]; ok {
nameData := val.(map[string]interface{}) nameData := val.(map[string]interface{})
givenName := nameData["firstName"].(string) if nameVal, ok := nameData["firstName"]; ok {
familyName := nameData["lastName"].(string) givenName := nameVal.(string)
user.GivenName = &givenName user.GivenName = &givenName
user.FamilyName = &familyName }
if nameVal, ok := nameData["lastName"]; ok {
familyName := nameVal.(string)
user.FamilyName = &familyName
}
} }
return user, err return user, err

View File

@@ -100,13 +100,13 @@ func OAuthLoginHandler() gin.HandlerFunc {
provider := c.Param("oauth_provider") provider := c.Param("oauth_provider")
isProviderConfigured := true isProviderConfigured := true
switch provider { switch provider {
case constants.SignupMethodGoogle: case constants.AuthRecipeMethodGoogle:
if oauth.OAuthProviders.GoogleConfig == nil { if oauth.OAuthProviders.GoogleConfig == nil {
log.Debug("Google OAuth provider is not configured") log.Debug("Google OAuth provider is not configured")
isProviderConfigured = false isProviderConfigured = false
break break
} }
err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGoogle) err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodGoogle)
if err != nil { if err != nil {
log.Debug("Error setting state: ", err) log.Debug("Error setting state: ", err)
c.JSON(500, gin.H{ c.JSON(500, gin.H{
@@ -115,16 +115,16 @@ func OAuthLoginHandler() gin.HandlerFunc {
return return
} }
// during the init of OAuthProvider authorizer url might be empty // during the init of OAuthProvider authorizer url might be empty
oauth.OAuthProviders.GoogleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodGoogle oauth.OAuthProviders.GoogleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodGoogle
url := oauth.OAuthProviders.GoogleConfig.AuthCodeURL(oauthStateString) url := oauth.OAuthProviders.GoogleConfig.AuthCodeURL(oauthStateString)
c.Redirect(http.StatusTemporaryRedirect, url) c.Redirect(http.StatusTemporaryRedirect, url)
case constants.SignupMethodGithub: case constants.AuthRecipeMethodGithub:
if oauth.OAuthProviders.GithubConfig == nil { if oauth.OAuthProviders.GithubConfig == nil {
log.Debug("Github OAuth provider is not configured") log.Debug("Github OAuth provider is not configured")
isProviderConfigured = false isProviderConfigured = false
break break
} }
err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGithub) err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodGithub)
if err != nil { if err != nil {
log.Debug("Error setting state: ", err) log.Debug("Error setting state: ", err)
c.JSON(500, gin.H{ c.JSON(500, gin.H{
@@ -132,16 +132,16 @@ func OAuthLoginHandler() gin.HandlerFunc {
}) })
return return
} }
oauth.OAuthProviders.GithubConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodGithub oauth.OAuthProviders.GithubConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodGithub
url := oauth.OAuthProviders.GithubConfig.AuthCodeURL(oauthStateString) url := oauth.OAuthProviders.GithubConfig.AuthCodeURL(oauthStateString)
c.Redirect(http.StatusTemporaryRedirect, url) c.Redirect(http.StatusTemporaryRedirect, url)
case constants.SignupMethodFacebook: case constants.AuthRecipeMethodFacebook:
if oauth.OAuthProviders.FacebookConfig == nil { if oauth.OAuthProviders.FacebookConfig == nil {
log.Debug("Facebook OAuth provider is not configured") log.Debug("Facebook OAuth provider is not configured")
isProviderConfigured = false isProviderConfigured = false
break break
} }
err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodFacebook) err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodFacebook)
if err != nil { if err != nil {
log.Debug("Error setting state: ", err) log.Debug("Error setting state: ", err)
c.JSON(500, gin.H{ c.JSON(500, gin.H{
@@ -149,16 +149,16 @@ func OAuthLoginHandler() gin.HandlerFunc {
}) })
return return
} }
oauth.OAuthProviders.FacebookConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodFacebook oauth.OAuthProviders.FacebookConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodFacebook
url := oauth.OAuthProviders.FacebookConfig.AuthCodeURL(oauthStateString) url := oauth.OAuthProviders.FacebookConfig.AuthCodeURL(oauthStateString)
c.Redirect(http.StatusTemporaryRedirect, url) c.Redirect(http.StatusTemporaryRedirect, url)
case constants.SignupMethodLinkedIn: case constants.AuthRecipeMethodLinkedIn:
if oauth.OAuthProviders.LinkedInConfig == nil { if oauth.OAuthProviders.LinkedInConfig == nil {
log.Debug("Linkedin OAuth provider is not configured") log.Debug("Linkedin OAuth provider is not configured")
isProviderConfigured = false isProviderConfigured = false
break break
} }
err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodLinkedIn) err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodLinkedIn)
if err != nil { if err != nil {
log.Debug("Error setting state: ", err) log.Debug("Error setting state: ", err)
c.JSON(500, gin.H{ c.JSON(500, gin.H{
@@ -166,16 +166,16 @@ func OAuthLoginHandler() gin.HandlerFunc {
}) })
return return
} }
oauth.OAuthProviders.LinkedInConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodLinkedIn oauth.OAuthProviders.LinkedInConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodLinkedIn
url := oauth.OAuthProviders.LinkedInConfig.AuthCodeURL(oauthStateString) url := oauth.OAuthProviders.LinkedInConfig.AuthCodeURL(oauthStateString)
c.Redirect(http.StatusTemporaryRedirect, url) c.Redirect(http.StatusTemporaryRedirect, url)
case constants.SignupMethodApple: case constants.AuthRecipeMethodApple:
if oauth.OAuthProviders.AppleConfig == nil { if oauth.OAuthProviders.AppleConfig == nil {
log.Debug("Apple OAuth provider is not configured") log.Debug("Apple OAuth provider is not configured")
isProviderConfigured = false isProviderConfigured = false
break break
} }
err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodApple) err := memorystore.Provider.SetState(oauthStateString, constants.AuthRecipeMethodApple)
if err != nil { if err != nil {
log.Debug("Error setting state: ", err) log.Debug("Error setting state: ", err)
c.JSON(500, gin.H{ c.JSON(500, gin.H{
@@ -183,8 +183,10 @@ func OAuthLoginHandler() gin.HandlerFunc {
}) })
return return
} }
oauth.OAuthProviders.AppleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.SignupMethodApple oauth.OAuthProviders.AppleConfig.RedirectURL = hostname + "/oauth_callback/" + constants.AuthRecipeMethodApple
url := oauth.OAuthProviders.AppleConfig.AuthCodeURL(oauthStateString, oauth2.SetAuthURLParam("response_mode", "form_post")) // there is scope encoding issue with oauth2 and how apple expects, hence added scope manually
// check: https://github.com/golang/oauth2/issues/449
url := oauth.OAuthProviders.AppleConfig.AuthCodeURL(oauthStateString, oauth2.SetAuthURLParam("response_mode", "form_post")) + "&scope=name email"
c.Redirect(http.StatusTemporaryRedirect, url) c.Redirect(http.StatusTemporaryRedirect, url)
default: default:
log.Debug("Invalid oauth provider: ", provider) log.Debug("Invalid oauth provider: ", provider)

View File

@@ -12,8 +12,8 @@ import (
"github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/token"
) )
// Revoke handler to revoke refresh token // RevokeRefreshTokenHandler handler to revoke refresh token
func RevokeHandler() gin.HandlerFunc { func RevokeRefreshTokenHandler() gin.HandlerFunc {
return func(gc *gin.Context) { return func(gc *gin.Context) {
var reqBody map[string]string var reqBody map[string]string
if err := gc.BindJSON(&reqBody); err != nil { if err := gc.BindJSON(&reqBody); err != nil {
@@ -56,7 +56,14 @@ func RevokeHandler() gin.HandlerFunc {
return return
} }
memorystore.Provider.DeleteUserSession(claims["sub"].(string), claims["nonce"].(string)) userID := claims["sub"].(string)
loginMethod := claims["login_method"]
sessionToken := userID
if loginMethod != nil && loginMethod != "" {
sessionToken = loginMethod.(string) + ":" + userID
}
memorystore.Provider.DeleteUserSession(sessionToken, claims["nonce"].(string))
gc.JSON(http.StatusOK, gin.H{ gc.JSON(http.StatusOK, gin.H{
"message": "Token revoked successfully", "message": "Token revoked successfully",

View File

@@ -72,6 +72,9 @@ func TokenHandler() gin.HandlerFunc {
var userID string var userID string
var roles, scope []string var roles, scope []string
loginMethod := ""
sessionKey := ""
if isAuthorizationCodeGrant { if isAuthorizationCodeGrant {
if codeVerifier == "" { if codeVerifier == "" {
@@ -134,8 +137,13 @@ func TokenHandler() gin.HandlerFunc {
userID = claims.Subject userID = claims.Subject
roles = claims.Roles roles = claims.Roles
scope = claims.Scope scope = claims.Scope
loginMethod = claims.LoginMethod
// rollover the session for security // rollover the session for security
go memorystore.Provider.DeleteUserSession(userID, claims.Nonce) sessionKey = userID
if loginMethod != "" {
sessionKey = loginMethod + ":" + userID
}
go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce)
} else { } else {
// validate refresh token // validate refresh token
if refreshToken == "" { if refreshToken == "" {
@@ -155,6 +163,7 @@ func TokenHandler() gin.HandlerFunc {
}) })
} }
userID = claims["sub"].(string) userID = claims["sub"].(string)
loginMethod := claims["login_method"]
rolesInterface := claims["roles"].([]interface{}) rolesInterface := claims["roles"].([]interface{})
scopeInterface := claims["scope"].([]interface{}) scopeInterface := claims["scope"].([]interface{})
for _, v := range rolesInterface { for _, v := range rolesInterface {
@@ -163,11 +172,25 @@ func TokenHandler() gin.HandlerFunc {
for _, v := range scopeInterface { for _, v := range scopeInterface {
scope = append(scope, v.(string)) scope = append(scope, v.(string))
} }
sessionKey = userID
if loginMethod != nil && loginMethod != "" {
sessionKey = loginMethod.(string) + ":" + sessionKey
}
// remove older refresh token and rotate it for security // remove older refresh token and rotate it for security
go memorystore.Provider.DeleteUserSession(userID, claims["nonce"].(string)) go memorystore.Provider.DeleteUserSession(sessionKey, claims["nonce"].(string))
} }
user, err := db.Provider.GetUserByID(userID) if sessionKey == "" {
log.Debug("Error getting sessionKey: ", sessionKey, loginMethod)
gc.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"error_description": "User not found",
})
return
}
user, err := db.Provider.GetUserByID(gc, userID)
if err != nil { if err != nil {
log.Debug("Error getting user: ", err) log.Debug("Error getting user: ", err)
gc.JSON(http.StatusUnauthorized, gin.H{ gc.JSON(http.StatusUnauthorized, gin.H{
@@ -177,7 +200,7 @@ func TokenHandler() gin.HandlerFunc {
return return
} }
authToken, err := token.CreateAuthToken(gc, user, roles, scope) authToken, err := token.CreateAuthToken(gc, user, roles, scope, loginMethod)
if err != nil { if err != nil {
log.Debug("Error creating auth token: ", err) log.Debug("Error creating auth token: ", err)
gc.JSON(http.StatusUnauthorized, gin.H{ gc.JSON(http.StatusUnauthorized, gin.H{
@@ -186,8 +209,8 @@ func TokenHandler() gin.HandlerFunc {
}) })
return return
} }
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
cookie.SetSession(gc, authToken.FingerPrintHash) cookie.SetSession(gc, authToken.FingerPrintHash)
expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix()
@@ -205,7 +228,7 @@ func TokenHandler() gin.HandlerFunc {
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res["refresh_token"] = authToken.RefreshToken.Token res["refresh_token"] = authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
} }
gc.JSON(http.StatusOK, res) gc.JSON(http.StatusOK, res)

View File

@@ -31,7 +31,7 @@ func UserInfoHandler() gin.HandlerFunc {
} }
userID := claims["sub"].(string) userID := claims["sub"].(string)
user, err := db.Provider.GetUserByID(userID) user, err := db.Provider.GetUserByID(gc, userID)
if err != nil { if err != nil {
log.Debug("Error getting user: ", err) log.Debug("Error getting user: ", err)
gc.JSON(http.StatusUnauthorized, gin.H{ gc.JSON(http.StatusUnauthorized, gin.H{

View File

@@ -33,7 +33,7 @@ func VerifyEmailHandler() gin.HandlerFunc {
return return
} }
verificationRequest, err := db.Provider.GetVerificationRequestByToken(tokenInQuery) verificationRequest, err := db.Provider.GetVerificationRequestByToken(c, tokenInQuery)
if err != nil { if err != nil {
log.Debug("Error getting verification request: ", err) log.Debug("Error getting verification request: ", err)
errorRes["error_description"] = err.Error() errorRes["error_description"] = err.Error()
@@ -58,7 +58,7 @@ func VerifyEmailHandler() gin.HandlerFunc {
return return
} }
user, err := db.Provider.GetUserByEmail(verificationRequest.Email) user, err := db.Provider.GetUserByEmail(c, verificationRequest.Email)
if err != nil { if err != nil {
log.Debug("Error getting user: ", err) log.Debug("Error getting user: ", err)
errorRes["error_description"] = err.Error() errorRes["error_description"] = err.Error()
@@ -66,14 +66,16 @@ func VerifyEmailHandler() gin.HandlerFunc {
return return
} }
isSignUp := false
// update email_verified_at in users table // update email_verified_at in users table
if user.EmailVerifiedAt == nil { if user.EmailVerifiedAt == nil {
now := time.Now().Unix() now := time.Now().Unix()
user.EmailVerifiedAt = &now user.EmailVerifiedAt = &now
db.Provider.UpdateUser(user) isSignUp = true
db.Provider.UpdateUser(c, user)
} }
// delete from verification table // delete from verification table
db.Provider.DeleteVerificationRequest(verificationRequest) db.Provider.DeleteVerificationRequest(c, verificationRequest)
state := strings.TrimSpace(c.Query("state")) state := strings.TrimSpace(c.Query("state"))
redirectURL := strings.TrimSpace(c.Query("redirect_uri")) redirectURL := strings.TrimSpace(c.Query("redirect_uri"))
@@ -92,7 +94,11 @@ func VerifyEmailHandler() gin.HandlerFunc {
} else { } else {
scope = strings.Split(scopeString, " ") scope = strings.Split(scopeString, " ")
} }
authToken, err := token.CreateAuthToken(c, user, roles, scope) loginMethod := constants.AuthRecipeMethodBasicAuth
if verificationRequest.Identifier == constants.VerificationTypeMagicLinkLogin {
loginMethod = constants.AuthRecipeMethodMagicLinkLogin
}
authToken, err := token.CreateAuthToken(c, user, roles, scope, loginMethod)
if err != nil { if err != nil {
log.Debug("Error creating auth token: ", err) log.Debug("Error creating auth token: ", err)
errorRes["error_description"] = err.Error() errorRes["error_description"] = err.Error()
@@ -107,13 +113,14 @@ func VerifyEmailHandler() gin.HandlerFunc {
params := "access_token=" + authToken.AccessToken.Token + "&token_type=bearer&expires_in=" + strconv.FormatInt(expiresIn, 10) + "&state=" + state + "&id_token=" + authToken.IDToken.Token params := "access_token=" + authToken.AccessToken.Token + "&token_type=bearer&expires_in=" + strconv.FormatInt(expiresIn, 10) + "&state=" + state + "&id_token=" + authToken.IDToken.Token
sessionKey := loginMethod + ":" + user.ID
cookie.SetSession(c, authToken.FingerPrintHash) cookie.SetSession(c, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
params = params + `&refresh_token=` + authToken.RefreshToken.Token params = params + `&refresh_token=` + authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
} }
if redirectURL == "" { if redirectURL == "" {
@@ -126,11 +133,19 @@ func VerifyEmailHandler() gin.HandlerFunc {
redirectURL = redirectURL + "?" + strings.TrimPrefix(params, "&") redirectURL = redirectURL + "?" + strings.TrimPrefix(params, "&")
} }
go db.Provider.AddSession(models.Session{ go func() {
UserID: user.ID, if isSignUp {
UserAgent: utils.GetUserAgent(c.Request), utils.RegisterEvent(c, constants.UserSignUpWebhookEvent, loginMethod, user)
IP: utils.GetIP(c.Request), } else {
}) utils.RegisterEvent(c, constants.UserLoginWebhookEvent, loginMethod, user)
}
db.Provider.AddSession(c, models.Session{
UserID: user.ID,
UserAgent: utils.GetUserAgent(c.Request),
IP: utils.GetIP(c.Request),
})
}()
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
} }

View File

@@ -30,6 +30,7 @@ func InitMemStore() error {
constants.EnvKeyDisableEmailVerification: false, constants.EnvKeyDisableEmailVerification: false,
constants.EnvKeyDisableLoginPage: false, constants.EnvKeyDisableLoginPage: false,
constants.EnvKeyDisableSignUp: false, constants.EnvKeyDisableSignUp: false,
constants.EnvKeyDisableStrongPassword: false,
} }
requiredEnvs := RequiredEnvStoreObj.GetRequiredEnv() requiredEnvs := RequiredEnvStoreObj.GetRequiredEnv()

View File

@@ -26,24 +26,34 @@ func (c *provider) GetUserSession(userId, sessionToken string) (string, error) {
// DeleteAllUserSessions deletes all the user sessions from in-memory store. // DeleteAllUserSessions deletes all the user sessions from in-memory store.
func (c *provider) DeleteAllUserSessions(userId string) error { func (c *provider) DeleteAllUserSessions(userId string) error {
if os.Getenv("ENV") != constants.TestEnv { namespaces := []string{
c.mutex.Lock() constants.AuthRecipeMethodBasicAuth,
defer c.mutex.Unlock() constants.AuthRecipeMethodMagicLinkLogin,
constants.AuthRecipeMethodApple,
constants.AuthRecipeMethodFacebook,
constants.AuthRecipeMethodGithub,
constants.AuthRecipeMethodGoogle,
constants.AuthRecipeMethodLinkedIn,
}
for _, namespace := range namespaces {
c.sessionStore.RemoveAll(namespace + ":" + userId)
} }
c.sessionStore.RemoveAll(userId)
return nil return nil
} }
// DeleteUserSession deletes the user session from the in-memory store. // DeleteUserSession deletes the user session from the in-memory store.
func (c *provider) DeleteUserSession(userId, sessionToken string) error { func (c *provider) DeleteUserSession(userId, sessionToken string) error {
if os.Getenv("ENV") != constants.TestEnv {
c.mutex.Lock()
defer c.mutex.Unlock()
}
c.sessionStore.Remove(userId, sessionToken) c.sessionStore.Remove(userId, sessionToken)
return nil return nil
} }
// DeleteSessionForNamespace to delete session for a given namespace example google,github
func (c *provider) DeleteSessionForNamespace(namespace string) error {
c.sessionStore.RemoveByNamespace(namespace)
return nil
}
// SetState sets the state in the in-memory store. // SetState sets the state in the in-memory store.
func (c *provider) SetState(key, state string) error { func (c *provider) SetState(key, state string) error {
if os.Getenv("ENV") != constants.TestEnv { if os.Getenv("ENV") != constants.TestEnv {

View File

@@ -1,10 +1,7 @@
package stores package stores
import ( import (
"os"
"sync" "sync"
"github.com/authorizerdev/authorizer/server/constants"
) )
// EnvStore struct to store the env variables // EnvStore struct to store the env variables
@@ -23,12 +20,10 @@ func NewEnvStore() *EnvStore {
// UpdateEnvStore to update the whole env store object // UpdateEnvStore to update the whole env store object
func (e *EnvStore) UpdateStore(store map[string]interface{}) { func (e *EnvStore) UpdateStore(store map[string]interface{}) {
if os.Getenv("ENV") != constants.TestEnv { e.mutex.Lock()
e.mutex.Lock() defer e.mutex.Unlock()
defer e.mutex.Unlock()
}
// just override the keys + new keys
// just override the keys + new keys
for key, value := range store { for key, value := range store {
e.store[key] = value e.store[key] = value
} }
@@ -46,9 +41,8 @@ func (e *EnvStore) Get(key string) interface{} {
// Set sets the value of the key in env store // Set sets the value of the key in env store
func (e *EnvStore) Set(key string, value interface{}) { func (e *EnvStore) Set(key string, value interface{}) {
if os.Getenv("ENV") != constants.TestEnv { e.mutex.Lock()
e.mutex.Lock() defer e.mutex.Unlock()
defer e.mutex.Unlock()
}
e.store[key] = value e.store[key] = value
} }

View File

@@ -1,10 +1,8 @@
package stores package stores
import ( import (
"os" "strings"
"sync" "sync"
"github.com/authorizerdev/authorizer/server/constants"
) )
// SessionStore struct to store the env variables // SessionStore struct to store the env variables
@@ -28,10 +26,9 @@ func (s *SessionStore) Get(key, subKey string) string {
// Set sets the value of the key in state store // Set sets the value of the key in state store
func (s *SessionStore) Set(key string, subKey, value string) { func (s *SessionStore) Set(key string, subKey, value string) {
if os.Getenv("ENV") != constants.TestEnv { s.mutex.Lock()
s.mutex.Lock() defer s.mutex.Unlock()
defer s.mutex.Unlock()
}
if _, ok := s.store[key]; !ok { if _, ok := s.store[key]; !ok {
s.store[key] = make(map[string]string) s.store[key] = make(map[string]string)
} }
@@ -40,19 +37,15 @@ func (s *SessionStore) Set(key string, subKey, value string) {
// RemoveAll all values for given key // RemoveAll all values for given key
func (s *SessionStore) RemoveAll(key string) { func (s *SessionStore) RemoveAll(key string) {
if os.Getenv("ENV") != constants.TestEnv { s.mutex.Lock()
s.mutex.Lock() defer s.mutex.Unlock()
defer s.mutex.Unlock()
}
delete(s.store, key) delete(s.store, key)
} }
// Remove value for given key and subkey // Remove value for given key and subkey
func (s *SessionStore) Remove(key, subKey string) { func (s *SessionStore) Remove(key, subKey string) {
if os.Getenv("ENV") != constants.TestEnv { s.mutex.Lock()
s.mutex.Lock() defer s.mutex.Unlock()
defer s.mutex.Unlock()
}
if _, ok := s.store[key]; ok { if _, ok := s.store[key]; ok {
delete(s.store[key], subKey) delete(s.store[key], subKey)
} }
@@ -65,3 +58,15 @@ func (s *SessionStore) GetAll(key string) map[string]string {
} }
return s.store[key] return s.store[key]
} }
// RemoveByNamespace to delete session for a given namespace example google,github
func (s *SessionStore) RemoveByNamespace(namespace string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
for key := range s.store {
if strings.Contains(key, namespace+":") {
delete(s.store, key)
}
}
return nil
}

View File

@@ -1,10 +1,7 @@
package stores package stores
import ( import (
"os"
"sync" "sync"
"github.com/authorizerdev/authorizer/server/constants"
) )
// StateStore struct to store the env variables // StateStore struct to store the env variables
@@ -28,19 +25,16 @@ func (s *StateStore) Get(key string) string {
// Set sets the value of the key in state store // Set sets the value of the key in state store
func (s *StateStore) Set(key string, value string) { func (s *StateStore) Set(key string, value string) {
if os.Getenv("ENV") != constants.TestEnv { s.mutex.Lock()
s.mutex.Lock() defer s.mutex.Unlock()
defer s.mutex.Unlock()
}
s.store[key] = value s.store[key] = value
} }
// Remove removes the key from state store // Remove removes the key from state store
func (s *StateStore) Remove(key string) { func (s *StateStore) Remove(key string) {
if os.Getenv("ENV") != constants.TestEnv { s.mutex.Lock()
s.mutex.Lock() defer s.mutex.Unlock()
defer s.mutex.Unlock()
}
delete(s.store, key) delete(s.store, key)
} }

View File

@@ -12,6 +12,8 @@ type Provider interface {
DeleteUserSession(userId, key string) error DeleteUserSession(userId, key string) error
// DeleteAllSessions deletes all the sessions from the session store // DeleteAllSessions deletes all the sessions from the session store
DeleteAllUserSessions(userId string) error DeleteAllUserSessions(userId string) error
// DeleteSessionForNamespace deletes the session for a given namespace
DeleteSessionForNamespace(namespace string) error
// SetState sets the login state (key, value form) in the session store // SetState sets the login state (key, value form) in the session store
SetState(key, state string) error SetState(key, state string) error

View File

@@ -63,11 +63,48 @@ func (c *provider) DeleteUserSession(userId, key string) error {
// DeleteAllUserSessions deletes all the user session from redis // DeleteAllUserSessions deletes all the user session from redis
func (c *provider) DeleteAllUserSessions(userID string) error { func (c *provider) DeleteAllUserSessions(userID string) error {
err := c.store.Del(c.ctx, userID).Err() namespaces := []string{
if err != nil { constants.AuthRecipeMethodBasicAuth,
log.Debug("Error deleting all user sessions from redis: ", err) constants.AuthRecipeMethodMagicLinkLogin,
return err constants.AuthRecipeMethodApple,
constants.AuthRecipeMethodFacebook,
constants.AuthRecipeMethodGithub,
constants.AuthRecipeMethodGoogle,
constants.AuthRecipeMethodLinkedIn,
} }
for _, namespace := range namespaces {
err := c.store.Del(c.ctx, namespace+":"+userID).Err()
if err != nil {
log.Debug("Error deleting all user sessions from redis: ", err)
return err
}
}
return nil
}
// DeleteSessionForNamespace to delete session for a given namespace example google,github
func (c *provider) DeleteSessionForNamespace(namespace string) error {
var cursor uint64
for {
keys := []string{}
keys, cursor, err := c.store.Scan(c.ctx, cursor, namespace+":*", 0).Result()
if err != nil {
log.Debugf("Error scanning keys for %s namespace: %s", namespace, err.Error())
return err
}
for _, key := range keys {
err := c.store.Del(c.ctx, key).Err()
if err != nil {
log.Debugf("Error deleting sessions for %s namespace: %s", namespace, err.Error())
return err
}
}
if cursor == 0 { // no more keys
break
}
}
return nil return nil
} }
@@ -123,7 +160,7 @@ func (c *provider) GetEnvStore() (map[string]interface{}, error) {
return nil, err return nil, err
} }
for key, value := range data { for key, value := range data {
if key == constants.EnvKeyDisableBasicAuthentication || key == constants.EnvKeyDisableEmailVerification || key == constants.EnvKeyDisableLoginPage || key == constants.EnvKeyDisableMagicLinkLogin || key == constants.EnvKeyDisableRedisForEnv || key == constants.EnvKeyDisableSignUp { if key == constants.EnvKeyDisableBasicAuthentication || key == constants.EnvKeyDisableEmailVerification || key == constants.EnvKeyDisableLoginPage || key == constants.EnvKeyDisableMagicLinkLogin || key == constants.EnvKeyDisableRedisForEnv || key == constants.EnvKeyDisableSignUp || key == constants.EnvKeyDisableStrongPassword {
boolValue, err := strconv.ParseBool(value) boolValue, err := strconv.ParseBool(value)
if err != nil { if err != nil {
return res, err return res, err

View File

@@ -130,7 +130,6 @@ func InitOAuth() error {
AuthURL: "https://appleid.apple.com/auth/authorize", AuthURL: "https://appleid.apple.com/auth/authorize",
TokenURL: "https://appleid.apple.com/auth/token", TokenURL: "https://appleid.apple.com/auth/token",
}, },
Scopes: []string{"openid", "name", "email"},
} }
} }

14
server/refs/bool.go Normal file
View File

@@ -0,0 +1,14 @@
package refs
// NewBoolRef returns a reference to a bool with given value
func NewBoolRef(v bool) *bool {
return &v
}
// BoolValue returns the value of the given bool ref
func BoolValue(r *bool) bool {
if r == nil {
return false
}
return *r
}

14
server/refs/int.go Normal file
View File

@@ -0,0 +1,14 @@
package refs
// NewInt64Ref returns a reference to a int64 with given value
func NewInt64Ref(v int64) *int64 {
return &v
}
// Int64Value returns the value of the given bool ref
func Int64Value(r *int64) int64 {
if r == nil {
return 0
}
return *r
}

19
server/refs/string.go Normal file
View File

@@ -0,0 +1,19 @@
package refs
// NewStringRef returns a reference to a string with given value
func NewStringRef(v string) *string {
return &v
}
// StringValue returns the value of the given string ref
func StringValue(r *string, defaultValue ...string) string {
if r != nil {
return *r
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
return ""
}

View File

@@ -0,0 +1,53 @@
package resolvers
import (
"context"
"fmt"
"strings"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/token"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/authorizerdev/authorizer/server/validators"
log "github.com/sirupsen/logrus"
)
// TODO add template validator
// AddEmailTemplateResolver resolver for add email template mutation
func AddEmailTemplateResolver(ctx context.Context, params model.AddEmailTemplateRequest) (*model.Response, error) {
gc, err := utils.GinContextFromContext(ctx)
if err != nil {
log.Debug("Failed to get GinContext: ", err)
return nil, err
}
if !token.IsSuperAdmin(gc) {
log.Debug("Not logged in as super admin")
return nil, fmt.Errorf("unauthorized")
}
if !validators.IsValidEmailTemplateEventName(params.EventName) {
log.Debug("Invalid Event Name: ", params.EventName)
return nil, fmt.Errorf("invalid event name %s", params.EventName)
}
if strings.TrimSpace(params.Template) == "" {
return nil, fmt.Errorf("empty template not allowed")
}
_, err = db.Provider.AddEmailTemplate(ctx, models.EmailTemplate{
EventName: params.EventName,
Template: params.Template,
})
if err != nil {
log.Debug("Failed to add email template: ", err)
return nil, err
}
return &model.Response{
Message: `Email template added successfully`,
}, nil
}

View File

@@ -0,0 +1,60 @@
package resolvers
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/token"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/authorizerdev/authorizer/server/validators"
log "github.com/sirupsen/logrus"
)
// AddWebhookResolver resolver for add webhook mutation
func AddWebhookResolver(ctx context.Context, params model.AddWebhookRequest) (*model.Response, error) {
gc, err := utils.GinContextFromContext(ctx)
if err != nil {
log.Debug("Failed to get GinContext: ", err)
return nil, err
}
if !token.IsSuperAdmin(gc) {
log.Debug("Not logged in as super admin")
return nil, fmt.Errorf("unauthorized")
}
if !validators.IsValidWebhookEventName(params.EventName) {
log.Debug("Invalid Event Name: ", params.EventName)
return nil, fmt.Errorf("invalid event name %s", params.EventName)
}
if strings.TrimSpace(params.Endpoint) == "" {
log.Debug("empty endpoint not allowed")
return nil, fmt.Errorf("empty endpoint not allowed")
}
headerBytes, err := json.Marshal(params.Headers)
if err != nil {
return nil, err
}
_, err = db.Provider.AddWebhook(ctx, models.Webhook{
EventName: params.EventName,
EndPoint: params.Endpoint,
Enabled: params.Enabled,
Headers: string(headerBytes),
})
if err != nil {
log.Debug("Failed to add webhook: ", err)
return nil, err
}
return &model.Response{
Message: `Webhook added successfully`,
}, nil
}

View File

@@ -58,7 +58,7 @@ func AdminSignupResolver(ctx context.Context, params model.AdminSignupInput) (*m
return res, err return res, err
} }
env, err := db.Provider.GetEnv() env, err := db.Provider.GetEnv(ctx)
if err != nil { if err != nil {
log.Debug("Failed to get env: ", err) log.Debug("Failed to get env: ", err)
return res, err return res, err
@@ -71,7 +71,7 @@ func AdminSignupResolver(ctx context.Context, params model.AdminSignupInput) (*m
} }
env.EnvData = envData env.EnvData = envData
if _, err := db.Provider.UpdateEnv(env); err != nil { if _, err := db.Provider.UpdateEnv(ctx, env); err != nil {
log.Debug("Failed to update env: ", err) log.Debug("Failed to update env: ", err)
return res, err return res, err
} }

View File

@@ -0,0 +1,49 @@
package resolvers
import (
"context"
"fmt"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/token"
"github.com/authorizerdev/authorizer/server/utils"
log "github.com/sirupsen/logrus"
)
// DeleteEmailTemplateResolver resolver to delete email template and its relevant logs
func DeleteEmailTemplateResolver(ctx context.Context, params model.DeleteEmailTemplateRequest) (*model.Response, error) {
gc, err := utils.GinContextFromContext(ctx)
if err != nil {
log.Debug("Failed to get GinContext: ", err)
return nil, err
}
if !token.IsSuperAdmin(gc) {
log.Debug("Not logged in as super admin")
return nil, fmt.Errorf("unauthorized")
}
if params.ID == "" {
log.Debug("email template is required")
return nil, fmt.Errorf("email template ID required")
}
log := log.WithField("email_template_id", params.ID)
emailTemplate, err := db.Provider.GetEmailTemplateByID(ctx, params.ID)
if err != nil {
log.Debug("failed to get email template: ", err)
return nil, err
}
err = db.Provider.DeleteEmailTemplate(ctx, emailTemplate)
if err != nil {
log.Debug("failed to delete email template: ", err)
return nil, err
}
return &model.Response{
Message: "Email templated deleted successfully",
}, nil
}

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/memorystore"
@@ -32,15 +33,13 @@ func DeleteUserResolver(ctx context.Context, params model.DeleteUserInput) (*mod
"email": params.Email, "email": params.Email,
}) })
user, err := db.Provider.GetUserByEmail(params.Email) user, err := db.Provider.GetUserByEmail(ctx, params.Email)
if err != nil { if err != nil {
log.Debug("Failed to get user from DB: ", err) log.Debug("Failed to get user from DB: ", err)
return res, err return res, err
} }
go memorystore.Provider.DeleteAllUserSessions(user.ID) err = db.Provider.DeleteUser(ctx, user)
err = db.Provider.DeleteUser(user)
if err != nil { if err != nil {
log.Debug("Failed to delete user: ", err) log.Debug("Failed to delete user: ", err)
return res, err return res, err
@@ -50,5 +49,10 @@ func DeleteUserResolver(ctx context.Context, params model.DeleteUserInput) (*mod
Message: `user deleted successfully`, Message: `user deleted successfully`,
} }
go func() {
memorystore.Provider.DeleteAllUserSessions(user.ID)
utils.RegisterEvent(ctx, constants.UserDeletedWebhookEvent, "", user)
}()
return res, nil return res, nil
} }

View File

@@ -0,0 +1,49 @@
package resolvers
import (
"context"
"fmt"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/token"
"github.com/authorizerdev/authorizer/server/utils"
log "github.com/sirupsen/logrus"
)
// DeleteWebhookResolver resolver to delete webhook and its relevant logs
func DeleteWebhookResolver(ctx context.Context, params model.WebhookRequest) (*model.Response, error) {
gc, err := utils.GinContextFromContext(ctx)
if err != nil {
log.Debug("Failed to get GinContext: ", err)
return nil, err
}
if !token.IsSuperAdmin(gc) {
log.Debug("Not logged in as super admin")
return nil, fmt.Errorf("unauthorized")
}
if params.ID == "" {
log.Debug("webhookID is required")
return nil, fmt.Errorf("webhook ID required")
}
log := log.WithField("webhook_id", params.ID)
webhook, err := db.Provider.GetWebhookByID(ctx, params.ID)
if err != nil {
log.Debug("failed to get webhook: ", err)
return nil, err
}
err = db.Provider.DeleteWebhook(ctx, webhook)
if err != nil {
log.Debug("failed to delete webhook: ", err)
return nil, err
}
return &model.Response{
Message: "Webhook deleted successfully",
}, nil
}

View File

@@ -0,0 +1,35 @@
package resolvers
import (
"context"
"fmt"
"github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/token"
"github.com/authorizerdev/authorizer/server/utils"
log "github.com/sirupsen/logrus"
)
// EmailTemplatesResolver resolver for getting the list of email templates based on pagination
func EmailTemplatesResolver(ctx context.Context, params *model.PaginatedInput) (*model.EmailTemplates, error) {
gc, err := utils.GinContextFromContext(ctx)
if err != nil {
log.Debug("Failed to get GinContext: ", err)
return nil, err
}
if !token.IsSuperAdmin(gc) {
log.Debug("Not logged in as super admin")
return nil, fmt.Errorf("unauthorized")
}
pagination := utils.GetPagination(params)
emailTemplates, err := db.Provider.ListEmailTemplate(ctx, pagination)
if err != nil {
log.Debug("failed to get email templates: ", err)
return nil, err
}
return emailTemplates, nil
}

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/token"
@@ -31,7 +32,7 @@ func EnableAccessResolver(ctx context.Context, params model.UpdateAccessInput) (
"user_id": params.UserID, "user_id": params.UserID,
}) })
user, err := db.Provider.GetUserByID(params.UserID) user, err := db.Provider.GetUserByID(ctx, params.UserID)
if err != nil { if err != nil {
log.Debug("Failed to get user from DB: ", err) log.Debug("Failed to get user from DB: ", err)
return res, err return res, err
@@ -39,7 +40,7 @@ func EnableAccessResolver(ctx context.Context, params model.UpdateAccessInput) (
user.RevokedTimestamp = nil user.RevokedTimestamp = nil
user, err = db.Provider.UpdateUser(user) user, err = db.Provider.UpdateUser(ctx, user)
if err != nil { if err != nil {
log.Debug("Failed to update user: ", err) log.Debug("Failed to update user: ", err)
return res, err return res, err
@@ -49,5 +50,7 @@ func EnableAccessResolver(ctx context.Context, params model.UpdateAccessInput) (
Message: `user access enabled successfully`, Message: `user access enabled successfully`,
} }
go utils.RegisterEvent(ctx, constants.UserAccessEnabledWebhookEvent, "", user)
return res, nil return res, nil
} }

View File

@@ -10,6 +10,7 @@ import (
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/memorystore"
"github.com/authorizerdev/authorizer/server/refs"
"github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/token"
"github.com/authorizerdev/authorizer/server/utils" "github.com/authorizerdev/authorizer/server/utils"
) )
@@ -38,10 +39,10 @@ func EnvResolver(ctx context.Context) (*model.Env, error) {
} }
if val, ok := store[constants.EnvKeyAccessTokenExpiryTime]; ok { if val, ok := store[constants.EnvKeyAccessTokenExpiryTime]; ok {
res.AccessTokenExpiryTime = utils.NewStringRef(val.(string)) res.AccessTokenExpiryTime = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyAdminSecret]; ok { if val, ok := store[constants.EnvKeyAdminSecret]; ok {
res.AdminSecret = utils.NewStringRef(val.(string)) res.AdminSecret = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyClientID]; ok { if val, ok := store[constants.EnvKeyClientID]; ok {
res.ClientID = val.(string) res.ClientID = val.(string)
@@ -50,103 +51,103 @@ func EnvResolver(ctx context.Context) (*model.Env, error) {
res.ClientSecret = val.(string) res.ClientSecret = val.(string)
} }
if val, ok := store[constants.EnvKeyDatabaseURL]; ok { if val, ok := store[constants.EnvKeyDatabaseURL]; ok {
res.DatabaseURL = utils.NewStringRef(val.(string)) res.DatabaseURL = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyDatabaseName]; ok { if val, ok := store[constants.EnvKeyDatabaseName]; ok {
res.DatabaseName = utils.NewStringRef(val.(string)) res.DatabaseName = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyDatabaseType]; ok { if val, ok := store[constants.EnvKeyDatabaseType]; ok {
res.DatabaseType = utils.NewStringRef(val.(string)) res.DatabaseType = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyDatabaseUsername]; ok { if val, ok := store[constants.EnvKeyDatabaseUsername]; ok {
res.DatabaseUsername = utils.NewStringRef(val.(string)) res.DatabaseUsername = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyDatabasePassword]; ok { if val, ok := store[constants.EnvKeyDatabasePassword]; ok {
res.DatabasePassword = utils.NewStringRef(val.(string)) res.DatabasePassword = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyDatabaseHost]; ok { if val, ok := store[constants.EnvKeyDatabaseHost]; ok {
res.DatabaseHost = utils.NewStringRef(val.(string)) res.DatabaseHost = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyDatabasePort]; ok { if val, ok := store[constants.EnvKeyDatabasePort]; ok {
res.DatabasePort = utils.NewStringRef(val.(string)) res.DatabasePort = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyCustomAccessTokenScript]; ok { if val, ok := store[constants.EnvKeyCustomAccessTokenScript]; ok {
res.CustomAccessTokenScript = utils.NewStringRef(val.(string)) res.CustomAccessTokenScript = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeySmtpHost]; ok { if val, ok := store[constants.EnvKeySmtpHost]; ok {
res.SMTPHost = utils.NewStringRef(val.(string)) res.SMTPHost = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeySmtpPort]; ok { if val, ok := store[constants.EnvKeySmtpPort]; ok {
res.SMTPPort = utils.NewStringRef(val.(string)) res.SMTPPort = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeySmtpUsername]; ok { if val, ok := store[constants.EnvKeySmtpUsername]; ok {
res.SMTPUsername = utils.NewStringRef(val.(string)) res.SMTPUsername = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeySmtpPassword]; ok { if val, ok := store[constants.EnvKeySmtpPassword]; ok {
res.SMTPPassword = utils.NewStringRef(val.(string)) res.SMTPPassword = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeySenderEmail]; ok { if val, ok := store[constants.EnvKeySenderEmail]; ok {
res.SenderEmail = utils.NewStringRef(val.(string)) res.SenderEmail = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyJwtType]; ok { if val, ok := store[constants.EnvKeyJwtType]; ok {
res.JwtType = utils.NewStringRef(val.(string)) res.JwtType = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyJwtSecret]; ok { if val, ok := store[constants.EnvKeyJwtSecret]; ok {
res.JwtSecret = utils.NewStringRef(val.(string)) res.JwtSecret = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyJwtRoleClaim]; ok { if val, ok := store[constants.EnvKeyJwtRoleClaim]; ok {
res.JwtRoleClaim = utils.NewStringRef(val.(string)) res.JwtRoleClaim = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyJwtPublicKey]; ok { if val, ok := store[constants.EnvKeyJwtPublicKey]; ok {
res.JwtPublicKey = utils.NewStringRef(val.(string)) res.JwtPublicKey = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyJwtPrivateKey]; ok { if val, ok := store[constants.EnvKeyJwtPrivateKey]; ok {
res.JwtPrivateKey = utils.NewStringRef(val.(string)) res.JwtPrivateKey = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyAppURL]; ok { if val, ok := store[constants.EnvKeyAppURL]; ok {
res.AppURL = utils.NewStringRef(val.(string)) res.AppURL = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyRedisURL]; ok { if val, ok := store[constants.EnvKeyRedisURL]; ok {
res.RedisURL = utils.NewStringRef(val.(string)) res.RedisURL = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyResetPasswordURL]; ok { if val, ok := store[constants.EnvKeyResetPasswordURL]; ok {
res.ResetPasswordURL = utils.NewStringRef(val.(string)) res.ResetPasswordURL = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyGoogleClientID]; ok { if val, ok := store[constants.EnvKeyGoogleClientID]; ok {
res.GoogleClientID = utils.NewStringRef(val.(string)) res.GoogleClientID = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyGoogleClientSecret]; ok { if val, ok := store[constants.EnvKeyGoogleClientSecret]; ok {
res.GoogleClientSecret = utils.NewStringRef(val.(string)) res.GoogleClientSecret = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyFacebookClientID]; ok { if val, ok := store[constants.EnvKeyFacebookClientID]; ok {
res.FacebookClientID = utils.NewStringRef(val.(string)) res.FacebookClientID = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyFacebookClientSecret]; ok { if val, ok := store[constants.EnvKeyFacebookClientSecret]; ok {
res.FacebookClientSecret = utils.NewStringRef(val.(string)) res.FacebookClientSecret = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyGithubClientID]; ok { if val, ok := store[constants.EnvKeyGithubClientID]; ok {
res.GithubClientID = utils.NewStringRef(val.(string)) res.GithubClientID = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyGithubClientSecret]; ok { if val, ok := store[constants.EnvKeyGithubClientSecret]; ok {
res.GithubClientSecret = utils.NewStringRef(val.(string)) res.GithubClientSecret = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyLinkedInClientID]; ok { if val, ok := store[constants.EnvKeyLinkedInClientID]; ok {
res.LinkedinClientID = utils.NewStringRef(val.(string)) res.LinkedinClientID = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyLinkedInClientSecret]; ok { if val, ok := store[constants.EnvKeyLinkedInClientSecret]; ok {
res.LinkedinClientSecret = utils.NewStringRef(val.(string)) res.LinkedinClientSecret = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyAppleClientID]; ok { if val, ok := store[constants.EnvKeyAppleClientID]; ok {
res.AppleClientID = utils.NewStringRef(val.(string)) res.AppleClientID = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyAppleClientSecret]; ok { if val, ok := store[constants.EnvKeyAppleClientSecret]; ok {
res.AppleClientSecret = utils.NewStringRef(val.(string)) res.AppleClientSecret = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyOrganizationName]; ok { if val, ok := store[constants.EnvKeyOrganizationName]; ok {
res.OrganizationName = utils.NewStringRef(val.(string)) res.OrganizationName = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyOrganizationLogo]; ok { if val, ok := store[constants.EnvKeyOrganizationLogo]; ok {
res.OrganizationLogo = utils.NewStringRef(val.(string)) res.OrganizationLogo = refs.NewStringRef(val.(string))
} }
// string slice vars // string slice vars
@@ -168,6 +169,7 @@ func EnvResolver(ctx context.Context) (*model.Env, error) {
res.DisableMagicLinkLogin = store[constants.EnvKeyDisableMagicLinkLogin].(bool) res.DisableMagicLinkLogin = store[constants.EnvKeyDisableMagicLinkLogin].(bool)
res.DisableLoginPage = store[constants.EnvKeyDisableLoginPage].(bool) res.DisableLoginPage = store[constants.EnvKeyDisableLoginPage].(bool)
res.DisableSignUp = store[constants.EnvKeyDisableSignUp].(bool) res.DisableSignUp = store[constants.EnvKeyDisableSignUp].(bool)
res.DisableStrongPassword = store[constants.EnvKeyDisableStrongPassword].(bool)
return res, nil return res, nil
} }

View File

@@ -49,7 +49,7 @@ func ForgotPasswordResolver(ctx context.Context, params model.ForgotPasswordInpu
log := log.WithFields(log.Fields{ log := log.WithFields(log.Fields{
"email": params.Email, "email": params.Email,
}) })
_, err = db.Provider.GetUserByEmail(params.Email) _, err = db.Provider.GetUserByEmail(ctx, params.Email)
if err != nil { if err != nil {
log.Debug("User not found: ", err) log.Debug("User not found: ", err)
return res, fmt.Errorf(`user with this email not found`) return res, fmt.Errorf(`user with this email not found`)
@@ -71,7 +71,7 @@ func ForgotPasswordResolver(ctx context.Context, params model.ForgotPasswordInpu
log.Debug("Failed to create verification token", err) log.Debug("Failed to create verification token", err)
return res, err return res, err
} }
_, err = db.Provider.AddVerificationRequest(models.VerificationRequest{ _, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{
Token: verificationToken, Token: verificationToken,
Identifier: constants.VerificationTypeForgotPassword, Identifier: constants.VerificationTypeForgotPassword,
ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), ExpiresAt: time.Now().Add(time.Minute * 30).Unix(),

View File

@@ -70,7 +70,7 @@ func InviteMembersResolver(ctx context.Context, params model.InviteMemberInput)
// for each emails check if emails exists in db // for each emails check if emails exists in db
newEmails := []string{} newEmails := []string{}
for _, email := range emails { for _, email := range emails {
_, err := db.Provider.GetUserByEmail(email) _, err := db.Provider.GetUserByEmail(ctx, email)
if err != nil { if err != nil {
log.Debugf("User with %s email not found, so inviting user", email) log.Debugf("User with %s email not found, so inviting user", email)
newEmails = append(newEmails, email) newEmails = append(newEmails, email)
@@ -129,24 +129,24 @@ func InviteMembersResolver(ctx context.Context, params model.InviteMemberInput)
// use magic link login if that option is on // use magic link login if that option is on
if !isMagicLinkLoginDisabled { if !isMagicLinkLoginDisabled {
user.SignupMethods = constants.SignupMethodMagicLinkLogin user.SignupMethods = constants.AuthRecipeMethodMagicLinkLogin
verificationRequest.Identifier = constants.VerificationTypeMagicLinkLogin verificationRequest.Identifier = constants.VerificationTypeMagicLinkLogin
} else { } else {
// use basic authentication if that option is on // use basic authentication if that option is on
user.SignupMethods = constants.SignupMethodBasicAuth user.SignupMethods = constants.AuthRecipeMethodBasicAuth
verificationRequest.Identifier = constants.VerificationTypeForgotPassword verificationRequest.Identifier = constants.VerificationTypeForgotPassword
verifyEmailURL = appURL + "/setup-password" verifyEmailURL = appURL + "/setup-password"
} }
user, err = db.Provider.AddUser(user) user, err = db.Provider.AddUser(ctx, user)
if err != nil { if err != nil {
log.Debugf("Error adding user: %s, err: %v", email, err) log.Debugf("Error adding user: %s, err: %v", email, err)
return nil, err return nil, err
} }
_, err = db.Provider.AddVerificationRequest(verificationRequest) _, err = db.Provider.AddVerificationRequest(ctx, verificationRequest)
if err != nil { if err != nil {
log.Debugf("Error adding verification request: %s, err: %v", email, err) log.Debugf("Error adding verification request: %s, err: %v", email, err)
return nil, err return nil, err

View File

@@ -45,7 +45,7 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes
"email": params.Email, "email": params.Email,
}) })
params.Email = strings.ToLower(params.Email) params.Email = strings.ToLower(params.Email)
user, err := db.Provider.GetUserByEmail(params.Email) user, err := db.Provider.GetUserByEmail(ctx, params.Email)
if err != nil { if err != nil {
log.Debug("Failed to get user by email: ", err) log.Debug("Failed to get user by email: ", err)
return res, fmt.Errorf(`user with this email not found`) return res, fmt.Errorf(`user with this email not found`)
@@ -56,7 +56,7 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes
return res, fmt.Errorf(`user access has been revoked`) return res, fmt.Errorf(`user access has been revoked`)
} }
if !strings.Contains(user.SignupMethods, constants.SignupMethodBasicAuth) { if !strings.Contains(user.SignupMethods, constants.AuthRecipeMethodBasicAuth) {
log.Debug("User signup method is not basic auth") log.Debug("User signup method is not basic auth")
return res, fmt.Errorf(`user has not signed up email & password`) return res, fmt.Errorf(`user has not signed up email & password`)
} }
@@ -97,7 +97,7 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes
scope = params.Scope scope = params.Scope
} }
authToken, err := token.CreateAuthToken(gc, user, roles, scope) authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth)
if err != nil { if err != nil {
log.Debug("Failed to create auth token", err) log.Debug("Failed to create auth token", err)
return res, err return res, err
@@ -117,19 +117,23 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes
} }
cookie.SetSession(gc, authToken.FingerPrintHash) cookie.SetSession(gc, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) sessionStoreKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res.RefreshToken = &authToken.RefreshToken.Token res.RefreshToken = &authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(user.ID, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
} }
go db.Provider.AddSession(models.Session{ go func() {
UserID: user.ID, utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, constants.AuthRecipeMethodBasicAuth, user)
UserAgent: utils.GetUserAgent(gc.Request), db.Provider.AddSession(ctx, models.Session{
IP: utils.GetIP(gc.Request), UserID: user.ID,
}) UserAgent: utils.GetUserAgent(gc.Request),
IP: utils.GetIP(gc.Request),
})
}()
return res, nil return res, nil
} }

View File

@@ -41,7 +41,12 @@ func LogoutResolver(ctx context.Context) (*model.Response, error) {
return nil, err return nil, err
} }
memorystore.Provider.DeleteUserSession(sessionData.Subject, sessionData.Nonce) sessionKey := sessionData.Subject
if sessionData.LoginMethod != "" {
sessionKey = sessionData.LoginMethod + ":" + sessionData.Subject
}
memorystore.Provider.DeleteUserSession(sessionKey, sessionData.Nonce)
cookie.DeleteSession(gc) cookie.DeleteSession(gc)
res := &model.Response{ res := &model.Response{

View File

@@ -59,7 +59,7 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu
} }
// find user with email // find user with email
existingUser, err := db.Provider.GetUserByEmail(params.Email) existingUser, err := db.Provider.GetUserByEmail(ctx, params.Email)
if err != nil { if err != nil {
isSignupDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp) isSignupDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp)
if err != nil { if err != nil {
@@ -70,7 +70,7 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu
return res, fmt.Errorf(`signup is disabled for this instance`) return res, fmt.Errorf(`signup is disabled for this instance`)
} }
user.SignupMethods = constants.SignupMethodMagicLinkLogin user.SignupMethods = constants.AuthRecipeMethodMagicLinkLogin
// define roles for new user // define roles for new user
if len(params.Roles) > 0 { if len(params.Roles) > 0 {
// check if roles exists // check if roles exists
@@ -99,7 +99,8 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu
} }
user.Roles = strings.Join(inputRoles, ",") user.Roles = strings.Join(inputRoles, ",")
user, _ = db.Provider.AddUser(user) user, _ = db.Provider.AddUser(ctx, user)
go utils.RegisterEvent(ctx, constants.UserCreatedWebhookEvent, constants.AuthRecipeMethodMagicLinkLogin, user)
} else { } else {
user = existingUser user = existingUser
// There multiple scenarios with roles here in magic link login // There multiple scenarios with roles here in magic link login
@@ -158,12 +159,12 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu
} }
signupMethod := existingUser.SignupMethods signupMethod := existingUser.SignupMethods
if !strings.Contains(signupMethod, constants.SignupMethodMagicLinkLogin) { if !strings.Contains(signupMethod, constants.AuthRecipeMethodMagicLinkLogin) {
signupMethod = signupMethod + "," + constants.SignupMethodMagicLinkLogin signupMethod = signupMethod + "," + constants.AuthRecipeMethodMagicLinkLogin
} }
user.SignupMethods = signupMethod user.SignupMethods = signupMethod
user, _ = db.Provider.UpdateUser(user) user, _ = db.Provider.UpdateUser(ctx, user)
if err != nil { if err != nil {
log.Debug("Failed to update user: ", err) log.Debug("Failed to update user: ", err)
} }
@@ -205,7 +206,7 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu
if err != nil { if err != nil {
log.Debug("Failed to create verification token: ", err) log.Debug("Failed to create verification token: ", err)
} }
_, err = db.Provider.AddVerificationRequest(models.VerificationRequest{ _, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{
Token: verificationToken, Token: verificationToken,
Identifier: verificationType, Identifier: verificationType,
ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), ExpiresAt: time.Now().Add(time.Minute * 30).Unix(),

Some files were not shown because too many files have changed in this diff Show More