fix #636: mitingate long header attack
License: MIT Signed-off-by: Alexey Novikov <alexey@novikov.io>
This commit is contained in:
parent
5081310811
commit
53d624e701
|
@ -15,6 +15,7 @@ import (
|
||||||
const (
|
const (
|
||||||
configKey = "ipfsproxy"
|
configKey = "ipfsproxy"
|
||||||
envConfigKey = "cluster_ipfsproxy"
|
envConfigKey = "cluster_ipfsproxy"
|
||||||
|
minMaxHeaderBytes = 4096
|
||||||
)
|
)
|
||||||
|
|
||||||
// Default values for Config.
|
// Default values for Config.
|
||||||
|
@ -28,6 +29,7 @@ const (
|
||||||
DefaultIdleTimeout = 60 * time.Second
|
DefaultIdleTimeout = 60 * time.Second
|
||||||
DefaultExtractHeadersPath = "/api/v0/version"
|
DefaultExtractHeadersPath = "/api/v0/version"
|
||||||
DefaultExtractHeadersTTL = 5 * time.Minute
|
DefaultExtractHeadersTTL = 5 * time.Minute
|
||||||
|
DefaultMaxHeaderBytes = minMaxHeaderBytes
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config allows to customize behaviour of IPFSProxy.
|
// Config allows to customize behaviour of IPFSProxy.
|
||||||
|
@ -53,6 +55,10 @@ type Config struct {
|
||||||
// Maximum duration before timing out write of the response
|
// Maximum duration before timing out write of the response
|
||||||
WriteTimeout time.Duration
|
WriteTimeout time.Duration
|
||||||
|
|
||||||
|
// Maximum cumulative size of HTTP request headers in bytes
|
||||||
|
// accepted by the server
|
||||||
|
MaxHeaderBytes int
|
||||||
|
|
||||||
// Server-side amount of time a Keep-Alive connection will be
|
// Server-side amount of time a Keep-Alive connection will be
|
||||||
// kept idle before being reused
|
// kept idle before being reused
|
||||||
IdleTimeout time.Duration
|
IdleTimeout time.Duration
|
||||||
|
@ -88,6 +94,7 @@ type jsonConfig struct {
|
||||||
ReadHeaderTimeout string `json:"read_header_timeout"`
|
ReadHeaderTimeout string `json:"read_header_timeout"`
|
||||||
WriteTimeout string `json:"write_timeout"`
|
WriteTimeout string `json:"write_timeout"`
|
||||||
IdleTimeout string `json:"idle_timeout"`
|
IdleTimeout string `json:"idle_timeout"`
|
||||||
|
MaxHeaderBytes int `json:"max_header_bytes"`
|
||||||
|
|
||||||
ExtractHeadersExtra []string `json:"extract_headers_extra,omitempty"`
|
ExtractHeadersExtra []string `json:"extract_headers_extra,omitempty"`
|
||||||
ExtractHeadersPath string `json:"extract_headers_path,omitempty"`
|
ExtractHeadersPath string `json:"extract_headers_path,omitempty"`
|
||||||
|
@ -118,6 +125,7 @@ func (cfg *Config) Default() error {
|
||||||
cfg.ExtractHeadersExtra = nil
|
cfg.ExtractHeadersExtra = nil
|
||||||
cfg.ExtractHeadersPath = DefaultExtractHeadersPath
|
cfg.ExtractHeadersPath = DefaultExtractHeadersPath
|
||||||
cfg.ExtractHeadersTTL = DefaultExtractHeadersTTL
|
cfg.ExtractHeadersTTL = DefaultExtractHeadersTTL
|
||||||
|
cfg.MaxHeaderBytes = DefaultMaxHeaderBytes
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -173,6 +181,10 @@ func (cfg *Config) Validate() error {
|
||||||
err = errors.New("ipfsproxy.extract_headers_ttl is invalid")
|
err = errors.New("ipfsproxy.extract_headers_ttl is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.MaxHeaderBytes < minMaxHeaderBytes {
|
||||||
|
err = fmt.Errorf("ipfsproxy.max_header_size must be greater or equal to %d", minMaxHeaderBytes)
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,6 +231,12 @@ func (cfg *Config) applyJSONConfig(jcfg *jsonConfig) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if jcfg.MaxHeaderBytes == 0 {
|
||||||
|
cfg.MaxHeaderBytes = DefaultMaxHeaderBytes
|
||||||
|
} else {
|
||||||
|
cfg.MaxHeaderBytes = jcfg.MaxHeaderBytes
|
||||||
|
}
|
||||||
|
|
||||||
if extra := jcfg.ExtractHeadersExtra; extra != nil && len(extra) > 0 {
|
if extra := jcfg.ExtractHeadersExtra; extra != nil && len(extra) > 0 {
|
||||||
cfg.ExtractHeadersExtra = extra
|
cfg.ExtractHeadersExtra = extra
|
||||||
}
|
}
|
||||||
|
@ -255,6 +273,7 @@ func (cfg *Config) toJSONConfig() (jcfg *jsonConfig, err error) {
|
||||||
jcfg.ReadHeaderTimeout = cfg.ReadHeaderTimeout.String()
|
jcfg.ReadHeaderTimeout = cfg.ReadHeaderTimeout.String()
|
||||||
jcfg.WriteTimeout = cfg.WriteTimeout.String()
|
jcfg.WriteTimeout = cfg.WriteTimeout.String()
|
||||||
jcfg.IdleTimeout = cfg.IdleTimeout.String()
|
jcfg.IdleTimeout = cfg.IdleTimeout.String()
|
||||||
|
jcfg.MaxHeaderBytes = cfg.MaxHeaderBytes
|
||||||
jcfg.NodeHTTPS = cfg.NodeHTTPS
|
jcfg.NodeHTTPS = cfg.NodeHTTPS
|
||||||
|
|
||||||
jcfg.ExtractHeadersExtra = cfg.ExtractHeadersExtra
|
jcfg.ExtractHeadersExtra = cfg.ExtractHeadersExtra
|
||||||
|
|
|
@ -15,12 +15,21 @@ var cfgJSON = []byte(`
|
||||||
"read_header_timeout": "5s",
|
"read_header_timeout": "5s",
|
||||||
"write_timeout": "10m0s",
|
"write_timeout": "10m0s",
|
||||||
"idle_timeout": "1m0s",
|
"idle_timeout": "1m0s",
|
||||||
|
"max_header_bytes": 16384,
|
||||||
"extract_headers_extra": [],
|
"extract_headers_extra": [],
|
||||||
"extract_headers_path": "/api/v0/version",
|
"extract_headers_path": "/api/v0/version",
|
||||||
"extract_headers_ttl": "5m"
|
"extract_headers_ttl": "5m"
|
||||||
}
|
}
|
||||||
`)
|
`)
|
||||||
|
|
||||||
|
func TestLoadEmptyJSON(t *testing.T) {
|
||||||
|
cfg := &Config{}
|
||||||
|
err := cfg.LoadJSON([]byte(`{}`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLoadJSON(t *testing.T) {
|
func TestLoadJSON(t *testing.T) {
|
||||||
cfg := &Config{}
|
cfg := &Config{}
|
||||||
err := cfg.LoadJSON(cfgJSON)
|
err := cfg.LoadJSON(cfgJSON)
|
||||||
|
@ -63,6 +72,14 @@ func TestLoadJSON(t *testing.T) {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error in extract_headers_ttl")
|
t.Error("expected error in extract_headers_ttl")
|
||||||
}
|
}
|
||||||
|
j = &jsonConfig{}
|
||||||
|
json.Unmarshal(cfgJSON, j)
|
||||||
|
j.MaxHeaderBytes = minMaxHeaderBytes - 1
|
||||||
|
tst, _ = json.Marshal(j)
|
||||||
|
err = cfg.LoadJSON(tst)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error in extract_headers_ttl")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToJSON(t *testing.T) {
|
func TestToJSON(t *testing.T) {
|
||||||
|
|
|
@ -154,6 +154,7 @@ func New(cfg *Config) (*Server, error) {
|
||||||
ReadHeaderTimeout: cfg.ReadHeaderTimeout,
|
ReadHeaderTimeout: cfg.ReadHeaderTimeout,
|
||||||
IdleTimeout: cfg.IdleTimeout,
|
IdleTimeout: cfg.IdleTimeout,
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
|
MaxHeaderBytes: cfg.MaxHeaderBytes,
|
||||||
}
|
}
|
||||||
|
|
||||||
// See: https://github.com/ipfs/go-ipfs/issues/5168
|
// See: https://github.com/ipfs/go-ipfs/issues/5168
|
||||||
|
|
|
@ -24,14 +24,12 @@ func init() {
|
||||||
_ = logging.Logger
|
_ = logging.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func testIPFSProxy(t *testing.T) (*Server, *test.IpfsMock) {
|
func testIPFSProxyWithConfig(t *testing.T, cfg *Config) (*Server, *test.IpfsMock) {
|
||||||
mock := test.NewIpfsMock()
|
mock := test.NewIpfsMock()
|
||||||
nodeMAddr, _ := ma.NewMultiaddr(fmt.Sprintf("/ip4/%s/tcp/%d",
|
nodeMAddr, _ := ma.NewMultiaddr(fmt.Sprintf("/ip4/%s/tcp/%d",
|
||||||
mock.Addr, mock.Port))
|
mock.Addr, mock.Port))
|
||||||
proxyMAddr, _ := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0")
|
proxyMAddr, _ := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0")
|
||||||
|
|
||||||
cfg := &Config{}
|
|
||||||
cfg.Default()
|
|
||||||
cfg.NodeAddr = nodeMAddr
|
cfg.NodeAddr = nodeMAddr
|
||||||
cfg.ListenAddr = proxyMAddr
|
cfg.ListenAddr = proxyMAddr
|
||||||
cfg.ExtractHeadersExtra = []string{
|
cfg.ExtractHeadersExtra = []string{
|
||||||
|
@ -49,6 +47,12 @@ func testIPFSProxy(t *testing.T) (*Server, *test.IpfsMock) {
|
||||||
return proxy, mock
|
return proxy, mock
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testIPFSProxy(t *testing.T) (*Server, *test.IpfsMock) {
|
||||||
|
cfg := &Config{}
|
||||||
|
cfg.Default()
|
||||||
|
return testIPFSProxyWithConfig(t, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
func TestIPFSProxyVersion(t *testing.T) {
|
func TestIPFSProxyVersion(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
proxy, mock := testIPFSProxy(t)
|
proxy, mock := testIPFSProxy(t)
|
||||||
|
@ -617,3 +621,42 @@ func TestHeaderExtraction(t *testing.T) {
|
||||||
t.Error("should have refreshed the headers after TTL")
|
t.Error("should have refreshed the headers after TTL")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAttackHeaderSize(t *testing.T) {
|
||||||
|
const testHeaderSize = minMaxHeaderBytes * 4
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := &Config{}
|
||||||
|
cfg.Default()
|
||||||
|
cfg.MaxHeaderBytes = testHeaderSize
|
||||||
|
proxy, mock := testIPFSProxyWithConfig(t, cfg)
|
||||||
|
defer mock.Close()
|
||||||
|
defer proxy.Shutdown(ctx)
|
||||||
|
|
||||||
|
type testcase struct {
|
||||||
|
headerSize int
|
||||||
|
expectedStatus int
|
||||||
|
}
|
||||||
|
testcases := []testcase{
|
||||||
|
testcase{testHeaderSize / 2, http.StatusNotFound},
|
||||||
|
testcase{testHeaderSize * 2, http.StatusRequestHeaderFieldsTooLarge},
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", fmt.Sprintf("%s/foo", proxyURL(proxy)), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, tc := range testcases {
|
||||||
|
for size := 0; size < tc.headerSize; size += 8 {
|
||||||
|
req.Header.Add("Foo", "bar")
|
||||||
|
}
|
||||||
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("should forward requests to ipfs host: ", err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
if res.StatusCode != tc.expectedStatus {
|
||||||
|
t.Errorf("proxy returned unexpected status %d, expected status code was %d",
|
||||||
|
res.StatusCode, tc.expectedStatus)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -23,6 +23,8 @@ import (
|
||||||
const configKey = "restapi"
|
const configKey = "restapi"
|
||||||
const envConfigKey = "cluster_restapi"
|
const envConfigKey = "cluster_restapi"
|
||||||
|
|
||||||
|
const minMaxHeaderBytes = 4096
|
||||||
|
|
||||||
// These are the default values for Config
|
// These are the default values for Config
|
||||||
const (
|
const (
|
||||||
DefaultHTTPListenAddr = "/ip4/127.0.0.1/tcp/9094"
|
DefaultHTTPListenAddr = "/ip4/127.0.0.1/tcp/9094"
|
||||||
|
@ -30,6 +32,7 @@ const (
|
||||||
DefaultReadHeaderTimeout = 5 * time.Second
|
DefaultReadHeaderTimeout = 5 * time.Second
|
||||||
DefaultWriteTimeout = 0
|
DefaultWriteTimeout = 0
|
||||||
DefaultIdleTimeout = 120 * time.Second
|
DefaultIdleTimeout = 120 * time.Second
|
||||||
|
DefaultMaxHeaderBytes = minMaxHeaderBytes
|
||||||
)
|
)
|
||||||
|
|
||||||
// These are the default values for Config.
|
// These are the default values for Config.
|
||||||
|
@ -89,6 +92,10 @@ type Config struct {
|
||||||
// kept idle before being reused
|
// kept idle before being reused
|
||||||
IdleTimeout time.Duration
|
IdleTimeout time.Duration
|
||||||
|
|
||||||
|
// Maximum cumulative size of HTTP request headers in bytes
|
||||||
|
// accepted by the server
|
||||||
|
MaxHeaderBytes int
|
||||||
|
|
||||||
// Listen address for the Libp2p REST API endpoint.
|
// Listen address for the Libp2p REST API endpoint.
|
||||||
Libp2pListenAddr ma.Multiaddr
|
Libp2pListenAddr ma.Multiaddr
|
||||||
|
|
||||||
|
@ -125,6 +132,7 @@ type jsonConfig struct {
|
||||||
ReadHeaderTimeout string `json:"read_header_timeout"`
|
ReadHeaderTimeout string `json:"read_header_timeout"`
|
||||||
WriteTimeout string `json:"write_timeout"`
|
WriteTimeout string `json:"write_timeout"`
|
||||||
IdleTimeout string `json:"idle_timeout"`
|
IdleTimeout string `json:"idle_timeout"`
|
||||||
|
MaxHeaderBytes int `json:"max_header_bytes"`
|
||||||
|
|
||||||
Libp2pListenMultiaddress string `json:"libp2p_listen_multiaddress,omitempty"`
|
Libp2pListenMultiaddress string `json:"libp2p_listen_multiaddress,omitempty"`
|
||||||
ID string `json:"id,omitempty"`
|
ID string `json:"id,omitempty"`
|
||||||
|
@ -158,6 +166,7 @@ func (cfg *Config) Default() error {
|
||||||
cfg.ReadHeaderTimeout = DefaultReadHeaderTimeout
|
cfg.ReadHeaderTimeout = DefaultReadHeaderTimeout
|
||||||
cfg.WriteTimeout = DefaultWriteTimeout
|
cfg.WriteTimeout = DefaultWriteTimeout
|
||||||
cfg.IdleTimeout = DefaultIdleTimeout
|
cfg.IdleTimeout = DefaultIdleTimeout
|
||||||
|
cfg.MaxHeaderBytes = DefaultMaxHeaderBytes
|
||||||
|
|
||||||
// libp2p
|
// libp2p
|
||||||
cfg.ID = ""
|
cfg.ID = ""
|
||||||
|
@ -208,6 +217,8 @@ func (cfg *Config) Validate() error {
|
||||||
return errors.New("restapi.write_timeout is invalid")
|
return errors.New("restapi.write_timeout is invalid")
|
||||||
case cfg.IdleTimeout < 0:
|
case cfg.IdleTimeout < 0:
|
||||||
return errors.New("restapi.idle_timeout invalid")
|
return errors.New("restapi.idle_timeout invalid")
|
||||||
|
case cfg.MaxHeaderBytes < minMaxHeaderBytes:
|
||||||
|
return fmt.Errorf("restapi.max_header_bytes must be not less then %d", minMaxHeaderBytes)
|
||||||
case cfg.BasicAuthCreds != nil && len(cfg.BasicAuthCreds) == 0:
|
case cfg.BasicAuthCreds != nil && len(cfg.BasicAuthCreds) == 0:
|
||||||
return errors.New("restapi.basic_auth_creds should be null or have at least one entry")
|
return errors.New("restapi.basic_auth_creds should be null or have at least one entry")
|
||||||
case (cfg.pathSSLCertFile != "" || cfg.pathSSLKeyFile != "") && cfg.TLS == nil:
|
case (cfg.pathSSLCertFile != "" || cfg.pathSSLKeyFile != "") && cfg.TLS == nil:
|
||||||
|
@ -280,6 +291,12 @@ func (cfg *Config) loadHTTPOptions(jcfg *jsonConfig) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if jcfg.MaxHeaderBytes == 0 {
|
||||||
|
cfg.MaxHeaderBytes = DefaultMaxHeaderBytes
|
||||||
|
} else {
|
||||||
|
cfg.MaxHeaderBytes = jcfg.MaxHeaderBytes
|
||||||
|
}
|
||||||
|
|
||||||
// CORS
|
// CORS
|
||||||
cfg.CORSAllowedOrigins = jcfg.CORSAllowedOrigins
|
cfg.CORSAllowedOrigins = jcfg.CORSAllowedOrigins
|
||||||
cfg.CORSAllowedMethods = jcfg.CORSAllowedMethods
|
cfg.CORSAllowedMethods = jcfg.CORSAllowedMethods
|
||||||
|
@ -390,6 +407,7 @@ func (cfg *Config) toJSONConfig() (jcfg *jsonConfig, err error) {
|
||||||
ReadHeaderTimeout: cfg.ReadHeaderTimeout.String(),
|
ReadHeaderTimeout: cfg.ReadHeaderTimeout.String(),
|
||||||
WriteTimeout: cfg.WriteTimeout.String(),
|
WriteTimeout: cfg.WriteTimeout.String(),
|
||||||
IdleTimeout: cfg.IdleTimeout.String(),
|
IdleTimeout: cfg.IdleTimeout.String(),
|
||||||
|
MaxHeaderBytes: cfg.MaxHeaderBytes,
|
||||||
BasicAuthCreds: cfg.BasicAuthCreds,
|
BasicAuthCreds: cfg.BasicAuthCreds,
|
||||||
Headers: cfg.Headers,
|
Headers: cfg.Headers,
|
||||||
CORSAllowedOrigins: cfg.CORSAllowedOrigins,
|
CORSAllowedOrigins: cfg.CORSAllowedOrigins,
|
||||||
|
|
|
@ -21,6 +21,7 @@ var cfgJSON = []byte(`
|
||||||
"read_header_timeout": "5s",
|
"read_header_timeout": "5s",
|
||||||
"write_timeout": "1m0s",
|
"write_timeout": "1m0s",
|
||||||
"idle_timeout": "2m0s",
|
"idle_timeout": "2m0s",
|
||||||
|
"max_header_bytes": 16384,
|
||||||
"basic_auth_credentials": null,
|
"basic_auth_credentials": null,
|
||||||
"cors_allowed_origins": ["myorigin"],
|
"cors_allowed_origins": ["myorigin"],
|
||||||
"cors_allowed_methods": ["GET"],
|
"cors_allowed_methods": ["GET"],
|
||||||
|
@ -31,6 +32,14 @@ var cfgJSON = []byte(`
|
||||||
}
|
}
|
||||||
`)
|
`)
|
||||||
|
|
||||||
|
func TestLoadEmptyJSON(t *testing.T) {
|
||||||
|
cfg := &Config{}
|
||||||
|
err := cfg.LoadJSON([]byte(`{}`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLoadJSON(t *testing.T) {
|
func TestLoadJSON(t *testing.T) {
|
||||||
cfg := &Config{}
|
cfg := &Config{}
|
||||||
err := cfg.LoadJSON(cfgJSON)
|
err := cfg.LoadJSON(cfgJSON)
|
||||||
|
@ -108,6 +117,15 @@ func TestLoadJSON(t *testing.T) {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error with private key")
|
t.Error("expected error with private key")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
j = &jsonConfig{}
|
||||||
|
json.Unmarshal(cfgJSON, j)
|
||||||
|
j.MaxHeaderBytes = minMaxHeaderBytes - 1
|
||||||
|
tst, _ = json.Marshal(j)
|
||||||
|
err = cfg.LoadJSON(tst)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error with MaxHeaderBytes")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApplyEnvVars(t *testing.T) {
|
func TestApplyEnvVars(t *testing.T) {
|
||||||
|
|
|
@ -137,12 +137,14 @@ func NewAPIWithHost(ctx context.Context, cfg *Config, h host.Host) (*API, error)
|
||||||
WriteTimeout: cfg.WriteTimeout,
|
WriteTimeout: cfg.WriteTimeout,
|
||||||
IdleTimeout: cfg.IdleTimeout,
|
IdleTimeout: cfg.IdleTimeout,
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
|
MaxHeaderBytes: cfg.MaxHeaderBytes,
|
||||||
}
|
}
|
||||||
|
|
||||||
// See: https://github.com/ipfs/go-ipfs/issues/5168
|
// See: https://github.com/ipfs/go-ipfs/issues/5168
|
||||||
// See: https://github.com/ipfs/ipfs-cluster/issues/548
|
// See: https://github.com/ipfs/ipfs-cluster/issues/548
|
||||||
// on why this is re-enabled.
|
// on why this is re-enabled.
|
||||||
s.SetKeepAlivesEnabled(true)
|
s.SetKeepAlivesEnabled(true)
|
||||||
|
s.MaxHeaderBytes = cfg.MaxHeaderBytes
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -1073,21 +1074,27 @@ type httpTestcase struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpStatusCodeChecker(resp *http.Response, expectedStatus int) error {
|
func httpStatusCodeChecker(resp *http.Response, expectedStatus int) error {
|
||||||
if resp.StatusCode != expectedStatus {
|
if resp.StatusCode == expectedStatus {
|
||||||
return fmt.Errorf("bad HTTP status code: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unexpected HTTP status code: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertHTTPStatusIsUnauthoriazed(resp *http.Response) error {
|
func assertHTTPStatusIsUnauthoriazed(resp *http.Response) error {
|
||||||
return httpStatusCodeChecker(resp, http.StatusUnauthorized)
|
return httpStatusCodeChecker(resp, http.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertHTTPStatusIsNotUnauthoriazed(resp *http.Response) error {
|
func assertHTTPStatusIsTooLarge(resp *http.Response) error {
|
||||||
if assertHTTPStatusIsUnauthoriazed(resp) == nil {
|
return httpStatusCodeChecker(resp, http.StatusRequestHeaderFieldsTooLarge)
|
||||||
return fmt.Errorf("unexpected HTTP status code: %d", http.StatusUnauthorized)
|
}
|
||||||
|
|
||||||
|
func makeInvertedHTTPStatusAssert(checker responseChecker) responseChecker {
|
||||||
|
return func(resp *http.Response) error {
|
||||||
|
if checker(resp) == nil {
|
||||||
|
return fmt.Errorf("unexpected HTTP status code: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tc *httpTestcase) getTestFunction(api *API) testF {
|
func (tc *httpTestcase) getTestFunction(api *API) testF {
|
||||||
|
@ -1115,7 +1122,12 @@ func (tc *httpTestcase) getTestFunction(api *API) testF {
|
||||||
}
|
}
|
||||||
if tc.checker != nil {
|
if tc.checker != nil {
|
||||||
if err := tc.checker(resp); err != nil {
|
if err := tc.checker(resp); err != nil {
|
||||||
t.Error("Assertion failed: ", err)
|
r, e := httputil.DumpRequest(req, true)
|
||||||
|
if e != nil {
|
||||||
|
t.Errorf("Assertion failed with: %q", err)
|
||||||
|
} else {
|
||||||
|
t.Errorf("Assertion failed with: %q on request: \n%.100s", err, r)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1128,6 +1140,15 @@ func makeBasicAuthRequestShaper(username, password string) requestShaper {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeLongHeaderShaper(size int) requestShaper {
|
||||||
|
return func(req *http.Request) error {
|
||||||
|
for sz := size; sz > 0; sz -= 8 {
|
||||||
|
req.Header.Add("Foo", "bar")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBasicAuth(t *testing.T) {
|
func TestBasicAuth(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
rest := testAPIwithBasicAuth(t)
|
rest := testAPIwithBasicAuth(t)
|
||||||
|
@ -1223,31 +1244,58 @@ func TestBasicAuth(t *testing.T) {
|
||||||
method: "GET",
|
method: "GET",
|
||||||
path: "/foo",
|
path: "/foo",
|
||||||
shaper: makeBasicAuthRequestShaper(validUserName, validUserPassword),
|
shaper: makeBasicAuthRequestShaper(validUserName, validUserPassword),
|
||||||
checker: assertHTTPStatusIsNotUnauthoriazed,
|
checker: makeInvertedHTTPStatusAssert(assertHTTPStatusIsUnauthoriazed),
|
||||||
},
|
},
|
||||||
httpTestcase{
|
httpTestcase{
|
||||||
method: "POST",
|
method: "POST",
|
||||||
path: "/foo",
|
path: "/foo",
|
||||||
shaper: makeBasicAuthRequestShaper(validUserName, validUserPassword),
|
shaper: makeBasicAuthRequestShaper(validUserName, validUserPassword),
|
||||||
checker: assertHTTPStatusIsNotUnauthoriazed,
|
checker: makeInvertedHTTPStatusAssert(assertHTTPStatusIsUnauthoriazed),
|
||||||
},
|
},
|
||||||
httpTestcase{
|
httpTestcase{
|
||||||
method: "DELETE",
|
method: "DELETE",
|
||||||
path: "/foo",
|
path: "/foo",
|
||||||
shaper: makeBasicAuthRequestShaper(validUserName, validUserPassword),
|
shaper: makeBasicAuthRequestShaper(validUserName, validUserPassword),
|
||||||
checker: assertHTTPStatusIsNotUnauthoriazed,
|
checker: makeInvertedHTTPStatusAssert(assertHTTPStatusIsUnauthoriazed),
|
||||||
},
|
},
|
||||||
httpTestcase{
|
httpTestcase{
|
||||||
method: "BAR",
|
method: "BAR",
|
||||||
path: "/foo",
|
path: "/foo",
|
||||||
shaper: makeBasicAuthRequestShaper(validUserName, validUserPassword),
|
shaper: makeBasicAuthRequestShaper(validUserName, validUserPassword),
|
||||||
checker: assertHTTPStatusIsNotUnauthoriazed,
|
checker: makeInvertedHTTPStatusAssert(assertHTTPStatusIsUnauthoriazed),
|
||||||
},
|
},
|
||||||
httpTestcase{
|
httpTestcase{
|
||||||
method: "GET",
|
method: "GET",
|
||||||
path: "/id",
|
path: "/id",
|
||||||
shaper: makeBasicAuthRequestShaper(validUserName, validUserPassword),
|
shaper: makeBasicAuthRequestShaper(validUserName, validUserPassword),
|
||||||
checker: assertHTTPStatusIsNotUnauthoriazed,
|
checker: makeInvertedHTTPStatusAssert(assertHTTPStatusIsUnauthoriazed),
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
testBothEndpoints(t, tc.getTestFunction(rest))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimitMaxHeaderSize(t *testing.T) {
|
||||||
|
const maxHeaderBytes = 4 * DefaultMaxHeaderBytes
|
||||||
|
cfg := &Config{}
|
||||||
|
cfg.Default()
|
||||||
|
cfg.MaxHeaderBytes = maxHeaderBytes
|
||||||
|
ctx := context.Background()
|
||||||
|
rest := testAPIwithConfig(t, cfg, "http with maxHeaderBytes")
|
||||||
|
defer rest.Shutdown(ctx)
|
||||||
|
|
||||||
|
for _, tc := range []httpTestcase{
|
||||||
|
httpTestcase{
|
||||||
|
method: "GET",
|
||||||
|
path: "/foo",
|
||||||
|
shaper: makeLongHeaderShaper(maxHeaderBytes * 2),
|
||||||
|
checker: assertHTTPStatusIsTooLarge,
|
||||||
|
},
|
||||||
|
httpTestcase{
|
||||||
|
method: "GET",
|
||||||
|
path: "/foo",
|
||||||
|
shaper: makeLongHeaderShaper(maxHeaderBytes / 2),
|
||||||
|
checker: makeInvertedHTTPStatusAssert(assertHTTPStatusIsTooLarge),
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
testBothEndpoints(t, tc.getTestFunction(rest))
|
testBothEndpoints(t, tc.getTestFunction(rest))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user