Skip to content

Commit

Permalink
Merge pull request #386 from BindsNET/hananel
Browse files Browse the repository at this point in the history
Add testing phase
  • Loading branch information
Hananel-Hazan authored Jun 19, 2020
2 parents 97a655c + 45884ec commit cd2b7d9
Show file tree
Hide file tree
Showing 5 changed files with 313 additions and 66 deletions.
87 changes: 77 additions & 10 deletions examples/mnist/SOM_LM-SNNs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,16 @@

# Sets up Gpu use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on Device = ", device)
if torch.cuda.is_available():
if gpu and torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
else:
torch.manual_seed(seed)
if gpu:
gpu = False
device = 'cpu'

torch.set_num_threads(os.cpu_count() - 1)
print("Running on Device = ", device)

# Determines number of workers to use
if n_workers == -1:
Expand Down Expand Up @@ -102,7 +106,7 @@
)

# Record spikes during the simulation.
spike_record = torch.zeros(update_interval, time, n_neurons).cpu()
spike_record = torch.zeros(update_interval, int(time/dt), n_neurons).cpu()

# Neuron assignments and spike proportions.
n_classes = 10
Expand All @@ -114,18 +118,18 @@
accuracy = {"all": [], "proportion": []}

# Voltage recording for excitatory and inhibitory layers.
som_voltage_monitor = Monitor(network.layers["Y"], ["v"], time=time)
som_voltage_monitor = Monitor(network.layers["Y"], ["v"], time=int(time/dt))
network.add_monitor(som_voltage_monitor, name="som_voltage")

# Set up monitors for spikes and voltages
spikes = {}
for layer in set(network.layers):
spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=int(time/dt))
network.add_monitor(spikes[layer], name="%s_spikes" % layer)

voltages = {}
for layer in set(network.layers) - {"X"}:
voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time)
voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=int(time/dt))
network.add_monitor(voltages[layer], name="%s_voltages" % layer)

inpt_ims, inpt_axes = None, None
Expand Down Expand Up @@ -164,7 +168,7 @@

for step, batch in enumerate(tqdm(dataloader)):
# Get next input sample.
inputs = {"X": batch["encoded_image"].view(time, 1, 1, 28, 28).to(device)}
inputs = {"X": batch["encoded_image"].view(int(time/dt), 1, 1, 28, 28).to(device)}

if step > 0:
if step % update_inhibation_weights == 0:
Expand Down Expand Up @@ -243,10 +247,10 @@
if temp_spikes.sum().sum() < 2:
inputs["X"] *= (
poisson(
datum=factor * batch["image"].clamp(min=0), dt=dt, time=time
datum=factor * batch["image"].clamp(min=0), dt=dt, time=int(time/dt)
)
.to(device)
.view(time, 1, 1, 28, 28)
.view(int(time/dt), 1, 1, 28, 28)
)
factor *= factor
else:
Expand All @@ -256,7 +260,7 @@
exc_voltages = som_voltage_monitor.get("v")

# Add to spikes recording.
spike_record[step % update_interval] = temp_spikes.detach().clone().cpu()
# spike_record[step % update_interval] = temp_spikes.detach().clone().cpu()
spike_record[step % update_interval].copy_(temp_spikes, non_blocking=True)

# Optionally plot various simulation information.
Expand Down Expand Up @@ -291,3 +295,66 @@

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")


# Load MNIST data.
test_dataset = MNIST(
PoissonEncoder(time=time, dt=dt),
None,
root=os.path.join("..", "..", "data", "MNIST"),
download=True,
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
),
)

# Sequence of accuracy estimates.
accuracy = {"all": 0, "proportion": 0}

# Record spikes during the simulation.
spike_record = torch.zeros(1, int(time/dt), n_neurons)

# Train the network.
print("\nBegin testing\n")
network.train(mode=False)
start = t()

for step, batch in enumerate(tqdm(test_dataset)):
# Get next input sample.
inputs = {"X": batch["encoded_image"].view(int(time/dt), 1, 1, 28, 28)}
if gpu:
inputs = {k: v.cuda() for k, v in inputs.items()}

# Run the network on the input.
network.run(inputs=inputs, time=time, input_time_dim=1)

# Add to spikes recording.
spike_record[0] = spikes["Y"].get("s").squeeze()

# Convert the array of labels into a tensor
label_tensor = torch.tensor(batch["label"])

# Get network predictions.
all_activity_pred = all_activity(
spikes=spike_record, assignments=assignments, n_labels=n_classes
)
proportion_pred = proportion_weighting(
spikes=spike_record,
assignments=assignments,
proportions=proportions,
n_labels=n_classes,
)

