parent
bdbbe4adee
commit
8f7582e1ec
2
Makefile
2
Makefile
|
@ -6,4 +6,4 @@ cmd:
|
|||
clean:
|
||||
rm -rf build
|
||||
test:
|
||||
cd server && go clean --testcache && go test ./...
|
||||
cd server && go clean --testcache && go test -v ./...
|
17
server/env/env.go
vendored
17
server/env/env.go
vendored
|
@ -87,15 +87,30 @@ func InitEnv() {
|
|||
|
||||
allowedOriginsSplit := strings.Split(os.Getenv("ALLOWED_ORIGINS"), ",")
|
||||
allowedOrigins := []string{}
|
||||
hasWildCard := false
|
||||
|
||||
for _, val := range allowedOriginsSplit {
|
||||
trimVal := strings.TrimSpace(val)
|
||||
if trimVal != "" {
|
||||
allowedOrigins = append(allowedOrigins, trimVal)
|
||||
if trimVal != "*" {
|
||||
host, port := utils.GetHostParts(trimVal)
|
||||
allowedOrigins = append(allowedOrigins, host+":"+port)
|
||||
} else {
|
||||
hasWildCard = true
|
||||
allowedOrigins = append(allowedOrigins, trimVal)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedOrigins) > 1 && hasWildCard {
|
||||
allowedOrigins = []string{"*"}
|
||||
}
|
||||
|
||||
if len(allowedOrigins) == 0 {
|
||||
allowedOrigins = []string{"*"}
|
||||
}
|
||||
|
||||
constants.ALLOWED_ORIGINS = allowedOrigins
|
||||
|
||||
if *ARG_AUTHORIZER_URL != "" {
|
||||
|
|
|
@ -49,7 +49,7 @@ func AppHandler() gin.HandlerFunc {
|
|||
stateObj.RedirectURL = strings.TrimSuffix(stateObj.RedirectURL, "/")
|
||||
|
||||
// validate redirect url with allowed origins
|
||||
if !utils.IsValidRedirectURL(stateObj.RedirectURL) {
|
||||
if !utils.IsValidOrigin(stateObj.RedirectURL) {
|
||||
c.JSON(400, gin.H{"error": "invalid redirect url"})
|
||||
return
|
||||
}
|
||||
|
|
44
server/integration_test/cors_test.go
Normal file
44
server/integration_test/cors_test.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package integration_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/authorizerdev/authorizer/server/constants"
|
||||
"github.com/authorizerdev/authorizer/server/env"
|
||||
"github.com/authorizerdev/authorizer/server/middlewares"
|
||||
"github.com/gin-contrib/location"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCors(t *testing.T) {
|
||||
constants.ENV_PATH = "../../.env.local"
|
||||
env.InitEnv()
|
||||
r := gin.Default()
|
||||
r.Use(location.Default())
|
||||
r.Use(middlewares.GinContextToContextMiddleware())
|
||||
r.Use(middlewares.CORSMiddleware())
|
||||
allowedOrigin := "http://localhost:8080" // The allowed origin that you want to check
|
||||
notAllowedOrigin := "http://myapp.com"
|
||||
|
||||
server := httptest.NewServer(r)
|
||||
defer server.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req, _ := http.NewRequest(
|
||||
"GET",
|
||||
"http://"+server.Listener.Addr().String()+"/api",
|
||||
nil,
|
||||
)
|
||||
req.Header.Add("Origin", allowedOrigin)
|
||||
|
||||
get, _ := client.Do(req)
|
||||
|
||||
// You should get your origin (or a * depending on your config) if the
|
||||
// passed origin is allowed.
|
||||
o := get.Header.Get("Access-Control-Allow-Origin")
|
||||
assert.NotEqual(t, o, notAllowedOrigin, "Origins should not match")
|
||||
assert.Equal(t, o, allowedOrigin, "Origins don't match")
|
||||
}
|
|
@ -1,13 +1,10 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/authorizerdev/authorizer/server/constants"
|
||||
"github.com/authorizerdev/authorizer/server/db"
|
||||
"github.com/authorizerdev/authorizer/server/env"
|
||||
"github.com/authorizerdev/authorizer/server/handlers"
|
||||
"github.com/authorizerdev/authorizer/server/middlewares"
|
||||
"github.com/authorizerdev/authorizer/server/oauth"
|
||||
"github.com/authorizerdev/authorizer/server/session"
|
||||
"github.com/authorizerdev/authorizer/server/utils"
|
||||
|
@ -15,39 +12,6 @@ import (
|
|||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GinContextToContextMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if constants.AUTHORIZER_URL == "" {
|
||||
url := location.Get(c)
|
||||
constants.AUTHORIZER_URL = url.Scheme + "://" + c.Request.Host
|
||||
log.Println("=> authorizer url:", constants.AUTHORIZER_URL)
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), "GinContextKey", c)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO use allowed origins for cors origin
|
||||
// TODO throw error if url is not allowed
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
constants.APP_URL = origin
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
env.InitEnv()
|
||||
db.InitDB()
|
||||
|
@ -57,8 +21,8 @@ func main() {
|
|||
|
||||
r := gin.Default()
|
||||
r.Use(location.Default())
|
||||
r.Use(GinContextToContextMiddleware())
|
||||
r.Use(CORSMiddleware())
|
||||
r.Use(middlewares.GinContextToContextMiddleware())
|
||||
r.Use(middlewares.CORSMiddleware())
|
||||
|
||||
r.GET("/", handlers.PlaygroundHandler())
|
||||
r.POST("/graphql", handlers.GraphqlHandler())
|
||||
|
|
23
server/middlewares/context.go
Normal file
23
server/middlewares/context.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/authorizerdev/authorizer/server/constants"
|
||||
"github.com/gin-contrib/location"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GinContextToContextMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if constants.AUTHORIZER_URL == "" {
|
||||
url := location.Get(c)
|
||||
constants.AUTHORIZER_URL = url.Scheme + "://" + c.Request.Host
|
||||
log.Println("=> authorizer url:", constants.AUTHORIZER_URL)
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), "GinContextKey", c)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
29
server/middlewares/cors.go
Normal file
29
server/middlewares/cors.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"github.com/authorizerdev/authorizer/server/constants"
|
||||
"github.com/authorizerdev/authorizer/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
constants.APP_URL = origin
|
||||
|
||||
if utils.IsValidOrigin(origin) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
|
@ -10,7 +10,7 @@ import (
|
|||
func SetCookie(gc *gin.Context, token string) {
|
||||
secure := true
|
||||
httpOnly := true
|
||||
host := GetHostName(constants.AUTHORIZER_URL)
|
||||
host, _ := GetHostParts(constants.AUTHORIZER_URL)
|
||||
domain := GetDomainName(constants.AUTHORIZER_URL)
|
||||
if domain != "localhost" {
|
||||
domain = "." + domain
|
||||
|
@ -37,7 +37,7 @@ func DeleteCookie(gc *gin.Context) {
|
|||
secure := true
|
||||
httpOnly := true
|
||||
|
||||
host := GetDomainName(constants.AUTHORIZER_URL)
|
||||
host, _ := GetHostParts(constants.AUTHORIZER_URL)
|
||||
domain := GetDomainName(constants.AUTHORIZER_URL)
|
||||
if domain != "localhost" {
|
||||
domain = "." + domain
|
||||
|
|
|
@ -5,21 +5,32 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// GetHostName function to get hostname
|
||||
func GetHostName(auth_url string) string {
|
||||
u, err := url.Parse(auth_url)
|
||||
// GetHostName function returns hostname and port
|
||||
func GetHostParts(uri string) (string, string) {
|
||||
tempURI := uri
|
||||
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
|
||||
tempURI = "https://" + tempURI
|
||||
}
|
||||
|
||||
u, err := url.Parse(tempURI)
|
||||
if err != nil {
|
||||
return `localhost`
|
||||
return "localhost", "8080"
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
port := u.Port()
|
||||
|
||||
return host
|
||||
return host, port
|
||||
}
|
||||
|
||||
// GetDomainName function to get domain name
|
||||
func GetDomainName(auth_url string) string {
|
||||
u, err := url.Parse(auth_url)
|
||||
func GetDomainName(uri string) string {
|
||||
tempURI := uri
|
||||
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
|
||||
tempURI = "https://" + tempURI
|
||||
}
|
||||
|
||||
u, err := url.Parse(tempURI)
|
||||
if err != nil {
|
||||
return `localhost`
|
||||
}
|
||||
|
|
|
@ -7,12 +7,13 @@ import (
|
|||
)
|
||||
|
||||
func TestGetHostName(t *testing.T) {
|
||||
authorizer_url := "http://test.herokuapp.com"
|
||||
authorizer_url := "http://test.herokuapp.com:80"
|
||||
|
||||
got := GetHostName(authorizer_url)
|
||||
want := "test.herokuapp.com"
|
||||
host, port := GetHostParts(authorizer_url)
|
||||
expectedHost := "test.herokuapp.com"
|
||||
|
||||
assert.Equal(t, got, want, "hostname should be equal")
|
||||
assert.Equal(t, host, expectedHost, "hostname should be equal")
|
||||
assert.Equal(t, port, "80", "port should be 80")
|
||||
}
|
||||
|
||||
func TestGetDomainName(t *testing.T) {
|
||||
|
|
|
@ -2,6 +2,7 @@ package utils
|
|||
|
||||
import (
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/authorizerdev/authorizer/server/constants"
|
||||
|
@ -13,16 +14,32 @@ func IsValidEmail(email string) bool {
|
|||
return err == nil
|
||||
}
|
||||
|
||||
func IsValidRedirectURL(url string) bool {
|
||||
func IsValidOrigin(url string) bool {
|
||||
if len(constants.ALLOWED_ORIGINS) == 1 && constants.ALLOWED_ORIGINS[0] == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
hasValidURL := false
|
||||
urlDomain := GetDomainName(url)
|
||||
hostName, port := GetHostParts(url)
|
||||
currentOrigin := hostName + ":" + port
|
||||
|
||||
for _, val := range constants.ALLOWED_ORIGINS {
|
||||
if strings.Contains(val, urlDomain) {
|
||||
for _, origin := range constants.ALLOWED_ORIGINS {
|
||||
replacedString := origin
|
||||
// if has regex whitelisted domains
|
||||
if strings.Contains(origin, "*") {
|
||||
replacedString = strings.Replace(origin, ".", "\\.", -1)
|
||||
replacedString = strings.Replace(replacedString, "*", ".*", -1)
|
||||
|
||||
if strings.HasPrefix(replacedString, ".*") {
|
||||
replacedString += "\\b"
|
||||
}
|
||||
|
||||
if strings.HasSuffix(replacedString, ".*") {
|
||||
replacedString = "\\b" + replacedString
|
||||
}
|
||||
}
|
||||
|
||||
if matched, _ := regexp.MatchString(replacedString, currentOrigin); matched {
|
||||
hasValidURL = true
|
||||
break
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package utils
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/authorizerdev/authorizer/server/constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -15,3 +16,19 @@ func TestIsValidEmail(t *testing.T) {
|
|||
assert.False(t, IsValidEmail(invalidEmail1), "it should be invalid email")
|
||||
assert.False(t, IsValidEmail(invalidEmail2), "it should be invalid email")
|
||||
}
|
||||
|
||||
func TestIsValidOrigin(t *testing.T) {
|
||||
// don't use portocal(http/https) for ALLOWED_ORIGINS while testing,
|
||||
// as we trim them off while running the main function
|
||||
constants.ALLOWED_ORIGINS = []string{"localhost:8080", "*.google.com", "*.google.in", "*abc.*"}
|
||||
|
||||
assert.False(t, IsValidOrigin("http://myapp.com"), "it should be invalid origin")
|
||||
assert.False(t, IsValidOrigin("http://appgoogle.com"), "it should be invalid origin")
|
||||
assert.True(t, IsValidOrigin("http://app.google.com"), "it should be valid origin")
|
||||
assert.False(t, IsValidOrigin("http://app.google.ind"), "it should be invalid origin")
|
||||
assert.True(t, IsValidOrigin("http://app.google.in"), "it should be valid origin")
|
||||
assert.True(t, IsValidOrigin("http://xyx.abc.com"), "it should be valid origin")
|
||||
assert.True(t, IsValidOrigin("http://xyx.abc.in"), "it should be valid origin")
|
||||
assert.True(t, IsValidOrigin("http://xyxabc.in"), "it should be valid origin")
|
||||
assert.True(t, IsValidOrigin("http://localhost:8080"), "it should be valid origin")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user