Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Events expansion for the progress tracker to provide more detailed information about pin status #38

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 3 additions & 34 deletions merkledag.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@ func init() {
ipld.Register(cid.DagCBOR, ipldcbor.DecodeBlock)
}

// contextKey is a type to use as value for the ProgressTracker contexts.
type contextKey string

const progressContextKey contextKey = "progress"

// NewDAGService constructs a new DAGService (using the default implementation).
// Note that the default implementation is also an ipld.LinkGetter.
func NewDAGService(bs bserv.BlockService) *dagService {
Expand Down Expand Up @@ -196,14 +191,14 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s
return false
}

v, _ := ctx.Value(progressContextKey).(*ProgressTracker)
if v == nil {
progressTracker := GetProgressTracker(ctx)
if progressTracker == nil {
return EnumerateChildrenAsyncDepth(ctx, GetLinksDirect(ng), root, 0, visit)
}

visitProgress := func(c cid.Cid, depth int) bool {
if visit(c, depth) {
v.Increment()
progressTracker.PlanToPin(c)
return true
}
return false
Expand Down Expand Up @@ -314,32 +309,6 @@ func EnumerateChildrenDepth(ctx context.Context, getLinks GetLinks, root cid.Cid
return nil
}

// ProgressTracker is used to show progress when fetching nodes.
type ProgressTracker struct {
Total int
lk sync.Mutex
}

// DeriveContext returns a new context with value "progress" derived from
// the given one.
func (p *ProgressTracker) DeriveContext(ctx context.Context) context.Context {
return context.WithValue(ctx, progressContextKey, p)
}

// Increment adds one to the total progress.
func (p *ProgressTracker) Increment() {
p.lk.Lock()
defer p.lk.Unlock()
p.Total++
}

// Value returns the current progress.
func (p *ProgressTracker) Value() int {
p.lk.Lock()
defer p.lk.Unlock()
return p.Total
}

// FetchGraphConcurrency is total number of concurrent fetches that
// 'fetchNodes' will start at a time
var FetchGraphConcurrency = 32
Expand Down
35 changes: 8 additions & 27 deletions merkledag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,47 +742,28 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) {
}
}

func TestProgressIndicator(t *testing.T) {
testProgressIndicator(t, 5)
}

func TestProgressIndicatorNoChildren(t *testing.T) {
testProgressIndicator(t, 0)
}

