diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvector.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvector.java index 5da311bc588..f1eeb3458e6 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvector.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvector.java @@ -38,6 +38,8 @@ default boolean isWritableSupported() { SszBitvector withBit(int i); + SszBitvector or(SszBitvector other); + /** Returns individual bit value */ boolean getBit(int i); diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitvectorImpl.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitvectorImpl.java index fc6bab3a421..3b7112a0840 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitvectorImpl.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitvectorImpl.java @@ -76,6 +76,16 @@ public List getSetBitIndices() { return data.stream().boxed().toList(); } + public BitvectorImpl or(final BitvectorImpl other) { + if (other.getSize() != getSize()) { + throw new IllegalArgumentException( + "Argument bitfield size is greater: " + other.getSize() + " > " + getSize()); + } + final BitSet newData = (BitSet) this.data.clone(); + newData.or(other.data); + return new BitvectorImpl(newData, size); + } + public BitvectorImpl withBit(final int i) { checkElementIndex(i, size); BitSet newSet = (BitSet) data.clone(); diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitvectorImpl.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitvectorImpl.java index cad2d1135bd..e39e5755015 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitvectorImpl.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitvectorImpl.java @@ -89,6 +89,11 @@ public SszBitvector withBit(final int i) { return new SszBitvectorImpl(getSchema(), value.withBit(i)); } + @Override + public SszBitvector or(final SszBitvector other) { + return new SszBitvectorImpl(getSchema(), value.or(toBitvectorImpl(other))); + } + @Override protected int sizeImpl() { return getSchema().getLength(); @@ -108,4 +113,8 @@ public boolean isWritableSupported() { public String toString() { return "SszBitvector{size=" + this.size() + ", " + value.toString() + "}"; } + + private BitvectorImpl toBitvectorImpl(final SszBitvector bv) { + return ((SszBitvectorImpl) bv).value; + } } diff --git a/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvectorTest.java b/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvectorTest.java index 0d3a801d16b..56326b5a878 100644 --- a/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvectorTest.java +++ b/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvectorTest.java @@ -97,6 +97,36 @@ void testTreeRoundtrip(SszBitvector bitvector1) { SszDataAssert.assertThatSszData(bitvector2).isEqualByAllMeansTo(bitvector1); } + @ParameterizedTest + @MethodSource("bitvectorArgs") + void or_testEqualList(SszBitvector bitvector) { + SszBitvector res = bitvector.or(bitvector); + assertThat(res).isEqualTo(bitvector); + } + + @ParameterizedTest + @MethodSource("bitvectorArgs") + void or_shouldThrowIfBitvectorSizeIsLarger(SszBitvector bitvector) { + SszBitvectorSchema largerSchema = + SszBitvectorSchema.create(bitvector.getSchema().getMaxLength() + 1); + SszBitvector largerBitvector = largerSchema.ofBits(bitvector.size() - 1, bitvector.size()); + assertThatThrownBy(() -> bitvector.or(largerBitvector)) + .isInstanceOf(IllegalArgumentException.class); + } + + @ParameterizedTest + @MethodSource("bitvectorArgs") + void or_shouldThrowIfBitvectorSizeIsSmaller(SszBitvector bitvector) { + if (bitvector.getSchema().getMaxLength() == 1) { + return; + } + SszBitvectorSchema smallerSchema = + SszBitvectorSchema.create(bitvector.getSchema().getMaxLength() - 1); + SszBitvector smallerBitvector = smallerSchema.ofBits(); + assertThatThrownBy(() -> bitvector.or(smallerBitvector)) + .isInstanceOf(IllegalArgumentException.class); + } + @ParameterizedTest @MethodSource("bitvectorArgs") void getBitCount_shouldReturnCorrectCount(SszBitvector bitvector) { @@ -177,6 +207,25 @@ void testBitMethodsAreConsistent(SszBitvector vector) { assertThat(vector.getBitCount()).isEqualTo(bitsIndices.size()); } + @ParameterizedTest + @MethodSource("bitvectorArgs") + void testOr(SszBitvector bitvector) { + SszBitvector orVector = random(bitvector.getSchema()); + SszBitvector res = bitvector.or(orVector); + assertThat(res.size()).isEqualTo(bitvector.size()); + assertThat(res.getSchema()).isEqualTo(bitvector.getSchema()); + for (int i = 0; i < bitvector.size(); i++) { + assertThat(res.getBit(i)).isEqualTo(bitvector.getBit(i) || orVector.getBit(i)); + } + } + + @ParameterizedTest + @MethodSource("bitvectorArgs") + void testOrWithEmptyBitvector(SszBitvector bitvector) { + SszBitvector empty = bitvector.getSchema().ofBits(); + assertThat(bitvector.or(empty)).isEqualTo(bitvector); + } + @ParameterizedTest @MethodSource("bitvectorArgs") void get_shouldThrowIndexOutOfBounds(SszBitvector vector) {