Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: DH-18472 correct the ordering of result columns in update_by #6630

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1392,15 +1392,26 @@ public static Table updateBy(@NotNull final QueryTable source,

final Map<String, ColumnSource<?>> resultSources = new LinkedHashMap<>(source.getColumnSourceMap());

// We have the source table and the row redirection; we can initialize the operators and add the output
// columns to the result sources
final Map<String, ColumnSource<?>> unorderedResultSources = new HashMap<>();
// We have the source table and the row redirection; we can initialize the operators and collect the output
// columns.
for (UpdateByWindow win : operatorCollection.windowArr) {
for (UpdateByOperator op : win.operators) {
op.initializeSources(source, rowRedirection);
resultSources.putAll(op.getOutputColumns());
unorderedResultSources.putAll(op.getOutputColumns());
}
}

// Add the output result sources to the table column map in the order specified by the updateBy call.
for (String outputColumnName : operatorCollection.outputColumnNames) {
final ColumnSource<?> cs = unorderedResultSources.get(outputColumnName);
if (cs == null) {
throw new IllegalStateException(
"Requested output column '" + outputColumnName + "' was not found in operator output");
}
resultSources.put(outputColumnName, cs);
}

if (operatorCollection.byColumnNames.length == 0) {
return LivenessScopeStack.computeEnclosed(() -> {
final ZeroKeyUpdateByManager zkm = new ZeroKeyUpdateByManager(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,21 @@ private class OutputColumnVisitor implements UpdateByOperation.Visitor<Void> {
@Override
public Void visit(@NotNull final ColumnUpdateOperation clause) {
final UpdateBySpec spec = clause.spec();
// Some specs carry their single output column directly rather than through match pairs; handle those first
if (spec instanceof CumCountWhereSpec) {
outputColumns.add(((CumCountWhereSpec) spec).column().name());
return null;
}
if (spec instanceof RollingCountWhereSpec) {
outputColumns.add(((RollingCountWhereSpec) spec).column().name());
return null;
}
if (spec instanceof RollingFormulaSpec && ((RollingFormulaSpec) spec).paramToken().isEmpty()) {
// An empty paramToken indicates that this is a multi-column formula, for which we have a single
// output column in #selectable()
outputColumns.add(((RollingFormulaSpec) spec).selectable().newColumn().name());
return null;
}
final MatchPair[] pairs =
createColumnsToAddIfMissing(tableDef, parseMatchPairs(clause.columns()), spec, groupByColumns);
for (MatchPair pair : pairs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io.deephaven.engine.table.impl.QueryTable;
import io.deephaven.engine.table.impl.UpdateErrorReporter;
import io.deephaven.engine.table.impl.util.AsyncClientErrorNotifier;
import io.deephaven.engine.table.impl.util.ColumnHolder;
import io.deephaven.engine.testutil.ControlledUpdateGraph;
import io.deephaven.engine.testutil.EvalNugget;
import io.deephaven.engine.table.impl.TableDefaults;
Expand All @@ -29,18 +30,19 @@
import junit.framework.TestCase;
import org.jetbrains.annotations.NotNull;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.experimental.categories.Category;

import java.time.Duration;
import java.util.*;

import static io.deephaven.api.updateby.UpdateByOperation.*;
import static io.deephaven.engine.testutil.GenerateTableUpdates.generateAppends;
import static io.deephaven.engine.testutil.TstUtils.*;
import static io.deephaven.engine.testutil.testcase.RefreshingTableTestCase.simulateShiftAwareStep;
import static io.deephaven.engine.util.TableTools.col;
import static io.deephaven.engine.util.TableTools.intCol;
import static io.deephaven.engine.util.TableTools.*;
import static io.deephaven.time.DateTimeUtils.MINUTE;

@Category(OutOfBandTest.class)
Expand Down Expand Up @@ -138,9 +140,9 @@ protected Table e() {
UpdateByOperation.RollingMin("ts", Duration.ofMinutes(5), Duration.ofMinutes(5),
makeOpColNames(columnNamesArray, "_rollmintimerev", "Sym", "ts", "boolCol")),

UpdateByOperation.RollingMax(50, 50,
RollingMax(50, 50,
makeOpColNames(columnNamesArray, "_rollmaxticksrev", "Sym", "ts", "boolCol")),
UpdateByOperation.RollingMax("ts", Duration.ofMinutes(5), Duration.ofMinutes(5),
RollingMax("ts", Duration.ofMinutes(5), Duration.ofMinutes(5),
makeOpColNames(columnNamesArray, "_rollmaxtimerev", "Sym", "ts", "boolCol")),

// Excluding 'bigDecimalCol' because we need fuzzy matching which doesn't exist for BD
Expand All @@ -154,8 +156,8 @@ protected Table e() {
UpdateByOperation.Ema(skipControl, "ts", 10 * MINUTE,
makeOpColNames(columnNamesArray, "_ema", "Sym", "ts", "boolCol")),
UpdateByOperation.CumSum(makeOpColNames(columnNamesArray, "_sum", "Sym", "ts")),
UpdateByOperation.CumMin(makeOpColNames(columnNamesArray, "_min", "boolCol")),
UpdateByOperation.CumMax(makeOpColNames(columnNamesArray, "_max", "boolCol")),
CumMin(makeOpColNames(columnNamesArray, "_min", "boolCol")),
CumMax(makeOpColNames(columnNamesArray, "_max", "boolCol")),
UpdateByOperation
.CumProd(makeOpColNames(columnNamesArray, "_prod", "Sym", "ts", "boolCol")));
final UpdateByControl control = UpdateByControl.builder().useRedirection(redirected).build();
Expand Down Expand Up @@ -272,38 +274,36 @@ public void testInMemoryColumn() {
final Collection<? extends UpdateByOperation> clauses = List.of(
UpdateByOperation.Fill(),

UpdateByOperation.RollingGroup(50, 50,
RollingGroup(50, 50,
makeOpColNames(columnNamesArray, "_rollgroupfwdrev", "Sym", "ts")),
UpdateByOperation.RollingGroup("ts", Duration.ofMinutes(5), Duration.ofMinutes(5),
RollingGroup("ts", Duration.ofMinutes(5), Duration.ofMinutes(5),
makeOpColNames(columnNamesArray, "_rollgrouptimefwdrev", "Sym", "ts")),

UpdateByOperation.RollingSum(100, 0,
RollingSum(100, 0,
makeOpColNames(columnNamesArray, "_rollsumticksrev", "Sym", "ts", "boolCol")),
UpdateByOperation.RollingSum("ts", Duration.ofMinutes(15), Duration.ofMinutes(0),
RollingSum("ts", Duration.ofMinutes(15), Duration.ofMinutes(0),
makeOpColNames(columnNamesArray, "_rollsumtimerev", "Sym", "ts", "boolCol")),

UpdateByOperation.RollingAvg(100, 0,
RollingAvg(100, 0,
makeOpColNames(columnNamesArray, "_rollavgticksrev", "Sym", "ts", "boolCol")),
UpdateByOperation.RollingAvg("ts", Duration.ofMinutes(15), Duration.ofMinutes(0),
RollingAvg("ts", Duration.ofMinutes(15), Duration.ofMinutes(0),
makeOpColNames(columnNamesArray, "_rollavgtimerev", "Sym", "ts", "boolCol")),

UpdateByOperation.RollingMin(100, 0,
RollingMin(100, 0,
makeOpColNames(columnNamesArray, "_rollminticksrev", "Sym", "ts", "boolCol")),
UpdateByOperation.RollingMin("ts", Duration.ofMinutes(5), Duration.ofMinutes(0),
RollingMin("ts", Duration.ofMinutes(5), Duration.ofMinutes(0),
makeOpColNames(columnNamesArray, "_rollmintimerev", "Sym", "ts", "boolCol")),

UpdateByOperation.RollingMax(100, 0,
RollingMax(100, 0,
makeOpColNames(columnNamesArray, "_rollmaxticksrev", "Sym", "ts", "boolCol")),
UpdateByOperation.RollingMax("ts", Duration.ofMinutes(5), Duration.ofMinutes(0),
RollingMax("ts", Duration.ofMinutes(5), Duration.ofMinutes(0),
makeOpColNames(columnNamesArray, "_rollmaxtimerev", "Sym", "ts", "boolCol")),

UpdateByOperation.Ema(skipControl, "ts", 10 * MINUTE,
makeOpColNames(columnNamesArray, "_ema", "Sym", "ts", "boolCol")),
UpdateByOperation.CumSum(makeOpColNames(columnNamesArray, "_sum", "Sym", "ts")),
UpdateByOperation.CumMin(makeOpColNames(columnNamesArray, "_min", "boolCol")),
UpdateByOperation.CumMax(makeOpColNames(columnNamesArray, "_max", "boolCol")),
UpdateByOperation
.CumProd(makeOpColNames(columnNamesArray, "_prod", "Sym", "ts", "boolCol")));
Ema(skipControl, "ts", 10 * MINUTE, makeOpColNames(columnNamesArray, "_ema", "Sym", "ts", "boolCol")),
CumSum(makeOpColNames(columnNamesArray, "_sum", "Sym", "ts")),
CumMin(makeOpColNames(columnNamesArray, "_min", "boolCol")),
CumMax(makeOpColNames(columnNamesArray, "_max", "boolCol")),
CumProd(makeOpColNames(columnNamesArray, "_prod", "Sym", "ts", "boolCol")));
final UpdateByControl control = UpdateByControl.builder().useRedirection(false).build();

final Table table = result.t.updateBy(control, clauses, ColumnName.from("Sym"));
Expand All @@ -322,4 +322,76 @@ public void run() {
}
});
}

/**
 * Verifies that the columns produced by {@code updateBy} appear in the result in the same order as the
 * operations were supplied by the caller.
 * <p>
 * Cases 1-5 apply the same four operations (CumMin, CumMax, RollingMin, RollingMax) to a five-row table in
 * different orders and diff the result against an expected table built with the columns in the requested order.
 * Case 6 interleaves RollingGroup operations with the others and checks only the resulting column names.
 */
@Test
public void testResultColumnOrdering() {
    final Table source = emptyTable(5).update("X=ii");

    // Expected column contents; only the order in which these are assembled varies between cases.
    final ColumnHolder<?> x = longCol("X", 0, 1, 2, 3, 4);
    final ColumnHolder<?> cumMin = longCol("cumMin", 0, 0, 0, 0, 0);
    final ColumnHolder<?> cumMax = longCol("cumMax", 0, 1, 2, 3, 4);
    final ColumnHolder<?> rollingMin = longCol("rollingMin", 0, 0, 1, 2, 3);
    final ColumnHolder<?> rollingMax = longCol("rollingMax", 0, 1, 2, 3, 4);

    // Case 1: cumulative operations first, min before max.
    final Table result_1 = source.updateBy(List.of(
            CumMin("cumMin=X"),
            CumMax("cumMax=X"),
            RollingMin(2, "rollingMin=X"),
            RollingMax(2, "rollingMax=X")));
    final Table expected_1 = TableTools.newTable(x, cumMin, cumMax, rollingMin, rollingMax);
    Assert.assertEquals("", diff(result_1, expected_1, 10));

    // Case 2: cumulative operations first, max before min.
    final Table result_2 = source.updateBy(List.of(
            CumMax("cumMax=X"),
            CumMin("cumMin=X"),
            RollingMax(2, "rollingMax=X"),
            RollingMin(2, "rollingMin=X")));
    final Table expected_2 = TableTools.newTable(x, cumMax, cumMin, rollingMax, rollingMin);
    Assert.assertEquals("", diff(result_2, expected_2, 10));

    // Case 3: rolling operations first, min before max.
    final Table result_3 = source.updateBy(List.of(
            RollingMin(2, "rollingMin=X"),
            RollingMax(2, "rollingMax=X"),
            CumMin("cumMin=X"),
            CumMax("cumMax=X")));
    final Table expected_3 = TableTools.newTable(x, rollingMin, rollingMax, cumMin, cumMax);
    Assert.assertEquals("", diff(result_3, expected_3, 10));

    // Case 4: rolling operations first, max before min.
    final Table result_4 = source.updateBy(List.of(
            RollingMax(2, "rollingMax=X"),
            RollingMin(2, "rollingMin=X"),
            CumMax("cumMax=X"),
            CumMin("cumMin=X")));
    final Table expected_4 = TableTools.newTable(x, rollingMax, rollingMin, cumMax, cumMin);
    Assert.assertEquals("", diff(result_4, expected_4, 10));

    // Case 5: cumulative and rolling operations interleaved.
    final Table result_5 = source.updateBy(List.of(
            CumMin("cumMin=X"),
            RollingMin(2, "rollingMin=X"),
            CumMax("cumMax=X"),
            RollingMax(2, "rollingMax=X")));
    final Table expected_5 = TableTools.newTable(x, cumMin, rollingMin, cumMax, rollingMax);
    Assert.assertEquals("", diff(result_5, expected_5, 10));

    // Trickiest one, since we internally combine groupBy operations.
    final Table source_2 = source.update("Y=ii % 2");
    final Table result_6 = source_2.updateBy(List.of(
            CumMin("cumMin=X"),
            RollingGroup(2, "rollingGroupY=Y"),
            RollingMin(2, "rollingMin=X"),
            CumMax("cumMax=X"),
            RollingGroup(2, "rollingGroupX=X"),
            RollingMax(2, "rollingMax=X")));

    // JUnit's contract is assertArrayEquals(expecteds, actuals); pass the expected names first so that a
    // failure message reports "expected" and "actual" the right way around.
    Assert.assertArrayEquals(
            new String[] {
                    "X",
                    "Y",
                    "cumMin",
                    "rollingGroupY",
                    "rollingMin",
                    "cumMax",
                    "rollingGroupX",
                    "rollingMax"},
            result_6.getDefinition().getColumnNamesArray());
}
}