Skip to content

Commit

Permalink
move some from fluid.io to static.io
Browse files Browse the repository at this point in the history
  • Loading branch information
Difers committed Aug 8, 2023
1 parent 4dd9a3a commit 2041a4a
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 52 deletions.
50 changes: 0 additions & 50 deletions python/paddle/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,53 +65,3 @@
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


def prepend_feed_ops(
    inference_program, feed_target_names, feed_holder_name='feed'
):
    """Prepend one 'feed' op per feed target to the program's global block.

    Args:
        inference_program: Program whose global block receives the feed ops.
        feed_target_names (list[str]): Names of variables to be fed; each
            must already exist in the global block.
        feed_holder_name (str): Name of the persistable FEED_MINIBATCH
            holder variable that sources every feed op. Defaults to 'feed'.

    Raises:
        ValueError: If any name in ``feed_target_names`` does not exist in
            the pruned inference program.
    """
    # Idiomatic emptiness test (was: len(feed_target_names) == 0).
    if not feed_target_names:
        return

    global_block = inference_program.global_block()
    # Single persistable holder variable shared as input by all feed ops.
    feed_var = global_block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True,
    )

    for i, name in enumerate(feed_target_names):
        if not global_block.has_var(name):
            raise ValueError(
                "The feeded_var_names[{i}]: '{name}' doesn't exist in pruned inference program. "
                "Please check whether '{name}' is a valid feed_var name, or remove it from feeded_var_names "
                "if '{name}' is not involved in the target_vars calculation.".format(
                    i=i, name=name
                )
            )
        out = global_block.var(name)
        # _prepend_op places the feed ahead of all existing ops; 'col'
        # records this target's slot in the feed list.
        global_block._prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': i},
        )


def append_fetch_ops(
    inference_program, fetch_target_names, fetch_holder_name='fetch'
):
    """Append one 'fetch' op per fetch target to the program's global block.

    A single persistable FETCH_LIST variable named ``fetch_holder_name`` is
    created to collect every fetched value; each op's 'col' attribute
    records that target's position in the fetch list.
    """
    block = inference_program.global_block()
    holder = block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True,
    )

    col = 0
    for target_name in fetch_target_names:
        block.append_op(
            type='fetch',
            inputs={'X': [target_name]},
            outputs={'Out': [holder]},
            attrs={'col': col},
        )
        col += 1
51 changes: 50 additions & 1 deletion python/paddle/static/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
)
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.framework import Parameter, dygraph_not_support, static_only
from paddle.fluid.io import append_fetch_ops, prepend_feed_ops
from paddle.fluid.log_helper import get_logger
from paddle.framework.io_utils import (
_clone_var_in_block_,
Expand Down Expand Up @@ -138,6 +137,56 @@ def _clone_var_in_block(block, var):
)


def prepend_feed_ops(
    inference_program, feed_target_names, feed_holder_name='feed'
):
    """Prepend one 'feed' op per feed target to the program's global block.

    Args:
        inference_program: Program whose global block receives the feed ops.
        feed_target_names (list[str]): Names of variables to be fed; each
            must already exist in the global block.
        feed_holder_name (str): Name of the persistable FEED_MINIBATCH
            holder variable that sources every feed op. Defaults to 'feed'.

    Raises:
        ValueError: If any name in ``feed_target_names`` does not exist in
            the pruned inference program.
    """
    # Idiomatic emptiness test (was: len(feed_target_names) == 0).
    if not feed_target_names:
        return

    global_block = inference_program.global_block()
    # Single persistable holder variable shared as input by all feed ops.
    feed_var = global_block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True,
    )

    for i, name in enumerate(feed_target_names):
        if not global_block.has_var(name):
            raise ValueError(
                "The feeded_var_names[{i}]: '{name}' doesn't exist in pruned inference program. "
                "Please check whether '{name}' is a valid feed_var name, or remove it from feeded_var_names "
                "if '{name}' is not involved in the target_vars calculation.".format(
                    i=i, name=name
                )
            )
        out = global_block.var(name)
        # _prepend_op places the feed ahead of all existing ops; 'col'
        # records this target's slot in the feed list.
        global_block._prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': i},
        )


def append_fetch_ops(
    inference_program, fetch_target_names, fetch_holder_name='fetch'
):
    """Append one 'fetch' op per fetch target to the program's global block.

    A single persistable FETCH_LIST variable named ``fetch_holder_name`` is
    created to collect every fetched value; each op's 'col' attribute
    records that target's position in the fetch list.
    """
    block = inference_program.global_block()
    holder = block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True,
    )

    col = 0
    for target_name in fetch_target_names:
        block.append_op(
            type='fetch',
            inputs={'X': [target_name]},
            outputs={'Out': [holder]},
            attrs={'col': col},
        )
        col += 1


def normalize_program(program, feed_vars, fetch_vars):
"""
Expand Down
2 changes: 1 addition & 1 deletion test/ir/inference/quant_dequant_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from paddle.fluid import Program, Variable, core
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor
from paddle.fluid.framework import IrGraph
from paddle.fluid.io import append_fetch_ops, prepend_feed_ops
from paddle.static.io import append_fetch_ops, prepend_feed_ops
from paddle.static.quantization import (
AddQuantDequantPass,
OutScaleForInferencePass,
Expand Down

0 comments on commit 2041a4a

Please sign in to comment.