Skip to content

Commit

Permalink
Move cors fields into object
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeffail committed Sep 15, 2021
1 parent 4fe5c8a commit 46c86d0
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ All notable changes to this project will be documented in this file.
### Added

- Fields `cache_control`, `content_disposition`, `content_language` and `website_redirect_location` added to the `aws_s3` output.
- Field `enable_cors` added to the server wide `http` config.
- Field `cors.enabled` and `cors.allowed_origins` added to the server wide `http` config.

### Fixed

Expand Down
19 changes: 16 additions & 3 deletions lib/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ import (

//------------------------------------------------------------------------------

// CORS contains configuration for allowing CORS headers.
type CORS struct {
Enabled bool `json:"enabled" yaml:"enabled"`
AllowedOrigins []string `json:"allowed_origins" yaml:"allowed_origins"`
}

// Config contains the configuration fields for the Benthos API.
type Config struct {
Address string `json:"address" yaml:"address"`
Expand All @@ -30,7 +36,7 @@ type Config struct {
DebugEndpoints bool `json:"debug_endpoints" yaml:"debug_endpoints"`
CertFile string `json:"cert_file" yaml:"cert_file"`
KeyFile string `json:"key_file" yaml:"key_file"`
EnableCORS bool `json:"enable_cors" yaml:"enable_cors"`
CORS CORS `json:"cors" yaml:"cors"`
}

// NewConfig creates a new API config with default values.
Expand All @@ -43,7 +49,10 @@ func NewConfig() Config {
DebugEndpoints: false,
CertFile: "",
KeyFile: "",
EnableCORS: false,
CORS: CORS{
Enabled: false,
AllowedOrigins: []string{},
},
}
}

Expand Down Expand Up @@ -98,8 +107,12 @@ func New(
gMux := mux.NewRouter()

var handler http.Handler = gMux
if conf.EnableCORS {
if conf.CORS.Enabled {
if len(conf.CORS.AllowedOrigins) == 0 {
return nil, errors.New("must specify at least one allowed origin")
}
handler = handlers.CORS(
handlers.AllowedOrigins(conf.CORS.AllowedOrigins),
handlers.AllowedMethods([]string{"GET", "HEAD", "POST", "DELETE"}),
)(gMux)
}
Expand Down
53 changes: 52 additions & 1 deletion lib/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ import (

func TestAPIEnableCORS(t *testing.T) {
conf := NewConfig()
conf.EnableCORS = true
conf.CORS.Enabled = true
conf.CORS.AllowedOrigins = []string{"*"}

s, err := New("", "", conf, nil, log.Noop(), metrics.Noop())
require.NoError(t, err)
Expand All @@ -30,3 +31,53 @@ func TestAPIEnableCORS(t *testing.T) {
assert.Equal(t, http.StatusOK, response.Code)
assert.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin"))
}

func TestAPIEnableCORSOrigins(t *testing.T) {
conf := NewConfig()
conf.CORS.Enabled = true
conf.CORS.AllowedOrigins = []string{"foo", "bar"}

s, err := New("", "", conf, nil, log.Noop(), metrics.Noop())
require.NoError(t, err)

handler := s.server.Handler

request, _ := http.NewRequest("OPTIONS", "/version", nil)
request.Header.Add("Origin", "foo")
request.Header.Add("Access-Control-Request-Method", "POST")

response := httptest.NewRecorder()
handler.ServeHTTP(response, request)

assert.Equal(t, http.StatusOK, response.Code)
assert.Equal(t, "foo", response.Header().Get("Access-Control-Allow-Origin"))

request, _ = http.NewRequest("OPTIONS", "/version", nil)
request.Header.Add("Origin", "bar")
request.Header.Add("Access-Control-Request-Method", "POST")

response = httptest.NewRecorder()
handler.ServeHTTP(response, request)

assert.Equal(t, http.StatusOK, response.Code)
assert.Equal(t, "bar", response.Header().Get("Access-Control-Allow-Origin"))

request, _ = http.NewRequest("OPTIONS", "/version", nil)
request.Header.Add("Origin", "baz")
request.Header.Add("Access-Control-Request-Method", "POST")

response = httptest.NewRecorder()
handler.ServeHTTP(response, request)

assert.Equal(t, http.StatusOK, response.Code)
assert.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin"))
}

func TestAPIEnableCORSNoHeaders(t *testing.T) {
conf := NewConfig()
conf.CORS.Enabled = true

_, err := New("", "", conf, nil, log.Noop(), metrics.Noop())
require.Error(t, err)
assert.Contains(t, err.Error(), "must specify at least one allowed origin")
}
5 changes: 4 additions & 1 deletion lib/api/docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ func Spec() docs.FieldSpecs {
).HasDefault(false),
docs.FieldString("cert_file", "An optional certificate file for enabling TLS.").Advanced().HasDefault(""),
docs.FieldString("key_file", "An optional key file for enabling TLS.").Advanced().HasDefault(""),
docs.FieldBool("enable_cors", "Adds Cross-Origin Resource Sharing headers.").Advanced().HasDefault(false),
docs.FieldAdvanced("cors", "Adds Cross-Origin Resource Sharing headers.").WithChildren(
docs.FieldBool("enabled", "Whether to allow CORS requests.").HasDefault(false),
docs.FieldString("allowed_origins", "An explicit list of origins that are allowed for CORS requests.").Array().HasDefault([]string{}),
),
docs.FieldDeprecated("read_timeout"),
}
}
4 changes: 3 additions & 1 deletion website/docs/components/http/about.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ http:
debug_endpoints: false
cert_file: ""
key_file: ""
enable_cors: false
cors:
enabled: false
allowed_origins: []
```
The field `enabled` can be set to `false` in order to disable the server.
Expand Down

0 comments on commit 46c86d0

Please sign in to comment.