-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathmetrics.py
32 lines (28 loc) · 955 Bytes
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""
License: Apache 2.0
Author: Ashley Gritzman
E-mail: [email protected]
"""
import tensorflow as tf
def accuracy(logits, labels):
    """Compute the fraction of examples whose predicted class is correct.

    The predicted class of each example is the argmax of its logits. It is
    compared against the integer label, and the matches are averaged over
    the batch.

    Credit:
        Suofei Zhang's implementation on GitHub, "Matrix-Capsules-EM-
        Tensorflow"
        https://github.com/www0wwwjs1/Matrix-Capsules-EM-Tensorflow

    Args:
        logits: shape (batch_size, num_classes)
        labels: shape (batch_size,) containing index of correct class.
            The batch dimension may be statically unknown (None).

    Returns:
        accuracy: scalar float32 tensor, the mean of correct predictions
            over the batch.
    """
    with tf.variable_scope("accuracy"):
        # Named identities make the inputs easy to locate in the graph
        # under the "accuracy" scope.
        logits = tf.identity(logits, name="logits")
        labels = tf.identity(labels, name="labels")

        # argmax over the class axis returns int64 with shape (batch_size,).
        # Cast both sides to int32 so that tf.equal compares matching dtypes.
        predicted = tf.cast(tf.argmax(logits, axis=1), tf.int32)
        correct_preds = tf.equal(tf.cast(labels, tf.int32), predicted)

        # reduce_mean divides by the runtime number of elements, so this
        # works even when the static batch dimension is None. The previous
        # int(logits.get_shape()[0]) raised in that case.
        batch_accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
    return batch_accuracy