Merge pull request #15388 from cdapio/skip_rdd_conversion_pushdown
[CDAP-20658] Remove unneeded DataFrame->RDD conversion
tivv authored Oct 27, 2023
2 parents 766a67e + 1abcf61 commit 6bc4382
Showing 5 changed files with 166 additions and 20 deletions.
New file: BatchCollectionFactory.java
@@ -0,0 +1,49 @@
/*
* Copyright © 2023 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package io.cdap.cdap.etl.spark.batch;

import io.cdap.cdap.api.data.DatasetContext;
import io.cdap.cdap.api.spark.JavaSparkExecutionContext;
import io.cdap.cdap.etl.spark.function.FunctionCache;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;

/**
* Interface for factories that hold all the data needed to create a {@link BatchCollection}
* when provided with the various contexts. An instance of this factory is backed by a single
* specific dataset (RDD / Spark Dataset / Spark DataFrame / ...).
* @see RDDCollectionFactory
* @see DataframeCollectionFactory
*/
public interface BatchCollectionFactory<T> {

/**
* Creates a new BatchCollection with the data from this factory object, using the contexts
* provided in the parameters.
* @param sec java spark execution context
* @param jsc java spark context
* @param sqlContext sql context
* @param datasetContext dataset context
* @param sinkFactory sink factory
* @param functionCacheFactory function cache factory
* @return a specific instance of BatchCollection backed by this factory's data.
*/
BatchCollection<T> create(JavaSparkExecutionContext sec, JavaSparkContext jsc,
SQLContext sqlContext, DatasetContext datasetContext, SparkBatchSinkFactory sinkFactory,
FunctionCache.Factory functionCacheFactory);

}
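
For orientation, the snippet below sketches how a pull consumer is expected to use this factory. It mirrors the change in the last file of this diff (the calls and parameter names are taken from there); it is a sketch of the calling pattern, not a complete method.

    // Pull the dataset from the SQL engine and wait for the job to finish.
    SQLEngineJob<BatchCollectionFactory<T>> pullJob = adapter.pull(job);
    adapter.waitForJobAndHandleException(pullJob);
    BatchCollectionFactory<T> pullResult = pullJob.waitFor();
    // Let the factory pick the BatchCollection implementation (RDD- or DataFrame-backed)
    // now that the Spark and dataset contexts are available.
    BatchCollection<T> collection = pullResult.create(sec, jsc, sqlContext, datasetContext,
        sinkFactory, functionCacheFactory);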
Modified file (SQL engine adapter)
@@ -22,7 +22,6 @@
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.metrics.Metrics;
import io.cdap.cdap.api.spark.JavaSparkExecutionContext;
import io.cdap.cdap.api.spark.sql.DataFrames;
import io.cdap.cdap.etl.api.StageMetrics;
import io.cdap.cdap.etl.api.engine.sql.SQLEngine;
import io.cdap.cdap.etl.api.engine.sql.SQLEngineException;
@@ -72,7 +71,6 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -273,14 +271,14 @@ public SQLDataset pushInternal(String datasetName,
* @return Job representing this pull operation.
*/
@SuppressWarnings("unchecked,raw")
public <T> SQLEngineJob<JavaRDD<T>> pull(SQLEngineJob<SQLDataset> job) {
public <T> SQLEngineJob<BatchCollectionFactory<T>> pull(SQLEngineJob<SQLDataset> job) {
//If this job already exists, return the existing instance.
SQLEngineJobKey jobKey = new SQLEngineJobKey(job.getDatasetName(), SQLEngineJobType.PULL);
if (jobs.containsKey(jobKey)) {
return (SQLEngineJob<JavaRDD<T>>) jobs.get(jobKey);
return (SQLEngineJob<BatchCollectionFactory<T>>) jobs.get(jobKey);
}

CompletableFuture<JavaRDD<T>> future = new CompletableFuture<>();
CompletableFuture<BatchCollectionFactory<T>> future = new CompletableFuture<>();

Runnable pullTask = () -> {
try {
@@ -291,7 +289,7 @@ public <T> SQLEngineJob<JavaRDD<T>> pull(SQLEngineJob<SQLDataset> job) {

// Execute pull operation for the supplied dataset
SQLDataset sqlDataset = job.get();
JavaRDD<T> result = pullInternal(sqlDataset);
BatchCollectionFactory result = pullInternal(sqlDataset);
LOG.debug("Started pull for dataset '{}'", job.getDatasetName());

// Log number of records being pulled into metrics
@@ -304,7 +302,7 @@ public <T> SQLEngineJob<JavaRDD<T>> pull(SQLEngineJob<SQLDataset> job) {

executorService.submit(pullTask);

SQLEngineJob<JavaRDD<T>> pullJob = new SQLEngineJob<>(jobKey, future);
SQLEngineJob<BatchCollectionFactory<T>> pullJob = new SQLEngineJob<>(jobKey, future);
jobs.put(jobKey, pullJob);

return pullJob;
@@ -314,11 +312,11 @@ public <T> SQLEngineJob<JavaRDD<T>> pull(SQLEngineJob<SQLDataset> job) {
* Pull implementation. This method has blocking calls and should be executed in a separate thread.
*
* @param dataset the dataset to pull.
* @return {@link JavaRDD} representing the records contained in this dataset.
* @return {@link BatchCollectionFactory} representing the records contained in this dataset.
* @throws SQLEngineException if the pull process fails.
*/
@SuppressWarnings("unchecked,raw")
private <T> JavaRDD<T> pullInternal(SQLDataset dataset) throws SQLEngineException {
private <T> BatchCollectionFactory<T> pullInternal(SQLDataset dataset) throws SQLEngineException {
// Create pull operation for this dataset and wait until completion
SQLPullRequest pullRequest = new SQLPullRequest(dataset);

@@ -334,12 +332,10 @@ private <T> JavaRDD<T> pullInternal(SQLDataset dataset) throws SQLEngineExceptio
// Note that we only support Spark collections at this time.
// If the collection that was generated is not an instance of a SparkRecordCollection, skip.
if (recordCollection instanceof SparkRecordCollection) {
Schema schema = dataset.getSchema();
JavaRDD<T> rdd = (JavaRDD<T>) ((SparkRecordCollection) recordCollection).getDataFrame()
.javaRDD()
.map(r -> DataFrames.fromRow((Row) r, schema));
countExecutionStage(SQLEngineJobTypeMetric.SPARK_PULL);
return rdd;
return new DataframeCollectionFactory<T>(
dataset.getSchema(),
((SparkRecordCollection) recordCollection).getDataFrame());
}
}
}
@@ -357,7 +353,7 @@ private <T> JavaRDD<T> pullInternal(SQLDataset dataset) throws SQLEngineExceptio
return f;
});
countExecutionStage(SQLEngineJobTypeMetric.PULL);
return rdd;
return new RDDCollectionFactory<>(rdd);
}

/**
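Net effect of the hunks above: the pull result is no longer converted eagerly. Below is a condensed before/after of pullInternal's Spark branch, with the surrounding checks elided and intermediate variables inlined (`sparkCollection` here stands for the casted `(SparkRecordCollection) recordCollection` from the diff).

    // Before: every pulled DataFrame was mapped row by row into a JavaRDD of records.
    JavaRDD<T> rdd = (JavaRDD<T>) sparkCollection.getDataFrame()
        .javaRDD()
        .map(r -> DataFrames.fromRow((Row) r, schema));
    return rdd;

    // After: the DataFrame is handed back unchanged inside a DataframeCollectionFactory,
    // so the row-by-row conversion is skipped when downstream can consume a DataFrame.
    return new DataframeCollectionFactory<T>(schema, sparkCollection.getDataFrame());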
New file: DataframeCollectionFactory.java
@@ -0,0 +1,55 @@
/*
* Copyright © 2023 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package io.cdap.cdap.etl.spark.batch;

import io.cdap.cdap.api.data.DatasetContext;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.spark.JavaSparkExecutionContext;
import io.cdap.cdap.etl.spark.function.FunctionCache;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

/**
* Factory that creates a {@link DataframeCollection}
*/
public class DataframeCollectionFactory<T> implements BatchCollectionFactory<T> {

private final Schema schema;
private final Dataset<Row> dataFrame;

/**
* Creates a DataFrame-based pull result
*
* @param schema schema of the dataFrame
* @param dataFrame result dataFrame
*/
public DataframeCollectionFactory(Schema schema, Dataset<Row> dataFrame) {
this.schema = schema;
this.dataFrame = dataFrame;
}

@Override
public BatchCollection<T> create(JavaSparkExecutionContext sec, JavaSparkContext jsc,
SQLContext sqlContext, DatasetContext datasetContext, SparkBatchSinkFactory sinkFactory,
FunctionCache.Factory functionCacheFactory) {
return (BatchCollection<T>) new DataframeCollection(
schema, dataFrame, sec, jsc, sqlContext, datasetContext,
sinkFactory, functionCacheFactory);
}
}
New file: RDDCollectionFactory.java
@@ -0,0 +1,47 @@
/*
* Copyright © 2023 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package io.cdap.cdap.etl.spark.batch;

import io.cdap.cdap.api.data.DatasetContext;
import io.cdap.cdap.api.spark.JavaSparkExecutionContext;
import io.cdap.cdap.etl.spark.function.FunctionCache;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;

/**
* Factory that creates a {@link RDDCollection<T>}
*/
public class RDDCollectionFactory<T> implements BatchCollectionFactory<T> {

private final JavaRDD<T> rdd;

/**
* Creates a JavaRDD-based pull result
*/
public RDDCollectionFactory(JavaRDD<T> rdd) {
this.rdd = rdd;
}

@Override
public BatchCollection<T> create(JavaSparkExecutionContext sec, JavaSparkContext jsc,
SQLContext sqlContext, DatasetContext datasetContext, SparkBatchSinkFactory sinkFactory,
FunctionCache.Factory functionCacheFactory) {
return new RDDCollection<>(sec, functionCacheFactory, jsc, sqlContext, datasetContext,
sinkFactory, rdd);
}
}
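
The two factory implementations correspond to the two exit paths of pullInternal in the adapter diff above. Condensed, the dispatch looks like this (details of the fallback RDD construction are elided):

    if (recordCollection instanceof SparkRecordCollection) {
        // Spark engine result: keep the DataFrame as a DataFrame.
        return new DataframeCollectionFactory<T>(
            dataset.getSchema(), ((SparkRecordCollection) recordCollection).getDataFrame());
    }
    // Generic engine result: fall back to the record-by-record path that builds a JavaRDD,
    // and wrap that RDD instead.
    return new RDDCollectionFactory<>(rdd);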
Modified file (SQL engine backed collection)
@@ -36,7 +36,6 @@
import io.cdap.cdap.etl.spark.function.FunctionCache;
import io.cdap.cdap.etl.spark.join.JoinExpressionRequest;
import io.cdap.cdap.etl.spark.join.JoinRequest;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
@@ -131,11 +130,11 @@ protected BatchCollection<T> pull() {
// Ensure the local collection is only generated once across multiple threads
synchronized (this) {
if (localCollection == null) {
SQLEngineJob<JavaRDD<T>> pullJob = adapter.pull(job);
SQLEngineJob<BatchCollectionFactory<T>> pullJob = adapter.pull(job);
adapter.waitForJobAndHandleException(pullJob);
JavaRDD<T> rdd = pullJob.waitFor();
localCollection =
new RDDCollection<>(sec, functionCacheFactory, jsc, sqlContext, datasetContext, sinkFactory, rdd);
BatchCollectionFactory<T> pullResult = pullJob.waitFor();
localCollection = pullResult.create(sec, jsc, sqlContext, datasetContext,
sinkFactory, functionCacheFactory);
}
}
return localCollection;
