parent
bdbbe4adee
commit
8f7582e1ec
2
Makefile
2
Makefile
|
@ -6,4 +6,4 @@ cmd:
|
||||||
clean:
|
clean:
|
||||||
rm -rf build
|
rm -rf build
|
||||||
test:
|
test:
|
||||||
cd server && go clean --testcache && go test ./...
|
cd server && go clean --testcache && go test -v ./...
|
17
server/env/env.go
vendored
17
server/env/env.go
vendored
|
@ -87,15 +87,30 @@ func InitEnv() {
|
||||||
|
|
||||||
allowedOriginsSplit := strings.Split(os.Getenv("ALLOWED_ORIGINS"), ",")
|
allowedOriginsSplit := strings.Split(os.Getenv("ALLOWED_ORIGINS"), ",")
|
||||||
allowedOrigins := []string{}
|
allowedOrigins := []string{}
|
||||||
|
hasWildCard := false
|
||||||
|
|
||||||
for _, val := range allowedOriginsSplit {
|
for _, val := range allowedOriginsSplit {
|
||||||
trimVal := strings.TrimSpace(val)
|
trimVal := strings.TrimSpace(val)
|
||||||
if trimVal != "" {
|
if trimVal != "" {
|
||||||
allowedOrigins = append(allowedOrigins, trimVal)
|
if trimVal != "*" {
|
||||||
|
host, port := utils.GetHostParts(trimVal)
|
||||||
|
allowedOrigins = append(allowedOrigins, host+":"+port)
|
||||||
|
} else {
|
||||||
|
hasWildCard = true
|
||||||
|
allowedOrigins = append(allowedOrigins, trimVal)
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(allowedOrigins) > 1 && hasWildCard {
|
||||||
|
allowedOrigins = []string{"*"}
|
||||||
|
}
|
||||||
|
|
||||||
if len(allowedOrigins) == 0 {
|
if len(allowedOrigins) == 0 {
|
||||||
allowedOrigins = []string{"*"}
|
allowedOrigins = []string{"*"}
|
||||||
}
|
}
|
||||||
|
|
||||||
constants.ALLOWED_ORIGINS = allowedOrigins
|
constants.ALLOWED_ORIGINS = allowedOrigins
|
||||||
|
|
||||||
if *ARG_AUTHORIZER_URL != "" {
|
if *ARG_AUTHORIZER_URL != "" {
|
||||||
|
|
|
@ -49,7 +49,7 @@ func AppHandler() gin.HandlerFunc {
|
||||||
stateObj.RedirectURL = strings.TrimSuffix(stateObj.RedirectURL, "/")
|
stateObj.RedirectURL = strings.TrimSuffix(stateObj.RedirectURL, "/")
|
||||||
|
|
||||||
// validate redirect url with allowed origins
|
// validate redirect url with allowed origins
|
||||||
if !utils.IsValidRedirectURL(stateObj.RedirectURL) {
|
if !utils.IsValidOrigin(stateObj.RedirectURL) {
|
||||||
c.JSON(400, gin.H{"error": "invalid redirect url"})
|
c.JSON(400, gin.H{"error": "invalid redirect url"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
44
server/integration_test/cors_test.go
Normal file
44
server/integration_test/cors_test.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package integration_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/authorizerdev/authorizer/server/constants"
|
||||||
|
"github.com/authorizerdev/authorizer/server/env"
|
||||||
|
"github.com/authorizerdev/authorizer/server/middlewares"
|
||||||
|
"github.com/gin-contrib/location"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCors(t *testing.T) {
|
||||||
|
constants.ENV_PATH = "../../.env.local"
|
||||||
|
env.InitEnv()
|
||||||
|
r := gin.Default()
|
||||||
|
r.Use(location.Default())
|
||||||
|
r.Use(middlewares.GinContextToContextMiddleware())
|
||||||
|
r.Use(middlewares.CORSMiddleware())
|
||||||
|
allowedOrigin := "http://localhost:8080" // The allowed origin that you want to check
|
||||||
|
notAllowedOrigin := "http://myapp.com"
|
||||||
|
|
||||||
|
server := httptest.NewServer(r)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
req, _ := http.NewRequest(
|
||||||
|
"GET",
|
||||||
|
"http://"+server.Listener.Addr().String()+"/api",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
req.Header.Add("Origin", allowedOrigin)
|
||||||
|
|
||||||
|
get, _ := client.Do(req)
|
||||||
|
|
||||||
|
// You should get your origin (or a * depending on your config) if the
|
||||||
|
// passed origin is allowed.
|
||||||
|
o := get.Header.Get("Access-Control-Allow-Origin")
|
||||||
|
assert.NotEqual(t, o, notAllowedOrigin, "Origins should not match")
|
||||||
|
assert.Equal(t, o, allowedOrigin, "Origins don't match")
|
||||||
|
}
|
|
@ -1,13 +1,10 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/authorizerdev/authorizer/server/constants"
|
|
||||||
"github.com/authorizerdev/authorizer/server/db"
|
"github.com/authorizerdev/authorizer/server/db"
|
||||||
"github.com/authorizerdev/authorizer/server/env"
|
"github.com/authorizerdev/authorizer/server/env"
|
||||||
"github.com/authorizerdev/authorizer/server/handlers"
|
"github.com/authorizerdev/authorizer/server/handlers"
|
||||||
|
"github.com/authorizerdev/authorizer/server/middlewares"
|
||||||
"github.com/authorizerdev/authorizer/server/oauth"
|
"github.com/authorizerdev/authorizer/server/oauth"
|
||||||
"github.com/authorizerdev/authorizer/server/session"
|
"github.com/authorizerdev/authorizer/server/session"
|
||||||
"github.com/authorizerdev/authorizer/server/utils"
|
"github.com/authorizerdev/authorizer/server/utils"
|
||||||
|
@ -15,39 +12,6 @@ import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GinContextToContextMiddleware() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
if constants.AUTHORIZER_URL == "" {
|
|
||||||
url := location.Get(c)
|
|
||||||
constants.AUTHORIZER_URL = url.Scheme + "://" + c.Request.Host
|
|
||||||
log.Println("=> authorizer url:", constants.AUTHORIZER_URL)
|
|
||||||
}
|
|
||||||
ctx := context.WithValue(c.Request.Context(), "GinContextKey", c)
|
|
||||||
c.Request = c.Request.WithContext(ctx)
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO use allowed origins for cors origin
|
|
||||||
// TODO throw error if url is not allowed
|
|
||||||
func CORSMiddleware() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
origin := c.Request.Header.Get("Origin")
|
|
||||||
constants.APP_URL = origin
|
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
|
|
||||||
|
|
||||||
if c.Request.Method == "OPTIONS" {
|
|
||||||
c.AbortWithStatus(204)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
env.InitEnv()
|
env.InitEnv()
|
||||||
db.InitDB()
|
db.InitDB()
|
||||||
|
@ -57,8 +21,8 @@ func main() {
|
||||||
|
|
||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
r.Use(location.Default())
|
r.Use(location.Default())
|
||||||
r.Use(GinContextToContextMiddleware())
|
r.Use(middlewares.GinContextToContextMiddleware())
|
||||||
r.Use(CORSMiddleware())
|
r.Use(middlewares.CORSMiddleware())
|
||||||
|
|
||||||
r.GET("/", handlers.PlaygroundHandler())
|
r.GET("/", handlers.PlaygroundHandler())
|
||||||
r.POST("/graphql", handlers.GraphqlHandler())
|
r.POST("/graphql", handlers.GraphqlHandler())
|
||||||
|
|
23
server/middlewares/context.go
Normal file
23
server/middlewares/context.go
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
package middlewares
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/authorizerdev/authorizer/server/constants"
|
||||||
|
"github.com/gin-contrib/location"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GinContextToContextMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if constants.AUTHORIZER_URL == "" {
|
||||||
|
url := location.Get(c)
|
||||||
|
constants.AUTHORIZER_URL = url.Scheme + "://" + c.Request.Host
|
||||||
|
log.Println("=> authorizer url:", constants.AUTHORIZER_URL)
|
||||||
|
}
|
||||||
|
ctx := context.WithValue(c.Request.Context(), "GinContextKey", c)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
29
server/middlewares/cors.go
Normal file
29
server/middlewares/cors.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package middlewares
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/authorizerdev/authorizer/server/constants"
|
||||||
|
"github.com/authorizerdev/authorizer/server/utils"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CORSMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
origin := c.Request.Header.Get("Origin")
|
||||||
|
constants.APP_URL = origin
|
||||||
|
|
||||||
|
if utils.IsValidOrigin(origin) {
|
||||||
|
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||||
|
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
|
||||||
|
|
||||||
|
if c.Request.Method == "OPTIONS" {
|
||||||
|
c.AbortWithStatus(204)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
|
@ -10,7 +10,7 @@ import (
|
||||||
func SetCookie(gc *gin.Context, token string) {
|
func SetCookie(gc *gin.Context, token string) {
|
||||||
secure := true
|
secure := true
|
||||||
httpOnly := true
|
httpOnly := true
|
||||||
host := GetHostName(constants.AUTHORIZER_URL)
|
host, _ := GetHostParts(constants.AUTHORIZER_URL)
|
||||||
domain := GetDomainName(constants.AUTHORIZER_URL)
|
domain := GetDomainName(constants.AUTHORIZER_URL)
|
||||||
if domain != "localhost" {
|
if domain != "localhost" {
|
||||||
domain = "." + domain
|
domain = "." + domain
|
||||||
|
@ -37,7 +37,7 @@ func DeleteCookie(gc *gin.Context) {
|
||||||
secure := true
|
secure := true
|
||||||
httpOnly := true
|
httpOnly := true
|
||||||
|
|
||||||
host := GetDomainName(constants.AUTHORIZER_URL)
|
host, _ := GetHostParts(constants.AUTHORIZER_URL)
|
||||||
domain := GetDomainName(constants.AUTHORIZER_URL)
|
domain := GetDomainName(constants.AUTHORIZER_URL)
|
||||||
if domain != "localhost" {
|
if domain != "localhost" {
|
||||||
domain = "." + domain
|
domain = "." + domain
|
||||||
|
|
|
@ -5,21 +5,32 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetHostName function to get hostname
|
// GetHostName function returns hostname and port
|
||||||
func GetHostName(auth_url string) string {
|
func GetHostParts(uri string) (string, string) {
|
||||||
u, err := url.Parse(auth_url)
|
tempURI := uri
|
||||||
|
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
|
||||||
|
tempURI = "https://" + tempURI
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(tempURI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return `localhost`
|
return "localhost", "8080"
|
||||||
}
|
}
|
||||||
|
|
||||||
host := u.Hostname()
|
host := u.Hostname()
|
||||||
|
port := u.Port()
|
||||||
|
|
||||||
return host
|
return host, port
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDomainName function to get domain name
|
// GetDomainName function to get domain name
|
||||||
func GetDomainName(auth_url string) string {
|
func GetDomainName(uri string) string {
|
||||||
u, err := url.Parse(auth_url)
|
tempURI := uri
|
||||||
|
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
|
||||||
|
tempURI = "https://" + tempURI
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(tempURI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return `localhost`
|
return `localhost`
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,12 +7,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetHostName(t *testing.T) {
|
func TestGetHostName(t *testing.T) {
|
||||||
authorizer_url := "http://test.herokuapp.com"
|
authorizer_url := "http://test.herokuapp.com:80"
|
||||||
|
|
||||||
got := GetHostName(authorizer_url)
|
host, port := GetHostParts(authorizer_url)
|
||||||
want := "test.herokuapp.com"
|
expectedHost := "test.herokuapp.com"
|
||||||
|
|
||||||
assert.Equal(t, got, want, "hostname should be equal")
|
assert.Equal(t, host, expectedHost, "hostname should be equal")
|
||||||
|
assert.Equal(t, port, "80", "port should be 80")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetDomainName(t *testing.T) {
|
func TestGetDomainName(t *testing.T) {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/mail"
|
"net/mail"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/authorizerdev/authorizer/server/constants"
|
"github.com/authorizerdev/authorizer/server/constants"
|
||||||
|
@ -13,16 +14,32 @@ func IsValidEmail(email string) bool {
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsValidRedirectURL(url string) bool {
|
func IsValidOrigin(url string) bool {
|
||||||
if len(constants.ALLOWED_ORIGINS) == 1 && constants.ALLOWED_ORIGINS[0] == "*" {
|
if len(constants.ALLOWED_ORIGINS) == 1 && constants.ALLOWED_ORIGINS[0] == "*" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
hasValidURL := false
|
hasValidURL := false
|
||||||
urlDomain := GetDomainName(url)
|
hostName, port := GetHostParts(url)
|
||||||
|
currentOrigin := hostName + ":" + port
|
||||||
|
|
||||||
for _, val := range constants.ALLOWED_ORIGINS {
|
for _, origin := range constants.ALLOWED_ORIGINS {
|
||||||
if strings.Contains(val, urlDomain) {
|
replacedString := origin
|
||||||
|
// if has regex whitelisted domains
|
||||||
|
if strings.Contains(origin, "*") {
|
||||||
|
replacedString = strings.Replace(origin, ".", "\\.", -1)
|
||||||
|
replacedString = strings.Replace(replacedString, "*", ".*", -1)
|
||||||
|
|
||||||
|
if strings.HasPrefix(replacedString, ".*") {
|
||||||
|
replacedString += "\\b"
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(replacedString, ".*") {
|
||||||
|
replacedString = "\\b" + replacedString
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if matched, _ := regexp.MatchString(replacedString, currentOrigin); matched {
|
||||||
hasValidURL = true
|
hasValidURL = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package utils
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/authorizerdev/authorizer/server/constants"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,3 +16,19 @@ func TestIsValidEmail(t *testing.T) {
|
||||||
assert.False(t, IsValidEmail(invalidEmail1), "it should be invalid email")
|
assert.False(t, IsValidEmail(invalidEmail1), "it should be invalid email")
|
||||||
assert.False(t, IsValidEmail(invalidEmail2), "it should be invalid email")
|
assert.False(t, IsValidEmail(invalidEmail2), "it should be invalid email")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsValidOrigin(t *testing.T) {
|
||||||
|
// don't use portocal(http/https) for ALLOWED_ORIGINS while testing,
|
||||||
|
// as we trim them off while running the main function
|
||||||
|
constants.ALLOWED_ORIGINS = []string{"localhost:8080", "*.google.com", "*.google.in", "*abc.*"}
|
||||||
|
|
||||||
|
assert.False(t, IsValidOrigin("http://myapp.com"), "it should be invalid origin")
|
||||||
|
assert.False(t, IsValidOrigin("http://appgoogle.com"), "it should be invalid origin")
|
||||||
|
assert.True(t, IsValidOrigin("http://app.google.com"), "it should be valid origin")
|
||||||
|
assert.False(t, IsValidOrigin("http://app.google.ind"), "it should be invalid origin")
|
||||||
|
assert.True(t, IsValidOrigin("http://app.google.in"), "it should be valid origin")
|
||||||
|
assert.True(t, IsValidOrigin("http://xyx.abc.com"), "it should be valid origin")
|
||||||
|
assert.True(t, IsValidOrigin("http://xyx.abc.in"), "it should be valid origin")
|
||||||
|
assert.True(t, IsValidOrigin("http://xyxabc.in"), "it should be valid origin")
|
||||||
|
assert.True(t, IsValidOrigin("http://localhost:8080"), "it should be valid origin")
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user