Skip to content

Commit

Permalink
fix trailing checksum
Browse files Browse the repository at this point in the history
  • Loading branch information
sbiscigl committed Nov 11, 2024
1 parent 01f7ee3 commit 4f32699
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 61 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#pragma once
#include <aws/core/utils/Outcome.h>
#include <aws/core/utils/crypto/Hash.h>
#include <aws/core/utils/HashingUtils.h>

namespace Aws {
namespace Utils {
namespace Stream {

static const size_t AWS_DATA_BUFFER_SIZE = 65536;

template <size_t DataBufferSize = AWS_DATA_BUFFER_SIZE>
class AwsChunkedStream {
public:
AwsChunkedStream(Http::HttpRequest *request, const std::shared_ptr<Aws::IOStream> &stream)
: m_chunkingStream{Aws::MakeShared<StringStream>("AwsChunkedStream")}, m_request(request), m_stream(stream) {}

size_t BufferedRead(char *dst, size_t amountToRead) {
// the chunk has ended and cannot be read from
if (m_chunkEnd) {
return 0;
}

// If we've read all of the underlying stream write the checksum trailing header
// the set that the chunked stream is over.
if (m_stream->eof() && (m_chunkingStream->eof() || m_chunkingStream->peek() == EOF)) {
Aws::StringStream chunkedTrailer;
chunkedTrailer << "0\r\n";
if (m_request->GetRequestHash().second != nullptr) {
chunkedTrailer << "x-amz-checksum-" << m_request->GetRequestHash().first << ":"
<< HashingUtils::Base64Encode(m_request->GetRequestHash().second->GetHash().GetResult()) << "\r\n";
}
chunkedTrailer << "\r\n";
auto trailerSize = chunkedTrailer.str().size();
memcpy(dst, chunkedTrailer.str().c_str(), trailerSize);
m_chunkEnd = true;
return trailerSize;
}

// Try to read in a 64K chunk, if we cant we know the stream is over
size_t bytesRead = 0;
while (m_stream->good() && bytesRead < DataBufferSize) {
m_stream->read(&m_data[bytesRead], DataBufferSize - bytesRead);
bytesRead += m_stream->gcount();
}

// update the trailing checksum to be sent only if we read data and buffered.
if (bytesRead > 0 && m_request->GetRequestHash().second != nullptr) {
m_request->GetRequestHash().second->Update(reinterpret_cast<unsigned char *>(m_data.data()), bytesRead);
}

// Buffer chunked encoding from data if there was data read to the buffer, otherwise leave it alone/
if (bytesRead > 0 && m_chunkingStream != nullptr) {
*m_chunkingStream << Aws::Utils::StringUtils::ToHexString(bytesRead) << "\r\n";
std::copy(m_data.begin(), m_data.begin() + bytesRead, std::ostream_iterator<char>(*m_chunkingStream));
*m_chunkingStream << "\r\n";
auto curr = m_chunkingStream->tellg();
const auto rn = m_chunkingStream->rdbuf();
AWS_UNREFERENCED_PARAM(rn);
m_chunkingStream->seekg(curr);
}

// Read to destination buffer, return how much was read
m_chunkingStream->read(dst, amountToRead);
return m_chunkingStream->gcount();
}

private:
std::array<char, DataBufferSize> m_data;
std::shared_ptr<Aws::IOStream> m_chunkingStream;
bool m_chunkEnd{false};
Http::HttpRequest *m_request;
std::shared_ptr<Aws::IOStream> m_stream;
};
} // namespace Stream
} // namespace Utils
} // namespace Aws
83 changes: 22 additions & 61 deletions src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/core/utils/ratelimiter/RateLimiterInterface.h>
#include <aws/core/utils/stream/AwsChunkedStream.h>
#include <aws/core/utils/DateTime.h>
#include <aws/core/utils/crypto/Hash.h>
#include <aws/core/utils/Outcome.h>
Expand All @@ -24,6 +25,7 @@ using namespace Aws::Http;
using namespace Aws::Http::Standard;
using namespace Aws::Utils;
using namespace Aws::Utils::Logging;
using namespace Aws::Utils::Stream;
using namespace Aws::Monitoring;

