Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#32929] Add OrderedListState support to Prism. #33350

Merged
merged 5 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Support OnWindowExpiration in Prism ([#32211](https://github.com/apache/beam/issues/32211)).
* This enables initial Java GroupIntoBatches support.
* Support OrderedListState in Prism ([#32929](https://github.com/apache/beam/issues/32929)).

## Breaking Changes

Expand Down
4 changes: 0 additions & 4 deletions runners/prism/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,6 @@ def createPrismValidatesRunnerTask = { name, environmentType ->
excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService'
excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment'

// Not yet implemented in Prism
// https://github.com/apache/beam/issues/32929
excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState'

// Not supported in Portable Java SDK yet.
// https://github.com/apache/beam/issues?q=is%3Aissue+is%3Aopen+MultimapState
excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState'
Expand Down
97 changes: 97 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/engine/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@ package engine

import (
"bytes"
"cmp"
"fmt"
"log/slog"
"slices"
"sort"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"google.golang.org/protobuf/encoding/protowire"
)

// StateData is a "union" between Bag state and MultiMap state to increase common code.
Expand All @@ -42,6 +46,10 @@ type TimerKey struct {
type TentativeData struct {
Raw map[string][][]byte

// stateTypeLen is a map from LinkID to valueLen function for parsing data.
// Only used by OrderedListState, since Prism must manipulate these datavalues,
// which isn't expected, or a requirement of other state values.
stateTypeLen map[LinkID]func([]byte) int
// state is a map from transformID + UserStateID, to window, to userKey, to datavalues.
state map[LinkID]map[typex.Window]map[string]StateData
// timers is a map from the Timer transform+family to the encoded timer.
Expand Down Expand Up @@ -220,3 +228,92 @@ func (d *TentativeData) ClearMultimapKeysState(stateID LinkID, wKey, uKey []byte
kmap[string(uKey)] = StateData{}
slog.Debug("State() MultimapKeys.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("WindowKey", wKey))
}

// AppendOrderedListState appends the incoming timestamped data to the existing tentative data bundle.
// Assumes the data is TimestampedValue encoded, which has a BigEndian int64 suffixed to the data.
// This means we may always use the last 8 bytes to determine the value sorting.
//
// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
func (d *TentativeData) AppendOrderedListState(stateID LinkID, wKey, uKey []byte, data []byte) {
kmap := d.appendState(stateID, wKey)
typeLen := d.stateTypeLen[stateID]
var datums [][]byte

// We need to parse out all values individually for later sorting.
//
// OrderedListState is encoded as KVs with varint encoded millis followed by the value.
// This is not the standard TimestampValueCoder encoding, which
// uses a big-endian long as a suffix to the value. This is important since
// values may be concatenated, and we'll need to split them out out.
//
// The TentativeData.stateTypeLen is populated with a function to extract
// the length of a the next value, so we can skip through elements individually.
for i := 0; i < len(data); {
// Get the length of the VarInt for the timestamp.
_, tn := protowire.ConsumeVarint(data[i:])

// Get the length of the encoded value.
vn := typeLen(data[i+tn:])
prev := i
i += tn + vn
datums = append(datums, data[prev:i])
}

s := StateData{Bag: append(kmap[string(uKey)].Bag, datums...)}
sort.SliceStable(s.Bag, func(i, j int) bool {
vi := s.Bag[i]
vj := s.Bag[j]
return compareTimestampSuffixes(vi, vj)
})
kmap[string(uKey)] = s
slog.Debug("State() OrderedList.Append", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Any("NewData", s))
}

func compareTimestampSuffixes(vi, vj []byte) bool {
ims, _ := protowire.ConsumeVarint(vi)
jms, _ := protowire.ConsumeVarint(vj)
return (int64(ims)) < (int64(jms))
}

// GetOrderedListState available state from the tentative bundle data.
// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
func (d *TentativeData) GetOrderedListState(stateID LinkID, wKey, uKey []byte, start, end int64) [][]byte {
winMap := d.state[stateID]
w := d.toWindow(wKey)
data := winMap[w][string(uKey)]

lo, hi := findRange(data.Bag, start, end)
slog.Debug("State() OrderedList.Get", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", slog.Int64("start", start), slog.Int64("end", end)), slog.Group("outrange", slog.Int("lo", lo), slog.Int("hi", hi)), slog.Any("Data", data.Bag[lo:hi]))
return data.Bag[lo:hi]
}

func cmpSuffix(vs [][]byte, target int64) func(i int) int {
return func(i int) int {
v := vs[i]
ims, _ := protowire.ConsumeVarint(v)
tvsbi := cmp.Compare(target, int64(ims))
slog.Debug("cmpSuffix", "target", target, "bi", ims, "tvsbi", tvsbi)
return tvsbi
}
}

func findRange(bag [][]byte, start, end int64) (int, int) {
lo, _ := sort.Find(len(bag), cmpSuffix(bag, start))
hi, _ := sort.Find(len(bag), cmpSuffix(bag, end))
return lo, hi
}

func (d *TentativeData) ClearOrderedListState(stateID LinkID, wKey, uKey []byte, start, end int64) {
winMap := d.state[stateID]
w := d.toWindow(wKey)
kMap := winMap[w]
data := kMap[string(uKey)]

lo, hi := findRange(data.Bag, start, end)
slog.Debug("State() OrderedList.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", slog.Int64("start", start), slog.Int64("end", end)), "lo", lo, "hi", hi, slog.Any("PreClearData", data.Bag))

cleared := slices.Delete(data.Bag, lo, hi)
// Zero the current entry to clear.
// Delete makes it difficult to delete the persisted stage state for the key.
kMap[string(uKey)] = StateData{Bag: cleared}
}
222 changes: 222 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package engine

import (
"bytes"
"encoding/binary"
"math"
"testing"

"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/encoding/protowire"
)

func TestCompareTimestampSuffixes(t *testing.T) {
t.Run("simple", func(t *testing.T) {
loI := int64(math.MinInt64)
hiI := int64(math.MaxInt64)

loB := binary.BigEndian.AppendUint64(nil, uint64(loI))
hiB := binary.BigEndian.AppendUint64(nil, uint64(hiI))

if compareTimestampSuffixes(loB, hiB) != (loI < hiI) {
t.Errorf("lo vs Hi%v < %v: bytes %v vs %v, %v %v", loI, hiI, loB, hiB, loI < hiI, compareTimestampSuffixes(loB, hiB))
}
})
}

func TestOrderedListState(t *testing.T) {
time1 := protowire.AppendVarint(nil, 11)
time2 := protowire.AppendVarint(nil, 22)
time3 := protowire.AppendVarint(nil, 33)
time4 := protowire.AppendVarint(nil, 44)
time5 := protowire.AppendVarint(nil, 55)

wKey := []byte{} // global window.
uKey := []byte("\u0007userkey")
linkID := LinkID{
Transform: "dofn",
Local: "localStateName",
}
cc := func(a []byte, b ...byte) []byte {
return bytes.Join([][]byte{a, b}, []byte{})
}

t.Run("bool", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(_ []byte) int {
return 1
},
},
}

d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 1))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 1))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 1))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0))

