Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Add Paddle-to-ONNX model conversion script #2722

Merged
merged 17 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions paddlex/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,17 @@ def console_entry() -> int:
try:
# Flush output here to force SIGPIPE to be triggered while inside this
# try block.
code = main()
main()
sys.stdout.flush()
sys.stderr.flush()
return code
except BrokenPipeError:
# Python flushes standard streams on exit;
# redirect remaining output to devnull to avoid another BrokenPipeError
# at shutdown.
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, sys.stdout.fileno())
return 1
sys.exit(1)


if __name__ == "__main__":
sys.exit(console_entry())
console_entry()
1 change: 1 addition & 0 deletions paddlex/paddle2onnx_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
paddle2onnx>=1.3
152 changes: 144 additions & 8 deletions paddlex/paddlex_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,20 @@
import argparse
import subprocess
import sys
import shutil
from pathlib import Path

from importlib_resources import files, as_file

from . import create_pipeline
from .inference.pipelines import create_pipeline_from_config, load_pipeline_config
from .repo_manager import setup, get_all_supported_repo_names
from .utils.flags import FLAGS_json_format_model
from .utils import logging
from .utils.interactive_get_pipeline import interactive_get_pipeline
from .utils.pipeline_arguments import PIPELINE_ARGUMENTS


def _install_serving_deps():
    """Install the Python dependencies required by the serving plugin via pip."""
    requirements = files("paddlex").joinpath("serving_requirements.txt")
    with as_file(requirements) as req_path:
        pip_cmd = [sys.executable, "-m", "pip", "install", "-r", str(req_path)]
        return subprocess.check_call(pip_cmd)


def args_cfg():
"""parse cli arguments"""

Expand All @@ -57,6 +53,7 @@ def parse_str(s):
install_group = parser.add_argument_group("Install PaddleX Options")
pipeline_group = parser.add_argument_group("Pipeline Predict Options")
serving_group = parser.add_argument_group("Serving Options")
paddle2onnx_group = parser.add_argument_group("Paddle2ONNX Options")

################# install pdx #################
install_group.add_argument(
Expand Down Expand Up @@ -148,6 +145,23 @@ def parse_str(s):
help="Port number to serve on (default: 8080).",
)

################# paddle2onnx #################
paddle2onnx_group.add_argument(
"--paddle2onnx", action="store_true", help="Convert Paddle model to ONNX format"
)
paddle2onnx_group.add_argument(
"--paddle_model_dir", type=str, help="Directory containing the Paddle model"
)
paddle2onnx_group.add_argument(
"--onnx_model_dir",
type=str,
default="onnx",
help="Output directory for the ONNX model",
)
paddle2onnx_group.add_argument(
"--opset_version", type=int, help="Version of the ONNX opset to use"
)

# Parse known arguments to get the pipeline name
args, remaining_args = parser.parse_known_args()
pipeline_name = args.pipeline
Expand Down Expand Up @@ -180,6 +194,21 @@ def parse_str(s):

def install(args):
"""install paddlex"""

def _install_serving_deps():
with as_file(files("paddlex").joinpath("serving_requirements.txt")) as req_file:
return subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-r", str(req_file)]
)

def _install_paddle2onnx_deps():
with as_file(
files("paddlex").joinpath("paddle2onnx_requirements.txt")
) as req_file:
return subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-r", str(req_file)]
)

# Enable debug info
os.environ["PADDLE_PDX_DEBUG"] = "True"
# Disable eager initialization
Expand All @@ -189,9 +218,20 @@ def install(args):

if "serving" in plugins:
plugins.remove("serving")
if plugins:
logging.error("`serving` cannot be used together with other plugins.")
sys.exit(2)
_install_serving_deps()
return

if "paddle2onnx" in plugins:
plugins.remove("paddle2onnx")
if plugins:
logging.error("`paddle2onnx` cannot be used together with other plugins.")
sys.exit(2)
_install_paddle2onnx_deps()
return

if plugins:
repo_names = plugins
elif len(plugins) == 0:
Expand Down Expand Up @@ -234,6 +274,95 @@ def serve(pipeline, *, device, use_hpip, host, port):
run_server(app, host=host, port=port, debug=False)


