Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
change the DType
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jul 25, 2018
1 parent 4b70ef6 commit 4ce87dd
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 76 deletions.
31 changes: 20 additions & 11 deletions scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.slf4j.LoggerFactory
import scala.annotation.varargs
import scala.collection.immutable.ListMap
import scala.collection.mutable.ListBuffer

import scala.reflect.runtime.universe._
/**
* IO iterators for loading training & validation data
*/
Expand Down Expand Up @@ -108,8 +108,12 @@ object IO {
val labelName = params.getOrElse("label_name", "label")
val dataLayout = params.getOrElse("dataLayout", "NCHW")
val labelLayout = params.getOrElse("labelLayout", "N")
val dataDType = params.getOrElse("dataDType", "Float32")
val labelDType = params.getOrElse("labelDType", "Int32")
new MXDataIter(out.value, dataName, labelName,
dataLayout = dataLayout, labelLayout = labelLayout)
dataLayout = dataLayout, labelLayout = labelLayout,
dataDType = q"DType ${TermName(dataDType)}".asInstanceOf[DType],
labelDType = q"DType ${TermName(labelDType)}".asInstanceOf[DType])
}

// Convert data into canonical form.
Expand Down Expand Up @@ -144,7 +148,8 @@ class DataBatch(val data: IndexedSeq[NDArray],
// (must match the order of input data/label)
private val providedData: ListMap[String, Shape] = null,
private val providedLabel: ListMap[String, Shape] = null,
val dtype: DType = Base.MX_REAL_TYPE,
val dataDType: DType = Base.MX_REAL_TYPE,
val labelDType: DType = DType.Int32,
val dataLayout: String = "NCHW",
val labelLayout: String = "N") {
/**
Expand Down Expand Up @@ -178,7 +183,8 @@ object DataBatch {
private var pad: Int = 0
private var dataLayout: String = "NCHW"
private var labelLayout: String = "N"
private var dtype: DType = Base.MX_REAL_TYPE
private var dataDType: DType = Base.MX_REAL_TYPE
private var labelDType: DType = DType.Int32
private var bucketKey: AnyRef = null
private var datatShapes: ListMap[String, Shape] = null
private var labelShapes: ListMap[String, Shape] = null
Expand Down Expand Up @@ -227,11 +233,13 @@ object DataBatch {

/**
* Set the dtypes of the data and of the label.
* @param dtype The dtype of the label, default is Float32
* @param dataDType The dtype of the data, default is Float32
* @param labelDType The dtype of the label, default is Int32
* @return this
*/
def setDType(dtype: DType): Builder = {
this.dtype = dtype
def setDType(dataDType: DType, labelDType: DType): Builder = {
this.dataDType = dataDType
this.labelDType = labelDType
this
}

Expand Down Expand Up @@ -290,7 +298,7 @@ object DataBatch {
def build(): DataBatch = {
require(data != null, "data is required.")
new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes,
dtype, dataLayout, labelLayout)
dataDType, labelDType, dataLayout, labelLayout)
}
}
}
Expand All @@ -313,7 +321,8 @@ abstract class DataIter extends Iterator[DataBatch] {
@throws(classOf[NoSuchElementException])
def next(): DataBatch = {
new DataBatch(getData(), getLabel(), getIndex(), getPad(),
dtype = getDType(), dataLayout = getLayout()._1, labelLayout = getLayout()._2)
dataDType = getDType()._1, labelDType = getDType()._2,
dataLayout = getLayout()._1, labelLayout = getLayout()._2)
}

/**
Expand All @@ -337,9 +346,9 @@ abstract class DataIter extends Iterator[DataBatch] {

/**
* Get the DType
* @return DType of the DataIter
* @return data and label DType of the DataIter
*/
def getDType(): DType
def getDType(): (DType, DType)