got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want := [][]byte{
cc(time1, 1),
cc(time2, 0),
cc(time3, 1),
cc(time4, 0),
cc(time5, 1),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList booleans \n%v", d)
}

d.ClearOrderedListState(linkID, wKey, uKey, 12, 54)
got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want = [][]byte{
cc(time1, 1),
cc(time5, 1),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList booleans, after clear\n%v", d)
}
})
t.Run("float64", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(_ []byte) int {
return 8
},
},
}

d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 0, 0, 0, 0, 0, 0, 0, 1))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 0, 0, 0, 0, 0, 0, 1, 0))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 0, 0, 0, 0, 0, 1, 0, 0))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0, 0, 0, 0, 1, 0, 0, 0))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0, 0, 0, 1, 0, 0, 0, 0))

got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want := [][]byte{
cc(time1, 0, 0, 0, 0, 0, 0, 1, 0),
cc(time2, 0, 0, 0, 0, 1, 0, 0, 0),
cc(time3, 0, 0, 0, 0, 0, 1, 0, 0),
cc(time4, 0, 0, 0, 1, 0, 0, 0, 0),
cc(time5, 0, 0, 0, 0, 0, 0, 0, 1),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList float64s \n%v", d)
}

d.ClearOrderedListState(linkID, wKey, uKey, 11, 12)
d.ClearOrderedListState(linkID, wKey, uKey, 33, 34)
d.ClearOrderedListState(linkID, wKey, uKey, 55, 56)

