diff --git a/athena-dynamodb/athena-dynamodb.yaml b/athena-dynamodb/athena-dynamodb.yaml
index ddcf2c8456..83390efb5c 100644
--- a/athena-dynamodb/athena-dynamodb.yaml
+++ b/athena-dynamodb/athena-dynamodb.yaml
@@ -105,6 +105,7 @@ Resources:
                 - dynamodb:ListTables
                 - dynamodb:Query
                 - dynamodb:Scan
+                - dynamodb:PartiQLSelect
                 - glue:GetTableVersions
                 - glue:GetPartitions
                 - glue:GetTables
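Note: the new dynamodb:PartiQLSelect action is what authorizes the ExecuteStatement calls introduced in the handlers below; without it, the connector's Lambda role gets an access-denied error on any passthrough query. A minimal sketch of the call this policy gates, assuming a hypothetical table named "books":

```java
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementRequest;
import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementResponse;

public final class PartiQLPermissionSmokeTest
{
    public static void main(String[] args)
    {
        try (DynamoDbClient ddb = DynamoDbClient.create()) {
            // This is the API the connector now invokes; it is authorized by
            // dynamodb:PartiQLSelect rather than dynamodb:Query/Scan.
            ExecuteStatementRequest request = ExecuteStatementRequest.builder()
                    .statement("SELECT * FROM \"books\"") // "books" is a placeholder
                    .limit(1)
                    .build();
            ExecuteStatementResponse response = ddb.executeStatement(request);
            System.out.println("Fetched " + response.items().size() + " item(s)");
        }
    }
}
```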
diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandler.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandler.java
index 9ea540bbb7..4036d4176c 100644
--- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandler.java
+++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandler.java
@@ -30,6 +30,8 @@
 import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet;
 import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
 import com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler;
+import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
+import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
 import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
 import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
 import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
@@ -40,12 +42,14 @@
 import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
 import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
 import com.amazonaws.athena.connector.lambda.metadata.glue.GlueFieldLexer;
+import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
 import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
 import com.amazonaws.athena.connectors.dynamodb.constants.DynamoDBConstants;
 import com.amazonaws.athena.connectors.dynamodb.credentials.CrossAccountCredentialsProviderV2;
 import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBIndex;
 import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBPaginatedTables;
 import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBTable;
+import com.amazonaws.athena.connectors.dynamodb.qpt.DDBQueryPassthrough;
 import com.amazonaws.athena.connectors.dynamodb.resolver.DynamoDBTableResolver;
 import com.amazonaws.athena.connectors.dynamodb.util.DDBPredicateUtils;
 import com.amazonaws.athena.connectors.dynamodb.util.DDBRecordMetadata;
@@ -59,6 +63,7 @@
 import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.amazonaws.util.json.Jackson;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableMap;
 import org.apache.arrow.vector.complex.reader.FieldReader;
 import org.apache.arrow.vector.types.Types;
 import org.apache.arrow.vector.types.pojo.Field;
@@ -68,6 +73,8 @@
 import software.amazon.awssdk.enhanced.dynamodb.document.EnhancedDocument;
 import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
 import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
+import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementRequest;
+import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementResponse;
 
 import java.util.ArrayList;
 import java.util.Collections;
@@ -98,6 +105,7 @@
 import static com.amazonaws.athena.connectors.dynamodb.constants.DynamoDBConstants.SEGMENT_ID_PROPERTY;
 import static com.amazonaws.athena.connectors.dynamodb.constants.DynamoDBConstants.TABLE_METADATA;
 import static com.amazonaws.athena.connectors.dynamodb.throttling.DynamoDBExceptionFilter.EXCEPTION_FILTER;
+import static com.amazonaws.athena.connectors.dynamodb.util.DDBTableUtils.SCHEMA_INFERENCE_NUM_RECORDS;
 
 /**
  * Handles metadata requests for the Athena DynamoDB Connector.
@@ -134,6 +142,8 @@ public class DynamoDBMetadataHandler
     private final AWSGlue glueClient;
     private final DynamoDBTableResolver tableResolver;
 
+    private final DDBQueryPassthrough queryPassthrough;
+
     public DynamoDBMetadataHandler(java.util.Map<String, String> configOptions)
     {
         super(SOURCE_TYPE, configOptions);
@@ -143,6 +153,7 @@ public DynamoDBMetadataHandler(java.util.Map<String, String> configOptions)
         this.glueClient = getAwsGlue();
         this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions).build();
         this.tableResolver = new DynamoDBTableResolver(invoker, ddbClient);
+        this.queryPassthrough = new DDBQueryPassthrough();
     }
 
     @VisibleForTesting
@@ -161,6 +172,16 @@ public DynamoDBMetadataHandler(java.util.Map<String, String> configOptions)
         this.ddbClient = ddbClient;
         this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions).build();
         this.tableResolver = new DynamoDBTableResolver(invoker, ddbClient);
+        this.queryPassthrough = new DDBQueryPassthrough();
+    }
+
+    @Override
+    public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
+    {
+        ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
+        this.queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, this.configOptions);
+
+        return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
     }
 
     /**
@@ -230,6 +251,27 @@ public ListTablesResponse doListTables(BlockAllocator allocator, ListTablesRequest request)
         return new ListTablesResponse(request.getCatalogName(), new ArrayList<>(combinedTables), token);
     }
 
+    @Override
+    public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception
+    {
+        if (!request.isQueryPassthrough()) {
+            throw new IllegalArgumentException("No Query passed through [" + request + "]");
+        }
+
+        queryPassthrough.verify(request.getQueryPassthroughArguments());
+        String partiQLStatement = request.getQueryPassthroughArguments().get(DDBQueryPassthrough.QUERY);
+        ExecuteStatementRequest executeStatementRequest =
+                ExecuteStatementRequest.builder()
+                        .statement(partiQLStatement)
+                        .limit(SCHEMA_INFERENCE_NUM_RECORDS)
+                        .build();
+        //PartiQL on DynamoDB doesn't allow a dry run; therefore, we "peek" at the first few records
+        ExecuteStatementResponse response = ddbClient.executeStatement(executeStatementRequest);
+        SchemaBuilder schemaBuilder = DDBTableUtils.buildSchemaFromItems(response.items());
+
+        return new GetTableResponse(request.getCatalogName(), request.getTableName(), schemaBuilder.build(), Collections.emptySet());
+    }
+
     /**
      * Fetches a table's schema from Glue DataCatalog if present and not disabled, otherwise falls
      * back to doing a small table scan and derives a schema from that.
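Note on the peek above: PartiQL on DynamoDB has no dry-run or describe call, so the passthrough schema is inferred from at most SCHEMA_INFERENCE_NUM_RECORDS (currently 4) sampled items; attributes that only appear in later items will be absent from the reported schema. A sketch of the same inference in isolation, with the client and statement as placeholders:

```java
import com.amazonaws.athena.connectors.dynamodb.util.DDBTableUtils;
import org.apache.arrow.vector.types.pojo.Schema;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementRequest;
import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementResponse;

final class SchemaPeek
{
    // Same shape as the handler's peek: cap the read at the inference sample size.
    static Schema infer(DynamoDbClient ddbClient, String partiQLStatement)
    {
        ExecuteStatementRequest peek = ExecuteStatementRequest.builder()
                .statement(partiQLStatement)
                .limit(DDBTableUtils.SCHEMA_INFERENCE_NUM_RECORDS)
                .build();
        ExecuteStatementResponse sample = ddbClient.executeStatement(peek);
        // The union of attributes seen in the sampled items becomes the schema.
        return DDBTableUtils.buildSchemaFromItems(sample.items()).build();
    }
}
```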
@@ -268,6 +310,10 @@ public GetTableResponse doGetTable(BlockAllocator allocator, GetTableRequest request)
     @Override
     public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTableLayoutRequest request)
     {
+        if (request.getTableName().getQualifiedTableName().equalsIgnoreCase(queryPassthrough.getFunctionSignature())) {
+            //Query passthrough does not support partitions
+            return;
+        }
         // use the source table name from the schema if available (in case Glue table name != actual table name)
         String tableName = getSourceTableName(request.getSchema());
         if (tableName == null) {
@@ -414,6 +460,11 @@ private void precomputeAdditionalMetadata(Set<String> columnsToIgnore, Map<String, String> metadata)
     @Override
     public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest request)
     {
+        if (request.getConstraints().isQueryPassThrough()) {
+            logger.info("doGetSplits: QPT enabled");
+            return setupQueryPassthroughSplit(request);
+        }
+
         int partitionContd = decodeContinuationToken(request);
         Set<Split> splits = new HashSet<>();
         Block partitions = request.getPartitions();
@@ -509,4 +560,21 @@ private String encodeContinuationToken(int partition)
     {
         return String.valueOf(partition);
     }
+
+    /**
+     * Helper function that provides a single split for Query Pass-Through
+     *
+     */
+    private GetSplitsResponse setupQueryPassthroughSplit(GetSplitsRequest request)
+    {
+        //Every split must have a unique location if we wish to spill to avoid failures
+        SpillLocation spillLocation = makeSpillLocation(request);
+
+        //Since this is a QPT query we return a fixed split.
+        Map<String, String> qptArguments = request.getConstraints().getQueryPassthroughArguments();
+        return new GetSplitsResponse(request.getCatalogName(),
+                Split.newBuilder(spillLocation, makeEncryptionKey())
+                        .applyProperties(qptArguments)
+                        .build());
+    }
 }
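Unlike the Scan path, which fans out over parallel-scan segments, a passthrough query is pinned to exactly one split (and therefore one record-handler invocation), since ExecuteStatement cannot be segmented. A hypothetical unit-test fragment capturing that contract; handler, allocator, and qptRequest are assumed to be constructed with query-passthrough constraints elsewhere in the test class:

```java
// One split per passthrough query, regardless of table size.
GetSplitsResponse response = handler.doGetSplits(allocator, qptRequest);
assertEquals(1, response.getSplits().size());

// The PartiQL statement rides along on the split itself via applyProperties.
Split onlySplit = response.getSplits().iterator().next();
assertNotNull(onlySplit.getProperty(DDBQueryPassthrough.QUERY));
```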
diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java
index ce74d92cd1..b06bb5aa1e 100644
--- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java
+++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java
@@ -30,6 +30,7 @@
 import com.amazonaws.athena.connector.lambda.handlers.RecordHandler;
 import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
 import com.amazonaws.athena.connectors.dynamodb.credentials.CrossAccountCredentialsProviderV2;
+import com.amazonaws.athena.connectors.dynamodb.qpt.DDBQueryPassthrough;
 import com.amazonaws.athena.connectors.dynamodb.resolver.DynamoDBFieldResolver;
 import com.amazonaws.athena.connectors.dynamodb.util.DDBPredicateUtils;
 import com.amazonaws.athena.connectors.dynamodb.util.DDBRecordMetadata;
@@ -50,6 +51,8 @@
 import software.amazon.awssdk.enhanced.dynamodb.document.EnhancedDocument;
 import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
 import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
+import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementRequest;
+import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementResponse;
 import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
 import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
 import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
@@ -104,6 +107,8 @@ public class DynamoDBRecordHandler
     private final LoadingCache<String, ThrottlingInvoker> invokerCache;
     private final DynamoDbClient ddbClient;
 
+    private final DDBQueryPassthrough queryPassthrough = new DDBQueryPassthrough();
+
     public DynamoDBRecordHandler(java.util.Map<String, String> configOptions)
     {
         super(sourceType, configOptions);
@@ -149,6 +154,11 @@ public ThrottlingInvoker load(String tableName)
     protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker)
             throws ExecutionException
     {
+        if (recordsRequest.getConstraints().isQueryPassThrough()) {
+            logger.info("readWithConstraint for QueryPassthrough PartiQL Query");
+            handleQueryPassthroughPartiQLQuery(spiller, recordsRequest, queryStatusChecker);
+            return;
+        }
         Split split = recordsRequest.getSplit();
         // use the property instead of the request table name because of case sensitivity
         String tableName = split.getProperty(TABLE_METADATA);
@@ -190,8 +200,43 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker)
         }
 
         Iterator<Map<String, AttributeValue>> itemIterator = getIterator(split, tableName, recordsRequest.getSchema(), recordsRequest.getConstraints(), disableProjectionAndCasing);
+        writeItemsToBlock(spiller, recordsRequest, queryStatusChecker, recordMetadata, itemIterator, disableProjectionAndCasing);
+    }
+
+    private void handleQueryPassthroughPartiQLQuery(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker)
+    {
+        if (!recordsRequest.getConstraints().isQueryPassThrough()) {
+            throw new RuntimeException("Attempting to read records with Query Passthrough but no PartiQL query was provided");
+        }
+        queryPassthrough.verify(recordsRequest.getConstraints().getQueryPassthroughArguments());
+
+        DDBRecordMetadata recordMetadata = new DDBRecordMetadata(recordsRequest.getSchema());
+
+        String partiQLStatement = recordsRequest.getConstraints().getQueryPassthroughArguments().get(DDBQueryPassthrough.QUERY);
+        ExecuteStatementRequest executeStatementRequest =
+                ExecuteStatementRequest.builder()
+                        .statement(partiQLStatement)
+                        .build();
+
+        ExecuteStatementResponse response = ddbClient.executeStatement(executeStatementRequest);
+
+        Iterator<Map<String, AttributeValue>> itemIterator = response.items().iterator();
+        writeItemsToBlock(spiller, recordsRequest, queryStatusChecker, recordMetadata, itemIterator, false);
+    }
+
+    private void writeItemsToBlock(
+            BlockSpiller spiller,
+            ReadRecordsRequest recordsRequest,
+            QueryStatusChecker queryStatusChecker,
+            DDBRecordMetadata recordMetadata,
+            Iterator<Map<String, AttributeValue>> itemIterator,
+            boolean disableProjectionAndCasing)
+    {
         DynamoDBFieldResolver resolver = new DynamoDBFieldResolver(recordMetadata);
+        String disableProjectionAndCasingEnvValue = configOptions.getOrDefault(DISABLE_PROJECTION_AND_CASING_ENV, "auto").toLowerCase();
+        logger.info(DISABLE_PROJECTION_AND_CASING_ENV + " environment variable set to: " + disableProjectionAndCasingEnvValue);
+
         GeneratedRowWriter.RowWriterBuilder rowWriterBuilder = GeneratedRowWriter.newBuilder(recordsRequest.getConstraints());
         //register extract and field writer factory for each field.
         for (Field next : recordsRequest.getSchema().getFields()) {
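One behavior worth flagging in handleQueryPassthroughPartiQLQuery: ExecuteStatement is a paginated API (roughly 1 MB per page), and the code above consumes only response.items() from the first page. If larger passthrough results are expected, a drain loop along these lines might be needed; writePage is a hypothetical stand-in for the block-writing logic:

```java
// Sketch, not part of this PR: drain every page of a PartiQL result.
ExecuteStatementResponse page = ddbClient.executeStatement(executeStatementRequest);
writePage(page.items()); // hypothetical helper that spills items to blocks
while (page.nextToken() != null) {
    page = ddbClient.executeStatement(
            executeStatementRequest.toBuilder().nextToken(page.nextToken()).build());
    writePage(page.items());
}
```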
diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/qpt/DDBQueryPassthrough.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/qpt/DDBQueryPassthrough.java
new file mode 100644
index 0000000000..5e714bdc5c
--- /dev/null
+++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/qpt/DDBQueryPassthrough.java
@@ -0,0 +1,93 @@
+/*-
+ * #%L
+ * athena-dynamodb
+ * %%
+ * Copyright (C) 2019 Amazon Web Services
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+package com.amazonaws.athena.connectors.dynamodb.qpt;
+
+import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature;
+import com.google.common.collect.ImmutableSet;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Set;
+
+public class DDBQueryPassthrough implements QueryPassthroughSignature
+{
+    // Constant value representing the name of the query.
+    public static final String NAME = "query";
+
+    // Constant value representing the domain of the query.
+    public static final String SCHEMA_NAME = "system";
+
+    // The single argument the function takes: the PartiQL statement to run.
+    public static final String QUERY = "QUERY";
+
+    // List of arguments for the query, statically initialized as it always contains the same value.
+    public static final List<String> ARGUMENTS = Arrays.asList(QUERY);
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(DDBQueryPassthrough.class);
+
+    @Override
+    public String getFunctionSchema()
+    {
+        return SCHEMA_NAME;
+    }
+
+    @Override
+    public String getFunctionName()
+    {
+        return NAME;
+    }
+
+    @Override
+    public List<String> getFunctionArguments()
+    {
+        return ARGUMENTS;
+    }
+
+    @Override
+    public Logger getLogger()
+    {
+        return LOGGER;
+    }
+
+    @Override
+    public void customConnectorVerifications(Map<String, String> engineQptArguments)
+    {
+        String partiQLStatement = engineQptArguments.get(QUERY);
+        String upperCaseStatement = partiQLStatement.trim().toUpperCase(Locale.ENGLISH);
+
+        // Immediately check if the statement starts with "SELECT"
+        if (!upperCaseStatement.startsWith("SELECT")) {
+            throw new UnsupportedOperationException("Statement does not start with SELECT.");
+        }
+
+        // List of disallowed keywords
+        Set<String> disallowedKeywords = ImmutableSet.of("INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER");
+
+        // Check if the statement contains any disallowed keywords
+        for (String keyword : disallowedKeywords) {
+            if (upperCaseStatement.contains(keyword)) {
+                throw new UnsupportedOperationException("Unaccepted operation; only SELECT statements are allowed. Found: " + keyword);
+            }
+        }
+    }
+}
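With this signature, the passthrough is addressed from Athena as system.query, with a statement along the lines of SELECT * FROM TABLE(system.query(query => 'SELECT * FROM "books"')). A quick sketch of the verification behavior, including a false positive the substring-based keyword check can produce:

```java
// Illustrative only; the statements and table names are placeholders.
DDBQueryPassthrough qpt = new DDBQueryPassthrough();

qpt.customConnectorVerifications(Map.of("QUERY", "SELECT * FROM \"books\""));
// -> passes

qpt.customConnectorVerifications(Map.of("QUERY", "DROP TABLE \"books\""));
// -> UnsupportedOperationException: statement does not start with SELECT

// Caveat: contains() matches substrings, so a harmless projection such as
// this one is rejected because "CREATED_AT" contains the keyword "CREATE".
qpt.customConnectorVerifications(Map.of("QUERY", "SELECT created_at FROM \"books\""));
```

A word-boundary match (for example, a regex with \b anchors) would avoid that false positive while keeping the same intent.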
diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTableUtils.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTableUtils.java
index 9c0c97ead1..c1a98945f2 100644
--- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTableUtils.java
+++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBTableUtils.java
@@ -24,7 +24,6 @@
 import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBIndex;
 import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBTable;
 import com.google.common.collect.ImmutableList;
-import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -43,11 +42,10 @@
 import software.amazon.awssdk.services.dynamodb.model.ScanResponse;
 import software.amazon.awssdk.services.dynamodb.model.TableDescription;
 
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
-import java.util.Set;
 import java.util.concurrent.TimeoutException;
 
 /**
@@ -63,7 +61,7 @@ public final class DDBTableUtils
     private static final int MIN_SCAN_SEGMENTS = 1;
     private static final long MAX_BYTES_PER_SEGMENT = 1024L * 1024L * 1024L;
     private static final double MIN_IO_PER_SEGMENT = 100.0;
-    private static final int SCHEMA_INFERENCE_NUM_RECORDS = 4;
+    public static final int SCHEMA_INFERENCE_NUM_RECORDS = 4;
 
     private DDBTableUtils() {}
 
@@ -154,19 +152,7 @@ public static Schema peekTableForSchema(String tableName, ThrottlingInvoker invoker, DynamoDbClient ddbClient)
         ScanResponse scanResponse = invoker.invoke(() -> ddbClient.scan(scanRequest));
         if (!scanResponse.items().isEmpty()) {
             List<Map<String, AttributeValue>> items = scanResponse.items();
-            Set<String> discoveredColumns = new HashSet<>();
-
-            for (Map<String, AttributeValue> item : items) {
-                for (Map.Entry<String, AttributeValue> column : item.entrySet()) {
-                    if (!discoveredColumns.contains(column.getKey())) {
-                        Field field = DDBTypeUtils.inferArrowField(column.getKey(), column.getValue());
-                        if (field != null) {
-                            schemaBuilder.addField(field);
-                            discoveredColumns.add(column.getKey());
-                        }
-                    }
-                }
-            }
+            schemaBuilder = buildSchemaFromItems(items);
         }
         else {
             // there's no items, so use any attributes defined in the table metadata
@@ -187,6 +173,23 @@ public static Schema peekTableForSchema(String tableName, ThrottlingInvoker invoker, DynamoDbClient ddbClient)
         return schemaBuilder.build();
     }
 
+    /**
+     * A utility method that takes a list of items and returns a schema builder.
+     * @param items a list of DynamoDB items (attribute-name-to-value maps)
+     * @return a schema builder containing the union of inferred fields
+     */
+    public static SchemaBuilder buildSchemaFromItems(List<Map<String, AttributeValue>> items)
+    {
+        SchemaBuilder schemaBuilder = new SchemaBuilder();
+        items.stream()
+                .flatMap(item -> item.entrySet().stream())
+                .map(column -> DDBTypeUtils.inferArrowField(column.getKey(), column.getValue()))
+                .filter(Objects::nonNull)
+                .distinct()
+                .forEach(schemaBuilder::addField);
+        return schemaBuilder;
+    }
+
     /**
      * This heuristic determines an optimal segment count to perform Parallel Scans with using the table's capacity
      * and size.
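A small usage sketch for buildSchemaFromItems, with hypothetical items. One subtlety of the refactor: the old loop tracked discovered column names, so the first inferred type for an attribute won; distinct() compares whole Field objects, so two items that disagree on an attribute's type both reach addField, leaving the outcome to SchemaBuilder's duplicate-name handling.

```java
// Hypothetical items that disagree on the type of "price".
Map<String, AttributeValue> first = Map.of(
        "id", AttributeValue.builder().s("a").build(),
        "price", AttributeValue.builder().n("10").build());    // inferred numeric
Map<String, AttributeValue> second = Map.of(
        "id", AttributeValue.builder().s("b").build(),
        "price", AttributeValue.builder().s("cheap").build()); // inferred string
Schema schema = DDBTableUtils.buildSchemaFromItems(List.of(first, second)).build();
```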
diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java
index 502453ed0b..2604b5e228 100644
--- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java
+++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java
@@ -35,6 +35,7 @@
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.io.ByteStreams;
 import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.commons.lang3.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -499,8 +500,10 @@ private void safeClose(AutoCloseable block)
     private ThreadPoolExecutor makeAsyncSpillPool(SpillConfig config)
     {
         int spillQueueCapacity = config.getNumSpillThreads();
-        if (configOptions.get(SPILL_QUEUE_CAPACITY) != null) {
-            spillQueueCapacity = Integer.parseInt(configOptions.get(SPILL_QUEUE_CAPACITY));
+
+        String capacity = StringUtils.isNotBlank(configOptions.get(SPILL_QUEUE_CAPACITY)) ? configOptions.get(SPILL_QUEUE_CAPACITY) : configOptions.get(SPILL_QUEUE_CAPACITY.toLowerCase());
+        if (capacity != null) {
+            spillQueueCapacity = Integer.parseInt(capacity);
             logger.debug("Setting Spill Queue Capacity to {}", spillQueueCapacity);
         }
 
diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/metadata/optimizations/querypassthrough/QueryPassthroughSignature.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/metadata/optimizations/querypassthrough/QueryPassthroughSignature.java
index 1b5472cc23..4e137d943d 100644
--- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/metadata/optimizations/querypassthrough/QueryPassthroughSignature.java
+++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/metadata/optimizations/querypassthrough/QueryPassthroughSignature.java
@@ -90,13 +90,13 @@ public default void verify(Map<String, String> engineQptArguments)
             }
         }
         //Finally, perform any connector-specific verification;
-        customConnectorVerifications();
+        customConnectorVerifications(engineQptArguments);
     }
 
     /**
      * Provides a mechanism to perform custom connector verification logic.
      */
-    default void customConnectorVerifications()
+    default void customConnectorVerifications(Map<String, String> engineQptArguments)
    {
         //No Op
     }
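Since customConnectorVerifications now receives the engine's arguments, existing implementors of QueryPassthroughSignature must update their overrides; the default remains a no-op, so connectors without custom checks compile unchanged. A minimal, hypothetical implementor under the new signature:

```java
// Hypothetical connector-side implementation; class name and checks are illustrative.
public class ExampleQueryPassthrough implements QueryPassthroughSignature
{
    private static final Logger LOGGER = LoggerFactory.getLogger(ExampleQueryPassthrough.class);

    @Override
    public String getFunctionSchema() { return "system"; }

    @Override
    public String getFunctionName() { return "query"; }

    @Override
    public List<String> getFunctionArguments() { return List.of("QUERY"); }

    @Override
    public Logger getLogger() { return LOGGER; }

    @Override
    public void customConnectorVerifications(Map<String, String> engineQptArguments)
    {
        // The arguments map is now available for connector-specific checks.
        if (engineQptArguments.getOrDefault("QUERY", "").isBlank()) {
            throw new IllegalArgumentException("QUERY must be a non-empty statement");
        }
    }
}
```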