/**
* Get the layout
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
labelName: String = "label",
dtype: DType = DType.Float32,
dataLayout: String = "NCHW",
labelLayout: String = "N")
labelLayout: String = "N",
dataDType: DType = DType.Float32,
labelDType: DType = DType.Int32)
extends DataIter with WarnIfNotDisposed {

private val logger = LoggerFactory.getLogger(classOf[MXDataIter])
Expand All @@ -45,40 +47,30 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
// fix me if any better way found)
private var currentBatch: DataBatch = null

private val (_provideData: ListMap[String, Shape],
_provideLabel: ListMap[String, Shape]) =
if (hasNext) {
iterNext()
val data = currentBatch.data(0)
val label = currentBatch.label(0)
// properties
val res = (ListMap(dataName -> data.shape), ListMap(labelName -> label.shape))
currentBatch.dispose()
reset()
res
} else {
(null, null)
}

private val (_provideDataDesc: IndexedSeq[DataDesc],
_provideLabelDesc: IndexedSeq[DataDesc],
_provideData: ListMap[String, Shape],
_provideLabel: ListMap[String, Shape],
_batchSize: Int) = {
if (hasNext) {
iterNext()
val data = currentBatch.data(0)
val label = currentBatch.label(0)
val dType = currentBatch.dtype
val dataType = currentBatch.dataDType
val labelDType = currentBatch.labelDType
val dataLayout = currentBatch.dataLayout
val labelLayout = currentBatch.labelLayout
// properties
val res = (IndexedSeq(new DataDesc(dataName, data.shape, dType, dataLayout)),
IndexedSeq(new DataDesc(labelName, label.shape, dType, labelLayout)),
val res = (IndexedSeq(new DataDesc(dataName, data.shape, dataDType, dataLayout)),
IndexedSeq(new DataDesc(labelName, label.shape, labelDType, labelLayout)),
ListMap(dataName -> data.shape),
ListMap(labelName -> label.shape),
data.shape(0))
currentBatch.dispose()
reset()
res
} else {
(null, null, 0)
(null, null, null, null, 0)
}
}

Expand Down Expand Up @@ -130,8 +122,8 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
if (next.value > 0) {
currentBatch = new DataBatch(data = getData(), label = getLabel(),
index = getIndex(), pad = getPad(),
dtype = getDType(), dataLayout = getLayout()._1,
labelLayout = getLayout()._2)
dataDType = getDType()._1, labelDType = getDType()._2,
dataLayout = getLayout()._1, labelLayout = getLayout()._2)
} else {
currentBatch = null
}
Expand Down Expand Up @@ -184,7 +176,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
* Get the DType
* @return the (data DType, label DType) pair of this iterator
*/
def getDType(): DType = dtype
def getDType(): (DType, DType) = (dataDType, labelDType)

