diff --git a/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py index e48778ad70..4afb257f9d 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py @@ -1,4 +1,3 @@ -import pathlib from typing import Type from pyspark.ml import PipelineModel @@ -24,22 +23,17 @@ def to_literal( python_type: Type[PipelineModel], expected: LiteralType, ) -> Literal: - local_path = ctx.file_access.get_random_local_path() - pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) - python_val.save(local_path) - + # Must write to remote directory remote_dir = ctx.file_access.get_random_remote_directory() - ctx.file_access.upload_directory(local_path, remote_dir) + python_val.write().overwrite().save(remote_dir) return Literal(scalar=Scalar(blob=Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO)))) def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[PipelineModel] ) -> PipelineModel: - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.download_directory(lv.scalar.blob.uri, local_dir) - - return PipelineModel.load(local_dir) + remote_dir = lv.scalar.blob.uri + return PipelineModel.load(remote_dir) TypeEngine.register(PySparkPipelineModelTransformer())