Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Move default arguments out of primary constructors into new auxiliary constructors (MXDataIter, NDArrayIter, BucketSentenceIter) and update call sites
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Aug 7, 2018
1 parent 68c3e8c commit f49eeca
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,22 @@ import scala.collection.mutable.ListBuffer
* @param handle the handle to the underlying C++ Data Iterator
*/
private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
dataName: String = "data",
labelName: String = "label",
dtype: DType = DType.Float32,
dataLayout: String = "NCHW",
labelLayout: String = "N",
dataDType: DType = DType.Float32,
labelDType: DType = DType.Int32)
dataName: String,
labelName: String,
dataLayout: String,
labelLayout: String,
dataDType: DType,
labelDType: DType)
extends DataIter with WarnIfNotDisposed {

private val logger = LoggerFactory.getLogger(classOf[MXDataIter])

/**
 * Convenience constructor that delegates to the primary constructor,
 * supplying the pre-refactor defaults: "NCHW" data layout, "N" label layout,
 * Float32 data dtype and Int32 label dtype.
 *
 * @param handle the handle to the underlying C++ Data Iterator
 * @param dataName name of the data ndarray, defaults to "data"
 * @param labelName name of the label ndarray, defaults to "label"
 */
def this(handle: DataIterHandle,
dataName: String = "data",
labelName: String = "label") {
this(handle, dataName, labelName, "NCHW", "N", DType.Float32, DType.Int32)
}

// use currentBatch to implement hasNext
// (may be this is not the best way to do this work,
// fix me if any better way found)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,36 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
dataLayout: String, labelLayout: String) extends DataIter {

// scalastyle:off
/**
* @param data Specify the data. Data names will be data_0, data_1, ..., etc.
* @param label Same as data, but is not fed to the model during testing.
* Label names will be label_0, label_1, ..., etc.
* @param dataBatchSize Batch Size
* @param shuffle Whether to shuffle the data
* @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
*
* This iterator will pad, discard or roll over the last batch if
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label",
dataDType: DType = MX_REAL_TYPE, labelDType: DType = DType.Int32,
dataLayout: String = "NCHW", labelLayout: String = "N") {
/**
 * Fully-specified auxiliary constructor: every option is passed explicitly
 * (no default arguments), which keeps this overload usable from Java.
 * Initializes the data/label descriptors via IO.initData before delegating
 * to the primary constructor. Note that label is allowed to be empty
 * (allowEmpty = true) to support prediction without labels, while data is not.
 *
 * @param data data ndarrays; names will be derived from dataName
 * @param label label ndarrays; may be empty (e.g. for prediction)
 * @param dataBatchSize batch size
 * @param shuffle whether to shuffle the data
 * @param lastBatchHandle "pad", "discard" or "roll_over"
 * @param dataName base name for the data
 * @param labelName base name for the label
 * @param dataDType dtype of the data ndarrays
 * @param labelDType dtype of the label ndarrays
 * @param dataLayout layout string of the data (e.g. "NCHW")
 * @param labelLayout layout string of the label (e.g. "N")
 */
def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray],
dataBatchSize: Int, shuffle: Boolean,
lastBatchHandle: String,
dataName: String, labelName: String,
dataDType: DType, labelDType: DType,
dataLayout: String, labelLayout: String) {
this(IO.initData(data, allowEmpty = false, dataName),
IO.initData(label, allowEmpty = true, labelName),
dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType, dataLayout, labelLayout)
}
// scalastyle:on
/**
 * Default-argument auxiliary constructor. Delegates to the fully-specified
 * constructor with MX_REAL_TYPE/Int32 dtypes and "NCHW"/"N" layouts.
 *
 * @param data Specify the data. Data names will be data_0, data_1, ..., etc.
 * @param label Same as data, but is not fed to the model during testing.
 *              Label names will be label_0, label_1, ..., etc.
 *              Defaults to empty (e.g. for prediction without labels).
 * @param dataBatchSize Batch Size
 * @param shuffle Whether to shuffle the data
 * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
 * @param dataName base name for the data, defaults to "data"
 * @param labelName base name for the label, defaults to "label"
 *
 * This iterator will pad, discard or roll over the last batch if
 * the size of data does not match batch_size. Roll over is intended
 * for training and can cause problems if used for prediction.
 */
