diff --git a/chain/types/bitfield.go b/chain/types/bitfield.go index 617229c376c..1f46e9451a6 100644 --- a/chain/types/bitfield.go +++ b/chain/types/bitfield.go @@ -3,19 +3,24 @@ package types import ( "fmt" "io" - "sort" - "github.com/filecoin-project/lotus/extern/rleplus" + rlepluslazy "github.com/filecoin-project/lotus/lib/rlepluslazy" cbg "github.com/whyrusleeping/cbor-gen" "golang.org/x/xerrors" ) type BitField struct { + rle rlepluslazy.RLE + bits map[uint64]struct{} } func NewBitField() BitField { - return BitField{bits: make(map[uint64]struct{})} + rle, _ := rlepluslazy.FromBuf([]byte{}) + return BitField{ + rle: rle, + bits: make(map[uint64]struct{}), + } } func BitFieldFromSet(setBits []uint64) BitField { @@ -26,31 +31,59 @@ func BitFieldFromSet(setBits []uint64) BitField { return res } +func (bf BitField) sum() (rlepluslazy.RunIterator, error) { + if len(bf.bits) == 0 { + return bf.rle.RunIterator() + } + + a, err := bf.rle.RunIterator() + if err != nil { + return nil, err + } + slc := make([]uint64, 0, len(bf.bits)) + for b := range bf.bits { + slc = append(slc, b) + } + + b, err := rlepluslazy.RunsFromSlice(slc) + if err != nil { + return nil, err + } + + res, err := rlepluslazy.Sum(a, b) + if err != nil { + return nil, err + } + return res, nil +} + // Set ...s bit in the BitField func (bf BitField) Set(bit uint64) { bf.bits[bit] = struct{}{} } -// Clear ...s bit in the BitField -func (bf BitField) Clear(bit uint64) { - delete(bf.bits, bit) -} - -// Has checkes if bit is set in the BitField -func (bf BitField) Has(bit uint64) bool { - _, ok := bf.bits[bit] - return ok +func (bf BitField) Count() (uint64, error) { + s, err := bf.sum() + if err != nil { + return 0, err + } + return rlepluslazy.Count(s) } // All returns all set bits, in random order -func (bf BitField) All() []uint64 { - res := make([]uint64, 0, len(bf.bits)) - for i := range bf.bits { - res = append(res, i) +func (bf BitField) All() ([]uint64, error) { + + runs, err := bf.sum() + if err != nil { + return nil, err } - sort.Slice(res, func(i, j int) bool { return res[i] < res[j] }) - return res + res, err := rlepluslazy.SliceFromRuns(runs) + if err != nil { + return nil, err + } + + return res, err } func (bf BitField) MarshalCBOR(w io.Writer) error { @@ -59,7 +92,12 @@ func (bf BitField) MarshalCBOR(w io.Writer) error { ints = append(ints, i) } - rle, _, err := rleplus.Encode(ints) // Encode sorts internally + s, err := bf.sum() + if err != nil { + return err + } + + rle, err := rlepluslazy.EncodeRuns(s, []byte{}) if err != nil { return err } @@ -88,19 +126,17 @@ func (bf *BitField) UnmarshalCBOR(r io.Reader) error { return fmt.Errorf("expected byte array") } - rle := make([]byte, extra) - if _, err := io.ReadFull(br, rle); err != nil { + buf := make([]byte, extra) + if _, err := io.ReadFull(br, buf); err != nil { return err } - ints, err := rleplus.Decode(rle) + rle, err := rlepluslazy.FromBuf(buf) if err != nil { return xerrors.Errorf("could not decode rle+: %w", err) } + bf.rle = rle bf.bits = make(map[uint64]struct{}) - for _, i := range ints { - bf.bits[i] = struct{}{} - } return nil } diff --git a/lib/rlepluslazy/rleplus.go b/lib/rlepluslazy/rleplus.go index c72ae20c7dd..6bf8f378968 100644 --- a/lib/rlepluslazy/rleplus.go +++ b/lib/rlepluslazy/rleplus.go @@ -16,25 +16,18 @@ var ( type RLE struct { buf []byte - - changes []change } -type change struct { - set bool - reset bool - index uint64 -} - -func (c change) valid() bool { - return c.reset || c.set -} - -func FromBuf(buf []byte) (*RLE, error) { - rle := &RLE{buf: buf} +func FromBuf(buf []byte) (RLE, error) { + rle := RLE{buf: buf} if len(buf) > 0 && buf[0]&3 != Version { - return nil, xerrors.Errorf("could not create RLE+ for a buffer: %w", ErrWrongVersion) + return RLE{}, xerrors.Errorf("could not create RLE+ for a buffer: %w", ErrWrongVersion) + } + + _, err := rle.Count() + if err != nil { + return RLE{}, err } return rle, nil @@ -45,7 +38,26 @@ func (rle *RLE) RunIterator() (RunIterator, error) { return source, err } +func (rle *RLE) Count() (uint64, error) { + it, err := rle.RunIterator() + if err != nil { + return 0, err + } + return Count(it) +} + /* + +type change struct { + set bool + reset bool + index uint64 +} +func (c change) valid() bool { + return c.reset || c.set +} + +func (rle *RLE) RunIterator() (RunIterator, error) { if err != nil { return nil, err } diff --git a/lib/rlepluslazy/runs.go b/lib/rlepluslazy/runs.go index c29e0002ff7..c8ae36644bf 100644 --- a/lib/rlepluslazy/runs.go +++ b/lib/rlepluslazy/runs.go @@ -1,5 +1,11 @@ package rlepluslazy +import ( + "math" + + "golang.org/x/xerrors" +) + func Sum(a, b RunIterator) (RunIterator, error) { it := addIt{a: a, b: b} it.prep() @@ -94,3 +100,21 @@ func (it *addIt) NextRun() (Run, error) { next := it.next return next, it.prep() } + +func Count(ri RunIterator) (uint64, error) { + var count uint64 + + for ri.HasNext() { + r, err := ri.NextRun() + if err != nil { + return 0, err + } + if r.Val { + if math.MaxUint64-r.Len > count { + return 0, xerrors.New("RLE+ overflows") + } + count += r.Len + } + } + return count, nil +} diff --git a/lib/rlepluslazy/runs_test.go b/lib/rlepluslazy/runs_test.go index 3ab86e76f19..aefeb2c708e 100644 --- a/lib/rlepluslazy/runs_test.go +++ b/lib/rlepluslazy/runs_test.go @@ -76,8 +76,8 @@ func TestSumRandom(t *testing.T) { N := 100 for i := 0; i < N; i++ { - abits := randomBits(1000, 2000) - bbits := randomBits(1000, 2000) + abits := randomBits(1000, 1500) + bbits := randomBits(1000, 1500) sumbits := sum(abits, bbits) a, err := RunsFromSlice(abits)