Adadelta Optimizer #26590
Conversation
Thanks for your contribution!
Args:
    learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
        It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
There is no default value in ``def __init__``, though.
Thanks, updated.
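For context, a minimal sketch of the signature being discussed, assuming the docstring's stated defaults are mirrored in ``__init__`` (abridged to the arguments mentioned in this review; the real signature may carry more):

```python
# Sketch only: defaults mirror the docstring (learning_rate=0.001,
# epsilon=1.0e-6, rho=0.95); not the actual Paddle source.
class Adadelta:
    def __init__(self, learning_rate=0.001, epsilon=1.0e-6, rho=0.95,
                 parameters=None, name=None):
        self._learning_rate = learning_rate
        self._epsilon = epsilon
        self._rho = rho
        self._parameters = parameters
        self._name = name
```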
python/paddle/optimizer/adadelta.py
Outdated
        It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
    epsilon (float): a small float number for numeric stability. Default 1.0e-6.
    rho (float): a floating point value indicating the decay rate. Default 0.95.
    parameters (list, optional): List of ``Tensor`` names to update to minimize ``loss``. \
In dygraph mode, is this a list of Parameters or a list of Tensor names? In some examples I have seen, what gets passed in here is the return value of paddle.nn.Layer.parameters().
Thanks, updated.
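For reference, a minimal sketch of how ``parameters`` is typically populated in dygraph mode, assuming the 2.0-style ``paddle.optimizer.Adadelta`` introduced by this PR; the layer and shapes are arbitrary:

```python
import paddle

# paddle.nn.Layer.parameters() returns a list of Parameter objects
# (not Tensor names), which is what the optimizer consumes in dygraph mode.
linear = paddle.nn.Linear(10, 10)
adadelta = paddle.optimizer.Adadelta(
    learning_rate=0.001,
    epsilon=1.0e-6,
    rho=0.95,
    parameters=linear.parameters())

inp = paddle.rand([10, 10], dtype="float32")
loss = paddle.mean(linear(inp))
loss.backward()
adadelta.step()
adadelta.clear_grad()
```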
python/paddle/optimizer/momentum.py
Outdated
.. code-block:: python

    import paddle
    import paddle.fluid as fluid
no need to use fluid here
Thanks for pointing it out; removed.
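A sketch of what the cleaned-up example can look like once the fluid import is gone, assuming the 2.0-style ``paddle.optimizer.Momentum`` API:

```python
import paddle

# Only the top-level paddle package is needed; the legacy
# paddle.fluid namespace is not required by the 2.0-style API.
linear = paddle.nn.Linear(2, 2)
momentum = paddle.optimizer.Momentum(
    learning_rate=0.01,
    momentum=0.9,
    parameters=linear.parameters())
```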
python/paddle/optimizer/sgd.py
Outdated
.. code-block:: python

    import paddle
    import paddle.fluid as fluid
no fluid
Thanks for pointing it out; removed.
lgtm
LGTM
Args:
    learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
        It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
    epsilon (float): a small float number for numeric stability. Default 1.0e-6.
This should be ``float, optional``.
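Applied to the snippet above, the suggested line would read (sketch): ``epsilon (float, optional): a small float number for numeric stability. Default 1.0e-6.``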
    learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
        It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
    epsilon (float): a small float number for numeric stability. Default 1.0e-6.
    rho (float): a floating point value indicating the decay rate. Default 0.95.
This should be ``float, optional``.
    :ref:`api_guide_Name`.

Examples:
    .. code-block:: python
Add a blank line between ``.. code-block`` and ``import``, otherwise the rendered preview is buggy.
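A minimal sketch of the fix (the function name is illustrative): reST requires an empty line between the directive and the literal block, otherwise the docs preview renders incorrectly.

```python
def adadelta_docstring_sketch():
    """
    Examples:
        .. code-block:: python

            import paddle
    """
```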
    learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
        It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
    momentum (float): Momentum factor. The default value is 0.9.
This should be ``float, optional``.
LGTM
TODO: fix docs
PR types
New features
PR changes
OPs
Describe
Improve the Adadelta op, SGD op, and Momentum op.
Rename AdadeltaOptimizer to Adadelta, SGDOptimizer to SGD, and MomentumOptimizer to Momentum; the remaining changes follow the base Optimizer class. A usage sketch of the renamed classes is given below.
Update the C++-side error messages of the Adadelta op and the top k op.
TODO: corresponding Adadelta changes in FluidDoc.
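A minimal usage sketch of the renamed classes, assuming they land under ``paddle.optimizer`` as described above:

```python
import paddle

# 2.0-style names introduced by this PR; the old AdadeltaOptimizer /
# SGDOptimizer / MomentumOptimizer classes live under paddle.fluid.optimizer.
linear = paddle.nn.Linear(4, 4)
params = linear.parameters()

adadelta = paddle.optimizer.Adadelta(learning_rate=0.001, parameters=params)
sgd = paddle.optimizer.SGD(learning_rate=0.001, parameters=params)
momentum = paddle.optimizer.Momentum(
    learning_rate=0.001, momentum=0.9, parameters=params)
```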