diff --git a/api/rest/restapi.go b/api/rest/restapi.go index 6ea6088e..60693349 100644 --- a/api/rest/restapi.go +++ b/api/rest/restapi.go @@ -21,9 +21,10 @@ import ( "sync" "time" + "github.com/rs/cors" + "github.com/ipfs/ipfs-cluster/adder/adderutils" types "github.com/ipfs/ipfs-cluster/api" - "github.com/rs/cors" mux "github.com/gorilla/mux" gostream "github.com/hsanjuan/go-libp2p-gostream" @@ -109,15 +110,20 @@ func NewAPIWithHost(cfg *Config, h host.Host) (*API, error) { return nil, err } + // Our handler is a gorilla router, + // wrapped with the cors handler, + // wrapped with the basic auth handler. router := mux.NewRouter().StrictSlash(true) - c := cors.New(*cfg.corsOptions()) - withCorsRouter := c.Handler(router) + handler := basicAuthHandler( + cfg.BasicAuthCreds, + cors.New(*cfg.corsOptions()).Handler(router), + ) s := &http.Server{ ReadTimeout: cfg.ReadTimeout, ReadHeaderTimeout: cfg.ReadHeaderTimeout, WriteTimeout: cfg.WriteTimeout, IdleTimeout: cfg.IdleTimeout, - Handler: withCorsRouter, + Handler: handler, } // See: https://github.com/ipfs/go-ipfs/issues/5168 @@ -228,9 +234,6 @@ func (api *API) Host() host.Host { func (api *API) addRoutes(router *mux.Router) { for _, route := range api.routes() { - if api.config.BasicAuthCreds != nil { - route.HandlerFunc = basicAuth(route.HandlerFunc, api.config.BasicAuthCreds) - } router. Methods(route.Method). Path(route.Pattern). @@ -240,8 +243,13 @@ func (api *API) addRoutes(router *mux.Router) { api.router = router } -func basicAuth(h http.HandlerFunc, credentials map[string]string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +// basicAuth wraps a given handler with basic authentication +func basicAuthHandler(credentials map[string]string, h http.Handler) http.Handler { + if credentials == nil { + return h + } + + wrap := func(w http.ResponseWriter, r *http.Request) { w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) username, password, ok := r.BasicAuth() if !ok { @@ -271,6 +279,7 @@ func basicAuth(h http.HandlerFunc, credentials map[string]string) http.HandlerFu } h.ServeHTTP(w, r) } + return http.HandlerFunc(wrap) } func unauthorizedResp() (string, error) {