#ifdef USE_AWS_MEMORY_MANAGEMENT
Expand Down Expand Up @@ -144,25 +146,27 @@ struct CurlWriteCallbackContext
int64_t m_numBytesResponseReceived;
};

static const char* CURL_HTTP_CLIENT_TAG = "CurlHttpClient";

struct CurlReadCallbackContext
{
CurlReadCallbackContext(const CurlHttpClient* client, CURL* curlHandle, HttpRequest* request, Aws::Utils::RateLimits::RateLimiterInterface* limiter) :
m_client(client),
m_curlHandle(curlHandle),
m_rateLimiter(limiter),
m_request(request),
m_chunkEnd(false)
m_chunkEnd(false),
m_chunkedStream{Aws::MakeShared<AwsChunkedStream<>>(CURL_HTTP_CLIENT_TAG, request, request->GetContentBody())}
{}

const CurlHttpClient* m_client;
CURL* m_curlHandle;
Aws::Utils::RateLimits::RateLimiterInterface* m_rateLimiter;
HttpRequest* m_request;
bool m_chunkEnd;
std::shared_ptr<Stream::AwsChunkedStream<>> m_chunkedStream;
};

static const char* CURL_HTTP_CLIENT_TAG = "CurlHttpClient";

static int64_t GetContentLengthFromHeader(CURL* connectionHandle,
bool& hasContentLength) {
#if LIBCURL_VERSION_NUM >= 0x073700 // 7.55.0
Expand Down Expand Up @@ -293,68 +297,25 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo
size_t amountToRead = size * nmemb;
bool isAwsChunked = request->HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) &&
request->GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER) == Aws::Http::AWS_CHUNKED_VALUE;
// aws-chunk = hex(chunk-size) + CRLF + chunk-data + CRLF
// Needs to reserve bytes of sizeof(hex(chunk-size)) + sizeof(CRLF) + sizeof(CRLF)
if (isAwsChunked)
{
Aws::String amountToReadHexString = Aws::Utils::StringUtils::ToHexString(amountToRead);
amountToRead -= (amountToReadHexString.size() + 4);
}

if (ioStream != nullptr && amountToRead > 0)
{
size_t amountRead = 0;
if (isStreaming)
{
if (!ioStream->eof() && ioStream->peek() != EOF)
{
amountRead = (size_t) ioStream->readsome(ptr, amountToRead);
}
if (amountRead == 0 && !ioStream->eof())
{
return CURL_READFUNC_PAUSE;
}
}
else
{
ioStream->read(ptr, amountToRead);
amountRead = static_cast<size_t>(ioStream->gcount());
}

if (isAwsChunked)
{
if (amountRead > 0)
{
if (request->GetRequestHash().second != nullptr)
{
request->GetRequestHash().second->Update(reinterpret_cast<unsigned char*>(ptr), amountRead);
}

Aws::String hex = Aws::Utils::StringUtils::ToHexString(amountRead);
memmove(ptr + hex.size() + 2, ptr, amountRead);
memmove(ptr + hex.size() + 2 + amountRead, "\r\n", 2);
memmove(ptr, hex.c_str(), hex.size());
memmove(ptr + hex.size(), "\r\n", 2);
amountRead += hex.size() + 4;
}
else if (!context->m_chunkEnd)
{
Aws::StringStream chunkedTrailer;
chunkedTrailer << "0\r\n";
if (request->GetRequestHash().second != nullptr)
{
chunkedTrailer << "x-amz-checksum-"
<< request->GetRequestHash().first
<< ":"
<< HashingUtils::Base64Encode(request->GetRequestHash().second->GetHash().GetResult())
<< "\r\n";
}
chunkedTrailer << "\r\n";
amountRead = chunkedTrailer.str().size();
memcpy(ptr, chunkedTrailer.str().c_str(), amountRead);
context->m_chunkEnd = true;
}
}
if (isStreaming) {
if (!ioStream->eof() && ioStream->peek() != EOF) {
amountRead = (size_t)ioStream->readsome(ptr, amountToRead);
}
if (amountRead == 0 && !ioStream->eof()) {
return CURL_READFUNC_PAUSE;
}
} else if (isAwsChunked) {
AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Called with size: " << amountToRead);
amountRead = context->m_chunkedStream->BufferedRead(ptr, amountToRead);
AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "read: " << amountRead);
} else {
ioStream->read(ptr, amountToRead);
amountRead = static_cast<size_t>(ioStream->gcount());
}

