fix: add valid origin check for cors (#83)

Resolves #72
This commit is contained in:
Lakhan Samani
2021-12-21 18:46:54 +05:30
committed by GitHub
parent bdbbe4adee
commit 8f7582e1ec
12 changed files with 180 additions and 59 deletions

View File

@@ -10,7 +10,7 @@ import (
func SetCookie(gc *gin.Context, token string) {
secure := true
httpOnly := true
host := GetHostName(constants.AUTHORIZER_URL)
host, _ := GetHostParts(constants.AUTHORIZER_URL)
domain := GetDomainName(constants.AUTHORIZER_URL)
if domain != "localhost" {
domain = "." + domain
@@ -37,7 +37,7 @@ func DeleteCookie(gc *gin.Context) {
secure := true
httpOnly := true
host := GetDomainName(constants.AUTHORIZER_URL)
host, _ := GetHostParts(constants.AUTHORIZER_URL)
domain := GetDomainName(constants.AUTHORIZER_URL)
if domain != "localhost" {
domain = "." + domain

View File

@@ -5,21 +5,32 @@ import (
"strings"
)
// GetHostName function to get hostname
func GetHostName(auth_url string) string {
u, err := url.Parse(auth_url)
// GetHostName function returns hostname and port
func GetHostParts(uri string) (string, string) {
tempURI := uri
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
tempURI = "https://" + tempURI
}
u, err := url.Parse(tempURI)
if err != nil {
return `localhost`
return "localhost", "8080"
}
host := u.Hostname()
port := u.Port()
return host
return host, port
}
// GetDomainName function to get domain name
func GetDomainName(auth_url string) string {
u, err := url.Parse(auth_url)
func GetDomainName(uri string) string {
tempURI := uri
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
tempURI = "https://" + tempURI
}
u, err := url.Parse(tempURI)
if err != nil {
return `localhost`
}

View File

@@ -7,12 +7,13 @@ import (
)
func TestGetHostName(t *testing.T) {
authorizer_url := "http://test.herokuapp.com"
authorizer_url := "http://test.herokuapp.com:80"
got := GetHostName(authorizer_url)
want := "test.herokuapp.com"
host, port := GetHostParts(authorizer_url)
expectedHost := "test.herokuapp.com"
assert.Equal(t, got, want, "hostname should be equal")
assert.Equal(t, host, expectedHost, "hostname should be equal")
assert.Equal(t, port, "80", "port should be 80")
}
func TestGetDomainName(t *testing.T) {

View File

@@ -2,6 +2,7 @@ package utils
import (
"net/mail"
"regexp"
"strings"
"github.com/authorizerdev/authorizer/server/constants"
@@ -13,16 +14,32 @@ func IsValidEmail(email string) bool {
return err == nil
}
func IsValidRedirectURL(url string) bool {
func IsValidOrigin(url string) bool {
if len(constants.ALLOWED_ORIGINS) == 1 && constants.ALLOWED_ORIGINS[0] == "*" {
return true
}
hasValidURL := false
urlDomain := GetDomainName(url)
hostName, port := GetHostParts(url)
currentOrigin := hostName + ":" + port
for _, val := range constants.ALLOWED_ORIGINS {
if strings.Contains(val, urlDomain) {
for _, origin := range constants.ALLOWED_ORIGINS {
replacedString := origin
// if has regex whitelisted domains
if strings.Contains(origin, "*") {
replacedString = strings.Replace(origin, ".", "\\.", -1)
replacedString = strings.Replace(replacedString, "*", ".*", -1)
if strings.HasPrefix(replacedString, ".*") {
replacedString += "\\b"
}
if strings.HasSuffix(replacedString, ".*") {
replacedString = "\\b" + replacedString
}
}
if matched, _ := regexp.MatchString(replacedString, currentOrigin); matched {
hasValidURL = true
break
}

View File

@@ -3,6 +3,7 @@ package utils
import (
"testing"
"github.com/authorizerdev/authorizer/server/constants"
"github.com/stretchr/testify/assert"
)
@@ -15,3 +16,19 @@ func TestIsValidEmail(t *testing.T) {
assert.False(t, IsValidEmail(invalidEmail1), "it should be invalid email")
assert.False(t, IsValidEmail(invalidEmail2), "it should be invalid email")
}
func TestIsValidOrigin(t *testing.T) {
// don't use portocal(http/https) for ALLOWED_ORIGINS while testing,
// as we trim them off while running the main function
constants.ALLOWED_ORIGINS = []string{"localhost:8080", "*.google.com", "*.google.in", "*abc.*"}
assert.False(t, IsValidOrigin("http://myapp.com"), "it should be invalid origin")
assert.False(t, IsValidOrigin("http://appgoogle.com"), "it should be invalid origin")
assert.True(t, IsValidOrigin("http://app.google.com"), "it should be valid origin")
assert.False(t, IsValidOrigin("http://app.google.ind"), "it should be invalid origin")
assert.True(t, IsValidOrigin("http://app.google.in"), "it should be valid origin")
assert.True(t, IsValidOrigin("http://xyx.abc.com"), "it should be valid origin")
assert.True(t, IsValidOrigin("http://xyx.abc.in"), "it should be valid origin")
assert.True(t, IsValidOrigin("http://xyxabc.in"), "it should be valid origin")
assert.True(t, IsValidOrigin("http://localhost:8080"), "it should be valid origin")
}