Qpt vertica changes #1853

Merged: 8 commits, Apr 12, 2024
6 changes: 6 additions & 0 deletions athena-vertica/pom.xml
@@ -66,6 +66,12 @@
<artifactId>ST4</artifactId>
<version>${antlr.st4.version}</version>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>athena-jdbc</artifactId>
<version>2022.47.1</version>
<scope>compile</scope>
</dependency>
</dependencies>
<build>
<plugins>
athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java
@@ -30,7 +30,19 @@
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
import com.amazonaws.athena.connector.lambda.handlers.MetadataHandler;
import com.amazonaws.athena.connector.lambda.metadata.*;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetTableRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse;
import com.amazonaws.athena.connector.lambda.metadata.ListSchemasRequest;
import com.amazonaws.athena.connector.lambda.metadata.ListSchemasResponse;
import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
import com.amazonaws.athena.connector.lambda.metadata.MetadataRequest;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.athena.connectors.vertica.query.QueryFactory;
import com.amazonaws.athena.connectors.vertica.query.VerticaExportQueryBuilder;
@@ -41,17 +53,31 @@
import com.amazonaws.services.s3.model.ObjectListing;
import com.amazonaws.services.s3.model.S3ObjectSummary;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import org.apache.commons.lang3.StringUtils;
import com.google.common.collect.ImmutableMap;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.complex.reader.FieldReader;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.stringtemplate.v4.ST;

import java.sql.*;
import java.util.*;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

import static com.amazonaws.athena.connectors.vertica.VerticaSchemaUtils.convertToArrowType;


public class VerticaMetadataHandler
@@ -77,7 +103,8 @@ public class VerticaMetadataHandler
private final VerticaSchemaUtils verticaSchemaUtils;
private AmazonS3 amazonS3;

public VerticaMetadataHandler(java.util.Map<String, String> configOptions)
private final VerticaQueryPassthrough queryPassthrough = new VerticaQueryPassthrough();
public VerticaMetadataHandler(Map<String, String> configOptions)
{
super(SOURCE_TYPE, configOptions);
amazonS3 = AmazonS3ClientBuilder.defaultClient();
@@ -87,15 +114,15 @@ public VerticaMetadataHandler(java.util.Map<String, String> configOptions)

@VisibleForTesting
protected VerticaMetadataHandler(
EncryptionKeyFactory keyFactory,
VerticaConnectionFactory connectionFactory,
AWSSecretsManager awsSecretsManager,
AmazonAthena athena,
String spillBucket,
String spillPrefix,
VerticaSchemaUtils verticaSchemaUtils,
AmazonS3 amazonS3,
java.util.Map<String, String> configOptions)
EncryptionKeyFactory keyFactory,
VerticaConnectionFactory connectionFactory,
AWSSecretsManager awsSecretsManager,
AmazonAthena athena,
String spillBucket,
String spillPrefix,
VerticaSchemaUtils verticaSchemaUtils,
AmazonS3 amazonS3,
Map<String, String> configOptions)
{
super(keyFactory, awsSecretsManager, athena, SOURCE_TYPE, spillBucket, spillPrefix, configOptions);
this.connectionFactory = connectionFactory;
@@ -174,6 +201,51 @@ public ListTablesResponse doListTables(BlockAllocator allocator, ListTablesReque
return new ListTablesResponse(request.getCatalogName(), tables, null);

}
protected ArrowType getArrayArrowTypeFromTypeName(String typeName, int precision, int scale)
{
// Default ARRAY type is VARCHAR.
return new ArrowType.Utf8();
}
@Override
public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
{
ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);

return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
}
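
For context on when this capability actually shows up: the SDK helper only advertises query passthrough when it is enabled in the Lambda configuration. A minimal standalone sketch of that wiring follows; the "enable_query_passthrough" config key is an assumption about the SDK default, since the diff only shows the helper call itself.

import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
import com.google.common.collect.ImmutableMap;

import java.util.List;
import java.util.Map;

public class QptCapabilitySketch
{
    public static void main(String[] args)
    {
        VerticaQueryPassthrough qpt = new VerticaQueryPassthrough();

        // Assumed SDK config key; set to "false" to hide the capability from Athena.
        Map<String, String> configOptions = Map.of("enable_query_passthrough", "true");

        ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
        qpt.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);

        // When enabled, the built map advertises the system.query passthrough function.
        System.out.println(capabilities.build().keySet());
    }
}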


Contributor:
The implementation of this method brings up a few issues:
  1. We are using the JDBC implementation here (almost a copy/paste from the JDBC MetadataHandler), whereas doGetTable uses the VerticaSchemaUtils class.
  2. This raises an important question: why are we not using the JDBC MetadataHandler to begin with? If there is a good reason for that, then we shouldn't be using the JDBC implementation here either.
  3. If it turns out we don't actually need different implementations between the JDBC metadata handler and Vertica's, we should use this opportunity to unify the code.

Contributor Author:
Incorporated the Vertica datatype-to-Arrow conversion into doGetQueryPassthroughSchema.

@Override
public GetTableResponse doGetQueryPassthroughSchema(final BlockAllocator blockAllocator, final GetTableRequest getTableRequest)
        throws Exception
{
if (!getTableRequest.isQueryPassthrough()) {
throw new IllegalArgumentException("No query passed through: " + getTableRequest);
}

queryPassthrough.verify(getTableRequest.getQueryPassthroughArguments());
String customerPassedQuery = getTableRequest.getQueryPassthroughArguments().get(VerticaQueryPassthrough.QUERY);

try (Connection connection = getConnection(getTableRequest)) {
PreparedStatement preparedStatement = connection.prepareStatement(customerPassedQuery);
ResultSetMetaData metadata = preparedStatement.getMetaData();
if (metadata == null) {
throw new UnsupportedOperationException("Query not supported: ResultSetMetaData not available for query: " + customerPassedQuery);
}
SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();

for (int columnIndex = 1; columnIndex <= metadata.getColumnCount(); columnIndex++) {
String columnName = metadata.getColumnName(columnIndex);
String columnLabel = metadata.getColumnLabel(columnIndex);
columnName = columnName.equals(columnLabel) ? columnName : columnLabel;
convertToArrowType(schemaBuilder, columnName, metadata.getColumnTypeName(columnIndex));
}

Schema schema = schemaBuilder.build();
return new GetTableResponse(getTableRequest.getCatalogName(), getTableRequest.getTableName(), schema, Collections.emptySet());
}
}
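
Not visible in the diff is how a user reaches this code path from the engine side. A hedged sketch of the Athena-side invocation, assuming the documented TABLE(...) passthrough syntax and a hypothetical catalog name vertica_lambda; the system.query signature matches VerticaQueryPassthrough below.

public class QptInvocationSketch
{
    public static void main(String[] args)
    {
        // Hypothetical catalog name; QUERY is the single argument declared by
        // VerticaQueryPassthrough (schema "system", function "query").
        String athenaSql =
                "SELECT * FROM TABLE("
                + "vertica_lambda.system.query("
                + "QUERY => 'SELECT customer_id, SUM(amount) FROM sales.orders GROUP BY customer_id'))";
        System.out.println(athenaSql);
    }
}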

/**
* Used to get definition (field names, types, descriptions, etc...) of a Table.
@@ -200,7 +272,7 @@ public GetTableResponse doGetTable(BlockAllocator allocator, GetTableRequest req
request.getTableName(),
schema,
partitionCols
);
);
}

/**
@@ -228,7 +300,7 @@ public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTabl
* @param queryStatusChecker A QueryStatusChecker that you can use to stop doing work for a query that has already terminated
*/
@Override
public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker) throws SQLException {
public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker) throws SQLException {
logger.info("in getPartitions: "+ request);

Schema schemaName = request.getSchema();
@@ -243,20 +315,29 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request

//Build the SQL query
Connection connection = getConnection(request);
DatabaseMetaData dbMetadata = connection.getMetaData();
ResultSet definition = dbMetadata.getColumns(null, tableName.getSchemaName(), tableName.getTableName(), null);

// If this is a query passthrough (QPT) request, take the input query supplied from the Athena console;
// otherwise fall back to the old metadata-driven export logic.

VerticaExportQueryBuilder queryBuilder = queryFactory.createVerticaExportQueryBuilder();
String preparedSQLStmt;

String preparedSQLStmt = queryBuilder.withS3ExportBucket(s3ExportBucket)
.withQueryID(queryID)
.withColumns(definition, schemaName)
.fromTable(tableName.getSchemaName(), tableName.getTableName())
.withConstraints(constraints, schemaName)
.build();
if (!request.getTableName().getQualifiedTableName().equalsIgnoreCase(queryPassthrough.getFunctionSignature())) {

logger.info("Vertica Export Statement: {}", preparedSQLStmt);
DatabaseMetaData dbMetadata = connection.getMetaData();
ResultSet definition = dbMetadata.getColumns(null, tableName.getSchemaName(), tableName.getTableName(), null);

preparedSQLStmt = queryBuilder.withS3ExportBucket(s3ExportBucket)
.withQueryID(queryID)
.withColumns(definition, schemaName)
.fromTable(tableName.getSchemaName(), tableName.getTableName())
.withConstraints(constraints, schemaName)
.build();
} else {
preparedSQLStmt = null;
}

logger.info("Vertica Export Statement: {}", preparedSQLStmt);
// Build the Set AWS Region SQL
String awsRegionSql = queryBuilder.buildSetAwsRegionSql(amazonS3.getRegion().toString());
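
For readers new to this connector's export flow: the builder above emits a Vertica EXPORT TO PARQUET statement that writes result files to S3, which doGetSplits later maps to splits. A minimal sketch of the statement's shape, with placeholder bucket, query id, columns, and predicate; the real template lives in VerticaExportQueryBuilder.

public class ExportSqlSketch
{
    public static void main(String[] args)
    {
        // Placeholders; in the connector the column list comes from JDBC metadata
        // and the predicate from the request's Constraints.
        String s3ExportBucket = "my-export-bucket";
        String queryID = "abc123";
        String exportSql =
                "EXPORT TO PARQUET(directory = 's3://" + s3ExportBucket + "/" + queryID + "') "
                + "AS SELECT \"customer_id\", \"amount\" "
                + "FROM \"sales\".\"orders\" WHERE \"amount\" > 100";
        System.out.println(exportSql);
    }
}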

@@ -294,28 +375,40 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest
Set<Split> splits = new HashSet<>();
String exportBucket = getS3ExportBucket();
String queryId = request.getQueryId().replace("-","");

Constraints constraints = request.getConstraints();
String s3ExportBucket = getS3ExportBucket();
String sqlStatement;
//testing if the user has access to the requested table
testAccess(connection, request.getTableName());

FieldReader fieldReaderQid = request.getPartitions().getFieldReader("queryId");
String queryID = fieldReaderQid.readText().toString();

//get the SQL statement which was created in getPartitions
FieldReader fieldReaderPS = request.getPartitions().getFieldReader("preparedStmt");
String sqlStatement = fieldReaderPS.readText().toString();
if (constraints.isQueryPassThrough()) {
String preparedSQL = buildQueryPassthroughSql(constraints);
VerticaExportQueryBuilder queryBuilder = queryFactory.createQptVerticaExportQueryBuilder();
sqlStatement = queryBuilder.withS3ExportBucket(s3ExportBucket)
.withQueryID(queryID)
.withPreparedStatementSQL(preparedSQL).build();
logger.info("Vertica Export Statement: {}", sqlStatement);
}
else {
testAccess(connection, request.getTableName());
sqlStatement = fieldReaderPS.readText().toString();
}
String catalogName = request.getCatalogName();

FieldReader fieldReaderQid = request.getPartitions().getFieldReader("queryId");
String queryID = fieldReaderQid.readText().toString();

FieldReader fieldReaderAwsRegion = request.getPartitions().getFieldReader("awsRegionSql");
String awsRegionSql = fieldReaderAwsRegion.readText().toString();


//execute the queries on Vertica
executeQueriesOnVertica(connection, sqlStatement, awsRegionSql);

/*
* For each generated S3 object, create a split and add data to the split.
*/
/*
* For each generated S3 object, create a split and add data to the split.
*/
Split split;
List<S3ObjectSummary> s3ObjectSummaries = getlistExportedObjects(exportBucket, queryId);

@@ -336,19 +429,19 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest
return new GetSplitsResponse(catalogName, splits);
}
else
{
//No records were exported by Vertica for the issued query, creating an "empty" split
logger.info("No records were exported by Vertica");
split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey())
.add("query_id", queryID)
.add(VERTICA_CONN_STR, getConnStr(request))
.add("exportBucket", exportBucket)
.add("s3ObjectKey", EMPTY_STRING)
.build();
splits.add(split);
logger.info("doGetSplits: exit - " + splits.size());
return new GetSplitsResponse(catalogName,split);
}
{
//No records were exported by Vertica for the issued query, creating an "empty" split
logger.info("No records were exported by Vertica");
split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey())
.add("query_id", queryID)
.add(VERTICA_CONN_STR, getConnStr(request))
.add("exportBucket", exportBucket)
.add("s3ObjectKey", EMPTY_STRING)
.build();
splits.add(split);
logger.info("doGetSplits: exit - " + splits.size());
return new GetSplitsResponse(catalogName,split);
}

}
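
A hedged sketch of the passthrough branch in isolation: verify(...) rejects argument maps missing the QUERY key, and the verified customer SQL is then wrapped in the same export flow as the metadata-driven path. The EXPORT statement shape and all names below are placeholders; the actual template belongs to createQptVerticaExportQueryBuilder.

import java.util.Map;

public class QptSplitFlowSketch
{
    public static void main(String[] args)
    {
        VerticaQueryPassthrough qpt = new VerticaQueryPassthrough();

        // Arguments as they would arrive on the request's Constraints; verify(...)
        // throws if the QUERY argument is absent.
        Map<String, String> qptArguments = Map.of(
                VerticaQueryPassthrough.QUERY,
                "SELECT customer_id, SUM(amount) FROM sales.orders GROUP BY customer_id");
        qpt.verify(qptArguments);

        // The verified customer SQL is wrapped in the export template (shape assumed).
        String exportSql = "EXPORT TO PARQUET(directory = 's3://my-export-bucket/abc123') AS "
                + qptArguments.get(VerticaQueryPassthrough.QUERY);
        System.out.println(exportSql);
    }
}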

@@ -409,7 +502,13 @@ private void testAccess(Connection conn, TableName table) {

public String getS3ExportBucket()
{
return configOptions.get(EXPORT_BUCKET_KEY);
return configOptions.get(EXPORT_BUCKET_KEY);
}

public String buildQueryPassthroughSql(Constraints constraints) throws SQLException
{
queryPassthrough.verify(constraints.getQueryPassthroughArguments());
return constraints.getQueryPassthroughArguments().get(VerticaQueryPassthrough.QUERY);
}

}
68 changes: 68 additions & 0 deletions athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaQueryPassthrough.java
@@ -0,0 +1,68 @@
/*-
* #%L
* athena-vertica
* %%
* Copyright (C) 2019 - 2024 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.vertica;

import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.List;

public class VerticaQueryPassthrough implements QueryPassthroughSignature
{
// Constant value representing the name of the query passthrough function.
public static final String NAME = "query";

// Constant value representing the schema that hosts the function.
public static final String SCHEMA_NAME = "system";

// Name of the single argument that carries the passed-through SQL text.
public static final String QUERY = "QUERY";

// List of arguments for the function, statically initialized as it always contains the same value.
public static final List<String> ARGUMENTS = Arrays.asList(QUERY);

private static final Logger LOGGER = LoggerFactory.getLogger(VerticaQueryPassthrough.class);

@Override
public String getFunctionSchema()
{
return SCHEMA_NAME;
}

@Override
public String getFunctionName()
{
return NAME;
}

@Override
public List<String> getFunctionArguments()
{
return ARGUMENTS;
}

@Override
public Logger getLogger()
{
return LOGGER;
}

}
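
A short usage sketch of the signature class on its own: getFunctionSignature() is inherited from QueryPassthroughSignature, and getPartitions compares it case-insensitively against the request's qualified table name to detect a passthrough query. The exact schema.name formatting of the inherited default method is an assumption.

public class SignatureSketch
{
    public static void main(String[] args)
    {
        VerticaQueryPassthrough qpt = new VerticaQueryPassthrough();

        System.out.println(qpt.getFunctionSchema());    // system
        System.out.println(qpt.getFunctionName());      // query
        System.out.println(qpt.getFunctionArguments()); // [QUERY]

        // Compared against TableName.getQualifiedTableName() in getPartitions
        // and doGetSplits to route passthrough requests.
        System.out.println(qpt.getFunctionSignature()); // e.g. "system.query" (assumed)
    }
}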