-
Notifications
You must be signed in to change notification settings - Fork 0
/
cors_handler_test.go
148 lines (123 loc) · 4.86 KB
/
cors_handler_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
// TestAllowAll checks that allowAll reports true when the configured origin
// list is the single wildcard "*" and false when it holds a specific origin.
func TestAllowAll(t *testing.T) {
	scenarios := []struct {
		title   string
		origins []string
		want    bool
	}{
		{title: "Return true on allow all", origins: []string{"*"}, want: true},
		{title: "Return false on allow all", origins: []string{"http://example.com"}, want: false},
	}

	for _, sc := range scenarios {
		t.Run(sc.title, func(t *testing.T) {
			// Configure the package-level origin list before querying allowAll.
			allowedOrigins = sc.origins
			if sc.want {
				assert.True(t, allowAll(), "should be true")
				return
			}
			assert.False(t, allowAll(), "should be false")
		})
	}
}
// TestAllowedOrigin checks isAllowedOrigin against a fixed two-entry origin
// list: a listed origin must be accepted, while an unlisted origin and the
// empty string must be rejected.
func TestAllowedOrigin(t *testing.T) {
	scenarios := []struct {
		title   string
		probes  []string
		allowed bool
	}{
		{"Return true if origin exists", []string{"http://www.benefactory.se"}, true},
		{"Return false if origin doesn't exists", []string{"http://www.example.com", ""}, false},
	}

	for _, sc := range scenarios {
		t.Run(sc.title, func(t *testing.T) {
			// Each subtest installs its own copy of the origin list.
			allowedOrigins = []string{"http://www.hawry.net", "http://www.benefactory.se"}
			for _, origin := range sc.probes {
				if sc.allowed {
					assert.True(t, isAllowedOrigin(origin), "should be true")
				} else {
					assert.False(t, isAllowedOrigin(origin), "should be false")
				}
			}
		})
	}
}
// TestAllowedMethod checks isAllowedMethod with GET and POST configured:
// DELETE and the empty string must be rejected, GET must be accepted.
func TestAllowedMethod(t *testing.T) {
	// configure installs the method list every subtest starts from.
	configure := func() {
		allowedCORSMethods = []string{"GET", "POST"}
	}

	t.Run("Return false if method isn't specified", func(t *testing.T) {
		configure()
		for _, method := range []string{"DELETE", ""} {
			assert.False(t, isAllowedMethod(method), "should be false")
		}
	})

	t.Run("Return true if method is specified", func(t *testing.T) {
		configure()
		assert.True(t, isAllowedMethod("GET"), "should be true")
	})
}
// TestAllowedHeaders checks the result of getAllowedHeaders for a request
// that asks for one configured header (X-Real-IP) and one header that is not
// configured (X-Requested-With): the former must be present in the result and
// the latter must be absent.
func TestAllowedHeaders(t *testing.T) {
	allowedHeaders = []string{"X-Real-IP", "Content-Type"}
	requested := []string{"X-Real-IP", "X-Requested-With"}

	got := getAllowedHeaders(requested)

	// Subset only proves that the allowed header survived; on its own it would
	// still pass if the disallowed header leaked through, so that case is
	// asserted explicitly as well.
	assert.Subset(t, got, []string{"X-Real-IP"})
	assert.NotContains(t, got, "X-Requested-With")
}
// TestSetMethods configures origins, methods, headers and exposed headers
// through the exported Allow* setters and then sends OPTIONS requests through
// CORSHandler, checking the CORS response headers for an allowed origin, a
// disallowed origin, and a preflight request carrying a requested method and
// requested headers.
func TestSetMethods(t *testing.T) {
	// Clear any leftover data from other tests.
	allowedCORSMethods = []string{}
	allowedHeaders = []string{}
	allowedOrigins = []string{}
	exposedHeaders = []string{}

	AllowCORSOrigins("http://example.com", "http://localhost")
	AllowCORSMethods("GET", "POST")
	AllowCORSHeaders("X-Real-IP", "Content-Type")
	AllowCORSExposedHeaders("X-Exposed-Header")

	// The wrapped handler does nothing; only the CORS response headers are
	// inspected.
	corsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

	// newRequest builds a fresh OPTIONS request for the given origin. Every
	// subtest gets its own request so that none of them depends on headers
	// left behind by an earlier subtest.
	newRequest := func(t *testing.T, origin string) *http.Request {
		t.Helper()
		req, err := http.NewRequest(http.MethodOptions, "/index", nil)
		if err != nil {
			t.Fatal(err)
		}
		req.Header.Set(corsOrigin, origin)
		return req
	}

	// serve sends req through a newly wrapped handler and returns the headers
	// written to the recorder.
	serve := func(req *http.Request) http.Header {
		rr := httptest.NewRecorder()
		CORSHandler(corsHandler).ServeHTTP(rr, req)
		return rr.Header()
	}

	t.Run("Requesting allowed origin", func(t *testing.T) {
		got := serve(newRequest(t, "http://localhost"))
		assert.Equal(t, "http://localhost", got.Get(corsAccessControlAllowOrigin), "should be equal")
	})

	t.Run("Requesting disallowed origin", func(t *testing.T) {
		got := serve(newRequest(t, "http://notexamples"))
		assert.Empty(t, got.Get(corsAccessControlAllowOrigin), "should be empty")
	})

	t.Run("Preflight request with methods and headers", func(t *testing.T) {
		req := newRequest(t, "http://localhost")
		req.Header.Set(corsAccessControlRequestMethod, "GET")
		req.Header.Set(corsAccessControlRequestHeaders, "X-Real-IP")

		got := serve(req)
		assert.Equal(t, "http://localhost", got.Get(corsAccessControlAllowOrigin))
		assert.Equal(t, "GET, POST", got.Get(corsAccessControlAllowMethods))
		assert.Equal(t, "X-Exposed-Header", got.Get(corsAccessControlExposeHeaders))
		assert.Equal(t, "X-Real-IP", got.Get(corsAccessControlAllowHeaders))
	})
}
// TestSupportCredentials enables credential support together with a wildcard
// origin and sends a GET request from http://localhost through CORSHandler.
// It expects the credentials header to be "true" and the allowed-origin
// header to carry the request's origin rather than "*".
func TestSupportCredentials(t *testing.T) {
	// Start from an empty CORS configuration.
	allowedCORSMethods, allowedHeaders = []string{}, []string{}
	allowedOrigins, exposedHeaders = []string{}, []string{}

	request, err := http.NewRequest("GET", "/index", nil)
	if err != nil {
		t.Fatal(err)
	}
	request.Header.Set("Origin", "http://localhost")

	SupportCredentials(true)
	AllowCORSOrigins("*")

	// The wrapped handler is a no-op; only the CORS response headers matter.
	noop := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	recorder := httptest.NewRecorder()
	CORSHandler(noop).ServeHTTP(recorder, request)

	got := recorder.Header()
	assert.Equal(t, "true", got.Get(corsAccessControlAllowCredentials))
	assert.NotEqual(t, "*", got.Get(corsAccessControlAllowOrigin))
	// NOTE(review): the Vary check below is disabled in the original; confirm
	// whether the handler is expected to emit it before re-enabling.
	// assert.Equal(t, "Origin", got.Get(corsVary))
	assert.Equal(t, "http://localhost", got.Get(corsAccessControlAllowOrigin))
}
// ExampleCORSHandler shows how to configure the allowed origins, methods and
// headers, wrap an http.Handler with CORSHandler, and serve it.
func ExampleCORSHandler() {
	AllowCORSOrigins("http://example.com")  // requests from http://example.com are allowed
	AllowCORSMethods("GET", "POST", "HEAD") // requests with method GET, POST and HEAD are allowed
	AllowCORSHeaders("X-Requested-With")    // the server will be able to handle the header X-Requested-With

	corsHandler := CORSHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// ... do something
	}))
	http.Handle("/", corsHandler)

	// ListenAndServe only returns when serving fails (for example when the
	// port is already in use), so its error must not be dropped silently.
	if err := http.ListenAndServe(":8080", nil); err != nil {
		panic(err)
	}
}