func testProgressIndicator(t *testing.T, depth int) {
ds := dstest.Mock()

top, numChildren := mkDag(ds, depth)
func mkProtoNode() *ProtoNode {
p := new(ProtoNode)
buf := make([]byte, 16)
rand.Read(buf)

v := new(ProgressTracker)
ctx := v.DeriveContext(context.Background())
p.SetData(buf)

err := FetchGraph(ctx, top, ds)
if err != nil {
t.Fatal(err)
}

if v.Value() != numChildren+1 {
t.Errorf("wrong number of children reported in progress indicator, expected %d, got %d",
numChildren+1, v.Value())
}
return p
}

func mkDag(ds ipld.DAGService, depth int) (cid.Cid, int) {
ctx := context.Background()

totalChildren := 0
f := func() *ProtoNode {
p := new(ProtoNode)
buf := make([]byte, 16)
rand.Read(buf)
p := mkProtoNode()

p.SetData(buf)
err := ds.Add(ctx, p)
if err != nil {
panic(err)
}

return p
}

Expand Down
73 changes: 73 additions & 0 deletions progresstracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package merkledag

import (
"context"
"sync"

cid "github.com/ipfs/go-cid"
)

// contextKey is a type to use as value for the ProgressTracker contexts.
type contextKey string

const progressContextKey contextKey = "progress"

func NewProgressTracker() *ProgressTracker {
return &ProgressTracker{
cidsToPin: make([]cid.Cid, 0),
}
}

// WithProgressTracker returns a new context with value "progress" derived from
// the given one.
func WithProgressTracker(ctx context.Context, p *ProgressTracker) (nCtx context.Context) {
return context.WithValue(ctx, progressContextKey, p)
}

// GetProgressTracker returns a progress tracker instance if present
func GetProgressTracker(ctx context.Context) *ProgressTracker {
v, _ := ctx.Value(progressContextKey).(*ProgressTracker)
return v
}

// ProgressTracker is used to show progress when fetching nodes.
type ProgressTracker struct {
lk sync.Mutex
totalToPin int
cidsToPin []cid.Cid
}

// DeriveContext returns a new context with value "progress" derived from
// the given one.
func (p *ProgressTracker) DeriveContext(ctx context.Context) context.Context {
return context.WithValue(ctx, progressContextKey, p)
}

// PlanToPin registers cid as a planned to pin
func (p *ProgressTracker) PlanToPin(c cid.Cid) {
p.lk.Lock()
defer p.lk.Unlock()

p.cidsToPin = append(p.cidsToPin, c)
p.totalToPin++
}

// TotalToPin returns how much pins were planned to pin
func (p *ProgressTracker) TotalToPin() int {
p.lk.Lock()
defer p.lk.Unlock()

return p.totalToPin
}

// PopPlannedToPin returns cids that were planned to pin since last call
func (p *ProgressTracker) PopPlannedToPin() []cid.Cid {
p.lk.Lock()
defer p.lk.Unlock()

cids := make([]cid.Cid, len(p.cidsToPin))
copy(cids, p.cidsToPin)
p.cidsToPin = p.cidsToPin[:0]

return cids
}
124 changes: 124 additions & 0 deletions progresstracker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package merkledag_test

import (
"context"
"sync"
"testing"
"time"

cid "github.com/ipfs/go-cid"
. "github.com/ipfs/go-merkledag"
dstest "github.com/ipfs/go-merkledag/test"
)

func TestProgressIndicator(t *testing.T) {
testProgressIndicator(t, 5)
}

func TestProgressIndicatorNoChildren(t *testing.T) {
testProgressIndicator(t, 0)
}

func testProgressIndicator(t *testing.T, depth int) {
ds := dstest.Mock()

top, numChildren := mkDag(ds, depth)

progressTracker := NewProgressTracker()
ctx := WithProgressTracker(context.Background(), progressTracker)

err := FetchGraph(ctx, top, ds)
if err != nil {
t.Fatal(err)
}

if progressTracker.TotalToPin() != numChildren+1 {
t.Errorf("wrong number of children reported in progress indicator, expected %d, got %d",
numChildren+1, progressTracker.TotalToPin())
}

plannedToPin := progressTracker.PopPlannedToPin()
if len(plannedToPin) != progressTracker.TotalToPin() {
t.Errorf("wrong number of children reported in progress indicator (total does not match concrete cids count), expected %d, got %d",
len(plannedToPin), progressTracker.TotalToPin())
}
}

func TestProgressIndicatorFlow(t *testing.T) {
progressTracker := NewProgressTracker()
ctx := WithProgressTracker(context.Background(), progressTracker)

ongoingCids := make(chan cid.Cid)
actualPinCids := make([]cid.Cid, 0)
registeredToPinCids := make([]cid.Cid, 0)

go func(ctx context.Context) {
ticker := time.NewTicker(5 * time.Millisecond)
defer func() {
close(ongoingCids)
ticker.Stop()
}()

progressTracker := GetProgressTracker(ctx)
upTo := time.After(1 * time.Second)

for {
select {
case <-ticker.C:
node := mkProtoNode()
ongoingCids <- node.Cid()
progressTracker.PlanToPin(node.Cid())
case <-upTo:
return
}

}
}(ctx)

cCtx, cancel := context.WithCancel(ctx)
defer cancel()

var wg sync.WaitGroup
wg.Add(1)
go func(ctx context.Context) {
defer func() {
registeredToPinCids = append(
registeredToPinCids,
progressTracker.PopPlannedToPin()...)
wg.Done()
}()

progressTracker := GetProgressTracker(ctx)
ticker := time.NewTicker(3 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-ticker.C:
registeredToPinCids = append(
registeredToPinCids,
progressTracker.PopPlannedToPin()...)
case <-ctx.Done():
return
}
}
}(cCtx)

for cid := range ongoingCids {
actualPinCids = append(actualPinCids, cid)
}
cancel()
wg.Wait()

if len(actualPinCids) != len(registeredToPinCids) {
t.Errorf("actual and registered pins mismatch: %d vs %d",
len(actualPinCids), len(registeredToPinCids))
}

for i := 0; i < len(actualPinCids)-1; i++ {
if actualPinCids[i] != registeredToPinCids[i] {
t.Errorf("actual and registered pins mismatch at %d: %v vs %v",
i, actualPinCids[i], registeredToPinCids[i])
}
}
}