diff --git a/merkledag.go b/merkledag.go index 3153cf4..9acff3f 100644 --- a/merkledag.go +++ b/merkledag.go @@ -22,11 +22,6 @@ func init() { ipld.Register(cid.DagCBOR, ipldcbor.DecodeBlock) } -// contextKey is a type to use as value for the ProgressTracker contexts. -type contextKey string - -const progressContextKey contextKey = "progress" - // NewDAGService constructs a new DAGService (using the default implementation). // Note that the default implementation is also an ipld.LinkGetter. func NewDAGService(bs bserv.BlockService) *dagService { @@ -196,14 +191,14 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s return false } - v, _ := ctx.Value(progressContextKey).(*ProgressTracker) - if v == nil { + progressTracker := GetProgressTracker(ctx) + if progressTracker == nil { return EnumerateChildrenAsyncDepth(ctx, GetLinksDirect(ng), root, 0, visit) } visitProgress := func(c cid.Cid, depth int) bool { if visit(c, depth) { - v.Increment() + progressTracker.PlanToPin(c) return true } return false @@ -314,32 +309,6 @@ func EnumerateChildrenDepth(ctx context.Context, getLinks GetLinks, root cid.Cid return nil } -// ProgressTracker is used to show progress when fetching nodes. -type ProgressTracker struct { - Total int - lk sync.Mutex -} - -// DeriveContext returns a new context with value "progress" derived from -// the given one. -func (p *ProgressTracker) DeriveContext(ctx context.Context) context.Context { - return context.WithValue(ctx, progressContextKey, p) -} - -// Increment adds one to the total progress. -func (p *ProgressTracker) Increment() { - p.lk.Lock() - defer p.lk.Unlock() - p.Total++ -} - -// Value returns the current progress. -func (p *ProgressTracker) Value() int { - p.lk.Lock() - defer p.lk.Unlock() - return p.Total -} - // FetchGraphConcurrency is total number of concurrent fetches that // 'fetchNodes' will start at a time var FetchGraphConcurrency = 32 diff --git a/merkledag_test.go b/merkledag_test.go index bc87f3b..e10779e 100644 --- a/merkledag_test.go +++ b/merkledag_test.go @@ -742,31 +742,14 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) { } } -func TestProgressIndicator(t *testing.T) { - testProgressIndicator(t, 5) -} - -func TestProgressIndicatorNoChildren(t *testing.T) { - testProgressIndicator(t, 0) -} - -func testProgressIndicator(t *testing.T, depth int) { - ds := dstest.Mock() - - top, numChildren := mkDag(ds, depth) +func mkProtoNode() *ProtoNode { + p := new(ProtoNode) + buf := make([]byte, 16) + rand.Read(buf) - v := new(ProgressTracker) - ctx := v.DeriveContext(context.Background()) + p.SetData(buf) - err := FetchGraph(ctx, top, ds) - if err != nil { - t.Fatal(err) - } - - if v.Value() != numChildren+1 { - t.Errorf("wrong number of children reported in progress indicator, expected %d, got %d", - numChildren+1, v.Value()) - } + return p } func mkDag(ds ipld.DAGService, depth int) (cid.Cid, int) { @@ -774,15 +757,13 @@ func mkDag(ds ipld.DAGService, depth int) (cid.Cid, int) { totalChildren := 0 f := func() *ProtoNode { - p := new(ProtoNode) - buf := make([]byte, 16) - rand.Read(buf) + p := mkProtoNode() - p.SetData(buf) err := ds.Add(ctx, p) if err != nil { panic(err) } + return p } diff --git a/progresstracker.go b/progresstracker.go new file mode 100644 index 0000000..ee16530 --- /dev/null +++ b/progresstracker.go @@ -0,0 +1,73 @@ +package merkledag + +import ( + "context" + "sync" + + cid "github.com/ipfs/go-cid" +) + +// contextKey is a type to use as value for the ProgressTracker contexts. +type contextKey string + +const progressContextKey contextKey = "progress" + +func NewProgressTracker() *ProgressTracker { + return &ProgressTracker{ + cidsToPin: make([]cid.Cid, 0), + } +} + +// WithProgressTracker returns a new context with value "progress" derived from +// the given one. +func WithProgressTracker(ctx context.Context, p *ProgressTracker) (nCtx context.Context) { + return context.WithValue(ctx, progressContextKey, p) +} + +// GetProgressTracker returns a progress tracker instance if present +func GetProgressTracker(ctx context.Context) *ProgressTracker { + v, _ := ctx.Value(progressContextKey).(*ProgressTracker) + return v +} + +// ProgressTracker is used to show progress when fetching nodes. +type ProgressTracker struct { + lk sync.Mutex + totalToPin int + cidsToPin []cid.Cid +} + +// DeriveContext returns a new context with value "progress" derived from +// the given one. +func (p *ProgressTracker) DeriveContext(ctx context.Context) context.Context { + return context.WithValue(ctx, progressContextKey, p) +} + +// PlanToPin registers cid as a planned to pin +func (p *ProgressTracker) PlanToPin(c cid.Cid) { + p.lk.Lock() + defer p.lk.Unlock() + + p.cidsToPin = append(p.cidsToPin, c) + p.totalToPin++ +} + +// TotalToPin returns how much pins were planned to pin +func (p *ProgressTracker) TotalToPin() int { + p.lk.Lock() + defer p.lk.Unlock() + + return p.totalToPin +} + +// PopPlannedToPin returns cids that were planned to pin since last call +func (p *ProgressTracker) PopPlannedToPin() []cid.Cid { + p.lk.Lock() + defer p.lk.Unlock() + + cids := make([]cid.Cid, len(p.cidsToPin)) + copy(cids, p.cidsToPin) + p.cidsToPin = p.cidsToPin[:0] + + return cids +} diff --git a/progresstracker_test.go b/progresstracker_test.go new file mode 100644 index 0000000..015a4ff --- /dev/null +++ b/progresstracker_test.go @@ -0,0 +1,124 @@ +package merkledag_test + +import ( + "context" + "sync" + "testing" + "time" + + cid "github.com/ipfs/go-cid" + . "github.com/ipfs/go-merkledag" + dstest "github.com/ipfs/go-merkledag/test" +) + +func TestProgressIndicator(t *testing.T) { + testProgressIndicator(t, 5) +} + +func TestProgressIndicatorNoChildren(t *testing.T) { + testProgressIndicator(t, 0) +} + +func testProgressIndicator(t *testing.T, depth int) { + ds := dstest.Mock() + + top, numChildren := mkDag(ds, depth) + + progressTracker := NewProgressTracker() + ctx := WithProgressTracker(context.Background(), progressTracker) + + err := FetchGraph(ctx, top, ds) + if err != nil { + t.Fatal(err) + } + + if progressTracker.TotalToPin() != numChildren+1 { + t.Errorf("wrong number of children reported in progress indicator, expected %d, got %d", + numChildren+1, progressTracker.TotalToPin()) + } + + plannedToPin := progressTracker.PopPlannedToPin() + if len(plannedToPin) != progressTracker.TotalToPin() { + t.Errorf("wrong number of children reported in progress indicator (total does not match concrete cids count), expected %d, got %d", + len(plannedToPin), progressTracker.TotalToPin()) + } +} + +func TestProgressIndicatorFlow(t *testing.T) { + progressTracker := NewProgressTracker() + ctx := WithProgressTracker(context.Background(), progressTracker) + + ongoingCids := make(chan cid.Cid) + actualPinCids := make([]cid.Cid, 0) + registeredToPinCids := make([]cid.Cid, 0) + + go func(ctx context.Context) { + ticker := time.NewTicker(5 * time.Millisecond) + defer func() { + close(ongoingCids) + ticker.Stop() + }() + + progressTracker := GetProgressTracker(ctx) + upTo := time.After(1 * time.Second) + + for { + select { + case <-ticker.C: + node := mkProtoNode() + ongoingCids <- node.Cid() + progressTracker.PlanToPin(node.Cid()) + case <-upTo: + return + } + + } + }(ctx) + + cCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var wg sync.WaitGroup + wg.Add(1) + go func(ctx context.Context) { + defer func() { + registeredToPinCids = append( + registeredToPinCids, + progressTracker.PopPlannedToPin()...) + wg.Done() + }() + + progressTracker := GetProgressTracker(ctx) + ticker := time.NewTicker(3 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + registeredToPinCids = append( + registeredToPinCids, + progressTracker.PopPlannedToPin()...) + case <-ctx.Done(): + return + } + } + }(cCtx) + + for cid := range ongoingCids { + actualPinCids = append(actualPinCids, cid) + } + cancel() + wg.Wait() + + if len(actualPinCids) != len(registeredToPinCids) { + t.Errorf("actual and registered pins mismatch: %d vs %d", + len(actualPinCids), len(registeredToPinCids)) + } + + for i := 0; i < len(actualPinCids)-1; i++ { + if actualPinCids[i] != registeredToPinCids[i] { + t.Errorf("actual and registered pins mismatch at %d: %v vs %v", + i, actualPinCids[i], registeredToPinCids[i]) + } + } +}