fix: add valid origin check for cors (#83)

Resolves #72
This commit is contained in:
Lakhan Samani 2021-12-21 18:46:54 +05:30 committed by GitHub
parent bdbbe4adee
commit 8f7582e1ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 180 additions and 59 deletions

View File

@ -6,4 +6,4 @@ cmd:
clean: clean:
rm -rf build rm -rf build
test: test:
cd server && go clean --testcache && go test ./... cd server && go clean --testcache && go test -v ./...

17
server/env/env.go vendored
View File

@ -87,15 +87,30 @@ func InitEnv() {
allowedOriginsSplit := strings.Split(os.Getenv("ALLOWED_ORIGINS"), ",") allowedOriginsSplit := strings.Split(os.Getenv("ALLOWED_ORIGINS"), ",")
allowedOrigins := []string{} allowedOrigins := []string{}
hasWildCard := false
for _, val := range allowedOriginsSplit { for _, val := range allowedOriginsSplit {
trimVal := strings.TrimSpace(val) trimVal := strings.TrimSpace(val)
if trimVal != "" { if trimVal != "" {
allowedOrigins = append(allowedOrigins, trimVal) if trimVal != "*" {
host, port := utils.GetHostParts(trimVal)
allowedOrigins = append(allowedOrigins, host+":"+port)
} else {
hasWildCard = true
allowedOrigins = append(allowedOrigins, trimVal)
break
}
} }
} }
if len(allowedOrigins) > 1 && hasWildCard {
allowedOrigins = []string{"*"}
}
if len(allowedOrigins) == 0 { if len(allowedOrigins) == 0 {
allowedOrigins = []string{"*"} allowedOrigins = []string{"*"}
} }
constants.ALLOWED_ORIGINS = allowedOrigins constants.ALLOWED_ORIGINS = allowedOrigins
if *ARG_AUTHORIZER_URL != "" { if *ARG_AUTHORIZER_URL != "" {

View File

@ -49,7 +49,7 @@ func AppHandler() gin.HandlerFunc {
stateObj.RedirectURL = strings.TrimSuffix(stateObj.RedirectURL, "/") stateObj.RedirectURL = strings.TrimSuffix(stateObj.RedirectURL, "/")
// validate redirect url with allowed origins // validate redirect url with allowed origins
if !utils.IsValidRedirectURL(stateObj.RedirectURL) { if !utils.IsValidOrigin(stateObj.RedirectURL) {
c.JSON(400, gin.H{"error": "invalid redirect url"}) c.JSON(400, gin.H{"error": "invalid redirect url"})
return return
} }

View File

@ -0,0 +1,44 @@
package integration_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/env"
"github.com/authorizerdev/authorizer/server/middlewares"
"github.com/gin-contrib/location"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestCors(t *testing.T) {
constants.ENV_PATH = "../../.env.local"
env.InitEnv()
r := gin.Default()
r.Use(location.Default())
r.Use(middlewares.GinContextToContextMiddleware())
r.Use(middlewares.CORSMiddleware())
allowedOrigin := "http://localhost:8080" // The allowed origin that you want to check
notAllowedOrigin := "http://myapp.com"
server := httptest.NewServer(r)
defer server.Close()
client := &http.Client{}
req, _ := http.NewRequest(
"GET",
"http://"+server.Listener.Addr().String()+"/api",
nil,
)
req.Header.Add("Origin", allowedOrigin)
get, _ := client.Do(req)
// You should get your origin (or a * depending on your config) if the
// passed origin is allowed.
o := get.Header.Get("Access-Control-Allow-Origin")
assert.NotEqual(t, o, notAllowedOrigin, "Origins should not match")
assert.Equal(t, o, allowedOrigin, "Origins don't match")
}

View File

@ -1,13 +1,10 @@
package main package main
import ( import (
"context"
"log"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/env" "github.com/authorizerdev/authorizer/server/env"
"github.com/authorizerdev/authorizer/server/handlers" "github.com/authorizerdev/authorizer/server/handlers"
"github.com/authorizerdev/authorizer/server/middlewares"
"github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/oauth"
"github.com/authorizerdev/authorizer/server/session" "github.com/authorizerdev/authorizer/server/session"
"github.com/authorizerdev/authorizer/server/utils" "github.com/authorizerdev/authorizer/server/utils"
@ -15,39 +12,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func GinContextToContextMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if constants.AUTHORIZER_URL == "" {
url := location.Get(c)
constants.AUTHORIZER_URL = url.Scheme + "://" + c.Request.Host
log.Println("=> authorizer url:", constants.AUTHORIZER_URL)
}
ctx := context.WithValue(c.Request.Context(), "GinContextKey", c)
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}
// TODO use allowed origins for cors origin
// TODO throw error if url is not allowed
func CORSMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
constants.APP_URL = origin
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}
func main() { func main() {
env.InitEnv() env.InitEnv()
db.InitDB() db.InitDB()
@ -57,8 +21,8 @@ func main() {
r := gin.Default() r := gin.Default()
r.Use(location.Default()) r.Use(location.Default())
r.Use(GinContextToContextMiddleware()) r.Use(middlewares.GinContextToContextMiddleware())
r.Use(CORSMiddleware()) r.Use(middlewares.CORSMiddleware())
r.GET("/", handlers.PlaygroundHandler()) r.GET("/", handlers.PlaygroundHandler())
r.POST("/graphql", handlers.GraphqlHandler()) r.POST("/graphql", handlers.GraphqlHandler())

View File

@ -0,0 +1,23 @@
package middlewares
import (
"context"
"log"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/gin-contrib/location"
"github.com/gin-gonic/gin"
)
func GinContextToContextMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if constants.AUTHORIZER_URL == "" {
url := location.Get(c)
constants.AUTHORIZER_URL = url.Scheme + "://" + c.Request.Host
log.Println("=> authorizer url:", constants.AUTHORIZER_URL)
}
ctx := context.WithValue(c.Request.Context(), "GinContextKey", c)
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}

View File

@ -0,0 +1,29 @@
package middlewares
import (
"github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/utils"
"github.com/gin-gonic/gin"
)
func CORSMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
constants.APP_URL = origin
if utils.IsValidOrigin(origin) {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
}
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}

View File

@ -10,7 +10,7 @@ import (
func SetCookie(gc *gin.Context, token string) { func SetCookie(gc *gin.Context, token string) {
secure := true secure := true
httpOnly := true httpOnly := true
host := GetHostName(constants.AUTHORIZER_URL) host, _ := GetHostParts(constants.AUTHORIZER_URL)
domain := GetDomainName(constants.AUTHORIZER_URL) domain := GetDomainName(constants.AUTHORIZER_URL)
if domain != "localhost" { if domain != "localhost" {
domain = "." + domain domain = "." + domain
@ -37,7 +37,7 @@ func DeleteCookie(gc *gin.Context) {
secure := true secure := true
httpOnly := true httpOnly := true
host := GetDomainName(constants.AUTHORIZER_URL) host, _ := GetHostParts(constants.AUTHORIZER_URL)
domain := GetDomainName(constants.AUTHORIZER_URL) domain := GetDomainName(constants.AUTHORIZER_URL)
if domain != "localhost" { if domain != "localhost" {
domain = "." + domain domain = "." + domain

View File

@ -5,21 +5,32 @@ import (
"strings" "strings"
) )
// GetHostName function to get hostname // GetHostName function returns hostname and port
func GetHostName(auth_url string) string { func GetHostParts(uri string) (string, string) {
u, err := url.Parse(auth_url) tempURI := uri
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
tempURI = "https://" + tempURI
}
u, err := url.Parse(tempURI)
if err != nil { if err != nil {
return `localhost` return "localhost", "8080"
} }
host := u.Hostname() host := u.Hostname()
port := u.Port()
return host return host, port
} }
// GetDomainName function to get domain name // GetDomainName function to get domain name
func GetDomainName(auth_url string) string { func GetDomainName(uri string) string {
u, err := url.Parse(auth_url) tempURI := uri
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
tempURI = "https://" + tempURI
}
u, err := url.Parse(tempURI)
if err != nil { if err != nil {
return `localhost` return `localhost`
} }

View File

@ -7,12 +7,13 @@ import (
) )
func TestGetHostName(t *testing.T) { func TestGetHostName(t *testing.T) {
authorizer_url := "http://test.herokuapp.com" authorizer_url := "http://test.herokuapp.com:80"
got := GetHostName(authorizer_url) host, port := GetHostParts(authorizer_url)
want := "test.herokuapp.com" expectedHost := "test.herokuapp.com"
assert.Equal(t, got, want, "hostname should be equal") assert.Equal(t, host, expectedHost, "hostname should be equal")
assert.Equal(t, port, "80", "port should be 80")
} }
func TestGetDomainName(t *testing.T) { func TestGetDomainName(t *testing.T) {

View File

@ -2,6 +2,7 @@ package utils
import ( import (
"net/mail" "net/mail"
"regexp"
"strings" "strings"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
@ -13,16 +14,32 @@ func IsValidEmail(email string) bool {
return err == nil return err == nil
} }
func IsValidRedirectURL(url string) bool { func IsValidOrigin(url string) bool {
if len(constants.ALLOWED_ORIGINS) == 1 && constants.ALLOWED_ORIGINS[0] == "*" { if len(constants.ALLOWED_ORIGINS) == 1 && constants.ALLOWED_ORIGINS[0] == "*" {
return true return true
} }
hasValidURL := false hasValidURL := false
urlDomain := GetDomainName(url) hostName, port := GetHostParts(url)
currentOrigin := hostName + ":" + port
for _, val := range constants.ALLOWED_ORIGINS { for _, origin := range constants.ALLOWED_ORIGINS {
if strings.Contains(val, urlDomain) { replacedString := origin
// if has regex whitelisted domains
if strings.Contains(origin, "*") {
replacedString = strings.Replace(origin, ".", "\\.", -1)
replacedString = strings.Replace(replacedString, "*", ".*", -1)
if strings.HasPrefix(replacedString, ".*") {
replacedString += "\\b"
}
if strings.HasSuffix(replacedString, ".*") {
replacedString = "\\b" + replacedString
}
}
if matched, _ := regexp.MatchString(replacedString, currentOrigin); matched {
hasValidURL = true hasValidURL = true
break break
} }

View File

@ -3,6 +3,7 @@ package utils
import ( import (
"testing" "testing"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -15,3 +16,19 @@ func TestIsValidEmail(t *testing.T) {
assert.False(t, IsValidEmail(invalidEmail1), "it should be invalid email") assert.False(t, IsValidEmail(invalidEmail1), "it should be invalid email")
assert.False(t, IsValidEmail(invalidEmail2), "it should be invalid email") assert.False(t, IsValidEmail(invalidEmail2), "it should be invalid email")
} }
func TestIsValidOrigin(t *testing.T) {
// don't use portocal(http/https) for ALLOWED_ORIGINS while testing,
// as we trim them off while running the main function
constants.ALLOWED_ORIGINS = []string{"localhost:8080", "*.google.com", "*.google.in", "*abc.*"}
assert.False(t, IsValidOrigin("http://myapp.com"), "it should be invalid origin")
assert.False(t, IsValidOrigin("http://appgoogle.com"), "it should be invalid origin")
assert.True(t, IsValidOrigin("http://app.google.com"), "it should be valid origin")
assert.False(t, IsValidOrigin("http://app.google.ind"), "it should be invalid origin")
assert.True(t, IsValidOrigin("http://app.google.in"), "it should be valid origin")
assert.True(t, IsValidOrigin("http://xyx.abc.com"), "it should be valid origin")
assert.True(t, IsValidOrigin("http://xyx.abc.in"), "it should be valid origin")
assert.True(t, IsValidOrigin("http://xyxabc.in"), "it should be valid origin")
assert.True(t, IsValidOrigin("http://localhost:8080"), "it should be valid origin")
}