# TODO: Move to another module
def paddle_to_onnx(paddle_model_dir, onnx_model_dir, *, opset_version):
    """Convert a Paddle inference model to ONNX via the ``paddle2onnx`` CLI.

    Args:
        paddle_model_dir: Directory containing the Paddle model files
            (``inference.json``/``inference.pdmodel``, ``inference.pdiparams``
            and ``inference.yml``).
        onnx_model_dir: Output directory for the ONNX model. May be the same
            directory as ``paddle_model_dir``, in which case no auxiliary
            files are copied.
        opset_version: ONNX opset version to use, or ``None`` to select a
            default based on the model file format.

    Exits the process via ``sys.exit`` with a message on any validation or
    conversion failure.
    """
    PD_MODEL_FILE_PREFIX = "inference"
    PD_PARAMS_FILENAME = "inference.pdiparams"
    ONNX_MODEL_FILENAME = "inference.onnx"
    CONFIG_FILENAME = "inference.yml"
    ADDITIONAL_FILENAMES = ["scaler.pkl"]

    def _check_model_files(input_dir, pd_model_file_ext):
        # Verify all required model files are present before invoking the CLI.
        model_path = (input_dir / PD_MODEL_FILE_PREFIX).with_suffix(pd_model_file_ext)
        if not model_path.exists():
            sys.exit(f"{model_path} does not exist")
        params_path = input_dir / PD_PARAMS_FILENAME
        if not params_path.exists():
            sys.exit(f"{params_path} does not exist")
        config_path = input_dir / CONFIG_FILENAME
        if not config_path.exists():
            sys.exit(f"{config_path} does not exist")

    def _check_paddle2onnx():
        # The converter ships as a separate plugin; fail with guidance if absent.
        if shutil.which("paddle2onnx") is None:
            sys.exit("Paddle2ONNX is not available. Please install the plugin first.")

    def _run_paddle2onnx(input_dir, pd_model_file_ext, output_dir, opset_version):
        # Invoke the paddle2onnx CLI; exits on a non-zero return code.
        logging.info("Paddle2ONNX conversion starting...")
        # XXX: To circumvent Paddle2ONNX's bug
        if opset_version is None:
            if pd_model_file_ext == ".json":
                opset_version = 19
            else:
                opset_version = 7
            logging.info("Using default ONNX opset version: %d", opset_version)
        cmd = [
            "paddle2onnx",
            "--model_dir",
            str(input_dir),
            "--model_filename",
            str(Path(PD_MODEL_FILE_PREFIX).with_suffix(pd_model_file_ext)),
            "--params_filename",
            PD_PARAMS_FILENAME,
            "--save_file",
            str(output_dir / ONNX_MODEL_FILENAME),
            "--opset_version",
            str(opset_version),
        ]
        try:
            subprocess.check_call(cmd)
        except subprocess.CalledProcessError as e:
            sys.exit(f"Paddle2ONNX conversion failed with exit code {e.returncode}")
        logging.info("Paddle2ONNX conversion succeeded")

    def _copy_config_file(input_dir, output_dir):
        # The inference config must travel with the converted model.
        src_path = input_dir / CONFIG_FILENAME
        dst_path = output_dir / CONFIG_FILENAME
        shutil.copy(src_path, dst_path)
        logging.info(f"Copied {src_path} to {dst_path}")

    def _copy_additional_files(input_dir, output_dir):
        # Optional sidecar files (e.g. a fitted scaler) are copied when present.
        for filename in ADDITIONAL_FILENAMES:
            src_path = input_dir / filename
            if not src_path.exists():
                continue
            dst_path = output_dir / filename
            shutil.copy(src_path, dst_path)
            logging.info(f"Copied {src_path} to {dst_path}")

    # Validate the input directory up front: Path(None) would raise TypeError
    # before the None check inside the old _check_input_dir could ever run.
    if paddle_model_dir is None:
        sys.exit("Input directory must be specified")
    paddle_model_dir = Path(paddle_model_dir)
    onnx_model_dir = Path(onnx_model_dir)
    if not paddle_model_dir.exists():
        sys.exit(f"{paddle_model_dir} does not exist")
    if not paddle_model_dir.is_dir():
        sys.exit(f"{paddle_model_dir} is not a directory")
    logging.info(f"Input dir: {paddle_model_dir}")
    logging.info(f"Output dir: {onnx_model_dir}")
    # Prefer the JSON model format when present (or when forced by the flag);
    # fall back to the legacy .pdmodel format otherwise.
    pd_model_file_ext = ".json"
    if not FLAGS_json_format_model:
        if not (paddle_model_dir / f"{PD_MODEL_FILE_PREFIX}.json").exists():
            pd_model_file_ext = ".pdmodel"
    _check_model_files(paddle_model_dir, pd_model_file_ext)
    _check_paddle2onnx()
    # Ensure the output directory exists; neither paddle2onnx's --save_file nor
    # the shutil.copy calls below reliably create it themselves.
    onnx_model_dir.mkdir(parents=True, exist_ok=True)
    _run_paddle2onnx(paddle_model_dir, pd_model_file_ext, onnx_model_dir, opset_version)
    if not onnx_model_dir.samefile(paddle_model_dir):
        _copy_config_file(paddle_model_dir, onnx_model_dir)
        _copy_additional_files(paddle_model_dir, onnx_model_dir)
    logging.info("Done")


# for CLI
def main():
"""API for commad line"""
Expand All @@ -243,7 +372,7 @@ def main():
if len(sys.argv) == 1:
logging.warning("No arguments provided. Displaying help information:")
parser.print_help()
return
sys.exit(2)

if args.install:
install(args)
Expand All @@ -255,12 +384,19 @@ def main():
host=args.host,
port=args.port,
)
elif args.paddle2onnx:
paddle_to_onnx(
args.paddle_model_dir,
args.onnx_model_dir,
opset_version=args.opset_version,
)
else:
if args.get_pipeline_config is not None:
interactive_get_pipeline(args.get_pipeline_config, args.save_path)
else:
pipeline_args_dict = {}
from .utils.flags import USE_NEW_INFERENCE

if USE_NEW_INFERENCE:
for arg in pipeline_args:
arg_name = arg["name"].lstrip("-")
Expand Down
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def serving_dependencies():
return file.read()


def paddle2onnx_dependencies():
    """Return the contents of the paddle2onnx plugin requirements file."""
    requirements_path = os.path.join("paddlex", "paddle2onnx_requirements.txt")
    with open(requirements_path, "r") as requirements_file:
        return requirements_file.read()


def version():
"""get version"""
with open(os.path.join("paddlex", ".version"), "r") as file:
Expand Down Expand Up @@ -92,6 +97,7 @@ def _recursively_find(pattern, exts=None):
install_requires=dependencies(),
extras_require={
"serving": serving_dependencies(),
"paddle2onnx": paddle2onnx_dependencies(),
},
packages=pkgs,
package_data=pkg_data,
Expand Down