/**
* Get the layout
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
label: IndexedSeq[(String, NDArray)],
private val dataBatchSize: Int, shuffle: Boolean,
lastBatchHandle: String,
dtype: DType, dataLayout: String, labelLayout: String) extends DataIter {
dataDType: DType, labelDType: DType,
dataLayout: String, labelLayout: String) extends DataIter {

/**
* @param data Specify the data. Data names will be data_0, data_1, ..., etc.
Expand All @@ -62,11 +63,11 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label",
dType: DType = MX_REAL_TYPE, dataLayout: String = "NCHW",
labelLayout: String = "N") {
dataDType: DType = MX_REAL_TYPE, labelDType: DType = DType.Int32,
dataLayout: String = "NCHW", labelLayout: String = "N") {
this(IO.initData(data, allowEmpty = false, dataName),
IO.initData(label, allowEmpty = true, labelName),
dataBatchSize, shuffle, lastBatchHandle, dType, dataLayout, labelLayout)
dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType, dataLayout, labelLayout)
}

private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
Expand Down Expand Up @@ -113,8 +114,9 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],

private val (_provideDataDesc: IndexedSeq[DataDesc],
_provideLabelDesc: IndexedSeq[DataDesc]) = {
val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, dataLayout))
val pLabel = initLabel.map(ele => new DataDesc(ele._1, getShape(ele)._2, dtype, labelLayout))
val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, dataDType, dataLayout))
val pLabel = initLabel.map(ele =>
new DataDesc(ele._1, getShape(ele)._2, labelDType, labelLayout))
(pData, pLabel)
}

Expand Down Expand Up @@ -160,7 +162,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
if (hasNext) {
cursor += dataBatchSize
new DataBatch(getData(), getLabel(), getIndex(), getPad(),
dtype = getDType(), dataLayout = getLayout()._1, labelLayout = getLayout()._2)
dataDType = getDType()._1, labelDType = getDType()._2,
dataLayout = getLayout()._1, labelLayout = getLayout()._2)
} else {
throw new NoSuchElementException
}
Expand Down Expand Up @@ -239,8 +242,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
* Get the DType
* @return the (data DType, label DType) pair of this iterator
*/
def getDType(): DType = {
dtype
def getDType(): (DType, DType) = {
(dataDType, labelDType)
}

/**
Expand Down Expand Up @@ -278,7 +281,8 @@ object NDArrayIter {
private var lastBatchHandle: String = "pad"
private var dataLayout: String = "NCHW"
private var labelLayout: String = "N"
private var dtype: DType = Base.MX_REAL_TYPE
private var dataDType: DType = Base.MX_REAL_TYPE
private var labelDType: DType = DType.Int32

/**
* Add one data input with its name.
Expand Down Expand Up @@ -324,11 +328,13 @@ object NDArrayIter {

/**
* Set the dtypes of the data and of the label.
* @param dtype The dtype of the label, default is Float32
* @param dataDType The dtype of the data, default is Float32
* @param labelDType The dtype of the label, default is Int32
* @return this
*/
def setDType(dtype: DType): Builder = {
this.dtype = dtype
def setDType(dataDType: DType, labelDType: DType): Builder = {
this.dataDType = dataDType
this.labelDType = labelDType
this
}

Expand All @@ -350,7 +356,7 @@ object NDArrayIter {
*/
def build(): NDArrayIter = {
new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle,
dtype, dataLayout, labelLayout)
dataDType, labelDType, dataLayout, labelLayout)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ class PrefetchingIter(
* Get the DType
* @return the (data DType, label DType) pair of the current batch
*/
def getDType(): DType = {
currentBatch.dtype
def getDType(): (DType, DType) = {
(currentBatch.dataDType, currentBatch.labelDType)
}

/**
Expand Down Expand Up @@ -226,7 +226,8 @@ class PrefetchingIter(
nextBatch(0).pad,
dataLayout = nextBatch(0).dataLayout,
labelLayout = nextBatch(0).labelLayout,
dtype = nextBatch(0).dtype)
dataDType = nextBatch(0).dataDType,
labelDType = nextBatch(0).labelDType)
for (e <- dataTaken) e.release()
true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ class ResizeIter(
* Get the DType
* @return the (data DType, label DType) pair of the current batch
*/
def getDType(): DType = {
currentBatch.dtype
def getDType(): (DType, DType) = {
(currentBatch.dataDType, currentBatch.labelDType)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ object GanMnist {
"input_shape" -> s"(1, 28, 28)",
"batch_size" -> s"$batchSize",
"shuffle" -> "True",
"dataLayout" -> "NT",
"dataLayout" -> "NCHW",
"labelLayout" -> "N"
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import scala.sys.process.Process

/**
* Example of multi-task
* @author Depeng Liang
*/
object ExampleMultiTask {
private val logger = LoggerFactory.getLogger(classOf[ExampleMultiTask])
Expand Down Expand Up @@ -68,8 +67,8 @@ object ExampleMultiTask {
new DataBatch(batch.data,
IndexedSeq(label, label),
batch.index,
batch.pad, dtype = batch.dtype, dataLayout = batch.dataLayout,
labelLayout = batch.labelLayout)
batch.pad, dataDType = batch.dataDType, labelDType = batch.labelDType,
dataLayout = batch.dataLayout, labelLayout = batch.labelLayout)
} else {
throw new NoSuchElementException
}
Expand Down Expand Up @@ -128,7 +127,7 @@ object ExampleMultiTask {
*/
override def getPad(): Int = this.dataIter.getPad()

override def getDType(): DType = this.dataIter.getDType()
override def getDType(): (DType, DType) = this.dataIter.getDType()

override def getLayout(): (String, String) = this.dataIter.getLayout()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ object BucketIo {
readContent: ReadContent = defaultReadContent,
dataLayout: String = "NT",
labelLayout: String = "N",
dtype : DType = DType.Float32) extends DataIter {
dataDType : DType = DType.Float32,
labelDType: DType = DType.Int32) extends DataIter {

private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter])

Expand Down Expand Up @@ -175,12 +176,12 @@ object BucketIo {

private val _provideDataDesc = {
val tmp = IndexedSeq(new DataDesc("data",
Shape(_batchSize, _defaultBucketKey), dtype, dataLayout))
tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), dtype, dataLayout))
Shape(_batchSize, _defaultBucketKey), dataDType, dataLayout))
tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), dataDType, dataLayout))
}

private val _provideLabelDesc = IndexedSeq(new DataDesc("softmax_label",
Shape(_batchSize, _defaultBucketKey), dtype, labelLayout))
Shape(_batchSize, _defaultBucketKey), labelDType, labelLayout))

private var iBucket = 0

Expand Down Expand Up @@ -212,7 +213,8 @@ object BucketIo {
getIndex(),
getPad(),
this.buckets(bucketIdx).asInstanceOf[AnyRef],
batchProvideData, batchProvideLabel, getDType(),
batchProvideData, batchProvideLabel,
getDType()._1, getDType()._2,
getLayout()._1, getLayout()._2)
}

Expand Down Expand Up @@ -251,7 +253,7 @@ object BucketIo {
*/
override def getPad(): Int = 0

override def getDType(): DType = dtype
override def getDType(): (DType, DType) = (dataDType, labelDType)

override def getLayout(): (String, String) = (dataLayout, labelLayout)

Expand Down
Loading

0 comments on commit 4ce87dd

Please sign in to comment.