2022-01-21 06:48:07 +00:00
package sql
import (
2022-12-29 14:44:06 +00:00
"fmt"
"log"
"os"
2022-02-05 03:30:56 +00:00
"time"
2022-01-21 06:48:07 +00:00
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db/models"
2022-05-29 11:52:46 +00:00
"github.com/authorizerdev/authorizer/server/memorystore"
2022-01-21 06:48:07 +00:00
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
2022-12-29 14:44:06 +00:00
"gorm.io/driver/sqlite"
2022-01-21 06:48:07 +00:00
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
2022-02-05 03:30:56 +00:00
"gorm.io/gorm/logger"
2022-01-21 06:48:07 +00:00
"gorm.io/gorm/schema"
)
// provider is the SQL implementation returned by NewProvider. It wraps the
// gorm connection that NewProvider opened and migrated.
type provider struct {
	db *gorm.DB // connection opened and auto-migrated by NewProvider
}
2022-10-21 16:28:56 +00:00
// Identifiers used by the SQL Server phone-number index fix-up in NewProvider.
const (
	// phoneNumberColumnName is the users column that must stay unique
	// while still allowing many NULL values.
	phoneNumberColumnName = "phone_number"
	// phoneNumberIndexName is the name given to the filtered unique index
	// created on phoneNumberColumnName.
	phoneNumberIndexName = "UQ_phone_number"
)
// indexInfo is one row of the SQL Server index catalog query in NewProvider.
// Each row pairs an index with one column it covers.
//
// The row is filled by gorm's Raw(...).Scan, which matches result columns by
// gorm column name and ignores json tags. The explicit gorm column tags
// therefore pin the mapping regardless of the configured naming strategy.
type indexInfo struct {
	IndexName  string `gorm:"column:index_name" json:"index_name"`
	ColumnName string `gorm:"column:column_name" json:"column_name"`
}
2022-01-21 06:48:07 +00:00
// NewProvider returns a new SQL provider
func NewProvider ( ) ( * provider , error ) {
var sqlDB * gorm . DB
var err error
2022-02-05 03:30:56 +00:00
customLogger := logger . New (
2022-12-29 14:44:06 +00:00
log . New ( os . Stdout , "\r\n" , log . LstdFlags ) , // io writer
2022-02-05 03:30:56 +00:00
logger . Config {
2022-12-29 14:44:06 +00:00
SlowThreshold : time . Second , // Slow SQL threshold
LogLevel : logger . Silent , // Log level
IgnoreRecordNotFoundError : true , // Ignore ErrRecordNotFound error for logger
Colorful : false , // Disable color
2022-02-05 03:30:56 +00:00
} ,
)
2022-01-21 06:48:07 +00:00
ormConfig := & gorm . Config {
2022-02-05 03:30:56 +00:00
Logger : customLogger ,
2022-01-21 06:48:07 +00:00
NamingStrategy : schema . NamingStrategy {
TablePrefix : models . Prefix ,
} ,
2022-08-02 08:42:36 +00:00
AllowGlobalUpdate : true ,
2022-01-21 06:48:07 +00:00
}
2022-05-29 11:52:46 +00:00
2022-05-31 02:44:03 +00:00
dbType := memorystore . RequiredEnvStoreObj . GetRequiredEnv ( ) . DatabaseType
dbURL := memorystore . RequiredEnvStoreObj . GetRequiredEnv ( ) . DatabaseURL
2022-05-29 11:52:46 +00:00
switch dbType {
2022-06-09 18:13:21 +00:00
case constants . DbTypePostgres , constants . DbTypeYugabyte , constants . DbTypeCockroachDB :
2022-05-29 11:52:46 +00:00
sqlDB , err = gorm . Open ( postgres . Open ( dbURL ) , ormConfig )
2022-01-21 06:48:07 +00:00
case constants . DbTypeSqlite :
2022-12-29 14:44:06 +00:00
sqlDB , err = gorm . Open ( sqlite . Open ( dbURL ) , ormConfig )
2022-07-11 17:07:07 +00:00
case constants . DbTypeMysql , constants . DbTypeMariaDB , constants . DbTypePlanetScaleDB :
2022-05-29 11:52:46 +00:00
sqlDB , err = gorm . Open ( mysql . Open ( dbURL ) , ormConfig )
2022-01-21 06:48:07 +00:00
case constants . DbTypeSqlserver :
2022-05-29 11:52:46 +00:00
sqlDB , err = gorm . Open ( sqlserver . Open ( dbURL ) , ormConfig )
2022-01-21 06:48:07 +00:00
}
if err != nil {
return nil , err
}
2022-07-23 10:25:06 +00:00
err = sqlDB . AutoMigrate ( & models . User { } , & models . VerificationRequest { } , & models . Session { } , & models . Env { } , & models . Webhook { } , models . WebhookLog { } , models . EmailTemplate { } , & models . OTP { } )
2022-03-02 12:12:31 +00:00
if err != nil {
return nil , err
}
2022-10-21 16:28:56 +00:00
// unique constraint on phone number does not work with multiple null values for sqlserver
// for more information check https://stackoverflow.com/a/767702
2022-12-29 14:44:06 +00:00
if dbType == constants . DbTypeSqlserver {
var indexInfos [ ] indexInfo
// remove index on phone number if present with different name
res := sqlDB . Raw ( "SELECT i.name AS index_name, i.type_desc AS index_algorithm, CASE i.is_unique WHEN 1 THEN 'TRUE' ELSE 'FALSE' END AS is_unique, ac.Name AS column_name FROM sys.tables AS t INNER JOIN sys.indexes AS i ON t.object_id = i.object_id INNER JOIN sys.index_columns AS ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id INNER JOIN sys.all_columns AS ac ON ic.object_id = ac.object_id AND ic.column_id = ac.column_id WHERE t.name = 'authorizer_users' AND SCHEMA_NAME(t.schema_id) = 'dbo';" ) . Scan ( & indexInfos )
if res . Error != nil {
return nil , res . Error
}
for _ , val := range indexInfos {
if val . ColumnName == phoneNumberColumnName && val . IndexName != phoneNumberIndexName {
// drop index & create new
if res := sqlDB . Exec ( fmt . Sprintf ( ` ALTER TABLE authorizer_users DROP CONSTRAINT "%s"; ` , val . IndexName ) ) ; res . Error != nil {
return nil , res . Error
}
// create index
if res := sqlDB . Exec ( fmt . Sprintf ( "CREATE UNIQUE NONCLUSTERED INDEX %s ON authorizer_users(phone_number) WHERE phone_number IS NOT NULL;" , phoneNumberIndexName ) ) ; res . Error != nil {
return nil , res . Error
}
}
}
}
2022-10-21 16:28:56 +00:00
2022-01-21 06:48:07 +00:00
return & provider {
db : sqlDB ,
} , nil
}