Skip to content

Commit

Permalink
created forward and backward methods for Net class
Browse files Browse the repository at this point in the history
  • Loading branch information
John Canny committed Apr 22, 2017
1 parent 96e1cf5 commit 7851e0e
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 64 deletions.
68 changes: 34 additions & 34 deletions .classpath
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
<?xml version="1.0" encoding="UTF-8"?>
<classpath>
<classpathentry kind="src" path="src/main/scala">
<attributes>
<attribute name="org.eclipse.jdt.launching.CLASSPATH_ATTR_LIBRARY_PATH_ENTRY" value="C:/code/BIDMach/lib"/>
</attributes>
</classpathentry>
<classpathentry kind="src" path="src/main/java">
<attributes>
<attribute name="org.eclipse.jdt.launching.CLASSPATH_ATTR_LIBRARY_PATH_ENTRY" value="c:/code/CUDA7/bin"/>
</attributes>
</classpathentry>
<classpathentry kind="con" path="org.scala-ide.sdt.launching.SCALA_CONTAINER"/>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
<classpathentry kind="lib" path="lib/lz4-1.3.jar"/>
<classpathentry kind="lib" path="lib/jfreechart-1.0.19.jar"/>
<classpathentry kind="lib" path="lib/jhdf5-3.2.1.jar"/>
<classpathentry kind="lib" path="lib/jline-2.10.jar"/>
<classpathentry kind="lib" path="lib/jocl-2.0.0.jar"/>
<classpathentry kind="lib" path="lib/json-io-4.5.0.jar"/>
<classpathentry kind="lib" path="lib/ptplot-1.0.jar"/>
<classpathentry kind="lib" path="lib/ptplotapplication-1.0.jar"/>
<classpathentry kind="lib" path="lib/scala-arm_2.11-1.4.jar"/>
<classpathentry kind="lib" path="lib/scala-compiler-2.11.0-M8.jar"/>
<classpathentry kind="lib" path="lib/jcublas-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/jcuda-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/jcudnn-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/jcufft-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/jcurand-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/jcusparse-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/protobuf-java-3.1.0.jar"/>
<classpathentry kind="lib" path="lib/BIDMat-2.0.1-cuda8.0beta.jar"/>
<classpathentry kind="output" path="bin"/>
</classpath>
<?xml version="1.0" encoding="UTF-8"?>
<classpath>
<classpathentry kind="src" path="src/main/scala">
<attributes>
<attribute name="org.eclipse.jdt.launching.CLASSPATH_ATTR_LIBRARY_PATH_ENTRY" value="C:/code/BIDMach/lib"/>
</attributes>
</classpathentry>
<classpathentry kind="src" path="src/main/java">
<attributes>
<attribute name="org.eclipse.jdt.launching.CLASSPATH_ATTR_LIBRARY_PATH_ENTRY" value="c:/code/CUDA7/bin"/>
</attributes>
</classpathentry>
<classpathentry kind="con" path="org.scala-ide.sdt.launching.SCALA_CONTAINER"/>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
<classpathentry kind="lib" path="lib/lz4-1.3.jar"/>
<classpathentry kind="lib" path="lib/jfreechart-1.0.19.jar"/>
<classpathentry kind="lib" path="lib/jhdf5-3.2.1.jar"/>
<classpathentry kind="lib" path="lib/jline-2.10.jar"/>
<classpathentry kind="lib" path="lib/jocl-2.0.0.jar"/>
<classpathentry kind="lib" path="lib/json-io-4.5.0.jar"/>
<classpathentry kind="lib" path="lib/ptplot-1.0.jar"/>
<classpathentry kind="lib" path="lib/ptplotapplication-1.0.jar"/>
<classpathentry kind="lib" path="lib/scala-arm_2.11-1.4.jar"/>
<classpathentry kind="lib" path="lib/scala-compiler-2.11.0-M8.jar"/>
<classpathentry kind="lib" path="lib/jcublas-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/jcuda-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/jcudnn-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/jcufft-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/jcurand-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/jcusparse-0.8.0.jar"/>
<classpathentry kind="lib" path="lib/protobuf-java-3.1.0.jar"/>
<classpathentry kind="lib" path="lib/BIDMat-2.0.2-cuda8.0beta.jar"/>
<classpathentry kind="output" path="bin"/>
</classpath>
2 changes: 1 addition & 1 deletion scripts/networks/testCIFAR10a.ssc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ val (nn,opts) = Net.learner(trainfname,labelsfname);
val convt = jcuda.jcudnn.cudnnConvolutionMode.CUDNN_CROSS_CORRELATION


opts.batchSize= 64
opts.batchSize= 32
opts.npasses = 10
opts.lrate = 1e-3f
opts.lrate = 1e-4f
Expand Down
68 changes: 39 additions & 29 deletions src/main/scala/BIDMach/networks/Net.scala
Original file line number Diff line number Diff line change
Expand Up @@ -132,41 +132,51 @@ class Net(override val opts:Net.Opts = new Net.Options) extends Model(opts) {
}
}

def forward:Int = {
if (mask.asInstanceOf[AnyRef] != null) {
modelmats(0) ~ modelmats(0) mask;
}
var i = 0;
while (i < layers.length) {
if (opts.debug > 0) {
println("dobatch forward %d %s" format (i, layers(i).getClass))
}
layers(i).forward;
i += 1;
}
i;
}

def backward(nl:Int, ipass:Int, pos:Long) = {
var i = nl;
var j = 0;
while (j < output_layers.length) {
output_layers(j).deriv.set(1);
j += 1;
}
if (opts.aopts == null) {
for (j <- 0 until updatemats.length) updatemats(j).clear;
}
while (i > 1) {
i -= 1;
if (opts.debug > 0) {
println("dobatch backward %d %s" format (i, layers(i).getClass))
}
layers(i).backward(ipass, pos);
}
if (mask.asInstanceOf[AnyRef] != null) {
updatemats(0) ~ updatemats(0) mask;
}
}


def dobatch(gmats:Array[Mat], ipass:Int, pos:Long):Unit = {
if (batchSize < 0) batchSize = gmats(0).ncols;
if (batchSize == gmats(0).ncols) { // discard odd-sized minibatches
assignInputs(gmats, ipass, pos);
assignTargets(gmats, ipass, pos);
if (mask.asInstanceOf[AnyRef] != null) {
modelmats(0) ~ modelmats(0) mask;
}
var i = 0;
while (i < layers.length) {
if (opts.debug > 0) {
println("dobatch forward %d %s" format (i, layers(i).getClass))
}
layers(i).forward;
i += 1;
}
var j = 0;
while (j < output_layers.length) {
output_layers(j).deriv.set(1);
j += 1;
}
if (opts.aopts == null) {
for (j <- 0 until updatemats.length) updatemats(j).clear;
}
while (i > 1) {
i -= 1;
if (opts.debug > 0) {
println("dobatch backward %d %s" format (i, layers(i).getClass))
}
layers(i).backward(ipass, pos);
}
if (mask.asInstanceOf[AnyRef] != null) {
updatemats(0) ~ updatemats(0) mask;
}
val nl = forward;
backward(nl, ipass, pos);
}
}

Expand Down

0 comments on commit 7851e0e

Please sign in to comment.