Skip to content

Commit

Permalink
fix print_summary bug and add groups of convolution (apache#9492)
Browse files Browse the repository at this point in the history
* fix print_summary bug and add groups of convolution

1. fix "int(node["attrs"]["no_bias"])" bug
2. add groups of convolution param calculation

* Update visualization.py

lint

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py
  • Loading branch information
chinakook authored and zheng-da committed Jun 28, 2018
1 parent ed93639 commit 997bf6a
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions python/mxnet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,24 @@ def print_layer_summary(node, out_shape):
pre_filter = pre_filter + int(shape[0])
cur_param = 0
if op == 'Convolution':
if ("no_bias" in node["attrs"]) and int(node["attrs"]["no_bias"]):
cur_param = pre_filter * int(node["attrs"]["num_filter"])
if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 'True':
num_group = int(node['attrs'].get('num_group', '1'))
cur_param = pre_filter * int(node["attrs"]["num_filter"]) \
// num_group
for k in _str2tuple(node["attrs"]["kernel"]):
cur_param *= int(k)
else:
cur_param = pre_filter * int(node["attrs"]["num_filter"])
num_group = int(node['attrs'].get('num_group', '1'))
cur_param = pre_filter * int(node["attrs"]["num_filter"]) \
// num_group
for k in _str2tuple(node["attrs"]["kernel"]):
cur_param *= int(k)
cur_param += int(node["attrs"]["num_filter"])
elif op == 'FullyConnected':
if ("no_bias" in node["attrs"]) and int(node["attrs"]["no_bias"]):
cur_param = pre_filter * (int(node["attrs"]["num_hidden"]))
if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 'True':
cur_param = pre_filter * int(node["attrs"]["num_hidden"])
else:
cur_param = (pre_filter+1) * (int(node["attrs"]["num_hidden"]))
cur_param = (pre_filter+1) * int(node["attrs"]["num_hidden"])
elif op == 'BatchNorm':
key = node["name"] + "_output"
if show_shape:
Expand Down

0 comments on commit 997bf6a

Please sign in to comment.