got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want = [][]byte{
cc(time2, 0, 0, 0, 0, 1, 0, 0, 0),
cc(time4, 0, 0, 0, 1, 0, 0, 0, 0),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList float64s, after clear \n%v", d)
}
})

t.Run("varint", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(b []byte) int {
_, n := protowire.ConsumeVarint(b)
return int(n)
},
},
}

d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, protowire.AppendVarint(nil, 56)...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, protowire.AppendVarint(nil, 20067)...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, protowire.AppendVarint(nil, 7777777)...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, protowire.AppendVarint(nil, 424242)...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, protowire.AppendVarint(nil, 0)...))

got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want := [][]byte{
cc(time1, protowire.AppendVarint(nil, 424242)...),
cc(time2, protowire.AppendVarint(nil, 56)...),
cc(time3, protowire.AppendVarint(nil, 7777777)...),
cc(time4, protowire.AppendVarint(nil, 20067)...),
cc(time5, protowire.AppendVarint(nil, 0)...),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList int32 \n%v", d)
}
})
t.Run("lp", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(b []byte) int {
l, n := protowire.ConsumeVarint(b)
return int(l) + n
},
},
}

d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, []byte("\u0003one")...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, []byte("\u0003two")...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, []byte("\u0005three")...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, []byte("\u0004four")...))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...))

got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want := [][]byte{
cc(time1, []byte("\u0003one")...),
cc(time2, []byte("\u0003two")...),
cc(time3, []byte("\u0005three")...),
cc(time4, []byte("\u0004four")...),
cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList int32 \n%v", d)
}
})
t.Run("lp_onecall", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(b []byte) int {
l, n := protowire.ConsumeVarint(b)
return int(l) + n
},
},
}
d.AppendOrderedListState(linkID, wKey, uKey, bytes.Join([][]byte{
time5, []byte("\u0019FourHundredAndEleventyTwo"),
time3, []byte("\u0005three"),
time2, []byte("\u0003two"),
time1, []byte("\u0003one"),
time4, []byte("\u0004four"),
}, nil))

got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want := [][]byte{
cc(time1, []byte("\u0003one")...),
cc(time2, []byte("\u0003two")...),
cc(time3, []byte("\u0005three")...),
cc(time4, []byte("\u0004four")...),
cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList int32 \n%v", d)
}
})
}
11 changes: 8 additions & 3 deletions sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,10 @@ func (em *ElementManager) StageAggregates(ID string) {

// StageStateful marks the given stage as stateful, which means elements are
// processed by key.
func (em *ElementManager) StageStateful(ID string) {
em.stages[ID].stateful = true
func (em *ElementManager) StageStateful(ID string, stateTypeLen map[LinkID]func([]byte) int) {
ss := em.stages[ID]
ss.stateful = true
ss.stateTypeLen = stateTypeLen
}

// StageOnWindowExpiration marks the given stage as stateful, which means elements are
Expand Down Expand Up @@ -669,7 +671,9 @@ func (em *ElementManager) StateForBundle(rb RunBundle) TentativeData {
ss := em.stages[rb.StageID]
ss.mu.Lock()
defer ss.mu.Unlock()
var ret TentativeData
ret := TentativeData{
stateTypeLen: ss.stateTypeLen,
}
keys := ss.inprogressKeysByBundle[rb.BundleID]
// TODO(lostluck): Also track windows per bundle, to reduce copying.
if len(ss.state) > 0 {
Expand Down Expand Up @@ -1136,6 +1140,7 @@ type stageState struct {
inprogressKeys set[string] // all keys that are assigned to bundles.
inprogressKeysByBundle map[string]set[string] // bundle to key assignments.
state map[LinkID]map[typex.Window]map[string]StateData // state data for this stage, from {tid, stateID} -> window -> userKey
stateTypeLen map[LinkID]func([]byte) int // map from state to a function that will produce the total length of a single value in bytes.

// Accounting for handling watermark holds for timers.
// We track the count of timers with the same hold, and clear it from
Expand Down
Loading
Loading