def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label") {
this(data, label, dataBatchSize, shuffle, lastBatchHandle, dataName, labelName,
MX_REAL_TYPE, DType.Int32, "NCHW", "N")
}

private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {

// test pad
val dataIter0 = new NDArrayIter(data, label, 128, false, "pad",
dataName = "data", labelName = "label", dataDType = DType.Float32, labelDType = DType.Int32,
dataLayout = "NTC", labelLayout = "NT")
var batchCount = 0
val nBatch0 = 8
Expand Down Expand Up @@ -287,8 +288,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
assert(batchCount === nBatch1)

// test empty label (for prediction)
val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard",
dataLayout = "NTC")
val dataIter2 = new NDArrayIter(data = data, label = IndexedSeq.empty,
dataBatchSize = 128, shuffle = false, lastBatchHandle = "discard",
dataName = "data", labelName = "label",
dataLayout = "NTC", labelLayout = "N",
dataDType = DType.Float32, labelDType = DType.Int32)
batchCount = 0
while(dataIter2.hasNext) {
val tBatch = dataIter2.next()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
val trainData = new NDArrayIter(
IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label",
dataLayout = "NCHW", labelLayout = "NCHW")
IndexedSeq(data), IndexedSeq(label), 1, false, "pad", "data", "softmax_label",
DType.Float32, DType.Int32, "NCHW", "NCHW")

// symbols
var x = Symbol.Variable("data")
Expand Down Expand Up @@ -235,8 +235,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
val trainData = new NDArrayIter(
IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label",
dataLayout = "NCHW", labelLayout = "NCHW")
IndexedSeq(data), IndexedSeq(label), 1, false, "pad", "data", "softmax_label",
DType.Float32, DType.Int32, "NCHW", "NCHW")

// symbols
var x = Symbol.Variable("data")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,22 @@ object BucketIo {
class BucketSentenceIter(
path: String, vocab: Map[String, Int], var buckets: IndexedSeq[Int],
_batchSize: Int, private val initStates: IndexedSeq[(String, (Int, Int))],
seperateChar: String = " <eos> ", text2Id: Text2Id = defaultText2Id,
readContent: ReadContent = defaultReadContent,
dataLayout: String = "NT",
labelLayout: String = "N",
dataDType : DType = DType.Float32,
labelDType: DType = DType.Int32) extends DataIter {
seperateChar: String, text2Id: Text2Id,
readContent: ReadContent,
dataLayout: String,
labelLayout: String,
dataDType : DType,
labelDType: DType) extends DataIter {

// scalastyle:off
/**
 * Convenience constructor delegating to the primary constructor with the
 * pre-refactor defaults: "NT" data layout, "N" label layout, Float32 data
 * dtype and Int32 label dtype.
 *
 * NOTE(review): "seperateChar" is a misspelling of "separatorChar", but
 * renaming it would break callers using named arguments — left as-is.
 *
 * @param path path to the text corpus
 * @param vocab mapping from token to integer id
 * @param buckets bucket sizes (sequence lengths)
 * @param _batchSize batch size
 * @param initStates initial RNN states as (name, (shape0, shape1)) pairs
 * @param seperateChar sentence separator token, defaults to " <eos> "
 * @param text2Id function converting a sentence to token ids
 * @param readContent function reading the corpus file content
 */
def this(path: String, vocab: Map[String, Int], buckets: IndexedSeq[Int],
_batchSize: Int, initStates: IndexedSeq[(String, (Int, Int))],
seperateChar: String = " <eos> ", text2Id: Text2Id = defaultText2Id,
readContent: ReadContent = defaultReadContent) {
this(path, vocab, buckets, _batchSize, initStates, seperateChar, text2Id,
readContent, "NT", "N", DType.Float32, DType.Int32)
}
// scalastyle:on

private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer

/**
* A helper converter for LabeledPoint
* @author Yizhi Liu
*/
class LabeledPointIter private[mxnet](
private val points: Iterator[LabeledPoint],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer

/**
* A temporary helper implementation for predicting Vectors
* @author Yizhi Liu
*/
class PointIter private[mxnet](
private val points: Iterator[Vector],
Expand Down

0 comments on commit f49eeca

Please sign in to comment.