Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix print_summary bug and add groups of convolution (#9492)
Browse files Browse the repository at this point in the history
* fix print_summary bug and add groups of convolution

1. fix "int(node["attrs"]["no_bias"])" bug
2. add groups of convolution param calculation

* Update visualization.py

lint

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py
  • Loading branch information
chinakook authored and piiswrong committed Feb 19, 2018
1 parent 50f326f commit f14b2eb
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions python/mxnet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,24 @@ def print_layer_summary(node, out_shape):
pre_filter = pre_filter + int(shape[0])
cur_param = 0
if op == 'Convolution':
if ("no_bias" in node["attrs"]) and int(node["attrs"]["no_bias"]):
cur_param = pre_filter * int(node["attrs"]["num_filter"])
if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 'True':
num_group = int(node['attrs'].get('num_group', '1'))
cur_param = pre_filter * int(node["attrs"]["num_filter"]) \
// num_group
for k in _str2tuple(node["attrs"]["kernel"]):
cur_param *= int(k)
else:
cur_param = pre_filter * int(node["attrs"]["num_filter"])
num_group = int(node['attrs'].get('num_group', '1'))
cur_param = pre_filter * int(node["attrs"]["num_filter"]) \
// num_group
for k in _str2tuple(node["attrs"]["kernel"]):
cur_param *= int(k)
cur_param += int(node["attrs"]["num_filter"])
elif op == 'FullyConnected':
if ("no_bias" in node["attrs"]) and int(node["attrs"]["no_bias"]):
cur_param = pre_filter * (int(node["attrs"]["num_hidden"]))
if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 'True':
cur_param = pre_filter * int(node["attrs"]["num_hidden"])
else:
cur_param = (pre_filter+1) * (int(node["attrs"]["num_hidden"]))
cur_param = (pre_filter+1) * int(node["attrs"]["num_hidden"])
elif op == 'BatchNorm':
key = node["name"] + "_output"
if show_shape:
Expand Down

1 comment on commit f14b2eb

@ranman
Copy link

@ranman ranman commented on f14b2eb Mar 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixes #10093

Please sign in to comment.