Skip to content

Commit

Permalink
Merge pull request #943 from ennru/sfali-multipart_copy
Browse files Browse the repository at this point in the history
AWS S3: multipart copy (continued)
  • Loading branch information
2m authored May 9, 2018
2 parents 7c49209 + 827ab07 commit 16127f1
Show file tree
Hide file tree
Showing 17 changed files with 1,073 additions and 21 deletions.
16 changes: 16 additions & 0 deletions docs/src/main/paradox/s3.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,22 @@ Scala
Java
: @@snip ($alpakka$/s3/src/test/java/akka/stream/alpakka/s3/javadsl/S3ClientTest.java) { #list-bucket }


### Copy upload (multi part)

Copy an S3 object from a source bucket to a target bucket using a multipart copy upload.

Scala
: @@snip ($alpakka$/s3/src/test/scala/akka/stream/alpakka/s3/scaladsl/S3SinkSpec.scala) { #multipart-copy }

Java
: @@snip ($alpakka$/s3/src/test/java/akka/stream/alpakka/s3/javadsl/S3ClientTest.java) { #multipart-copy }

#### Java examples with custom headers

Java
: @@snip ($alpakka$/s3/src/test/java/akka/stream/alpakka/s3/javadsl/JavaExamplesSnippets.java) { #java-example }

### Running the example code

The code in this guide is part of runnable tests of this project. You are welcome to edit the code and run it in sbt.
Expand Down
26 changes: 25 additions & 1 deletion s3/src/main/scala/akka/stream/alpakka/s3/impl/HttpRequests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package akka.stream.alpakka.s3.impl
import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport._
import akka.http.scaladsl.marshalling.Marshal
import akka.http.scaladsl.model.Uri.{Authority, Query}
import akka.http.scaladsl.model.headers.Host
import akka.http.scaladsl.model.headers.{Host, RawHeader}
import akka.http.scaladsl.model.{ContentTypes, RequestEntity, _}
import akka.stream.alpakka.s3.S3Settings
import akka.stream.scaladsl.Source
Expand Down Expand Up @@ -105,6 +105,30 @@ private[alpakka] object HttpRequests {
}
}

/**
 * Builds the PUT request that copies one part of a source object into an
 * ongoing multipart upload.
 *
 * @param multipartCopy   the target multipart upload together with the part to copy
 * @param sourceVersionId optional version of the source object; appended to the
 *                        copy-source header as `?versionId=...`
 * @param s3Headers       extra headers placed before the copy-specific headers
 * @return the request addressed to the upload's target location, carrying the
 *         `partNumber` and `uploadId` query parameters
 */
def uploadCopyPartRequest(multipartCopy: MultipartCopy,
                          sourceVersionId: Option[String] = None,
                          s3Headers: S3Headers = S3Headers.empty)(implicit conf: S3Settings): HttpRequest = {
  val targetUpload = multipartCopy.multipartUpload
  val partition = multipartCopy.copyPartition
  val origin = partition.sourceLocation

  // NOTE(review): bucket and key are inserted verbatim (no URL-encoding) - confirm this
  // is acceptable for keys containing special characters.
  val sourcePath = s"/${origin.bucket}/${origin.key}"
  val copySource =
    RawHeader("x-amz-copy-source", sourceVersionId.fold(sourcePath)(versionId => s"$sourcePath?versionId=$versionId"))

  // The header's end offset is sent as `last - 1`, i.e. the slice's `last` is treated as exclusive here.
  val copyRange = partition.range.map { slice =>
    RawHeader("x-amz-copy-source-range", s"bytes=${slice.first}-${slice.last - 1}")
  }

  val allHeaders = s3Headers.headers ++ (copySource :: copyRange.toList)
  val partQuery = Query("partNumber" -> partition.partNumber.toString, "uploadId" -> targetUpload.uploadId)

  s3Request(targetUpload.s3Location, HttpMethods.PUT, _.withQuery(partQuery))
    .withDefaultHeaders(allHeaders: _*)
}

private[this] def s3Request(s3Location: S3Location,
method: HttpMethod = HttpMethods.GET,
uriFn: (Uri => Uri) = identity)(implicit conf: S3Settings): HttpRequest =
Expand Down
10 changes: 10 additions & 0 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/Marshalling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,14 @@ private[alpakka] object Marshalling {
)
}
}

/**
 * Unmarshals the XML body of a copy-part response into a [[CopyPartResult]].
 *
 * Reads the `LastModified` element (parsed with `Instant.parse`) and the `ETag`
 * element. Fails with `Unmarshaller.NoContentException` when the body is empty.
 */
implicit val copyPartResultUnmarshaller: FromEntityUnmarshaller[CopyPartResult] = {
  nodeSeqUnmarshaller(MediaTypes.`application/xml`, ContentTypes.`application/octet-stream`) map {
    case NodeSeq.Empty => throw Unmarshaller.NoContentException
    case x =>
      val lastModified = Instant.parse((x \ "LastModified").text)
      // Remove only the surrounding double quotes, if present. Blindly dropping the
      // first and last character would corrupt an ETag that arrives unquoted.
      val eTag = (x \ "ETag").text.stripPrefix("\"").stripSuffix("\"")
      CopyPartResult(lastModified, eTag)
  }
}
}
13 changes: 13 additions & 0 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Headers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,19 @@ object ServerSideEncryption {
// Chooses which headers to send for the given kind of S3 request:
// the regular headers, the regular headers plus copy-source headers, or none.
override def headersFor(request: S3Request): immutable.Seq[HttpHeader] = {
  // Headers describing the customer key of the copy source; the MD5 is derived
  // from the Base64-decoded key when none was supplied.
  def copySourceHeaders: List[HttpHeader] = {
    val keyDigest = md5.getOrElse(Md5Utils.md5AsBase64(Base64.decode(key)))
    List(
      RawHeader("x-amz-copy-source-server-side-encryption-customer-algorithm", "AES256"),
      RawHeader("x-amz-copy-source-server-side-encryption-customer-key", key),
      RawHeader("x-amz-copy-source-server-side-encryption-customer-key-MD5", keyDigest)
    )
  }

  request match {
    case CopyPart =>
      headers ++: copySourceHeaders
    case GetObject | HeadObject | PutObject | InitiateMultipartUpload | UploadPart =>
      headers
    case _ =>
      Nil
  }
}
}
Expand Down
12 changes: 7 additions & 5 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Request.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ package akka.stream.alpakka.s3.impl

