-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-689] add DataDesc type for the Scala Package #11844
Conversation
When I am in the middle of changes, some parts looks confusing to me. Are we defining these field in the /**
* DataIter creator
* @param handle native memory ptr for the iterator
* @param params parameter passed to the iterator
* @return created DataIter
*/
private def creator(handle: DataIterCreator)(
params: Map[String, String]): DataIter = {
val out = new DataIterHandleRef
val keys = params.keys.toArray
val vals = params.values.toArray
checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out))
val dataName = params.getOrElse("data_name", "data")
val labelName = params.getOrElse("label_name", "label")
new MXDataIter(out.value, dataName, labelName)
} This creator did not pass in the layout and dtype information and MXDataIter extends DataIter. By default, the layout are set a "NCHW" that causing the problem. |
(null, null) | ||
} | ||
|
||
private val (_provideDataDesc: IndexedSeq[DataDesc], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we only keep DataDesc and construct ListMap from DataDesc? or at least combine it with L45-58
* @param dtype The dtype of the label, default is Float32 | ||
* @return this | ||
*/ | ||
def setDType(dtype: DType): Builder = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you specify the default value here
@@ -140,7 +140,9 @@ class DataBatch(val data: IndexedSeq[NDArray], | |||
// use ListMap to indicate the order of data/label loading | |||
// (must match the order of input data/label) | |||
private val providedData: ListMap[String, Shape] = null, | |||
private val providedLabel: ListMap[String, Shape] = null) { | |||
private val providedLabel: ListMap[String, Shape] = null, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you should replace this constructor with provideDataShape, provideLabelShape that DataDesc instead ListMap[String, Shape] and constructor(this) another to preserve ListMap so we don't break backwards compatibility
@@ -170,6 +172,8 @@ object DataBatch { | |||
private var label: IndexedSeq[NDArray] = null | |||
private var index: IndexedSeq[Long] = null | |||
private var pad: Int = 0 | |||
private var layout: String = "NCHW" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we wouldn't need this if add a new constructor with DataDesc
@@ -314,6 +351,12 @@ abstract class DataIter extends Iterator[DataBatch] { | |||
// The name and shape of label provided by this iterator | |||
def provideLabel: ListMap[String, Shape] | |||
|
|||
// Provide type:DataDesc of the data | |||
def provideDataDesc: IndexedSeq[DataDesc] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we deprecate the old ones?
* Get the DType | ||
* @return DType | ||
*/ | ||
def getDType(): DType = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggest to add two new parameters to the constructor dataDescriptor
and labelDescriptor
.
Make sure not break API backwards compatibility
After some days spent on the issues, I finally figure out the fix on the issues we have. @nswamy and @yzhliu thanks for your comments on the changes and things I am about to apply is:
|
@@ -59,10 +61,12 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)], | |||
def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty, | |||
dataBatchSize: Int = 1, shuffle: Boolean = false, | |||
lastBatchHandle: String = "pad", | |||
dataName: String = "data", labelName: String = "label") { | |||
dataName: String = "data", labelName: String = "label", | |||
dType: DType = MX_REAL_TYPE, dataLayout: String = "NCHW", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to add this here. add the defaults to the new constructor above in L47
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
create another this constructor with the old set of constructor args to not break backwards compatibility.
thinking about it, can you first add default to the class constructor and test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
its ok to keep here. let all the defaults remain in one place.
@@ -170,6 +180,10 @@ object DataBatch { | |||
private var label: IndexedSeq[NDArray] = null | |||
private var index: IndexedSeq[Long] = null | |||
private var pad: Int = 0 | |||
private var dataLayout: String = "NCHW" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For layouts I suggest to use extensible Enums similar to this https://blogs.oracle.com/darcy/mixing-in-an-enum.
free floating strings are generally not a good practice because it removes compile time type-checks.
WDYT? does it make sense?
Hi @benkamphaus, could you please take a look at the recent CI failure on Clojure. I believe these lines cause that failure: https://github.com/apache/incubator-mxnet/blob/master/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj#L194-L225. I have modified the input args for |
@benkamphaus, from Clojure perspective are we breaking backwards compatibility. I am curious to know how it is handled, since we added new parameters with defaults. |
It appears to be the case that the new constructor arglist for
The cause is straight forward -- the constructor call here is no longer valid due to the change to the args required for This constructor call on the clojure side backs Note that the keyword args in the Clojure API means that I'll discuss with @gigasquid as well. |
@benkamphaus thanks for your detailed analysis. I can tell how hard it is if we change the args in here since I also changed a lot of examples to match this change in Scala. Let's be cautious on this change and discuss if it is necessary here to introduce more parameters. |
I have reverted the changes back to Strings. Next step is to add a default value to Layout section so we can bypass the require check |
6c8289a
to
f49eeca
Compare
In the recent two PR, add Backward Compatibility fix on DataIter as well as DataBatch. User can still go with the old defs... |
It looks like if we want to force conversion from Float32 to Int32 will cause a crash on JVM. Need to be addressed.
This reverts commit abbe283.
* add dataDesc * Add amend * add changes with dataLayout and labelLayout * add depreciate and example changes * Gan and Customop fixes * change the DType * add one more class to convert Strings to DTypes * convert layout to global * scala style fix * Revert to 8c7d1f8 * fix coding style issue * print full stacktraces * apply changes to new constructor * add databatch bcc * introduce undefined field * Fix crashes when change provideData to provideDataDesc It looks like if we want to force conversion from Float32 to Int32 will cause a crash on JVM. Need to be addressed. * change spacing and revert test * apply DataDesc on DataBatch * unit test for NDArrayIter and MXDataiter * apply changes on CR * change NDArrayIter and revert the rest * revert change on examples * apply final changes * remove the provideLabelShape * add TODO about the findings
Description
This PR includes massive changes on all Iterators that extends
DataIter
, two new fields are added (dtype
andlayout
).@nswamy @yzhliu @andrewfayres
It intends to solve some issues about the layout information missing in DataDesc. For example: RNN based networks require different layouts(“TNC”/”NTC”). Current implementation of DataBatch, DataIter does not support different layouts and assumes a fixed layout based on the shape of the NDArray returned in the next() call to the iterator.
However I found it hard to have a general solution that both address backward compatibility as well as adding this information. Currently, there are several ways to address this issue:
Assume Shape == 2 follows
NT
format, Shape == 3 followsTNC
, Shape == 4 followsNCHW
and only change the implicit function in here.Add
Layout
andDType
defs for all iterator, then users can pass these information to createDataDesc
instead ofListMap[Name, Shape]
. However, it breaks the input args forNDArrayIter
as well as other iterators. Users are required to add more inputs. However, if the iterator are passed in with dynamic shapes (shapes are not determined), it will break the DataDesc checking scheme.Do what 2 contains and create
this
constructor to make sure BC.Test
Historical issues
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.