ipfsconn/ipfshttp: handle cid args passed in url path correctly

The extractCid function was added to enable the extraction of
a cid argument from either the url path or query string.
This puts the proxy behaviour on par with the current IPFS API.
The function does rely on the fact that ipfs-cluster doesn't
intercept any command that has more than one subcommand.
If that changes, this function will have to be updated.

License: MIT
Signed-off-by: Adrian Lanzafame <adrianlanzafame92@gmail.com>
This commit is contained in:
Adrian Lanzafame 2018-04-20 14:09:07 +10:00
parent dbcc5c2fde
commit 22ec210c25
No known key found for this signature in database
GPG Key ID: 87E40C5D62EAE192
3 changed files with 440 additions and 214 deletions

View File

@ -155,11 +155,13 @@ func NewConnector(cfg *Config) (*Connector, error) {
client: c, client: c,
} }
smux.HandleFunc("/", ipfs.handle) smux.HandleFunc("/", ipfs.defaultHandler)
ipfs.handlers["/api/v0/pin/add"] = ipfs.pinHandler smux.HandleFunc("/api/v0/pin/add/", ipfs.pinHandler)
ipfs.handlers["/api/v0/pin/rm"] = ipfs.unpinHandler smux.HandleFunc("/api/v0/pin/rm/", ipfs.unpinHandler)
ipfs.handlers["/api/v0/pin/ls"] = ipfs.pinLsHandler smux.HandleFunc("/api/v0/pin/ls", ipfs.pinLsHandler) // to handle /pin/ls for all pins
ipfs.handlers["/api/v0/add"] = ipfs.addHandler smux.HandleFunc("/api/v0/pin/ls/", ipfs.pinLsHandler)
smux.HandleFunc("/api/v0/add", ipfs.addHandler)
smux.HandleFunc("/api/v0/add/", ipfs.addHandler)
go ipfs.run() go ipfs.run()
return ipfs, nil return ipfs, nil
@ -179,9 +181,11 @@ func (ipfs *Connector) run() {
ipfs.wg.Add(1) ipfs.wg.Add(1)
go func() { go func() {
defer ipfs.wg.Done() defer ipfs.wg.Done()
logger.Infof("IPFS Proxy: %s -> %s", logger.Infof(
"IPFS Proxy: %s -> %s",
ipfs.config.ProxyAddr, ipfs.config.ProxyAddr,
ipfs.config.NodeAddr) ipfs.config.NodeAddr,
)
err := ipfs.server.Serve(ipfs.listener) // hangs here err := ipfs.server.Serve(ipfs.listener) // hangs here
if err != nil && !strings.Contains(err.Error(), "closed network connection") { if err != nil && !strings.Contains(err.Error(), "closed network connection") {
logger.Error(err) logger.Error(err)
@ -209,17 +213,6 @@ func (ipfs *Connector) run() {
}() }()
} }
// This will run a custom handler if we have one for a URL.Path, or
// otherwise just proxy the requests.
func (ipfs *Connector) handle(w http.ResponseWriter, r *http.Request) {
if customHandler, ok := ipfs.handlers[r.URL.Path]; ok {
customHandler(w, r)
} else {
ipfs.defaultHandler(w, r)
}
}
func (ipfs *Connector) proxyRequest(r *http.Request) (*http.Response, error) { func (ipfs *Connector) proxyRequest(r *http.Request) (*http.Response, error) {
newURL := *r.URL newURL := *r.URL
newURL.Host = ipfs.nodeAddr newURL.Host = ipfs.nodeAddr
@ -282,26 +275,26 @@ func ipfsErrorResponder(w http.ResponseWriter, errMsg string) {
} }
func (ipfs *Connector) pinOpHandler(op string, w http.ResponseWriter, r *http.Request) { func (ipfs *Connector) pinOpHandler(op string, w http.ResponseWriter, r *http.Request) {
argA := r.URL.Query()["arg"] arg, ok := extractArgument(r.URL)
if len(argA) == 0 { if !ok {
ipfsErrorResponder(w, "Error: bad argument") ipfsErrorResponder(w, "Error: bad argument")
return return
} }
arg := argA[0]
_, err := cid.Decode(arg) _, err := cid.Decode(arg)
if err != nil { if err != nil {
ipfsErrorResponder(w, "Error parsing CID: "+err.Error()) ipfsErrorResponder(w, "Error parsing CID: "+err.Error())
return return
} }
err = ipfs.rpcClient.Call("", err = ipfs.rpcClient.Call(
"",
"Cluster", "Cluster",
op, op,
api.PinSerial{ api.PinSerial{
Cid: arg, Cid: arg,
}, },
&struct{}{}) &struct{}{},
)
if err != nil { if err != nil {
ipfsErrorResponder(w, err.Error()) ipfsErrorResponder(w, err.Error())
return return
@ -329,24 +322,23 @@ func (ipfs *Connector) pinLsHandler(w http.ResponseWriter, r *http.Request) {
pinLs := ipfsPinLsResp{} pinLs := ipfsPinLsResp{}
pinLs.Keys = make(map[string]ipfsPinType) pinLs.Keys = make(map[string]ipfsPinType)
q := r.URL.Query() arg, ok := extractArgument(r.URL)
arg := q.Get("arg") if ok {
if arg != "" {
c, err := cid.Decode(arg) c, err := cid.Decode(arg)
if err != nil { if err != nil {
ipfsErrorResponder(w, err.Error()) ipfsErrorResponder(w, err.Error())
return return
} }
var pin api.PinSerial var pin api.PinSerial
err = ipfs.rpcClient.Call("", err = ipfs.rpcClient.Call(
"",
"Cluster", "Cluster",
"PinGet", "PinGet",
api.PinCid(c).ToSerial(), api.PinCid(c).ToSerial(),
&pin) &pin,
)
if err != nil { if err != nil {
ipfsErrorResponder(w, fmt.Sprintf( ipfsErrorResponder(w, fmt.Sprintf("Error: path '%s' is not pinned", arg))
"Error: path '%s' is not pinned",
arg))
return return
} }
pinLs.Keys[pin.Cid] = ipfsPinType{ pinLs.Keys[pin.Cid] = ipfsPinType{
@ -354,12 +346,13 @@ func (ipfs *Connector) pinLsHandler(w http.ResponseWriter, r *http.Request) {
} }
} else { } else {
var pins []api.PinSerial var pins []api.PinSerial
err := ipfs.rpcClient.Call("", err := ipfs.rpcClient.Call(
"",
"Cluster", "Cluster",
"Pins", "Pins",
struct{}{}, struct{}{},
&pins) &pins,
)
if err != nil { if err != nil {
ipfsErrorResponder(w, err.Error()) ipfsErrorResponder(w, err.Error())
return return
@ -455,13 +448,15 @@ func (ipfs *Connector) addHandler(w http.ResponseWriter, r *http.Request) {
logger.Debugf("proxy /add request and will pin %s", pinHashes) logger.Debugf("proxy /add request and will pin %s", pinHashes)
for _, pin := range pinHashes { for _, pin := range pinHashes {
err := ipfs.rpcClient.Call("", err := ipfs.rpcClient.Call(
"",
"Cluster", "Cluster",
"Pin", "Pin",
api.PinSerial{ api.PinSerial{
Cid: pin, Cid: pin,
}, },
&struct{}{}) &struct{}{},
)
if err != nil { if err != nil {
// we need to fail the operation and make sure the // we need to fail the operation and make sure the
// user knows about it. // user knows about it.
@ -782,11 +777,13 @@ func (ipfs *Connector) apiURL() string {
// triggers ipfs swarm connect requests // triggers ipfs swarm connect requests
func (ipfs *Connector) ConnectSwarms() error { func (ipfs *Connector) ConnectSwarms() error {
var idsSerial []api.IDSerial var idsSerial []api.IDSerial
err := ipfs.rpcClient.Call("", err := ipfs.rpcClient.Call(
"",
"Cluster", "Cluster",
"Peers", "Peers",
struct{}{}, struct{}{},
&idsSerial) &idsSerial,
)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
return err return err
@ -799,8 +796,7 @@ func (ipfs *Connector) ConnectSwarms() error {
// This is a best effort attempt // This is a best effort attempt
// We ignore errors which happens // We ignore errors which happens
// when passing in a bunch of addresses // when passing in a bunch of addresses
_, err := ipfs.post( _, err := ipfs.post(fmt.Sprintf("swarm/connect?arg=%s", addr))
fmt.Sprintf("swarm/connect?arg=%s", addr))
if err != nil { if err != nil {
logger.Debug(err) logger.Debug(err)
continue continue
@ -918,3 +914,24 @@ func (ipfs *Connector) SwarmPeers() (api.SwarmPeers, error) {
} }
return swarm, nil return swarm, nil
} }
// extractArgument extracts the cid argument from a url.URL, either via
// the query string parameters or from the url path itself.
func extractArgument(u *url.URL) (string, bool) {
arg := u.Query().Get("arg")
if arg != "" {
return arg, true
}
p := strings.TrimPrefix(u.Path, "/api/v0/")
segs := strings.Split(p, "/")
if len(segs) > 2 {
warnMsg := "You are using an undocumented form of the IPFS API."
warnMsg += "Consider passing your command arguments"
warnMsg += "with the '?arg=' query parameter"
logger.Warning(warnMsg)
return segs[len(segs)-1], true
}
return "", false
}

View File

@ -11,12 +11,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/ipfs/ipfs-cluster/api"
"github.com/ipfs/ipfs-cluster/test"
cid "github.com/ipfs/go-cid" cid "github.com/ipfs/go-cid"
logging "github.com/ipfs/go-log" logging "github.com/ipfs/go-log"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
"github.com/ipfs/ipfs-cluster/api"
"github.com/ipfs/ipfs-cluster/test"
) )
func init() { func init() {
@ -199,48 +199,100 @@ func TestIPFSProxyPin(t *testing.T) {
defer mock.Close() defer mock.Close()
defer ipfs.Shutdown() defer ipfs.Shutdown()
res, err := http.Post(fmt.Sprintf("%s/pin/add?arg=%s", proxyURL(ipfs), test.TestCid1), "", nil) type args struct {
if err != nil { urlPath string
t.Fatal("should have succeeded: ", err) testCid string
statusCode int
} }
defer res.Body.Close() tests := []struct {
name string
if res.StatusCode != http.StatusOK { args args
t.Error("the request should have succeeded") want string
wantErr bool
}{
{
"pin good cid query arg",
args{
"/pin/add?arg=",
test.TestCid1,
http.StatusOK,
},
test.TestCid1,
false,
},
{
"pin good cid url arg",
args{
"/pin/add/",
test.TestCid1,
http.StatusOK,
},
test.TestCid1,
false,
},
{
"pin bad cid query arg",
args{
"/pin/add?arg=",
test.ErrorCid,
http.StatusInternalServerError,
},
"",
true,
},
{
"pin bad cid url arg",
args{
"/pin/add/",
test.ErrorCid,
http.StatusInternalServerError,
},
"",
true,
},
} }
resBytes, _ := ioutil.ReadAll(res.Body) for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u := fmt.Sprintf("%s%s%s", proxyURL(ipfs), tt.args.urlPath, tt.args.testCid)
res, err := http.Post(u, "", nil)
if err != nil {
t.Fatal("should have succeeded: ", err)
}
defer res.Body.Close()
var resp ipfsPinOpResp if res.StatusCode != tt.args.statusCode {
err = json.Unmarshal(resBytes, &resp) t.Errorf("statusCode: got = %v, want %v", res.StatusCode, tt.args.statusCode)
if err != nil { }
t.Fatal(err)
}
if len(resp.Pins) != 1 || resp.Pins[0] != test.TestCid1 { resBytes, _ := ioutil.ReadAll(res.Body)
t.Error("wrong response")
}
// Try with a bad cid switch tt.wantErr {
res2, err := http.Post(fmt.Sprintf("%s/pin/add?arg=%s", proxyURL(ipfs), test.ErrorCid), "", nil) case false:
if err != nil { var resp ipfsPinOpResp
t.Fatal("request should work: ", err) err = json.Unmarshal(resBytes, &resp)
} if err != nil {
defer res2.Body.Close() t.Fatal(err)
}
t.Log(fmt.Sprintf("%s/pin/add?arg=%s", proxyURL(ipfs), test.ErrorCid)) if len(resp.Pins) != 1 {
if res2.StatusCode != http.StatusInternalServerError { t.Fatalf("wrong number of pins: got = %d, want %d", len(resp.Pins), 1)
t.Error("the request should return with InternalServerError") }
}
resBytes, _ = ioutil.ReadAll(res2.Body) if resp.Pins[0] != tt.want {
var respErr ipfsError t.Errorf("wrong pin cid: got = %s, want = %s", resp.Pins[0], tt.want)
err = json.Unmarshal(resBytes, &respErr) }
if err != nil { case true:
t.Fatal(err) var respErr ipfsError
} err = json.Unmarshal(resBytes, &respErr)
if err != nil {
t.Fatal(err)
}
if respErr.Message != test.ErrBadCid.Error() { if respErr.Message != test.ErrBadCid.Error() {
t.Error("wrong response") t.Errorf("wrong response: got = %s, want = %s", respErr.Message, test.ErrBadCid.Error())
}
}
})
} }
} }
@ -249,48 +301,100 @@ func TestIPFSProxyUnpin(t *testing.T) {
defer mock.Close() defer mock.Close()
defer ipfs.Shutdown() defer ipfs.Shutdown()
res, err := http.Post(fmt.Sprintf("%s/pin/rm?arg=%s", proxyURL(ipfs), test.TestCid1), "", nil) type args struct {
if err != nil { urlPath string
t.Fatal("should have succeeded: ", err) testCid string
statusCode int
} }
defer res.Body.Close() tests := []struct {
name string
if res.StatusCode != http.StatusOK { args args
t.Error("the request should have succeeded") want string
wantErr bool
}{
{
"unpin good cid query arg",
args{
"/pin/rm?arg=",
test.TestCid1,
http.StatusOK,
},
test.TestCid1,
false,
},
{
"unpin good cid url arg",
args{
"/pin/rm/",
test.TestCid1,
http.StatusOK,
},
test.TestCid1,
false,
},
{
"unpin bad cid query arg",
args{
"/pin/rm?arg=",
test.ErrorCid,
http.StatusInternalServerError,
},
"",
true,
},
{
"unpin bad cid url arg",
args{
"/pin/rm/",
test.ErrorCid,
http.StatusInternalServerError,
},
"",
true,
},
} }
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u := fmt.Sprintf("%s%s%s", proxyURL(ipfs), tt.args.urlPath, tt.args.testCid)
res, err := http.Post(u, "", nil)
if err != nil {
t.Fatal("should have succeeded: ", err)
}
defer res.Body.Close()
resBytes, _ := ioutil.ReadAll(res.Body) if res.StatusCode != tt.args.statusCode {
t.Errorf("statusCode: got = %v, want %v", res.StatusCode, tt.args.statusCode)
}
var resp ipfsPinOpResp resBytes, _ := ioutil.ReadAll(res.Body)
err = json.Unmarshal(resBytes, &resp)
if err != nil {
t.Fatal(err)
}
if len(resp.Pins) != 1 || resp.Pins[0] != test.TestCid1 { switch tt.wantErr {
t.Error("wrong response") case false:
} var resp ipfsPinOpResp
err = json.Unmarshal(resBytes, &resp)
if err != nil {
t.Fatal(err)
}
// Try with a bad cid if len(resp.Pins) != 1 {
res2, err := http.Post(fmt.Sprintf("%s/pin/rm?arg=%s", proxyURL(ipfs), test.ErrorCid), "", nil) t.Fatalf("wrong number of pins: got = %d, want %d", len(resp.Pins), 1)
if err != nil { }
t.Fatal("request should work: ", err)
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusInternalServerError { if resp.Pins[0] != tt.want {
t.Error("the request should return with InternalServerError") t.Errorf("wrong pin cid: got = %s, want = %s", resp.Pins[0], tt.want)
} }
case true:
var respErr ipfsError
err = json.Unmarshal(resBytes, &respErr)
if err != nil {
t.Fatal(err)
}
resBytes, _ = ioutil.ReadAll(res2.Body) if respErr.Message != test.ErrBadCid.Error() {
var respErr ipfsError t.Errorf("wrong response: got = %s, want = %s", respErr.Message, test.ErrBadCid.Error())
err = json.Unmarshal(resBytes, &respErr) }
if err != nil { }
t.Fatal(err) })
}
if respErr.Message != test.ErrBadCid.Error() {
t.Error("wrong response")
} }
} }
@ -299,55 +403,84 @@ func TestIPFSProxyPinLs(t *testing.T) {
defer mock.Close() defer mock.Close()
defer ipfs.Shutdown() defer ipfs.Shutdown()
res, err := http.Post(fmt.Sprintf("%s/pin/ls?arg=%s", proxyURL(ipfs), test.TestCid1), "", nil) t.Run("pin/ls query arg", func(t *testing.T) {
if err != nil { res, err := http.Post(fmt.Sprintf("%s/pin/ls?arg=%s", proxyURL(ipfs), test.TestCid1), "", nil)
t.Fatal("should have succeeded: ", err) if err != nil {
} t.Fatal("should have succeeded: ", err)
defer res.Body.Close() }
if res.StatusCode != http.StatusOK { defer res.Body.Close()
t.Error("the request should have succeeded") if res.StatusCode != http.StatusOK {
} t.Error("the request should have succeeded")
}
resBytes, _ := ioutil.ReadAll(res.Body) resBytes, _ := ioutil.ReadAll(res.Body)
var resp ipfsPinLsResp
err = json.Unmarshal(resBytes, &resp)
if err != nil {
t.Fatal(err)
}
var resp ipfsPinLsResp _, ok := resp.Keys[test.TestCid1]
err = json.Unmarshal(resBytes, &resp) if len(resp.Keys) != 1 || !ok {
if err != nil { t.Error("wrong response")
t.Fatal(err) }
} })
_, ok := resp.Keys[test.TestCid1] t.Run("pin/ls url arg", func(t *testing.T) {
if len(resp.Keys) != 1 || !ok { res, err := http.Post(fmt.Sprintf("%s/pin/ls/%s", proxyURL(ipfs), test.TestCid1), "", nil)
t.Error("wrong response") if err != nil {
} t.Fatal("should have succeeded: ", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Error("the request should have succeeded")
}
res2, err := http.Post(fmt.Sprintf("%s/pin/ls", proxyURL(ipfs)), "", nil) resBytes, _ := ioutil.ReadAll(res.Body)
if err != nil { var resp ipfsPinLsResp
t.Fatal("should have succeeded: ", err) err = json.Unmarshal(resBytes, &resp)
} if err != nil {
defer res2.Body.Close() t.Fatal(err)
if res2.StatusCode != http.StatusOK { }
t.Error("the request should have succeeded")
}
resBytes, _ = ioutil.ReadAll(res2.Body) _, ok := resp.Keys[test.TestCid1]
err = json.Unmarshal(resBytes, &resp) if len(resp.Keys) != 1 || !ok {
if err != nil { t.Error("wrong response")
t.Fatal(err) }
} })
if len(resp.Keys) != 3 { t.Run("pin/ls all no arg", func(t *testing.T) {
t.Error("wrong response") res2, err := http.Post(fmt.Sprintf("%s/pin/ls", proxyURL(ipfs)), "", nil)
} if err != nil {
t.Fatal("should have succeeded: ", err)
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
t.Error("the request should have succeeded")
}
res3, err := http.Post(fmt.Sprintf("%s/pin/ls?arg=%s", proxyURL(ipfs), test.ErrorCid), "", nil) resBytes, _ := ioutil.ReadAll(res2.Body)
if err != nil { var resp ipfsPinLsResp
t.Fatal("should have succeeded: ", err) err = json.Unmarshal(resBytes, &resp)
} if err != nil {
defer res3.Body.Close() t.Fatal(err)
if res3.StatusCode != http.StatusInternalServerError { }
t.Error("the request should have failed")
} if len(resp.Keys) != 3 {
t.Error("wrong response")
}
})
t.Run("pin/ls bad cid query arg", func(t *testing.T) {
res3, err := http.Post(fmt.Sprintf("%s/pin/ls?arg=%s", proxyURL(ipfs), test.ErrorCid), "", nil)
if err != nil {
t.Fatal("should have succeeded: ", err)
}
defer res3.Body.Close()
if res3.StatusCode != http.StatusInternalServerError {
t.Error("the request should have failed")
}
})
} }
func TestProxyAdd(t *testing.T) { func TestProxyAdd(t *testing.T) {
@ -388,42 +521,44 @@ func TestProxyAdd(t *testing.T) {
} }
for i := 0; i < len(urlQueries); i++ { for i := 0; i < len(urlQueries); i++ {
res, err := http.DefaultClient.Do(reqs[i]) t.Run(urlQueries[i], func(t *testing.T) {
if err != nil { res, err := http.DefaultClient.Do(reqs[i])
t.Fatal("should have succeeded: ", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Fatal("Bad response status")
}
var hash ipfsAddResp
// We might return a progress notification, so we do it
// like this to ignore it easily
dec := json.NewDecoder(res.Body)
for dec.More() {
var resp ipfsAddResp
err := dec.Decode(&resp)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal("should have succeeded: ", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Fatalf("Bad response status: got = %d, want = %d", res.StatusCode, http.StatusOK)
} }
if resp.Bytes != 0 { var hash ipfsAddResp
continue
} else {
hash = resp
}
}
if hash.Hash != test.TestCid3 { // We might return a progress notification, so we do it
t.Logf("%+v", hash) // like this to ignore it easily
t.Error("expected TestCid1 as it is hardcoded in ipfs mock") dec := json.NewDecoder(res.Body)
} for dec.More() {
if hash.Name != "testfile" { var resp ipfsAddResp
t.Logf("%+v", hash) err := dec.Decode(&resp)
t.Error("expected testfile for hash name") if err != nil {
} t.Fatal(err)
}
if resp.Bytes != 0 {
continue
} else {
hash = resp
}
}
if hash.Hash != test.TestCid3 {
t.Logf("%+v", hash)
t.Error("expected TestCid1 as it is hardcoded in ipfs mock")
}
if hash.Name != "testfile" {
t.Logf("%+v", hash)
t.Error("expected testfile for hash name")
}
})
} }
} }
@ -437,8 +572,7 @@ func TestProxyAddError(t *testing.T) {
} }
res.Body.Close() res.Body.Close()
if res.StatusCode != http.StatusInternalServerError { if res.StatusCode != http.StatusInternalServerError {
t.Log(res.StatusCode) t.Errorf("wrong status code: got = %d, want = %d", res.StatusCode, http.StatusInternalServerError)
t.Error("expected an error")
} }
} }
@ -634,3 +768,72 @@ func proxyURL(c *Connector) string {
addr := c.listener.Addr() addr := c.listener.Addr()
return fmt.Sprintf("http://%s/api/v0", addr.String()) return fmt.Sprintf("http://%s/api/v0", addr.String())
} }
func Test_extractArgument(t *testing.T) {
type args struct {
handlePath string
u *url.URL
}
tests := []struct {
name string
args args
want string
want1 bool
}{
{
"pin/add url arg",
args{
"add",
mustParseURL(fmt.Sprintf("/api/v0/pin/add/%s", test.TestCid1)),
},
test.TestCid1,
true,
},
{
"pin/add query arg",
args{
"add",
mustParseURL(fmt.Sprintf("/api/v0/pin/add?arg=%s", test.TestCid1)),
},
test.TestCid1,
true,
},
{
"pin/ls url arg",
args{
"pin/ls",
mustParseURL(fmt.Sprintf("/api/v0/pin/ls/%s", test.TestCid1)),
},
test.TestCid1,
true,
},
{
"pin/ls query arg",
args{
"pin/ls",
mustParseURL(fmt.Sprintf("/api/v0/pin/ls?arg=%s", test.TestCid1)),
},
test.TestCid1,
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := extractArgument(tt.args.u)
if got != tt.want {
t.Errorf("extractCid() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("extractCid() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func mustParseURL(rawurl string) *url.URL {
u, err := url.Parse(rawurl)
if err != nil {
panic(err)
}
return u
}

View File

@ -99,7 +99,6 @@ func NewIpfsMock() *IpfsMock {
func (m *IpfsMock) handler(w http.ResponseWriter, r *http.Request) { func (m *IpfsMock) handler(w http.ResponseWriter, r *http.Request) {
p := r.URL.Path p := r.URL.Path
endp := strings.TrimPrefix(p, "/api/v0/") endp := strings.TrimPrefix(p, "/api/v0/")
var cidStr string
switch endp { switch endp {
case "id": case "id":
resp := mockIDResp{ resp := mockIDResp{
@ -138,45 +137,40 @@ func (m *IpfsMock) handler(w http.ResponseWriter, r *http.Request) {
j, _ := json.Marshal(resp) j, _ := json.Marshal(resp)
w.Write(j) w.Write(j)
case "pin/add": case "pin/add":
query := r.URL.Query() arg, ok := extractCid(r.URL)
arg, ok := query["arg"] if !ok {
if !ok || len(arg) != 1 {
goto ERROR goto ERROR
} }
cidStr = arg[0] if arg == ErrorCid {
if cidStr == ErrorCid {
goto ERROR goto ERROR
} }
c, err := cid.Decode(cidStr) c, err := cid.Decode(arg)
if err != nil { if err != nil {
goto ERROR goto ERROR
} }
m.pinMap.Add(api.PinCid(c)) m.pinMap.Add(api.PinCid(c))
resp := mockPinResp{ resp := mockPinResp{
Pins: []string{cidStr}, Pins: []string{arg},
} }
j, _ := json.Marshal(resp) j, _ := json.Marshal(resp)
w.Write(j) w.Write(j)
case "pin/rm": case "pin/rm":
query := r.URL.Query() arg, ok := extractCid(r.URL)
arg, ok := query["arg"] if !ok {
if !ok || len(arg) != 1 {
goto ERROR goto ERROR
} }
cidStr = arg[0] c, err := cid.Decode(arg)
c, err := cid.Decode(cidStr)
if err != nil { if err != nil {
goto ERROR goto ERROR
} }
m.pinMap.Rm(c) m.pinMap.Rm(c)
resp := mockPinResp{ resp := mockPinResp{
Pins: []string{cidStr}, Pins: []string{arg},
} }
j, _ := json.Marshal(resp) j, _ := json.Marshal(resp)
w.Write(j) w.Write(j)
case "pin/ls": case "pin/ls":
query := r.URL.Query() arg, ok := extractCid(r.URL)
arg, ok := query["arg"]
if !ok { if !ok {
rMap := make(map[string]mockPinType) rMap := make(map[string]mockPinType)
pins := m.pinMap.List() pins := m.pinMap.List()
@ -187,11 +181,8 @@ func (m *IpfsMock) handler(w http.ResponseWriter, r *http.Request) {
w.Write(j) w.Write(j)
break break
} }
if len(arg) != 1 {
goto ERROR
}
cidStr = arg[0]
cidStr := arg
c, err := cid.Decode(cidStr) c, err := cid.Decode(cidStr)
if err != nil { if err != nil {
goto ERROR goto ERROR
@ -209,12 +200,11 @@ func (m *IpfsMock) handler(w http.ResponseWriter, r *http.Request) {
w.Write(j) w.Write(j)
} }
case "swarm/connect": case "swarm/connect":
query := r.URL.Query() arg, ok := extractCid(r.URL)
arg, ok := query["arg"]
if !ok { if !ok {
goto ERROR goto ERROR
} }
addr := arg[0] addr := arg
splits := strings.Split(addr, "/") splits := strings.Split(addr, "/")
pid := splits[len(splits)-1] pid := splits[len(splits)-1]
resp := struct { resp := struct {
@ -256,13 +246,12 @@ func (m *IpfsMock) handler(w http.ResponseWriter, r *http.Request) {
j, _ := json.Marshal(resp) j, _ := json.Marshal(resp)
w.Write(j) w.Write(j)
case "refs": case "refs":
query := r.URL.Query() arg, ok := extractCid(r.URL)
arg, ok := query["arg"]
if !ok { if !ok {
goto ERROR goto ERROR
} }
resp := mockRefsResp{ resp := mockRefsResp{
Ref: arg[0], Ref: arg,
} }
j, _ := json.Marshal(resp) j, _ := json.Marshal(resp)
w.Write(j) w.Write(j)
@ -281,3 +270,20 @@ ERROR:
func (m *IpfsMock) Close() { func (m *IpfsMock) Close() {
m.server.Close() m.server.Close()
} }
// extractCid extracts the cid argument from a url.URL, either via
// the query string parameters or from the url path itself.
func extractCid(u *url.URL) (string, bool) {
arg := u.Query().Get("arg")
if arg != "" {
return arg, true
}
p := strings.TrimPrefix(u.Path, "/api/v0/")
segs := strings.Split(p, "/")
if len(segs) > 2 {
return segs[len(segs)-1], true
}
return "", false
}