// Marker for the kind of S3 operation a request performs. `headersFor`
// implementations pattern match on it to decide which headers to attach.
// NOTE(review): in this listing every pre-existing object appears twice (once public,
// once `private[impl]`), which looks like old/new diff lines shown together - confirm
// against the actual file before editing.
trait S3Request

case object GetObject extends S3Request
private[impl] case object GetObject extends S3Request

case object HeadObject extends S3Request
private[impl] case object HeadObject extends S3Request

case object PutObject extends S3Request
private[impl] case object PutObject extends S3Request

case object InitiateMultipartUpload extends S3Request
private[impl] case object InitiateMultipartUpload extends S3Request

case object UploadPart extends S3Request
private[impl] case object UploadPart extends S3Request

// Request kind for copying one part of an existing object into a multipart upload.
private[impl] case object CopyPart extends S3Request
119 changes: 117 additions & 2 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

package akka.stream.alpakka.s3.impl

import java.time.LocalDate
import java.time.{Instant, LocalDate}

import scala.collection.immutable.Seq
import scala.concurrent.{ExecutionContext, Future}
Expand All @@ -20,7 +20,7 @@ import akka.stream.Materializer
import akka.stream.alpakka.s3.auth.{CredentialScope, Signer, SigningKey}
import akka.stream.alpakka.s3.scaladsl.{ListBucketResultContents, ObjectMetadata}
import akka.stream.alpakka.s3.{DiskBufferType, MemoryBufferType, S3Exception, S3Settings}
import akka.stream.scaladsl.{Flow, Keep, Sink, Source}
import akka.stream.scaladsl.{Flow, Keep, RunnableGraph, Sink, Source}
import akka.util.ByteString

