diff --git a/api/rest/restapi.go b/api/rest/restapi.go index eaf1e532..369d1f8d 100644 --- a/api/rest/restapi.go +++ b/api/rest/restapi.go @@ -743,6 +743,12 @@ func (api *API) allocationsHandler(w http.ResponseWriter, r *http.Request) { for _, f := range strings.Split(filterStr, ",") { filter |= types.PinTypeFromString(f) } + + if filter == types.BadType { + api.sendResponse(w, http.StatusBadRequest, errors.New("invalid filter value"), nil) + return + } + var pins []*types.Pin err := api.rpcClient.CallContext( r.Context(), @@ -814,7 +820,7 @@ func (api *API) statusAllHandler(w http.ResponseWriter, r *http.Request) { filterStr := queryValues.Get("filter") filter := types.TrackerStatusFromString(filterStr) if filter == types.TrackerStatusUndefined && filterStr != "" { - api.sendResponse(w, autoStatus, errors.New("invalid filter value"), nil) + api.sendResponse(w, http.StatusBadRequest, errors.New("invalid filter value"), nil) return } diff --git a/api/rest/restapi_test.go b/api/rest/restapi_test.go index 23551504..956b119a 100644 --- a/api/rest/restapi_test.go +++ b/api/rest/restapi_test.go @@ -748,6 +748,19 @@ func TestAPIAllocationsEndpoint(t *testing.T) { !resp[2].Cid.Equals(test.Cid3) { t.Error("unexpected pin list: ", resp) } + + makeGet(t, rest, url(rest)+"/allocations", &resp) + if len(resp) != 3 || + !resp[0].Cid.Equals(test.Cid1) || !resp[1].Cid.Equals(test.Cid2) || + !resp[2].Cid.Equals(test.Cid3) { + t.Error("unexpected pin list: ", resp) + } + + errResp := api.Error{} + makeGet(t, rest, url(rest)+"/allocations?filter=invalid", &errResp) + if errResp.Code != http.StatusBadRequest { + t.Error("an invalid filter value should 400") + } } testBothEndpoints(t, tf) @@ -851,6 +864,12 @@ func TestAPIStatusAllEndpoint(t *testing.T) { if len(resp7) != 2 { t.Errorf("unexpected statusAll+filter=error,pinned resp:\n %+v", resp7) } + + var errorResp api.Error + makeGet(t, rest, url(rest)+"/pins?filter=invalid", &errorResp) + if errorResp.Code != http.StatusBadRequest { + t.Error("an invalid filter value should 400") + } } testBothEndpoints(t, tf)