diff --git a/shell.go b/shell.go index 870f64e00..df5d0c540 100644 --- a/shell.go +++ b/shell.go @@ -3,6 +3,7 @@ package shell import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -13,6 +14,8 @@ import ( "strings" "time" + pstore "github.com/libp2p/go-libp2p-peerstore" + notif "github.com/libp2p/go-libp2p-routing/notifications" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr-net" files "github.com/whyrusleeping/go-multipart-files" @@ -371,6 +374,53 @@ func (s *Shell) FindPeer(peer string) (*PeerInfo, error) { return &str.Responses[0], nil } +func (s *Shell) FindProvs(ctx context.Context, cid string) (<-chan pstore.PeerInfo, error) { + ctx, cancel := context.WithCancel(ctx) + + resp, err := s.newRequest("dht/findprovs", cid).Send(s.httpcli) + if err != nil { + return nil, err + } + + if resp.Error != nil { + return nil, resp.Error + } + + // 4 is arbitrary here just to make the channel buffered + outchan := make(chan pstore.PeerInfo, 4) + + go func() { + defer close(outchan) + defer cancel() + + var n notif.QueryEvent + decoder := json.NewDecoder(resp.Output) + for { + err := decoder.Decode(&n) + if err != nil { + return + } + + if n.Type == notif.Provider { + for _, p := range n.Responses { + select { + case outchan <- *p: + case <-ctx.Done(): + return + } + } + } + } + }() + + go func() { + <-ctx.Done() + resp.Close() + }() + + return outchan, nil +} + func (s *Shell) Refs(hash string, recursive bool) (<-chan string, error) { req := s.newRequest("refs", hash) if recursive { diff --git a/shell_test.go b/shell_test.go index 9d21ebf29..809646a5e 100644 --- a/shell_test.go +++ b/shell_test.go @@ -2,6 +2,7 @@ package shell import ( "bytes" + "context" "crypto/md5" "fmt" "io" @@ -212,3 +213,18 @@ func TestObjectStat(t *testing.T) { is.Equal(stat.LinksSize, 3) is.Equal(stat.CumulativeSize, 1688) } + +func TestFindProvs(t *testing.T) { + is := is.New(t) + s := NewShell(shellUrl) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c, err := s.FindProvs(ctx, "Qme1g4e3m2SmdiSGGU3vSWmUStwUjc5oECnEriaK9Xa1HU") + is.Nil(err) + + p := <-c + t.Logf("prov: %s", p) + is.NotNil(p) + is.NotNil(p.ID) +}