final case class S3Location(bucket: String, key: String)
Expand Down Expand Up @@ -58,6 +58,12 @@ case object ListBucketVersion2 extends ApiVersion {
override val getInstance: ApiVersion = ListBucketVersion2
}

/**
 * Result of a single copy-part request.
 *
 * @param lastModified timestamp taken from the response's `LastModified` element
 * @param eTag         value taken from the response's `ETag` element
 */
final case class CopyPartResult(lastModified: Instant, eTag: String)

/**
 * One part of a multipart copy.
 *
 * @param partNumber     number of the part; `createPartitions` numbers parts starting at 1
 * @param sourceLocation bucket and key of the object being copied
 * @param range          byte range of the source to copy, or `None` for no range header;
 *                       when the request header is built the end is sent as `last - 1`
 */
final case class CopyPartition(partNumber: Int, sourceLocation: S3Location, range: Option[ByteRange.Slice] = None)

/**
 * Pairs a multipart upload with one of the partitions to copy into it.
 *
 * @param multipartUpload the target multipart upload
 * @param copyPartition   the source partition to copy as one part of that upload
 */
final case class MultipartCopy(multipartUpload: MultipartUpload, copyPartition: CopyPartition)

object S3Stream {

def apply(settings: S3Settings)(implicit system: ActorSystem, mat: Materializer): S3Stream =
Expand Down Expand Up @@ -231,6 +237,36 @@ private[alpakka] final class S3Stream(settings: S3Settings)(implicit system: Act
}
}

/**
 * Builds a graph that copies `sourceLocation` to `targetLocation` with a
 * multipart copy upload.
 *
 * The source object's metadata is looked up to obtain its content length, the
 * object is split into partitions of `chunkSize` bytes, one copy-part request is
 * issued per partition, and the responses are fed into the completion sink.
 * When no metadata is found for the source, no partitions are produced.
 *
 * @param sourceVersionId     optional version of the source object to copy
 * @param contentType         content type used when initiating the target upload
 * @param s3Headers           additional headers for initiating the target upload
 * @param sse                 optional server-side encryption settings
 * @param chunkSize           size in bytes of each copied part
 * @param chunkingParallelism parallelism used for signing and processing part requests
 */
def multipartCopy(
    sourceLocation: S3Location,
    targetLocation: S3Location,
    sourceVersionId: Option[String] = None,
    contentType: ContentType = ContentTypes.`application/octet-stream`,
    s3Headers: S3Headers,
    sse: Option[ServerSideEncryption] = None,
    chunkSize: Int = MinChunkSize,
    chunkingParallelism: Int = 4
): RunnableGraph[Future[CompleteMultipartUploadResult]] = {
  import mat.executionContext

  // Step 1: the source's content length decides how the object is partitioned.
  val eventualPartitions: Future[List[CopyPartition]] =
    getObjectMetadata(sourceLocation.bucket, sourceLocation.key, sse).map { maybeMetadata =>
      maybeMetadata.fold(List.empty[CopyPartition]) { metadata =>
        createPartitions(chunkSize, sourceLocation)(metadata.contentLength)
      }
    }

  // Step 2: initiate the target upload and create one signed copy-part request per partition.
  val signedPartRequests =
    createCopyRequests(targetLocation,
                       sourceVersionId,
                       contentType,
                       s3Headers,
                       sse,
                       Source.fromFuture(eventualPartitions))(chunkingParallelism)

  // Step 3: run the part requests and hand their results to the completion sink.
  processUploadCopyPartRequests(signedPartRequests)(chunkingParallelism)
    .toMat(completionSink(targetLocation))(Keep.right)
}

private def computeMetaData(headers: Seq[HttpHeader], entity: ResponseEntity): ObjectMetadata =
ObjectMetadata(
headers ++
Expand Down Expand Up @@ -394,4 +430,83 @@ private[alpakka] final class S3Stream(settings: S3Settings)(implicit system: Act
throw new S3Exception(err)
}
}

/**
 * Splits an object of `objectSize` bytes into consecutive copy partitions of
 * `chunkSize` bytes each (the last one may be shorter).
 *
 * An object that is empty or smaller than one chunk yields a single partition
 * without a byte range. Otherwise partitions are numbered from 1 and each range
 * runs from its start offset to the next partition's start offset (or to
 * `objectSize` for the last one).
 */
