Skip to content

Commit

Permalink
notebooks updated: replaced ckpt loader to use default saver instead …
Browse files Browse the repository at this point in the history
…of building var_list.

Also, rms error between darkflow and DW2TF converted networks is zero
now (not exactly sure what changed).
  • Loading branch information
sjain-stanford committed Feb 27, 2019
1 parent cbb8516 commit 3b289fc
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 63 deletions.
38 changes: 7 additions & 31 deletions docs/yolov2-tiny_dw2tf_validation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,14 @@
"source": [
"tf.reset_default_graph()\n",
"\n",
"tf.train.import_meta_graph('/data/darkflow/yolov2-tiny.ckpt.meta')\n",
"saver = tf.train.import_meta_graph('/data/darkflow/yolov2-tiny.ckpt.meta')\n",
"ckpt_path = '/data/darkflow/yolov2-tiny.ckpt'\n",
"\n",
"g = tf.get_default_graph()\n",
"in_t = g.get_tensor_by_name('input:0')\n",
"out_t = g.get_tensor_by_name('output:0')\n",
"\n",
"with tf.Session() as sess:\n",
" var_list = {} # dictionary mapping variable name to tensor, used to create saver object to restore\n",
" reader = tf.train.NewCheckpointReader(ckpt_path)\n",
" for key in reader.get_variable_to_shape_map():\n",
" # Look for all variables in ckpt that are used by the graph\n",
" try:\n",
" tensor = g.get_tensor_by_name(key + \":0\")\n",
" except KeyError:\n",
" # This tensor doesn't exist in the graph (for example it's\n",
" # 'global_step' or a similar housekeeping element) so skip it.\n",
" continue\n",
" var_list[key] = tensor\n",
" saver = tf.train.Saver(var_list=var_list)\n",
" saver.restore(sess, ckpt_path)\n",
"\n",
" out_data_darkflow = sess.run(out_t, feed_dict={in_t: in_data})"
Expand Down Expand Up @@ -91,26 +79,14 @@
"source": [
"tf.reset_default_graph()\n",
"\n",
"tf.train.import_meta_graph('/data/DW2TF/data/yolov2-tiny.ckpt.meta')\n",
"saver = tf.train.import_meta_graph('/data/DW2TF/data/yolov2-tiny.ckpt.meta')\n",
"ckpt_path = '/data/DW2TF/data/yolov2-tiny.ckpt'\n",
"\n",
"g = tf.get_default_graph()\n",
"in_t = g.get_tensor_by_name('yolov2-tiny/net1:0')\n",
"out_t = g.get_tensor_by_name('yolov2-tiny/convolutional9/BiasAdd:0')\n",
"\n",
"with tf.Session() as sess:\n",
" var_list = {} # dictionary mapping variable name to tensor, used to create saver object to restore\n",
" reader = tf.train.NewCheckpointReader(ckpt_path)\n",
" for key in reader.get_variable_to_shape_map():\n",
" # Look for all variables in ckpt that are used by the graph\n",
" try:\n",
" tensor = g.get_tensor_by_name(key + \":0\")\n",
" except KeyError:\n",
" # This tensor doesn't exist in the graph (for example it's\n",
" # 'global_step' or a similar housekeeping element) so skip it.\n",
" continue\n",
" var_list[key] = tensor\n",
" saver = tf.train.Saver(var_list=var_list)\n",
" saver.restore(sess, ckpt_path)\n",
" \n",
" out_data_dw2tf = sess.run(out_t, feed_dict={in_t: in_data})"
Expand All @@ -134,7 +110,7 @@
"text": [
"(64, 13, 13, 425)\n",
"(64, 13, 13, 425)\n",
"3.734651923107484e-14\n"
"0.0\n"
]
}
],
Expand All @@ -153,8 +129,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[-32.597694 13.941672 4.1781635 2.0687065 -85.162506 139.02223\n",
" 61.937466 125.50001 83.98407 36.48591 ]\n"
"[-54.208088 22.509918 3.4831426 6.245613 -61.15138 122.629616\n",
" 67.048996 84.28465 37.972435 -27.923262 ]\n"
]
}
],
Expand All @@ -171,8 +147,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[-32.59783 13.941475 4.178166 2.0686564 -85.161705 139.0219\n",
" 61.936897 125.500015 83.98348 36.486507 ]\n"
"[-54.208088 22.509918 3.4831426 6.245613 -61.15138 122.629616\n",
" 67.048996 84.28465 37.972435 -27.923262 ]\n"
]
}
],
Expand Down
40 changes: 8 additions & 32 deletions docs/yolov2_dw2tf_validation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,16 @@
"source": [
"tf.reset_default_graph()\n",
"\n",
"tf.train.import_meta_graph('/data/darkflow/yolov2.ckpt.meta')\n",
"saver = tf.train.import_meta_graph('/data/darkflow/yolov2.ckpt.meta')\n",
"ckpt_path = '/data/darkflow/yolov2.ckpt'\n",
"\n",
"g = tf.get_default_graph()\n",
"in_t = g.get_tensor_by_name('input:0')\n",
"out_t = g.get_tensor_by_name('output:0')\n",
"\n",
"with tf.Session() as sess:\n",
" var_list = {} # dictionary mapping variable name to tensor, used to create saver object to restore\n",
" reader = tf.train.NewCheckpointReader(ckpt_path)\n",
" for key in reader.get_variable_to_shape_map():\n",
" # Look for all variables in ckpt that are used by the graph\n",
" try:\n",
" tensor = g.get_tensor_by_name(key + \":0\")\n",
" except KeyError:\n",
" # This tensor doesn't exist in the graph (for example it's\n",
" # 'global_step' or a similar housekeeping element) so skip it.\n",
" continue\n",
" var_list[key] = tensor\n",
" saver = tf.train.Saver(var_list=var_list)\n",
" saver.restore(sess, ckpt_path)\n",
"\n",
" \n",
" out_data_darkflow = sess.run(out_t, feed_dict={in_t: in_data})"
]
},
Expand All @@ -91,26 +79,14 @@
"source": [
"tf.reset_default_graph()\n",
"\n",
"tf.train.import_meta_graph('/data/DW2TF/data/yolov2.ckpt.meta')\n",
"saver = tf.train.import_meta_graph('/data/DW2TF/data/yolov2.ckpt.meta')\n",
"ckpt_path = '/data/DW2TF/data/yolov2.ckpt'\n",
"\n",
"g = tf.get_default_graph()\n",
"in_t = g.get_tensor_by_name('yolov2/net1:0')\n",
"out_t = g.get_tensor_by_name('yolov2/convolutional23/BiasAdd:0')\n",
"\n",
"with tf.Session() as sess:\n",
" var_list = {} # dictionary mapping variable name to tensor, used to create saver object to restore\n",
" reader = tf.train.NewCheckpointReader(ckpt_path)\n",
" for key in reader.get_variable_to_shape_map():\n",
" # Look for all variables in ckpt that are used by the graph\n",
" try:\n",
" tensor = g.get_tensor_by_name(key + \":0\")\n",
" except KeyError:\n",
" # This tensor doesn't exist in the graph (for example it's\n",
" # 'global_step' or a similar housekeeping element) so skip it.\n",
" continue\n",
" var_list[key] = tensor\n",
" saver = tf.train.Saver(var_list=var_list)\n",
" saver.restore(sess, ckpt_path)\n",
" \n",
" out_data_dw2tf = sess.run(out_t, feed_dict={in_t: in_data})"
Expand All @@ -134,7 +110,7 @@
"text": [
"(64, 19, 19, 425)\n",
"(64, 19, 19, 425)\n",
"7.561380489925889e-15\n"
"0.0\n"
]
}
],
Expand All @@ -153,8 +129,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[ -2.2853928 4.57492 -4.6335416 -2.899775 -21.905552 16.398315\n",
" 14.86693 11.579112 -2.0311275 -14.31339 ]\n"
"[ -4.044286 4.3981853 -4.290732 -1.7628956 -19.399569 16.678816\n",
" 17.136553 17.491247 1.3627763 -14.287616 ]\n"
]
}
],
Expand All @@ -171,8 +147,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[ -2.2854545 4.574852 -4.633548 -2.899774 -21.905493 16.398537\n",
" 14.866977 11.579125 -2.0310893 -14.313481 ]\n"
"[ -4.044286 4.3981853 -4.290732 -1.7628956 -19.399569 16.678816\n",
" 17.136553 17.491247 1.3627763 -14.287616 ]\n"
]
}
],
Expand Down

0 comments on commit 3b289fc

Please sign in to comment.