auto& sentHandler = request->GetDataSentEventHandler();
if (sentHandler)
Expand Down
35 changes: 35 additions & 0 deletions tests/aws-cpp-sdk-core-tests/utils/stream/AwsChunkedStreamTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#include <aws/core/http/standard/StandardHttpRequest.h>
#include <aws/core/utils/crypto/CRC32.h>
#include <aws/core/utils/stream/AwsChunkedStream.h>
#include <aws/testing/AwsCppSdkGTestSuite.h>

using namespace Aws;
using namespace Aws::Http::Standard;
using namespace Aws::Utils::Stream;
using namespace Aws::Utils::Crypto;

class AwsChunkedStreamTest : public Aws::Testing::AwsCppSdkGTestSuite {};

const char* TEST_LOG_TAG = "AWS_CHUNKED_STREAM_TEST";

TEST_F(AwsChunkedStreamTest, ChunkedStreamShouldWork) {
StandardHttpRequest request{"www.elda.com/will", Http::HttpMethod::HTTP_GET};
auto requestHash = Aws::MakeShared<CRC32>(TEST_LOG_TAG);
request.SetRequestHash("crc32", requestHash);
std::shared_ptr<IOStream> inputStream = Aws::MakeShared<StringStream>(TEST_LOG_TAG, "1234567890123456789012345");
AwsChunkedStream<10> chunkedStream{&request, inputStream};
std::array<char, 100> outputBuffer{};
Aws::StringStream output;
size_t read = 0;
do {
read = chunkedStream.BufferedRead(outputBuffer.data(), 10);
std::copy(outputBuffer.begin(), outputBuffer.begin() + read, std::ostream_iterator<char>(output));
} while (read > 0);
const auto encodedStr = output.str();
auto expectedStreamWithChecksum = "A\r\n1234567890\r\nA\r\n1234567890\r\n5\r\n12345\r\n0\r\nx-amz-checksum-crc32:78DeVw==\r\n\r\n";
EXPECT_EQ(expectedStreamWithChecksum, encodedStr);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2518,4 +2518,38 @@ namespace
}
}
}

TEST_F(BucketAndObjectOperationTest, PutObjectChecksumWithGuarunteedChunkedObject) {
struct ChecksumTestCase {
std::function<PutObjectRequest(PutObjectRequest)> chucksumRequestMutator;
String body;
};

const String fullBucketName = CalculateBucketName(BASE_CHECKSUMS_BUCKET_NAME.c_str());
SCOPED_TRACE(Aws::String("FullBucketName ") + fullBucketName);
CreateBucketRequest createBucketRequest;
createBucketRequest.SetBucket(fullBucketName);
createBucketRequest.SetACL(BucketCannedACL::private_);
CreateBucketOutcome createBucketOutcome = CreateBucket(createBucketRequest);
AWS_ASSERT_SUCCESS(createBucketOutcome);

Vector<ChecksumTestCase> testCases{
{[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::CRC32); },
Aws::String(1024 * 1024, 'e')},
{[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::CRC32C); },
Aws::String(1024 * 1024, 'l')},
{[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::SHA1); },
Aws::String(1024 * 1024, 'd')},
{[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::SHA256); },
Aws::String(1024 * 1024, 'a')}};

for (const auto& testCase : testCases) {
auto request = testCase.chucksumRequestMutator(PutObjectRequest().WithBucket(fullBucketName).WithKey("Metaphor"));
std::shared_ptr<IOStream> body =
Aws::MakeShared<StringStream>(ALLOCATION_TAG, testCase.body, std::ios_base::in | std::ios_base::binary);
request.SetBody(body);
const auto response = Client->PutObject(request);
EXPECT_TRUE(response.IsSuccess());
}
}
}

0 comments on commit 4f32699

Please sign in to comment.