private[impl] def createPartitions(chunkSize: Int,
                                   sourceLocation: S3Location)(objectSize: Long): List[CopyPartition] =
  if (objectSize > 0 && objectSize >= chunkSize) {
    // Start offset of every chunk, followed by the total size as the closing boundary.
    val boundaries = (0L until objectSize by chunkSize).toList :+ objectSize
    boundaries
      .sliding(2)
      .zipWithIndex
      .map {
        case (bounds, position) =>
          CopyPartition(position + 1, sourceLocation, Some(ByteRange(bounds.head, bounds.last)))
      }
      .toList
  } else {
    List(CopyPartition(1, sourceLocation))
  }

/**
 * Initiates the multipart upload for `location` and produces one signed
 * copy-part request per partition, each paired with the [[MultipartCopy]]
 * it belongs to.
 *
 * @param partitions  source emitting the list of partitions to copy
 * @param parallelism number of requests signed concurrently
 */
private def createCopyRequests(
    location: S3Location,
    sourceVersionId: Option[String],
    contentType: ContentType,
    s3Headers: S3Headers,
    sse: Option[ServerSideEncryption],
    partitions: Source[List[CopyPartition], NotUsed]
)(parallelism: Int) = {
  // Encryption headers for the given request kind; empty when no SSE is configured.
  def sseHeadersFor(kind: S3Request): Seq[HttpHeader] =
    sse.fold[Seq[HttpHeader]](Seq.empty)(_.headersFor(kind))

  val initiateHeaders = S3Headers(s3Headers.headers ++ sseHeadersFor(InitiateMultipartUpload))
  val requestInfo: Source[(MultipartUpload, Int), NotUsed] =
    initiateUpload(location, contentType, initiateHeaders)

  // One signing key is shared by all part requests.
  val sharedKey: SigningKey = signingKey

  val copyPartHeaders: S3Headers = S3Headers(sseHeadersFor(CopyPart))

  requestInfo
    .zipWith(partitions) {
      case ((upload, _), parts) =>
        for (part <- parts) yield {
          val copy = MultipartCopy(upload, part)
          (uploadCopyPartRequest(copy, sourceVersionId, copyPartHeaders), copy)
        }
    }
    .mapConcat(identity)
    .mapAsync(parallelism) {
      case (unsigned, copy) => Signer.signedRequest(unsigned, sharedKey).zip(Future.successful(copy))
    }
}

/**
 * Sends the copy-part requests through the HTTP pool and turns every response
 * into an upload-part result.
 *
 * A 200 response is unmarshalled to a [[CopyPartResult]] and reported as a
 * successful part carrying its ETag. Any other status makes the element's future
 * fail with an `S3Exception` built from the response body. A request that failed
 * at the HTTP level is reported as a failed part.
 *
 * @param parallelism number of response futures awaited concurrently
 */
private def processUploadCopyPartRequests(
    requests: Source[(HttpRequest, MultipartCopy), NotUsed]
)(parallelism: Int) = {
  import mat.executionContext

  requests
    .via(Http().superPool[MultipartCopy]())
    .map {
      case (Failure(cause), copy) =>
        Future.successful(FailedUploadPart(copy.multipartUpload, copy.copyPartition.partNumber, cause))

      case (Success(response), copy) =>
        val targetUpload = copy.multipartUpload
        val partNumber = copy.copyPartition.partNumber
        if (response.status == StatusCodes.OK)
          Unmarshal(response.entity)
            .to[CopyPartResult]
            .map(result => SuccessfulUploadPart(targetUpload, partNumber, result.eTag))
        else
          Unmarshal(response.entity).to[String].map { body =>
            // NOTE(review): the fallback text is only used when `body` is null - confirm
            // whether an empty body should also trigger it.
            val message =
              Option(body).getOrElse(s"Failed to upload part into S3, status code was: ${response.status}")
            throw new S3Exception(message)
          }
    }
    .mapAsync(parallelism)(identity)
}

}
Loading

0 comments on commit 16127f1

Please sign in to comment.