Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
futami committed Nov 5, 2018
1 parent 339a03d commit 0c77c35
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 18 deletions.
12 changes: 7 additions & 5 deletions load.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@ def load_image_data(dataset_path, test_img, img_size, mode):

print(dataset_path)
print("train image data loading...")
train_image_list = glob.glob(dataset_path+'*.jpg')
train_image_list = glob.glob(dataset_path+'*.jpeg')
for img_path in train_image_list:
print(img_path)
img = load_img(img_path, target_size=(img_size, img_size))
img = load_img(img_path, target_size=(img_size, img_size), grayscale=True)
imgarray = img_to_array(img)
X_train.append(imgarray)

X_train = np.array(X_train).astype(np.float32)
X_train = (X_train -127.5) / 127.5

print(X_train.shape)

if mode == 'test':
test_img = load_img(test_img, target_size=(img_size, img_size))
test_img = load_img(test_img, target_size=(img_size, img_size), grayscale=True)
test_imgarray = img_to_array(test_img)
X_test = np.array(test_imgarray).astype(np.float32)
X_test = (X_test -127.5) / 127.5
Expand Down Expand Up @@ -86,8 +88,8 @@ def load_mnist_data():

X_test_original = X_test.copy()

X_train = X_train[Y_train==1]
X_test = X_test[Y_test==1]
X_train = X_train[Y_train==0]
X_test = X_test[Y_test==0]
print('train shape: ', X_train.shape)

return X_train, X_test, X_test_original, Y_test
31 changes: 18 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def anomaly_detection(test_img, args, g=None, d=None):
original_x = (test_img.reshape(args.imgsize,args.imgsize,args.channels)*127.5+127.5).astype(np.uint8)
similar_x = (similar_img.reshape(args.imgsize,args.imgsize,args.channels)*127.5+127.5).astype(np.uint8)

original_x_color = cv2.cvtColor(original_x, cv2.COLOR_RGB2BGR)
original_x_color = cv2.cvtColor(original_x, cv2.COLOR_GRAY2BGR)
residual_color = cv2.applyColorMap(np_residual, cv2.COLORMAP_JET)
show = cv2.addWeighted(original_x_color, 0.3, residual_color, 0.7, 0.)

Expand Down Expand Up @@ -68,10 +68,10 @@ def run(args):
#X_train, X_test, X_test_original, Y_test = load.load_mnist_data()

""" load image data """
#X_train, test_img = load.load_image_data(args.datapath, args.testpath, args.imgsize, args.mode)
X_train, test_img = load.load_image_data(args.datapath, args.testpath, args.imgsize, args.mode)

""" load csv data """
X_train, Y_test, X_test_original, Y_test = load.load_csv_data(args.datapath, args.imgsize)
#X_train, Y_test, X_test_original, Y_test = load.load_csv_data(args.datapath, args.imgsize)

""" init DCGAN """
print("initialize DCGAN ")
Expand Down Expand Up @@ -111,41 +111,46 @@ def run(args):
#test_img = X_test_original[Y_test==0][30]

# compute anomaly score - sample from strange image
img_idx = args.img_idx
label_idx = args.label_idx
test_img = X_test_original[Y_test==label_idx][img_idx]
# test_img = np.random.uniform(-1, 1 (28, 28, 1))
#img_idx = args.img_idx
#label_idx = args.label_idx
#test_img = X_test_original[Y_test==label_idx][img_idx]
#test_img = np.random.uniform(-1, 1 (args.imgsize, args.imgsize, args.channels))

start = cv2.getTickCount()
score, qurey, pred, diff = anomaly_detection(test_img, args)
time = (cv2.getTickCount() - start ) / cv2.getTickFrequency() * 1000
print ('%d label, %d : done ' %(label_idx, img_idx), '%.2f' %score, '%.2fms'%time)
#print ('%d label, %d : done ' %(label_idx, img_idx), '%.2f' %score, '%.2fms'%time)

""" matplot view """
plt.figure(1, figsize=(3, 3))
plt.title('query image')
plt.imshow(qurey.reshape(args.imgsize, args.imgsize, args.channels), cmap=plt.cm.gray)
plt.imshow(qurey.reshape(args.imgsize, args.imgsize), cmap=plt.cm.gray)
plt.savefig('./anomaly_detection/' + str(label_idx) +'_query_image.png' )

print('anomaly score :', score)
plt.figure(2, figsize=(3,3))
plt.title('generated similar image')
plt.imshow(pred.reshape(args.imgsize, args.imgsize, args.channels), cmap=plt.cm.gray)
plt.imshow(pred.reshape(args.imgsize, args.imgsize), cmap=plt.cm.gray)
plt.savefig('./anomaly_detection/' + str(label_idx) +'_generated_similar.png' )

plt.figure(3, figsize=(3,4))
plt.figure(3, figsize=(3,3))
plt.title('anomaly detection')
plt.imshow(cv2.cvtColor(diff, cv2.COLOR_BGR2RGB))
plt.savefig('./anomaly_detection/' + str(label_idx) +'_diff.png' )
plt.show()

def main():
parser = argparse.ArgumentParser(description='train AnoGAN')
parser.add_argument('--datapath', '-d',)
parser.add_argument('--epoch', '-e', default=1000)
parser.add_argument('--batchsize', '-b', default=32)
parser.add_argument('--batchsize', '-b', default=64)
parser.add_argument('--mode', '-m' , type=str, default='test',help='train, test')
parser.add_argument('--imgsize', type=int, default=28)
parser.add_argument('--imgsize', type=int, default=64)
parser.add_argument('--channels', type=int, default=1)
parser.add_argument('--zdims', type=int, default=100)
parser.add_argument('--testpath', '-p', type=str )
parser.add_argument('--label_idx', type=int ,default=1 )
parser.add_argument('--img_idx', type=int, default=14 )

args = parser.parse_args()

Expand Down

0 comments on commit 0c77c35

Please sign in to comment.