Restapi: Add configurable response headers

By default, CORS headers allowing GET requests from everywhere are
set. This should facilitate the IPFS Web UI integration with the
Cluster API.

This commit refactors the sendResponse methods in the API, merging
them into one as it was difficult to follow the flows that actually
send something to the client. All tests now check the presence of
the configured headers too, to make sure no route was missed.

License: MIT
Signed-off-by: Hector Sanjuan <code@hector.link>
This commit is contained in:
Hector Sanjuan 2018-10-16 15:23:06 +02:00
parent f65349e9c8
commit 322e87dd59
4 changed files with 148 additions and 90 deletions

View File

@ -27,6 +27,15 @@ const (
DefaultIdleTimeout = 120 * time.Second
)
// These are the default values for Config.
var (
DefaultHeaders = map[string][]string{
"Access-Control-Allow-Headers": []string{"X-Requested-With", "Range"},
"Access-Control-Allow-Methods": []string{"GET"},
"Access-Control-Allow-Origin": []string{"*"},
}
)
// Config is used to intialize the API object and allows to
// customize the behaviour of it. It implements the config.ComponentConfig
// interface.
@ -71,6 +80,10 @@ type Config struct {
// BasicAuthCreds is a map of username-password pairs
// which are authorized to use Basic Authentication
BasicAuthCreds map[string]string
// Headers provides customization for the headers returned
// by the API. By default it sets a CORS policy.
Headers map[string][]string
}
type jsonConfig struct {
@ -87,7 +100,8 @@ type jsonConfig struct {
ID string `json:"id,omitempty"`
PrivateKey string `json:"private_key,omitempty"`
BasicAuthCreds map[string]string `json:"basic_auth_credentials"`
BasicAuthCreds map[string]string `json:"basic_auth_credentials"`
Headers map[string][]string `json:"headers"`
}
// ConfigKey returns a human-friendly identifier for this type of
@ -116,6 +130,9 @@ func (cfg *Config) Default() error {
// Auth
cfg.BasicAuthCreds = nil
// Headers
cfg.Headers = DefaultHeaders
return nil
}
@ -177,6 +194,7 @@ func (cfg *Config) LoadJSON(raw []byte) error {
// Other options
cfg.BasicAuthCreds = jcfg.BasicAuthCreds
cfg.Headers = jcfg.Headers
return cfg.Validate()
}
@ -295,6 +313,7 @@ func (cfg *Config) ToJSON() (raw []byte, err error) {
WriteTimeout: cfg.WriteTimeout.String(),
IdleTimeout: cfg.IdleTimeout.String(),
BasicAuthCreds: cfg.BasicAuthCreds,
Headers: cfg.Headers,
}
if cfg.ID != "" {

View File

@ -55,6 +55,9 @@ var (
ErrHTTPEndpointNotEnabled = errors.New("the HTTP endpoint is not enabled")
)
// Used by sendResponse to set the right status
const autoStatus = -1
// For making a random sharding ID
var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
@ -479,7 +482,7 @@ func (api *API) idHandler(w http.ResponseWriter, r *http.Request) {
struct{}{},
&idSerial)
sendResponse(w, err, idSerial)
api.sendResponse(w, autoStatus, err, idSerial)
}
func (api *API) versionHandler(w http.ResponseWriter, r *http.Request) {
@ -490,7 +493,7 @@ func (api *API) versionHandler(w http.ResponseWriter, r *http.Request) {
struct{}{},
&v)
sendResponse(w, err, v)
api.sendResponse(w, autoStatus, err, v)
}
func (api *API) graphHandler(w http.ResponseWriter, r *http.Request) {
@ -500,22 +503,24 @@ func (api *API) graphHandler(w http.ResponseWriter, r *http.Request) {
"ConnectGraph",
struct{}{},
&graph)
sendResponse(w, err, graph)
api.sendResponse(w, autoStatus, err, graph)
}
func (api *API) addHandler(w http.ResponseWriter, r *http.Request) {
reader, err := r.MultipartReader()
if err != nil {
sendErrorResponse(w, http.StatusBadRequest, err.Error())
api.sendResponse(w, http.StatusBadRequest, err, nil)
return
}
params, err := types.AddParamsFromQuery(r.URL.Query())
if err != nil {
sendErrorResponse(w, http.StatusBadRequest, err.Error())
api.sendResponse(w, http.StatusBadRequest, err, nil)
return
}
api.setHeaders(w)
// any errors sent as trailer
adderutils.AddMultipartHTTPHandler(
api.ctx,
@ -537,7 +542,7 @@ func (api *API) peerListHandler(w http.ResponseWriter, r *http.Request) {
struct{}{},
&peersSerial)
sendResponse(w, err, peersSerial)
api.sendResponse(w, autoStatus, err, peersSerial)
}
func (api *API) peerAddHandler(w http.ResponseWriter, r *http.Request) {
@ -547,13 +552,13 @@ func (api *API) peerAddHandler(w http.ResponseWriter, r *http.Request) {
var addInfo peerAddBody
err := dec.Decode(&addInfo)
if err != nil {
sendErrorResponse(w, 400, "error decoding request body")
api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding request body"), nil)
return
}
_, err = peer.IDB58Decode(addInfo.PeerID)
if err != nil {
sendErrorResponse(w, 400, "error decoding peer_id")
api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding peer_id"), nil)
return
}
@ -563,22 +568,22 @@ func (api *API) peerAddHandler(w http.ResponseWriter, r *http.Request) {
"PeerAdd",
addInfo.PeerID,
&ids)
sendResponse(w, err, ids)
api.sendResponse(w, autoStatus, err, ids)
}
func (api *API) peerRemoveHandler(w http.ResponseWriter, r *http.Request) {
if p := parsePidOrError(w, r); p != "" {
if p := api.parsePidOrError(w, r); p != "" {
err := api.rpcClient.Call("",
"Cluster",
"PeerRemove",
p,
&struct{}{})
sendEmptyResponse(w, err)
api.sendResponse(w, autoStatus, err, nil)
}
}
func (api *API) pinHandler(w http.ResponseWriter, r *http.Request) {
if ps := parseCidOrError(w, r); ps.Cid != "" {
if ps := api.parseCidOrError(w, r); ps.Cid != "" {
logger.Debugf("rest api pinHandler: %s", ps.Cid)
err := api.rpcClient.Call("",
@ -586,20 +591,20 @@ func (api *API) pinHandler(w http.ResponseWriter, r *http.Request) {
"Pin",
ps,
&struct{}{})
sendAcceptedResponse(w, err)
api.sendResponse(w, http.StatusAccepted, err, nil)
logger.Debug("rest api pinHandler done")
}
}
func (api *API) unpinHandler(w http.ResponseWriter, r *http.Request) {
if ps := parseCidOrError(w, r); ps.Cid != "" {
if ps := api.parseCidOrError(w, r); ps.Cid != "" {
logger.Debugf("rest api unpinHandler: %s", ps.Cid)
err := api.rpcClient.Call("",
"Cluster",
"Unpin",
ps,
&struct{}{})
sendAcceptedResponse(w, err)
api.sendResponse(w, http.StatusAccepted, err, nil)
logger.Debug("rest api unpinHandler done")
}
}
@ -626,11 +631,11 @@ func (api *API) allocationsHandler(w http.ResponseWriter, r *http.Request) {
outPins = append(outPins, pinS)
}
}
sendResponse(w, err, outPins)
api.sendResponse(w, autoStatus, err, outPins)
}
func (api *API) allocationHandler(w http.ResponseWriter, r *http.Request) {
if ps := parseCidOrError(w, r); ps.Cid != "" {
if ps := api.parseCidOrError(w, r); ps.Cid != "" {
var pin types.PinSerial
err := api.rpcClient.Call("",
"Cluster",
@ -638,10 +643,10 @@ func (api *API) allocationHandler(w http.ResponseWriter, r *http.Request) {
ps,
&pin)
if err != nil { // errors here are 404s
sendErrorResponse(w, 404, err.Error())
api.sendResponse(w, http.StatusNotFound, err, nil)
return
}
sendJSONResponse(w, 200, pin)
api.sendResponse(w, autoStatus, nil, pin)
}
}
@ -656,7 +661,7 @@ func (api *API) statusAllHandler(w http.ResponseWriter, r *http.Request) {
"StatusAllLocal",
struct{}{},
&pinInfos)
sendResponse(w, err, pinInfosToGlobal(pinInfos))
api.sendResponse(w, autoStatus, err, pinInfosToGlobal(pinInfos))
} else {
var pinInfos []types.GlobalPinInfoSerial
err := api.rpcClient.Call("",
@ -664,7 +669,7 @@ func (api *API) statusAllHandler(w http.ResponseWriter, r *http.Request) {
"StatusAll",
struct{}{},
&pinInfos)
sendResponse(w, err, pinInfos)
api.sendResponse(w, autoStatus, err, pinInfos)
}
}
@ -672,7 +677,7 @@ func (api *API) statusHandler(w http.ResponseWriter, r *http.Request) {
queryValues := r.URL.Query()
local := queryValues.Get("local")
if ps := parseCidOrError(w, r); ps.Cid != "" {
if ps := api.parseCidOrError(w, r); ps.Cid != "" {
if local == "true" {
var pinInfo types.PinInfoSerial
err := api.rpcClient.Call("",
@ -680,7 +685,7 @@ func (api *API) statusHandler(w http.ResponseWriter, r *http.Request) {
"StatusLocal",
ps,
&pinInfo)
sendResponse(w, err, pinInfoToGlobal(pinInfo))
api.sendResponse(w, autoStatus, err, pinInfoToGlobal(pinInfo))
} else {
var pinInfo types.GlobalPinInfoSerial
err := api.rpcClient.Call("",
@ -688,7 +693,7 @@ func (api *API) statusHandler(w http.ResponseWriter, r *http.Request) {
"Status",
ps,
&pinInfo)
sendResponse(w, err, pinInfo)
api.sendResponse(w, autoStatus, err, pinInfo)
}
}
}
@ -704,7 +709,7 @@ func (api *API) syncAllHandler(w http.ResponseWriter, r *http.Request) {
"SyncAllLocal",
struct{}{},
&pinInfos)
sendResponse(w, err, pinInfosToGlobal(pinInfos))
api.sendResponse(w, autoStatus, err, pinInfosToGlobal(pinInfos))
} else {
var pinInfos []types.GlobalPinInfoSerial
err := api.rpcClient.Call("",
@ -712,7 +717,7 @@ func (api *API) syncAllHandler(w http.ResponseWriter, r *http.Request) {
"SyncAll",
struct{}{},
&pinInfos)
sendResponse(w, err, pinInfos)
api.sendResponse(w, autoStatus, err, pinInfos)
}
}
@ -720,7 +725,7 @@ func (api *API) syncHandler(w http.ResponseWriter, r *http.Request) {
queryValues := r.URL.Query()
local := queryValues.Get("local")
if ps := parseCidOrError(w, r); ps.Cid != "" {
if ps := api.parseCidOrError(w, r); ps.Cid != "" {
if local == "true" {
var pinInfo types.PinInfoSerial
err := api.rpcClient.Call("",
@ -728,7 +733,7 @@ func (api *API) syncHandler(w http.ResponseWriter, r *http.Request) {
"SyncLocal",
ps,
&pinInfo)
sendResponse(w, err, pinInfoToGlobal(pinInfo))
api.sendResponse(w, autoStatus, err, pinInfoToGlobal(pinInfo))
} else {
var pinInfo types.GlobalPinInfoSerial
err := api.rpcClient.Call("",
@ -736,7 +741,7 @@ func (api *API) syncHandler(w http.ResponseWriter, r *http.Request) {
"Sync",
ps,
&pinInfo)
sendResponse(w, err, pinInfo)
api.sendResponse(w, autoStatus, err, pinInfo)
}
}
}
@ -751,9 +756,9 @@ func (api *API) recoverAllHandler(w http.ResponseWriter, r *http.Request) {
"RecoverAllLocal",
struct{}{},
&pinInfos)
sendResponse(w, err, pinInfosToGlobal(pinInfos))
api.sendResponse(w, autoStatus, err, pinInfosToGlobal(pinInfos))
} else {
sendErrorResponse(w, 400, "only requests with parameter local=true are supported")
api.sendResponse(w, http.StatusBadRequest, errors.New("only requests with parameter local=true are supported"), nil)
}
}
@ -761,7 +766,7 @@ func (api *API) recoverHandler(w http.ResponseWriter, r *http.Request) {
queryValues := r.URL.Query()
local := queryValues.Get("local")
if ps := parseCidOrError(w, r); ps.Cid != "" {
if ps := api.parseCidOrError(w, r); ps.Cid != "" {
if local == "true" {
var pinInfo types.PinInfoSerial
err := api.rpcClient.Call("",
@ -769,7 +774,7 @@ func (api *API) recoverHandler(w http.ResponseWriter, r *http.Request) {
"RecoverLocal",
ps,
&pinInfo)
sendResponse(w, err, pinInfoToGlobal(pinInfo))
api.sendResponse(w, autoStatus, err, pinInfoToGlobal(pinInfo))
} else {
var pinInfo types.GlobalPinInfoSerial
err := api.rpcClient.Call("",
@ -777,18 +782,18 @@ func (api *API) recoverHandler(w http.ResponseWriter, r *http.Request) {
"Recover",
ps,
&pinInfo)
sendResponse(w, err, pinInfo)
api.sendResponse(w, autoStatus, err, pinInfo)
}
}
}
func parseCidOrError(w http.ResponseWriter, r *http.Request) types.PinSerial {
func (api *API) parseCidOrError(w http.ResponseWriter, r *http.Request) types.PinSerial {
vars := mux.Vars(r)
hash := vars["hash"]
_, err := cid.Decode(hash)
if err != nil {
sendErrorResponse(w, 400, "error decoding Cid: "+err.Error())
api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding Cid: "+err.Error()), nil)
return types.PinSerial{Cid: ""}
}
@ -827,12 +832,12 @@ func parseCidOrError(w http.ResponseWriter, r *http.Request) types.PinSerial {
return pin
}
func parsePidOrError(w http.ResponseWriter, r *http.Request) peer.ID {
func (api *API) parsePidOrError(w http.ResponseWriter, r *http.Request) peer.ID {
vars := mux.Vars(r)
idStr := vars["peer"]
pid, err := peer.IDB58Decode(idStr)
if err != nil {
sendErrorResponse(w, 400, "error decoding Peer ID: "+err.Error())
api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding Peer ID: "+err.Error()), nil)
return ""
}
return pid
@ -855,64 +860,70 @@ func pinInfosToGlobal(pInfos []types.PinInfoSerial) []types.GlobalPinInfoSerial
return gPInfos
}
func sendResponse(w http.ResponseWriter, err error, resp interface{}) {
if checkErr(w, err) {
sendJSONResponse(w, 200, resp)
}
}
// sendResponse wraps all the logic for writing the response to a request:
// * Write configured headers
// * Write application/json content type
// * Write status: determined automatically if given 0
// * Write an error if there is or write the response if there is
func (api *API) sendResponse(
w http.ResponseWriter,
status int,
err error,
resp interface{},
) {
// checkErr takes care of returning standard error responses if we
// pass an error to it. It returns true when everythings OK (no error
// was handled), or false otherwise.
func checkErr(w http.ResponseWriter, err error) bool {
api.setHeaders(w)
enc := json.NewEncoder(w)
// Send an error
if err != nil {
sendErrorResponse(w, http.StatusInternalServerError, err.Error())
return false
}
return true
}
if status == autoStatus || status < 400 { // set a default error status
status = http.StatusInternalServerError
}
w.WriteHeader(status)
func sendEmptyResponse(w http.ResponseWriter, err error) {
if checkErr(w, err) {
w.WriteHeader(http.StatusNoContent)
}
}
errorResp := types.Error{
Code: status,
Message: err.Error(),
}
logger.Errorf("sending error response: %d: %s", status, err.Error())
func sendAcceptedResponse(w http.ResponseWriter, err error) {
if checkErr(w, err) {
w.WriteHeader(http.StatusAccepted)
}
}
func sendJSONResponse(w http.ResponseWriter, code int, resp interface{}) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(code)
if err := json.NewEncoder(w).Encode(resp); err != nil {
logger.Error(err)
}
}
func sendErrorResponse(w http.ResponseWriter, code int, msg string) {
errorResp := types.Error{
Code: code,
Message: msg,
}
logger.Errorf("sending error response: %d: %s", code, msg)
sendJSONResponse(w, code, errorResp)
}
func sendStreamResponse(w http.ResponseWriter, err error, resp <-chan interface{}) {
if !checkErr(w, err) {
if err := enc.Encode(errorResp); err != nil {
logger.Error(err)
}
return
}
enc := json.NewEncoder(w)
w.Header().Add("Content-Type", "application/octet-stream")
w.WriteHeader(http.StatusOK)
for v := range resp {
err := enc.Encode(v)
if err != nil {
// Send a body
if resp != nil {
if status == autoStatus {
status = http.StatusOK
}
w.WriteHeader(status)
if err = enc.Encode(resp); err != nil {
logger.Error(err)
}
return
}
// Empty response
if status == autoStatus {
status = http.StatusNoContent
}
w.WriteHeader(status)
}
// this sets all the headers that are common to all responses
// from this API. Called from sendResponse() and /add.
func (api *API) setHeaders(w http.ResponseWriter) {
for header, values := range api.config.Headers {
for _, val := range values {
w.Header().Add(header, val)
}
}
w.Header().Add("Content-Type", "application/json")
}

View File

@ -124,6 +124,17 @@ func processStreamingResp(t *testing.T, httpResp *http.Response, err error, resp
}
}
func checkHeaders(t *testing.T, rest *API, url string, headers http.Header) {
for k, v := range rest.config.Headers {
if strings.Join(v, ",") != strings.Join(headers[k], ",") {
t.Errorf("%s does not show configured headers: %s", url, k)
}
}
if headers.Get("Content-Type") != "application/json" {
t.Errorf("%s is not application/json", url)
}
}
// makes a libp2p host that knows how to talk to the rest API host.
func makeHost(t *testing.T, rest *API) host.Host {
h, err := libp2p.New(context.Background())
@ -185,6 +196,7 @@ func makeGet(t *testing.T, rest *API, url string, resp interface{}) {
c := httpClient(t, h, isHTTPS(url))
httpResp, err := c.Get(url)
processResp(t, httpResp, err, resp)
checkHeaders(t, rest, url, httpResp.Header)
}
func makePost(t *testing.T, rest *API, url string, body []byte, resp interface{}) {
@ -193,6 +205,7 @@ func makePost(t *testing.T, rest *API, url string, body []byte, resp interface{}
c := httpClient(t, h, isHTTPS(url))
httpResp, err := c.Post(url, "application/json", bytes.NewReader(body))
processResp(t, httpResp, err, resp)
checkHeaders(t, rest, url, httpResp.Header)
}
func makeDelete(t *testing.T, rest *API, url string, resp interface{}) {
@ -202,6 +215,7 @@ func makeDelete(t *testing.T, rest *API, url string, resp interface{}) {
req, _ := http.NewRequest("DELETE", url, bytes.NewReader([]byte{}))
httpResp, err := c.Do(req)
processResp(t, httpResp, err, resp)
checkHeaders(t, rest, url, httpResp.Header)
}
func makeStreamingPost(t *testing.T, rest *API, url string, body io.Reader, contentType string, resp interface{}) {
@ -210,6 +224,7 @@ func makeStreamingPost(t *testing.T, rest *API, url string, body io.Reader, cont
c := httpClient(t, h, isHTTPS(url))
httpResp, err := c.Post(url, contentType, body)
processStreamingResp(t, httpResp, err, resp)
checkHeaders(t, rest, url, httpResp.Header)
}
type testF func(t *testing.T, url urlF)
@ -251,6 +266,7 @@ func TestRestAPIIDEndpoint(t *testing.T) {
rest := testAPI(t)
httpsrest := testHTTPSAPI(t)
defer rest.Shutdown()
defer httpsrest.Shutdown()
tf := func(t *testing.T, url urlF) {
id := api.IDSerial{}

View File

@ -50,7 +50,19 @@ var testingAPICfg = []byte(`{
"read_timeout": "0",
"read_header_timeout": "5s",
"write_timeout": "0",
"idle_timeout": "2m0s"
"idle_timeout": "2m0s",
"headers": {
"Access-Control-Allow-Headers": [
"X-Requested-With",
"Range"
],
"Access-Control-Allow-Methods": [
"GET"
],
"Access-Control-Allow-Origin": [
"*"
]
}
}`)
var testingIpfsCfg = []byte(`{