diff --git a/server/constants/db_types.go b/server/constants/db_types.go index cf08597..30ae23e 100644 --- a/server/constants/db_types.go +++ b/server/constants/db_types.go @@ -17,4 +17,6 @@ const ( DbTypeYugabyte = "yugabyte" // DbTypeMariaDB is the mariadb database type DbTypeMariaDB = "mariadb" + // DbTypeCassandra is the cassandra database type + DbTypeCassandraDB = "cassandradb" ) diff --git a/server/constants/env.go b/server/constants/env.go index 66b466a..4f391a9 100644 --- a/server/constants/env.go +++ b/server/constants/env.go @@ -30,6 +30,14 @@ const ( EnvKeyDatabaseURL = "DATABASE_URL" // EnvKeyDatabaseName key for env variable DATABASE_NAME EnvKeyDatabaseName = "DATABASE_NAME" + // EnvKeyDatabaseUsername key for env variable DATABASE_USERNAME + EnvKeyDatabaseUsername = "DATABASE_USERNAME" + // EnvKeyDatabasePassword key for env variable DATABASE_PASSWORD + EnvKeyDatabasePassword = "DATABASE_PASSWORD" + // EnvKeyDatabasePort key for env variable DATABASE_PORT + EnvKeyDatabasePort = "DATABASE_PORT" + // EnvKeyDatabaseHost key for env variable DATABASE_HOST + EnvKeyDatabaseHost = "DATABASE_HOST" // EnvKeySmtpHost key for env variable SMTP_HOST EnvKeySmtpHost = "SMTP_HOST" // EnvKeySmtpPort key for env variable SMTP_PORT diff --git a/server/db/db.go b/server/db/db.go index 17fdbec..70b1033 100644 --- a/server/db/db.go +++ b/server/db/db.go @@ -4,6 +4,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/providers" "github.com/authorizerdev/authorizer/server/db/providers/arangodb" + "github.com/authorizerdev/authorizer/server/db/providers/cassandradb" "github.com/authorizerdev/authorizer/server/db/providers/mongodb" "github.com/authorizerdev/authorizer/server/db/providers/sql" "github.com/authorizerdev/authorizer/server/envstore" @@ -15,9 +16,10 @@ var Provider providers.Provider func InitDB() error { var err error - isSQL := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) != constants.DbTypeArangodb && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) != constants.DbTypeMongodb + isSQL := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) != constants.DbTypeArangodb && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) != constants.DbTypeMongodb && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) != constants.DbTypeCassandraDB isArangoDB := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) == constants.DbTypeArangodb isMongoDB := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) == constants.DbTypeMongodb + isCassandra := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) == constants.DbTypeCassandraDB if isSQL { Provider, err = sql.NewProvider() @@ -40,5 +42,12 @@ func InitDB() error { } } + if isCassandra { + Provider, err = cassandradb.NewProvider() + if err != nil { + return err + } + } + return nil } diff --git a/server/db/models/env.go b/server/db/models/env.go index 16a250b..959284a 100644 --- a/server/db/models/env.go +++ b/server/db/models/env.go @@ -1,11 +1,13 @@ package models +// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation + // Env model for db type Env struct { - Key string `json:"_key,omitempty" bson:"_key"` // for arangodb - ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id"` - EnvData string `gorm:"type:text" json:"env" bson:"env"` - Hash string `gorm:"type:text" json:"hash" bson:"hash"` - UpdatedAt int64 `json:"updated_at" bson:"updated_at"` - CreatedAt int64 `json:"created_at" bson:"created_at"` + Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty"` // for arangodb + ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"` + EnvData string `gorm:"type:text" json:"env" bson:"env" cql:"env"` + Hash string `gorm:"type:text" json:"hash" bson:"hash" cql:"hash"` + UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at"` + CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"` } diff --git a/server/db/models/session.go b/server/db/models/session.go index 34b715f..a495135 100644 --- a/server/db/models/session.go +++ b/server/db/models/session.go @@ -1,13 +1,15 @@ package models +// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation + // Session model for db type Session struct { - Key string `json:"_key,omitempty" bson:"_key,omitempty"` // for arangodb - ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id"` - UserID string `gorm:"type:char(36),index:" json:"user_id" bson:"user_id"` - User User `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"-" bson:"-"` - UserAgent string `json:"user_agent" bson:"user_agent"` - IP string `json:"ip" bson:"ip"` - CreatedAt int64 `json:"created_at" bson:"created_at"` - UpdatedAt int64 `json:"updated_at" bson:"updated_at"` + Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty"` // for arangodb + ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"` + UserID string `gorm:"type:char(36),index:" json:"user_id" bson:"user_id" cql:"user_id"` + User User `gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;" json:"-" bson:"-" cql:"-"` + UserAgent string `json:"user_agent" bson:"user_agent" cql:"user_agent"` + IP string `json:"ip" bson:"ip" cql:"ip"` + CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"` + UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at"` } diff --git a/server/db/models/user.go b/server/db/models/user.go index d07486a..e072650 100644 --- a/server/db/models/user.go +++ b/server/db/models/user.go @@ -6,28 +6,30 @@ import ( "github.com/authorizerdev/authorizer/server/graph/model" ) +// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation + // User model for db type User struct { - Key string `json:"_key,omitempty" bson:"_key"` // for arangodb - ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id"` + Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty"` // for arangodb + ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"` - Email string `gorm:"unique" json:"email" bson:"email"` - EmailVerifiedAt *int64 `json:"email_verified_at" bson:"email_verified_at"` - Password *string `gorm:"type:text" json:"password" bson:"password"` - SignupMethods string `json:"signup_methods" bson:"signup_methods"` - GivenName *string `json:"given_name" bson:"given_name"` - FamilyName *string `json:"family_name" bson:"family_name"` - MiddleName *string `json:"middle_name" bson:"middle_name"` - Nickname *string `json:"nickname" bson:"nickname"` - Gender *string `json:"gender" bson:"gender"` - Birthdate *string `json:"birthdate" bson:"birthdate"` - PhoneNumber *string `gorm:"unique" json:"phone_number" bson:"phone_number"` - PhoneNumberVerifiedAt *int64 `json:"phone_number_verified_at" bson:"phone_number_verified_at"` - Picture *string `gorm:"type:text" json:"picture" bson:"picture"` - Roles string `json:"roles" bson:"roles"` - UpdatedAt int64 `json:"updated_at" bson:"updated_at"` - CreatedAt int64 `json:"created_at" bson:"created_at"` - RevokedTimestamp *int64 `json:"revoked_timestamp" bson:"revoked_timestamp"` + Email string `gorm:"unique" json:"email" bson:"email" cql:"email"` + EmailVerifiedAt *int64 `json:"email_verified_at" bson:"email_verified_at" cql:"email_verified_at"` + Password *string `gorm:"type:text" json:"password" bson:"password" cql:"password"` + SignupMethods string `json:"signup_methods" bson:"signup_methods" cql:"signup_methods"` + GivenName *string `json:"given_name" bson:"given_name" cql:"given_name"` + FamilyName *string `json:"family_name" bson:"family_name" cql:"family_name"` + MiddleName *string `json:"middle_name" bson:"middle_name" cql:"middle_name"` + Nickname *string `json:"nickname" bson:"nickname" cql:"nickname"` + Gender *string `json:"gender" bson:"gender" cql:"gender"` + Birthdate *string `json:"birthdate" bson:"birthdate" cql:"birthdate"` + PhoneNumber *string `gorm:"unique" json:"phone_number" bson:"phone_number" cql:"phone_number"` + PhoneNumberVerifiedAt *int64 `json:"phone_number_verified_at" bson:"phone_number_verified_at" cql:"phone_number_verified_at"` + Picture *string `gorm:"type:text" json:"picture" bson:"picture" cql:"picture"` + Roles string `json:"roles" bson:"roles" cql:"roles"` + RevokedTimestamp *int64 `json:"revoked_timestamp" bson:"revoked_timestamp" cql:"revoked_timestamp"` + UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at"` + CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"` } func (user *User) AsAPIUser() *model.User { @@ -53,8 +55,8 @@ func (user *User) AsAPIUser() *model.User { PhoneNumberVerified: &isPhoneVerified, Picture: user.Picture, Roles: strings.Split(user.Roles, ","), + RevokedTimestamp: revokedTimestamp, CreatedAt: &createdAt, UpdatedAt: &updatedAt, - RevokedTimestamp: revokedTimestamp, } } diff --git a/server/db/models/verification_requests.go b/server/db/models/verification_requests.go index cbb0322..afd9ad7 100644 --- a/server/db/models/verification_requests.go +++ b/server/db/models/verification_requests.go @@ -2,18 +2,20 @@ package models import "github.com/authorizerdev/authorizer/server/graph/model" +// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation + // VerificationRequest model for db type VerificationRequest struct { - Key string `json:"_key,omitempty" bson:"_key"` // for arangodb - ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id"` - Token string `gorm:"type:text" json:"token" bson:"token"` - Identifier string `gorm:"uniqueIndex:idx_email_identifier;type:varchar(64)" json:"identifier" bson:"identifier"` - ExpiresAt int64 `json:"expires_at" bson:"expires_at"` - CreatedAt int64 `json:"created_at" bson:"created_at"` - UpdatedAt int64 `json:"updated_at" bson:"updated_at"` - Email string `gorm:"uniqueIndex:idx_email_identifier;type:varchar(256)" json:"email" bson:"email"` - Nonce string `gorm:"type:text" json:"nonce" bson:"nonce"` - RedirectURI string `gorm:"type:text" json:"redirect_uri" bson:"redirect_uri"` + Key string `json:"_key,omitempty" bson:"_key" cql:"_key,omitempty"` // for arangodb + ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id"` + Token string `gorm:"type:text" json:"token" bson:"token" cql:"jwt_token"` // token is reserved keyword in cassandra + Identifier string `gorm:"uniqueIndex:idx_email_identifier;type:varchar(64)" json:"identifier" bson:"identifier" cql:"identifier"` + ExpiresAt int64 `json:"expires_at" bson:"expires_at" cql:"expires_at"` + Email string `gorm:"uniqueIndex:idx_email_identifier;type:varchar(256)" json:"email" bson:"email" cql:"email"` + Nonce string `gorm:"type:text" json:"nonce" bson:"nonce" cql:"nonce"` + RedirectURI string `gorm:"type:text" json:"redirect_uri" bson:"redirect_uri" cql:"redirect_uri"` + CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at"` + UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at"` } func (v *VerificationRequest) AsAPIVerificationRequest() *model.VerificationRequest { @@ -30,10 +32,10 @@ func (v *VerificationRequest) AsAPIVerificationRequest() *model.VerificationRequ Token: &token, Identifier: &identifier, Expires: &expires, - CreatedAt: &createdAt, - UpdatedAt: &updatedAt, Email: &email, Nonce: &nonce, RedirectURI: &redirectURI, + CreatedAt: &createdAt, + UpdatedAt: &updatedAt, } } diff --git a/server/db/providers/arangodb/arangodb.go b/server/db/providers/arangodb/provider.go similarity index 100% rename from server/db/providers/arangodb/arangodb.go rename to server/db/providers/arangodb/provider.go diff --git a/server/db/providers/cassandradb/env.go b/server/db/providers/cassandradb/env.go new file mode 100644 index 0000000..bd684b3 --- /dev/null +++ b/server/db/providers/cassandradb/env.go @@ -0,0 +1,52 @@ +package cassandradb + +import ( + "fmt" + "time" + + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/gocql/gocql" + "github.com/google/uuid" +) + +// AddEnv to save environment information in database +func (p *provider) AddEnv(env models.Env) (models.Env, error) { + if env.ID == "" { + env.ID = uuid.New().String() + } + + env.CreatedAt = time.Now().Unix() + env.UpdatedAt = time.Now().Unix() + insertEnvQuery := fmt.Sprintf("INSERT INTO %s (id, env, hash, created_at, updated_at) VALUES ('%s', '%s', '%s', %d, %d)", KeySpace+"."+models.Collections.Env, env.ID, env.EnvData, env.Hash, env.CreatedAt, env.UpdatedAt) + err := p.db.Query(insertEnvQuery).Exec() + if err != nil { + return env, err + } + + return env, nil +} + +// UpdateEnv to update environment information in database +func (p *provider) UpdateEnv(env models.Env) (models.Env, error) { + env.UpdatedAt = time.Now().Unix() + + updateEnvQuery := fmt.Sprintf("UPDATE %s SET env = '%s', updated_at = %d WHERE id = '%s'", KeySpace+"."+models.Collections.Env, env.EnvData, env.UpdatedAt, env.ID) + err := p.db.Query(updateEnvQuery).Exec() + if err != nil { + return env, err + } + return env, nil +} + +// GetEnv to get environment information from database +func (p *provider) GetEnv() (models.Env, error) { + var env models.Env + + query := fmt.Sprintf("SELECT id, env, hash, created_at, updated_at FROM %s LIMIT 1", KeySpace+"."+models.Collections.Env) + err := p.db.Query(query).Consistency(gocql.One).Scan(&env.ID, &env.EnvData, &env.Hash, &env.CreatedAt, &env.UpdatedAt) + if err != nil { + return env, err + } + + return env, nil +} diff --git a/server/db/providers/cassandradb/provider.go b/server/db/providers/cassandradb/provider.go new file mode 100644 index 0000000..4c41870 --- /dev/null +++ b/server/db/providers/cassandradb/provider.go @@ -0,0 +1,116 @@ +package cassandradb + +import ( + "fmt" + "log" + "strings" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/authorizerdev/authorizer/server/envstore" + cansandraDriver "github.com/gocql/gocql" +) + +type provider struct { + db *cansandraDriver.Session +} + +// KeySpace for the cassandra database +var KeySpace string + +// NewProvider to initialize arangodb connection +func NewProvider() (*provider, error) { + dbURL := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL) + KeySpace = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName) + clusterURL := []string{} + if strings.Contains(dbURL, ",") { + clusterURL = strings.Split(dbURL, ",") + } else { + clusterURL = append(clusterURL, dbURL) + } + cassandraClient := cansandraDriver.NewCluster(clusterURL...) + if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseUsername) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabasePassword) != "" { + cassandraClient.Authenticator = &cansandraDriver.PasswordAuthenticator{ + Username: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseUsername), + Password: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabasePassword), + } + } + + cassandraClient.RetryPolicy = &cansandraDriver.SimpleRetryPolicy{ + NumRetries: 3, + } + cassandraClient.Consistency = cansandraDriver.Quorum + + session, err := cassandraClient.CreateSession() + if err != nil { + log.Println("Error while creating connection to cassandra db", err) + return nil, err + } + + keyspaceQuery := fmt.Sprintf("CREATE KEYSPACE IF NOT EXISTS %s WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor':1}", + KeySpace) + err = session.Query(keyspaceQuery).Exec() + if err != nil { + log.Println("Unable to create keyspace:", err) + return nil, err + } + + // make sure collections are present + envCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, env text, hash text, updated_at bigint, created_at bigint, PRIMARY KEY (id))", + KeySpace, models.Collections.Env) + err = session.Query(envCollectionQuery).Exec() + if err != nil { + log.Println("Unable to create env collection:", err) + return nil, err + } + + sessionCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, user_id text, user_agent text, ip text, updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.Session) + err = session.Query(sessionCollectionQuery).Exec() + if err != nil { + log.Println("Unable to create session collection:", err) + return nil, err + } + + userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nickname text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id))", KeySpace, models.Collections.User) + err = session.Query(userCollectionQuery).Exec() + if err != nil { + log.Println("Unable to create user collection:", err) + return nil, err + } + userIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_user_email ON %s.%s (email)", KeySpace, models.Collections.User) + err = session.Query(userIndexQuery).Exec() + if err != nil { + log.Println("Unable to create user index:", err) + return nil, err + } + + // token is reserved keyword in cassandra, hence we need to use jwt_token + verificationRequestCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, jwt_token text, identifier text, expires_at bigint, email text, nonce text, redirect_uri text, created_at bigint, updated_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.VerificationRequest) + err = session.Query(verificationRequestCollectionQuery).Exec() + if err != nil { + log.Println("Unable to create verification request collection:", err) + return nil, err + } + verificationRequestIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_email ON %s.%s (email)", KeySpace, models.Collections.VerificationRequest) + err = session.Query(verificationRequestIndexQuery).Exec() + if err != nil { + log.Println("Unable to create verification_requests index:", err) + return nil, err + } + verificationRequestIndexQuery = fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_identifier ON %s.%s (identifier)", KeySpace, models.Collections.VerificationRequest) + err = session.Query(verificationRequestIndexQuery).Exec() + if err != nil { + log.Println("Unable to create verification_requests index:", err) + return nil, err + } + verificationRequestIndexQuery = fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_verification_request_jwt_token ON %s.%s (jwt_token)", KeySpace, models.Collections.VerificationRequest) + err = session.Query(verificationRequestIndexQuery).Exec() + if err != nil { + log.Println("Unable to create verification_requests index:", err) + return nil, err + } + + return &provider{ + db: session, + }, err +} diff --git a/server/db/providers/cassandradb/session.go b/server/db/providers/cassandradb/session.go new file mode 100644 index 0000000..28e46ed --- /dev/null +++ b/server/db/providers/cassandradb/session.go @@ -0,0 +1,36 @@ +package cassandradb + +import ( + "fmt" + "time" + + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/google/uuid" +) + +// AddSession to save session information in database +func (p *provider) AddSession(session models.Session) error { + if session.ID == "" { + session.ID = uuid.New().String() + } + + session.CreatedAt = time.Now().Unix() + session.UpdatedAt = time.Now().Unix() + + insertSessionQuery := fmt.Sprintf("INSERT INTO %s (id, user_id, user_agent, ip, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %d, %d)", KeySpace+"."+models.Collections.Session, session.ID, session.UserID, session.UserAgent, session.IP, session.CreatedAt, session.UpdatedAt) + err := p.db.Query(insertSessionQuery).Exec() + if err != nil { + return err + } + return nil +} + +// DeleteSession to delete session information from database +func (p *provider) DeleteSession(userId string) error { + deleteSessionQuery := fmt.Sprintf("DELETE FROM %s WHERE user_id = '%s'", KeySpace+"."+models.Collections.Session, userId) + err := p.db.Query(deleteSessionQuery).Exec() + if err != nil { + return err + } + return nil +} diff --git a/server/db/providers/cassandradb/user.go b/server/db/providers/cassandradb/user.go new file mode 100644 index 0000000..09b7476 --- /dev/null +++ b/server/db/providers/cassandradb/user.go @@ -0,0 +1,189 @@ +package cassandradb + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "time" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/gocql/gocql" + "github.com/google/uuid" +) + +// AddUser to save user information in database +func (p *provider) AddUser(user models.User) (models.User, error) { + if user.ID == "" { + user.ID = uuid.New().String() + } + + if user.Roles == "" { + user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + } + + user.CreatedAt = time.Now().Unix() + user.UpdatedAt = time.Now().Unix() + + bytes, err := json.Marshal(user) + if err != nil { + return user, err + } + + // use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling + decoder := json.NewDecoder(strings.NewReader(string(bytes))) + decoder.UseNumber() + userMap := map[string]interface{}{} + err = decoder.Decode(&userMap) + if err != nil { + return user, err + } + + fields := "(" + values := "(" + for key, value := range userMap { + if value != nil { + if key == "_id" { + fields += "id," + } else { + fields += key + "," + } + + valueType := reflect.TypeOf(value) + if valueType.Name() == "string" { + values += fmt.Sprintf("'%s',", value.(string)) + } else { + values += fmt.Sprintf("%v,", value) + } + } + } + + fields = fields[:len(fields)-1] + ")" + values = values[:len(values)-1] + ")" + + query := fmt.Sprintf("INSERT INTO %s %s VALUES %s IF NOT EXISTS", KeySpace+"."+models.Collections.User, fields, values) + + err = p.db.Query(query).Exec() + if err != nil { + return user, err + } + + return user, nil +} + +// UpdateUser to update user information in database +func (p *provider) UpdateUser(user models.User) (models.User, error) { + user.UpdatedAt = time.Now().Unix() + + bytes, err := json.Marshal(user) + if err != nil { + return user, err + } + // use decoder instead of json.Unmarshall, because it converts int64 -> float64 after unmarshalling + decoder := json.NewDecoder(strings.NewReader(string(bytes))) + decoder.UseNumber() + userMap := map[string]interface{}{} + err = decoder.Decode(&userMap) + if err != nil { + return user, err + } + + updateFields := "" + for key, value := range userMap { + if value != nil && key != "_id" { + } + + if key == "_id" { + continue + } + + if value == nil { + updateFields += fmt.Sprintf("%s = null,", key) + continue + } + + valueType := reflect.TypeOf(value) + if valueType.Name() == "string" { + updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string)) + } else { + updateFields += fmt.Sprintf("%s = %v, ", key, value) + } + } + updateFields = strings.Trim(updateFields, " ") + updateFields = strings.TrimSuffix(updateFields, ",") + + query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, updateFields, user.ID) + + err = p.db.Query(query).Exec() + if err != nil { + return user, err + } + + return user, nil +} + +// DeleteUser to delete user information from database +func (p *provider) DeleteUser(user models.User) error { + query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, user.ID) + err := p.db.Query(query).Exec() + return err +} + +// ListUsers to get list of users from database +func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) { + responseUsers := []*model.User{} + paginationClone := pagination + totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.User) + err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total) + if err != nil { + return nil, err + } + + // there is no offset in cassandra + // so we fetch till limit + offset + // and return the results from offset to limit + query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset) + + scanner := p.db.Query(query).Iter().Scanner() + counter := int64(0) + for scanner.Next() { + if counter >= pagination.Offset { + var user models.User + err := scanner.Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt) + if err != nil { + return nil, err + } + responseUsers = append(responseUsers, user.AsAPIUser()) + } + counter++ + } + return &model.Users{ + Users: responseUsers, + Pagination: &paginationClone, + }, nil +} + +// GetUserByEmail to get user information from database using email address +func (p *provider) GetUserByEmail(email string) (models.User, error) { + var user models.User + query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, email) + err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt) + if err != nil { + return user, err + } + return user, nil +} + +// GetUserByID to get user information from database using user ID +func (p *provider) GetUserByID(id string) (models.User, error) { + var user models.User + query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, id) + err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt) + if err != nil { + return user, err + } + return user, nil +} diff --git a/server/db/providers/cassandradb/verification_requests.go b/server/db/providers/cassandradb/verification_requests.go new file mode 100644 index 0000000..6c82462 --- /dev/null +++ b/server/db/providers/cassandradb/verification_requests.go @@ -0,0 +1,102 @@ +package cassandradb + +import ( + "fmt" + "log" + "time" + + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/gocql/gocql" + "github.com/google/uuid" +) + +// AddVerification to save verification request in database +func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { + if verificationRequest.ID == "" { + verificationRequest.ID = uuid.New().String() + } + + verificationRequest.CreatedAt = time.Now().Unix() + verificationRequest.UpdatedAt = time.Now().Unix() + + query := fmt.Sprintf("INSERT INTO %s (id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at) VALUES ('%s', '%s', '%s', %d, '%s', '%s', '%s', %d, %d)", KeySpace+"."+models.Collections.VerificationRequest, verificationRequest.ID, verificationRequest.Token, verificationRequest.Identifier, verificationRequest.ExpiresAt, verificationRequest.Email, verificationRequest.Nonce, verificationRequest.RedirectURI, verificationRequest.CreatedAt, verificationRequest.UpdatedAt) + err := p.db.Query(query).Exec() + if err != nil { + return verificationRequest, err + } + return verificationRequest, nil +} + +// GetVerificationRequestByToken to get verification request from database using token +func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) { + var verificationRequest models.VerificationRequest + query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE jwt_token = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, token) + + err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt) + if err != nil { + return verificationRequest, err + } + return verificationRequest, nil +} + +// GetVerificationRequestByEmail to get verification request by email from database +func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) { + var verificationRequest models.VerificationRequest + query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE email = '%s' AND identifier = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.VerificationRequest, email, identifier) + + err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt) + if err != nil { + return verificationRequest, err + } + + return verificationRequest, nil +} + +// ListVerificationRequests to get list of verification requests from database +func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) { + var verificationRequests []*model.VerificationRequest + + paginationClone := pagination + totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.VerificationRequest) + err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total) + if err != nil { + log.Println("Error while quering verification request", err) + return nil, err + } + + // there is no offset in cassandra + // so we fetch till limit + offset + // and return the results from offset to limit + query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s LIMIT %d`, KeySpace+"."+models.Collections.VerificationRequest, pagination.Limit+pagination.Offset) + + scanner := p.db.Query(query).Iter().Scanner() + counter := int64(0) + for scanner.Next() { + if counter >= pagination.Offset { + var verificationRequest models.VerificationRequest + err := scanner.Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt) + if err != nil { + log.Println("Error while parsing verification request", err) + return nil, err + } + verificationRequests = append(verificationRequests, verificationRequest.AsAPIVerificationRequest()) + } + counter++ + } + + return &model.VerificationRequests{ + VerificationRequests: verificationRequests, + Pagination: &paginationClone, + }, nil +} + +// DeleteVerificationRequest to delete verification request from database +func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error { + query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.VerificationRequest, verificationRequest.ID) + err := p.db.Query(query).Exec() + if err != nil { + return err + } + return nil +} diff --git a/server/db/providers/mongodb/mongodb.go b/server/db/providers/mongodb/provider.go similarity index 100% rename from server/db/providers/mongodb/mongodb.go rename to server/db/providers/mongodb/provider.go diff --git a/server/db/providers/provider_template/env.go b/server/db/providers/provider_template/env.go new file mode 100644 index 0000000..0f31e8e --- /dev/null +++ b/server/db/providers/provider_template/env.go @@ -0,0 +1,32 @@ +package provider_template + +import ( + "time" + + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/google/uuid" +) + +// AddEnv to save environment information in database +func (p *provider) AddEnv(env models.Env) (models.Env, error) { + if env.ID == "" { + env.ID = uuid.New().String() + } + + env.CreatedAt = time.Now().Unix() + env.UpdatedAt = time.Now().Unix() + return env, nil +} + +// UpdateEnv to update environment information in database +func (p *provider) UpdateEnv(env models.Env) (models.Env, error) { + env.UpdatedAt = time.Now().Unix() + return env, nil +} + +// GetEnv to get environment information from database +func (p *provider) GetEnv() (models.Env, error) { + var env models.Env + + return env, nil +} diff --git a/server/db/providers/provider_template/provider.go b/server/db/providers/provider_template/provider.go new file mode 100644 index 0000000..30490b4 --- /dev/null +++ b/server/db/providers/provider_template/provider.go @@ -0,0 +1,20 @@ +package provider_template + +import ( + "gorm.io/gorm" +) + +// TODO change following provider to new db provider +type provider struct { + db *gorm.DB +} + +// NewProvider returns a new SQL provider +// TODO change following provider to new db provider +func NewProvider() (*provider, error) { + var sqlDB *gorm.DB + + return &provider{ + db: sqlDB, + }, nil +} diff --git a/server/db/providers/provider_template/session.go b/server/db/providers/provider_template/session.go new file mode 100644 index 0000000..a914fdb --- /dev/null +++ b/server/db/providers/provider_template/session.go @@ -0,0 +1,24 @@ +package provider_template + +import ( + "time" + + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/google/uuid" +) + +// AddSession to save session information in database +func (p *provider) AddSession(session models.Session) error { + if session.ID == "" { + session.ID = uuid.New().String() + } + + session.CreatedAt = time.Now().Unix() + session.UpdatedAt = time.Now().Unix() + return nil +} + +// DeleteSession to delete session information from database +func (p *provider) DeleteSession(userId string) error { + return nil +} diff --git a/server/db/providers/provider_template/user.go b/server/db/providers/provider_template/user.go new file mode 100644 index 0000000..07f6a06 --- /dev/null +++ b/server/db/providers/provider_template/user.go @@ -0,0 +1,58 @@ +package provider_template + +import ( + "strings" + "time" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/google/uuid" +) + +// AddUser to save user information in database +func (p *provider) AddUser(user models.User) (models.User, error) { + if user.ID == "" { + user.ID = uuid.New().String() + } + + if user.Roles == "" { + user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + } + + user.CreatedAt = time.Now().Unix() + user.UpdatedAt = time.Now().Unix() + + return user, nil +} + +// UpdateUser to update user information in database +func (p *provider) UpdateUser(user models.User) (models.User, error) { + user.UpdatedAt = time.Now().Unix() + return user, nil +} + +// DeleteUser to delete user information from database +func (p *provider) DeleteUser(user models.User) error { + return nil +} + +// ListUsers to get list of users from database +func (p *provider) ListUsers(pagination model.Pagination) (*model.Users, error) { + return nil, nil +} + +// GetUserByEmail to get user information from database using email address +func (p *provider) GetUserByEmail(email string) (models.User, error) { + var user models.User + + return user, nil +} + +// GetUserByID to get user information from database using user ID +func (p *provider) GetUserByID(id string) (models.User, error) { + var user models.User + + return user, nil +} diff --git a/server/db/providers/provider_template/verification_requests.go b/server/db/providers/provider_template/verification_requests.go new file mode 100644 index 0000000..c0b4b2a --- /dev/null +++ b/server/db/providers/provider_template/verification_requests.go @@ -0,0 +1,45 @@ +package provider_template + +import ( + "time" + + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/google/uuid" +) + +// AddVerification to save verification request in database +func (p *provider) AddVerificationRequest(verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { + if verificationRequest.ID == "" { + verificationRequest.ID = uuid.New().String() + } + + verificationRequest.CreatedAt = time.Now().Unix() + verificationRequest.UpdatedAt = time.Now().Unix() + + return verificationRequest, nil +} + +// GetVerificationRequestByToken to get verification request from database using token +func (p *provider) GetVerificationRequestByToken(token string) (models.VerificationRequest, error) { + var verificationRequest models.VerificationRequest + + return verificationRequest, nil +} + +// GetVerificationRequestByEmail to get verification request by email from database +func (p *provider) GetVerificationRequestByEmail(email string, identifier string) (models.VerificationRequest, error) { + var verificationRequest models.VerificationRequest + + return verificationRequest, nil +} + +// ListVerificationRequests to get list of verification requests from database +func (p *provider) ListVerificationRequests(pagination model.Pagination) (*model.VerificationRequests, error) { + return nil, nil +} + +// DeleteVerificationRequest to delete verification request from database +func (p *provider) DeleteVerificationRequest(verificationRequest models.VerificationRequest) error { + return nil +} diff --git a/server/db/providers/sql/sql.go b/server/db/providers/sql/provider.go similarity index 100% rename from server/db/providers/sql/sql.go rename to server/db/providers/sql/provider.go diff --git a/server/email/invite_email.go b/server/email/invite_email.go index 5cbd1c9..bdebd81 100644 --- a/server/email/invite_email.go +++ b/server/email/invite_email.go @@ -107,7 +107,7 @@ func InviteEmail(toEmail, token, verificationURL, redirectURI string) error { err := SendMail(Receiver, Subject, message) if err != nil { - log.Println("=> error sending email:", err) + log.Println("error sending email:", err) } return err } diff --git a/server/email/verification_email.go b/server/email/verification_email.go index bb0881f..c373151 100644 --- a/server/email/verification_email.go +++ b/server/email/verification_email.go @@ -107,7 +107,7 @@ func SendVerificationMail(toEmail, token, hostname string) error { err := SendMail(Receiver, Subject, message) if err != nil { - log.Println("=> error sending email:", err) + log.Println("error sending email:", err) } return err } diff --git a/server/env/env.go b/server/env/env.go index 7202d9b..aeb5c27 100644 --- a/server/env/env.go +++ b/server/env/env.go @@ -38,6 +38,10 @@ func InitRequiredEnv() error { dbURL := os.Getenv(constants.EnvKeyDatabaseURL) dbType := os.Getenv(constants.EnvKeyDatabaseType) dbName := os.Getenv(constants.EnvKeyDatabaseName) + dbPort := os.Getenv(constants.EnvKeyDatabasePort) + dbHost := os.Getenv(constants.EnvKeyDatabaseHost) + dbUsername := os.Getenv(constants.EnvKeyDatabaseUsername) + dbPassword := os.Getenv(constants.EnvKeyDatabasePassword) if strings.TrimSpace(dbType) == "" { if envstore.ARG_DB_TYPE != nil && *envstore.ARG_DB_TYPE != "" { @@ -54,7 +58,7 @@ func InitRequiredEnv() error { dbURL = strings.TrimSpace(*envstore.ARG_DB_URL) } - if dbURL == "" { + if dbURL == "" && dbPort == "" && dbHost == "" && dbUsername == "" && dbPassword == "" { return errors.New("invalid database url. DATABASE_URL is required") } } @@ -69,6 +73,10 @@ func InitRequiredEnv() error { envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseURL, dbURL) envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseType, dbType) envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseName, dbName) + envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseHost, dbHost) + envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabasePort, dbPort) + envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseUsername, dbUsername) + envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabasePassword, dbPassword) return nil } diff --git a/server/go.mod b/server/go.mod index 92c9882..7c341f5 100644 --- a/server/go.mod +++ b/server/go.mod @@ -9,6 +9,7 @@ require ( github.com/gin-gonic/gin v1.7.2 github.com/go-playground/validator/v10 v10.8.0 // indirect github.com/go-redis/redis/v8 v8.11.0 + github.com/gocql/gocql v1.0.0 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/protobuf v1.5.2 // indirect github.com/google/uuid v1.3.0 diff --git a/server/go.sum b/server/go.sum index 1b2daa1..e386cbd 100644 --- a/server/go.sum +++ b/server/go.sum @@ -48,6 +48,10 @@ github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e h1:Xg+hGrY2 github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e/go.mod h1:mq7Shfa/CaixoDxiyAAc5jZ6CVBAyPaNQCGS7mkj4Ho= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -106,6 +110,8 @@ github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfC github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gocql/gocql v1.0.0 h1:UnbTERpP72VZ/viKE1Q1gPtmLvyTZTvuAstvSRydw/c= +github.com/gocql/gocql v1.0.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= @@ -140,8 +146,9 @@ github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -175,6 +182,8 @@ github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB github.com/gorilla/mux v1.6.1/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= @@ -677,6 +686,8 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/mail.v2 v2.3.1 h1:WYFn/oANrAGP2C0dcV6/pbkPzv8yGzqTjPmTeO7qoXk= gopkg.in/mail.v2 v2.3.1/go.mod h1:htwXN1Qh09vZJ1NVKxQqHPBaCBbzKhp5GzuJEA4VJWw= gopkg.in/readline.v1 v1.0.0-20160726135117-62c6fe619375/go.mod h1:lNEQeAhU009zbRxng+XOj5ITVgY24WcbNnQopyfKoYQ= diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index e33e660..c8cb538 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -61,9 +61,13 @@ type ComplexityRoot struct { ClientSecret func(childComplexity int) int CookieName func(childComplexity int) int CustomAccessTokenScript func(childComplexity int) int + DatabaseHost func(childComplexity int) int DatabaseName func(childComplexity int) int + DatabasePassword func(childComplexity int) int + DatabasePort func(childComplexity int) int DatabaseType func(childComplexity int) int DatabaseURL func(childComplexity int) int + DatabaseUsername func(childComplexity int) int DefaultRoles func(childComplexity int) int DisableBasicAuthentication func(childComplexity int) int DisableEmailVerification func(childComplexity int) int @@ -356,6 +360,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Env.CustomAccessTokenScript(childComplexity), true + case "Env.DATABASE_HOST": + if e.complexity.Env.DatabaseHost == nil { + break + } + + return e.complexity.Env.DatabaseHost(childComplexity), true + case "Env.DATABASE_NAME": if e.complexity.Env.DatabaseName == nil { break @@ -363,6 +374,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Env.DatabaseName(childComplexity), true + case "Env.DATABASE_PASSWORD": + if e.complexity.Env.DatabasePassword == nil { + break + } + + return e.complexity.Env.DatabasePassword(childComplexity), true + + case "Env.DATABASE_PORT": + if e.complexity.Env.DatabasePort == nil { + break + } + + return e.complexity.Env.DatabasePort(childComplexity), true + case "Env.DATABASE_TYPE": if e.complexity.Env.DatabaseType == nil { break @@ -377,6 +402,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Env.DatabaseURL(childComplexity), true + case "Env.DATABASE_USERNAME": + if e.complexity.Env.DatabaseUsername == nil { + break + } + + return e.complexity.Env.DatabaseUsername(childComplexity), true + case "Env.DEFAULT_ROLES": if e.complexity.Env.DefaultRoles == nil { break @@ -1394,6 +1426,10 @@ type Env { DATABASE_NAME: String! DATABASE_URL: String! DATABASE_TYPE: String! + DATABASE_USERNAME: String! + DATABASE_PASSWORD: String! + DATABASE_HOST: String! + DATABASE_PORT: String! CLIENT_ID: String! CLIENT_SECRET: String! CUSTOM_ACCESS_TOKEN_SCRIPT: String @@ -2400,6 +2436,146 @@ func (ec *executionContext) _Env_DATABASE_TYPE(ctx context.Context, field graphq return ec.marshalNString2string(ctx, field.Selections, res) } +func (ec *executionContext) _Env_DATABASE_USERNAME(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Env", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.DatabaseUsername, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) _Env_DATABASE_PASSWORD(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Env", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.DatabasePassword, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) _Env_DATABASE_HOST(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Env", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.DatabaseHost, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) _Env_DATABASE_PORT(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Env", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.DatabasePort, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + func (ec *executionContext) _Env_CLIENT_ID(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -8780,6 +8956,26 @@ func (ec *executionContext) _Env(ctx context.Context, sel ast.SelectionSet, obj if out.Values[i] == graphql.Null { invalids++ } + case "DATABASE_USERNAME": + out.Values[i] = ec._Env_DATABASE_USERNAME(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "DATABASE_PASSWORD": + out.Values[i] = ec._Env_DATABASE_PASSWORD(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "DATABASE_HOST": + out.Values[i] = ec._Env_DATABASE_HOST(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "DATABASE_PORT": + out.Values[i] = ec._Env_DATABASE_PORT(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "CLIENT_ID": out.Values[i] = ec._Env_CLIENT_ID(ctx, field, obj) if out.Values[i] == graphql.Null { diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index 09468b4..1660c9c 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -29,6 +29,10 @@ type Env struct { DatabaseName string `json:"DATABASE_NAME"` DatabaseURL string `json:"DATABASE_URL"` DatabaseType string `json:"DATABASE_TYPE"` + DatabaseUsername string `json:"DATABASE_USERNAME"` + DatabasePassword string `json:"DATABASE_PASSWORD"` + DatabaseHost string `json:"DATABASE_HOST"` + DatabasePort string `json:"DATABASE_PORT"` ClientID string `json:"CLIENT_ID"` ClientSecret string `json:"CLIENT_SECRET"` CustomAccessTokenScript *string `json:"CUSTOM_ACCESS_TOKEN_SCRIPT"` diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index 841ad8c..d3673cd 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -92,6 +92,10 @@ type Env { DATABASE_NAME: String! DATABASE_URL: String! DATABASE_TYPE: String! + DATABASE_USERNAME: String! + DATABASE_PASSWORD: String! + DATABASE_HOST: String! + DATABASE_PORT: String! CLIENT_ID: String! CLIENT_SECRET: String! CUSTOM_ACCESS_TOKEN_SCRIPT: String diff --git a/server/handlers/oauth_callback.go b/server/handlers/oauth_callback.go index bfa4f00..936e618 100644 --- a/server/handlers/oauth_callback.go +++ b/server/handlers/oauth_callback.go @@ -259,7 +259,7 @@ func processGithubUserInfo(code string) (models.User, error) { GivenName: &firstName, FamilyName: &lastName, Picture: &picture, - Email: userRawData["sub"], + Email: userRawData["email"], } return user, nil diff --git a/server/resolvers/env.go b/server/resolvers/env.go index dc7db8d..b8c2d3d 100644 --- a/server/resolvers/env.go +++ b/server/resolvers/env.go @@ -34,6 +34,10 @@ func EnvResolver(ctx context.Context) (*model.Env, error) { databaseURL := store.StringEnv[constants.EnvKeyDatabaseURL] databaseName := store.StringEnv[constants.EnvKeyDatabaseName] databaseType := store.StringEnv[constants.EnvKeyDatabaseType] + databaseUsername := store.StringEnv[constants.EnvKeyDatabaseUsername] + databasePassword := store.StringEnv[constants.EnvKeyDatabasePassword] + databaseHost := store.StringEnv[constants.EnvKeyDatabaseHost] + databasePort := store.StringEnv[constants.EnvKeyDatabasePort] customAccessTokenScript := store.StringEnv[constants.EnvKeyCustomAccessTokenScript] smtpHost := store.StringEnv[constants.EnvKeySmtpHost] smtpPort := store.StringEnv[constants.EnvKeySmtpPort] @@ -77,6 +81,10 @@ func EnvResolver(ctx context.Context) (*model.Env, error) { DatabaseName: databaseName, DatabaseURL: databaseURL, DatabaseType: databaseType, + DatabaseUsername: databaseUsername, + DatabasePassword: databasePassword, + DatabaseHost: databaseHost, + DatabasePort: databasePort, ClientID: clientID, ClientSecret: clientSecret, CustomAccessTokenScript: &customAccessTokenScript, diff --git a/server/test/enable_access_test.go b/server/test/enable_access_test.go index c54f91b..6d06153 100644 --- a/server/test/enable_access_test.go +++ b/server/test/enable_access_test.go @@ -15,9 +15,9 @@ import ( func enableAccessTest(t *testing.T, s TestSetup) { t.Helper() - t.Run(`should revoke access`, func(t *testing.T) { + t.Run(`should enable access`, func(t *testing.T) { req, ctx := createContext(s) - email := "revoke_access." + s.TestInfo.Email + email := "enable_access." + s.TestInfo.Email _, err := resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ Email: email, }) @@ -45,7 +45,7 @@ func enableAccessTest(t *testing.T, s TestSetup) { assert.NoError(t, err) assert.NotEmpty(t, res.Message) - // it should allow login with revoked access + // it should allow login with enabled access res, err = resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ Email: email, }) diff --git a/server/test/resolvers_test.go b/server/test/resolvers_test.go index 40812b1..84a5c76 100644 --- a/server/test/resolvers_test.go +++ b/server/test/resolvers_test.go @@ -11,9 +11,10 @@ import ( func TestResolvers(t *testing.T) { databases := map[string]string{ - constants.DbTypeSqlite: "../../data.db", + // constants.DbTypeSqlite: "../../data.db", // constants.DbTypeArangodb: "http://localhost:8529", // constants.DbTypeMongodb: "mongodb://localhost:27017", + constants.DbTypeCassandraDB: "127.0.0.1:9042", } for dbType, dbURL := range databases { diff --git a/server/test/test.go b/server/test/test.go index ff0bdb3..2c673b6 100644 --- a/server/test/test.go +++ b/server/test/test.go @@ -31,26 +31,31 @@ type TestSetup struct { } func cleanData(email string) { - verificationRequest, err := db.Provider.GetVerificationRequestByEmail(email, constants.VerificationTypeBasicAuthSignup) - if err == nil { - err = db.Provider.DeleteVerificationRequest(verificationRequest) - } + // verificationRequest, err := db.Provider.GetVerificationRequestByEmail(email, constants.VerificationTypeBasicAuthSignup) + // if err == nil { + // err = db.Provider.DeleteVerificationRequest(verificationRequest) + // } - verificationRequest, err = db.Provider.GetVerificationRequestByEmail(email, constants.VerificationTypeForgotPassword) - if err == nil { - err = db.Provider.DeleteVerificationRequest(verificationRequest) - } + // verificationRequest, err = db.Provider.GetVerificationRequestByEmail(email, constants.VerificationTypeForgotPassword) + // if err == nil { + // err = db.Provider.DeleteVerificationRequest(verificationRequest) + // } - verificationRequest, err = db.Provider.GetVerificationRequestByEmail(email, constants.VerificationTypeUpdateEmail) - if err == nil { - err = db.Provider.DeleteVerificationRequest(verificationRequest) - } + // verificationRequest, err = db.Provider.GetVerificationRequestByEmail(email, constants.VerificationTypeUpdateEmail) + // if err == nil { + // err = db.Provider.DeleteVerificationRequest(verificationRequest) + // } - dbUser, err := db.Provider.GetUserByEmail(email) - if err == nil { - db.Provider.DeleteUser(dbUser) - db.Provider.DeleteSession(dbUser.ID) - } + // verificationRequest, err = db.Provider.GetVerificationRequestByEmail(email, constants.VerificationTypeMagicLinkLogin) + // if err == nil { + // err = db.Provider.DeleteVerificationRequest(verificationRequest) + // } + + // dbUser, err := db.Provider.GetUserByEmail(email) + // if err == nil { + // db.Provider.DeleteUser(dbUser) + // db.Provider.DeleteSession(dbUser.ID) + // } } func createContext(s TestSetup) (*http.Request, context.Context) {