# Compute network accuracy according to available classification strategies.
accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred).item())
accuracy["proportion"] += float(torch.sum(label_tensor.long() == proportion_pred).item())

network.reset_state_variables() # Reset state variables.

print("\nAll activity accuracy: %.2f" % (accuracy["all"] / test_dataset.test_labels.shape[0]))
print("Proportion weighting accuracy: %.2f \n" % ( accuracy["proportion"] / test_dataset.test_labels.shape[0]))


print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Testing complete.\n")
91 changes: 83 additions & 8 deletions examples/mnist/batch_eth_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
parser.add_argument("--test", dest="train", action="store_false")
parser.add_argument("--plot", dest="plot", action="store_true")
parser.add_argument("--gpu", dest="gpu", action="store_true")
parser.set_defaults(plot=False, gpu=False, train=True)
parser.set_defaults(plot=False, gpu=False)

args = parser.parse_args()

Expand Down Expand Up @@ -73,6 +73,8 @@
torch.cuda.manual_seed_all(seed)
else:
torch.manual_seed(seed)
if gpu:
gpu = False

# Determines number of workers to use
if n_workers == -1:
Expand Down Expand Up @@ -119,20 +121,20 @@
accuracy = {"all": [], "proportion": []}

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=time)
inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=time)
exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=int(time/dt))
inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=int(time/dt))
network.add_monitor(exc_voltage_monitor, name="exc_voltage")
network.add_monitor(inh_voltage_monitor, name="inh_voltage")

# Set up monitors for spikes and voltages
spikes = {}
for layer in set(network.layers):
spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=int(time/dt))
network.add_monitor(spikes[layer], name="%s_spikes" % layer)

voltages = {}
for layer in set(network.layers) - {"X"}:
voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time)
voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=int(time/dt))
network.add_monitor(voltages[layer], name="%s_voltages" % layer)

inpt_ims, inpt_axes = None, None
Expand All @@ -142,7 +144,7 @@
perf_ax = None
voltage_axes, voltage_ims = None, None

spike_record = torch.zeros(update_interval, time, n_neurons)
spike_record = torch.zeros(update_interval, int(time/dt), n_neurons)

# Train the network.
print("\nBegin training.\n")
Expand All @@ -156,15 +158,15 @@
start = t()

# Create a dataloader to iterate and batch data
dataloader = DataLoader(
train_dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=n_workers,
pin_memory=gpu,
)

for step, batch in enumerate(tqdm(dataloader)):
for step, batch in enumerate(tqdm(train_dataloader)):
# Get next input sample.
inputs = {"X": batch["encoded_image"]}
if gpu:
Expand Down Expand Up @@ -272,5 +274,78 @@

network.reset_state_variables() # Reset state variables.

if step % update_steps == 0 and step > 0:
break

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")



# Load MNIST data.
test_dataset = MNIST(
PoissonEncoder(time=time, dt=dt),
None,
root=os.path.join(ROOT_DIR, "data", "MNIST"),
download=True,
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
),
)

# Create a dataloader to iterate and batch data
test_dataloader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=n_workers,
pin_memory=gpu,
)

# Sequence of accuracy estimates.
accuracy = {"all": 0, "proportion": 0}

# Train the network.
print("\nBegin testing\n")
network.train(mode=False)
start = t()

for step, batch in enumerate(tqdm(test_dataloader)):
# Get next input sample.
inputs = {"X": batch["encoded_image"]}
if gpu:
inputs = {k: v.cuda() for k, v in inputs.items()}

# Run the network on the input.
network.run(inputs=inputs, time=time, input_time_dim=1)

# Add to spikes recording.
spike_record = spikes["Ae"].get("s").permute((1, 0, 2))

# Convert the array of labels into a tensor
label_tensor = torch.tensor(batch["label"])

# Get network predictions.
all_activity_pred = all_activity(
spikes=spike_record, assignments=assignments, n_labels=n_classes
)
proportion_pred = proportion_weighting(
spikes=spike_record,
assignments=assignments,
proportions=proportions,
n_labels=n_classes,
)

# Compute network accuracy according to available classification strategies.
accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred).item())
accuracy["proportion"] += float(torch.sum(label_tensor.long() == proportion_pred).item())

network.reset_state_variables() # Reset state variables.

print("\nAll activity accuracy: %.2f" % (accuracy["all"] / test_dataset.test_labels.shape[0]))
print("Proportion weighting accuracy: %.2f \n" % ( accuracy["proportion"] / test_dataset.test_labels.shape[0]))


print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Testing complete.\n")
Loading

0 comments on commit cd2b7d9

Please sign in to comment.