Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
convert layout to global
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jul 27, 2018
1 parent dcbb093 commit 18bd6b9
Show file tree
Hide file tree
Showing 20 changed files with 153 additions and 94 deletions.
23 changes: 13 additions & 10 deletions scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.mxnet

import org.apache.mxnet.Base._
import org.apache.mxnet.DType.DType
import org.apache.mxnet.Layout.Layout
import org.apache.mxnet.io.{MXDataIter, MXDataPack}
import org.slf4j.LoggerFactory

Expand Down Expand Up @@ -110,7 +111,8 @@ object IO {
val dataDType = params.getOrElse("dataDType", "Float32")
val labelDType = params.getOrElse("labelDType", "Int32")
new MXDataIter(out.value, dataName, labelName,
dataLayout = dataLayout, labelLayout = labelLayout,
dataLayout = Layout.getLayout(dataLayout),
labelLayout = Layout.getLayout(labelLayout),
dataDType = DType.getType(dataDType),
labelDType = DType.getType(labelDType))
}
Expand Down Expand Up @@ -149,8 +151,8 @@ class DataBatch(val data: IndexedSeq[NDArray],
private val providedLabel: ListMap[String, Shape] = null,
val dataDType: DType = Base.MX_REAL_TYPE,
val labelDType: DType = DType.Int32,
val dataLayout: String = "NCHW",
val labelLayout: String = "N") {
val dataLayout: Layout = Layout.NCHW,
val labelLayout: Layout = Layout.N) {
/**
* Dispose its data and labels
* The object shall never be used after it is disposed.
Expand Down Expand Up @@ -180,8 +182,8 @@ object DataBatch {
private var label: IndexedSeq[NDArray] = null
private var index: IndexedSeq[Long] = null
private var pad: Int = 0
private var dataLayout: String = "NCHW"
private var labelLayout: String = "N"
private var dataLayout: Layout = Layout.NCHW
private var labelLayout: Layout = Layout.N
private var dataDType: DType = Base.MX_REAL_TYPE
private var labelDType: DType = DType.Int32
private var bucketKey: AnyRef = null
Expand Down Expand Up @@ -248,7 +250,7 @@ object DataBatch {
* @param labelLayout The layout of the label, default is N
* @return this
*/
def setLayout(dataLayout: String, labelLayout: String): Builder = {
def setLayout(dataLayout: Layout, labelLayout: Layout): Builder = {
this.dataLayout = dataLayout
this.labelLayout = labelLayout
this
Expand Down Expand Up @@ -353,7 +355,7 @@ abstract class DataIter extends Iterator[DataBatch] {
* Get the layout
* @return data and label layout of the DataIter
*/
def getLayout(): (String, String)
def getLayout(): (Layout, Layout)

/**
* Get the index of current batch
Expand Down Expand Up @@ -393,10 +395,11 @@ abstract class DataPack() extends Iterable[DataBatch] {

// Named data desc description contains name, shape, type and other extended attributes.
case class DataDesc(name: String, shape: Shape,
dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") {
require(shape.length == layout.length, ("number of dimensions in shape :%d with" +
dtype: DType = Base.MX_REAL_TYPE, layout: Layout = Layout.NCHW) {
val layoutStr = layout.toString
require(shape.length == layoutStr.length, ("number of dimensions in shape :%d with" +
" shape: %s should match the length of the layout: %d with layout: %s").
format(shape.length, shape.toString, layout.length, layout))
format(shape.length, shape.toString, layoutStr.length, layoutStr))

override def toString(): String = {
s"DataDesc[$name,$shape,$dtype,$layout]"
Expand Down
48 changes: 48 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet

/**
* Layout type that represent what inside of a shape
* N Batch size
* C number of channels
* H height (image)
* W width (image)
* T temporal axis representing time (NLP)
*/

object Layout extends Enumeration {
type Layout = Value
val NCHW = Value("NCHW")
val TNC = Value("TNC")
val CHW = Value("CHW")
val NT = Value("NT")
val N = Value("N")

private[mxnet] def getLayout(layoutStr: String): Layout = {
layoutStr match {
case "NCHW" => NCHW
case "TNC" => TNC
case "CHW" => CHW
case "NT" => NT
case "N" => N
case _ => throw new RuntimeException(
s"Unknown $layoutStr defined!, please check Layout.scala")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.mxnet.Base._
import org.apache.mxnet.DType.DType
import org.apache.mxnet._
import org.apache.mxnet.IO._
import org.apache.mxnet.Layout.Layout
import org.slf4j.LoggerFactory

import scala.collection.immutable.ListMap
Expand All @@ -33,9 +34,8 @@ import scala.collection.mutable.ListBuffer
private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
dataName: String = "data",
labelName: String = "label",
dtype: DType = DType.Float32,
dataLayout: String = "NCHW",
labelLayout: String = "N",
dataLayout: Layout = Layout.NCHW,
labelLayout: Layout = Layout.N,
dataDType: DType = DType.Float32,
labelDType: DType = DType.Int32)
extends DataIter with WarnIfNotDisposed {
Expand All @@ -56,7 +56,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
iterNext()
val data = currentBatch.data(0)
val label = currentBatch.label(0)
val dataType = currentBatch.dataDType
val dataDType = currentBatch.dataDType
val labelDType = currentBatch.labelDType
val dataLayout = currentBatch.dataLayout
val labelLayout = currentBatch.labelLayout
Expand Down Expand Up @@ -182,7 +182,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
* Get the layout
* @return layout
*/
def getLayout(): (String, String) = (dataLayout, labelLayout)
def getLayout(): (Layout, Layout) = (dataLayout, labelLayout)

// The name and shape of data provided by this iterator
override def provideData: ListMap[String, Shape] = _provideData
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.util.NoSuchElementException

import org.apache.mxnet.Base._
import org.apache.mxnet.DType.DType
import org.apache.mxnet.Layout.Layout
import org.apache.mxnet._
import org.slf4j.LoggerFactory

Expand All @@ -45,31 +46,34 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
private val dataBatchSize: Int, shuffle: Boolean,
lastBatchHandle: String,
dataDType: DType, labelDType: DType,
dataLayout: String, labelLayout: String) extends DataIter {

dataLayout: Layout, labelLayout: Layout) extends DataIter {
// scalastyle:off
/**
* @param data Specify the data. Data names will be data_0, data_1, ..., etc.
* @param label Same as data, but is not fed to the model during testing.
* Label names will be label_0, label_1, ..., etc.
* @param dataBatchSize Batch Size
* @param shuffle Whether to shuffle the data
* @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
*
* This iterator will pad, discard or roll over the last batch if
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
* @param data Specify the data. Data names will be data_0, data_1, ..., etc.
* @param label Same as data, but is not fed to the model during testing.
* Label names will be label_0, label_1, ..., etc.
* @param dataBatchSize Batch Size
* @param shuffle Whether to shuffle the data
* @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
*
* This iterator will pad, discard or roll over the last batch if
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label",
dataDType: DType = MX_REAL_TYPE, labelDType: DType = DType.Int32,
dataLayout: String = "NCHW", labelLayout: String = "N") {
dataDType: DType = Base.MX_REAL_TYPE,
labelDType: DType = DType.Int32,
dataLayout: Layout = Layout.NCHW,
labelLayout : Layout = Layout.N) {
this(IO.initData(data, allowEmpty = false, dataName),
IO.initData(label, allowEmpty = true, labelName),
dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType, dataLayout, labelLayout)
dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType,
dataLayout, labelLayout)
}

// scalastyle:on
private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])

val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = {
Expand Down Expand Up @@ -250,7 +254,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
* Get the layout
* @return layout
*/
def getLayout(): (String, String) = {
def getLayout(): (Layout, Layout) = {
(dataLayout, labelLayout)
}

Expand Down Expand Up @@ -279,8 +283,8 @@ object NDArrayIter {
private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
private var dataBatchSize: Int = 1
private var lastBatchHandle: String = "pad"
private var dataLayout: String = "NCHW"
private var labelLayout: String = "N"
private var dataLayout: Layout = Layout.NCHW
private var labelLayout: Layout = Layout.N
private var dataDType: DType = Base.MX_REAL_TYPE
private var labelDType: DType = DType.Int32

Expand Down Expand Up @@ -344,7 +348,7 @@ object NDArrayIter {
* @param labelLayout The layout of the label, default is N
* @return this
*/
def setLayout(dataLayout: String, labelLayout: String): Builder = {
def setLayout(dataLayout: Layout, labelLayout: Layout): Builder = {
this.dataLayout = dataLayout
this.labelLayout = labelLayout
this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.slf4j.LoggerFactory
import java.util.concurrent.Semaphore

import org.apache.mxnet.DType.DType
import org.apache.mxnet.Layout.Layout

import scala.collection.immutable.ListMap

Expand Down Expand Up @@ -189,7 +190,7 @@ class PrefetchingIter(
* Get the layout
* @return layout
*/
def getLayout(): (String, String) = {
def getLayout(): (Layout, Layout) = {
(currentBatch.dataLayout, currentBatch.labelLayout)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.mxnet.io
import java.util.NoSuchElementException

import org.apache.mxnet.DType.DType
import org.apache.mxnet.Layout.Layout
import org.apache.mxnet._
import org.slf4j.LoggerFactory

Expand Down Expand Up @@ -141,7 +142,7 @@ class ResizeIter(
* Get the layout
* @return layout
*/
def getLayout(): (String, String) = {
def getLayout(): (Layout, Layout) = {
(currentBatch.dataLayout, currentBatch.labelLayout)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class DataParallelExecutorGroup private[module](
*/
private def decideSlices(dataShapes: Seq[DataDesc]): Seq[Int] = {
require(dataShapes.size > 0)
val majorAxis = dataShapes.map(data => DataDesc.getBatchAxis(Option(data.layout)))
val majorAxis = dataShapes.map(data => DataDesc.getBatchAxis(Option(data.layoutStr)))

for ((dataDesc, axis) <- dataShapes.zip(majorAxis)) {
if (axis != -1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {

// test pad
val dataIter0 = new NDArrayIter(data, label, 128, false, "pad",
dataLayout = "NTC", labelLayout = "NT")
dataLayout = Layout.TNC, labelLayout = Layout.NT)
var batchCount = 0
val nBatch0 = 8
while(dataIter0.hasNext) {
Expand All @@ -271,7 +271,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
.addData("data0", data(0)).addData("data1", data(1))
.addLabel("label", label(0))
.setBatchSize(128)
.setLayout("NTC", "NT")
.setLayout(Layout.TNC, Layout.NT)
.setLastBatchHandle("discard").build()
val nBatch1 = 7
batchCount = 0
Expand All @@ -288,7 +288,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {

// test empty label (for prediction)
val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, lastBatchHandle = "discard",
dataLayout = "NTC")
dataLayout = Layout.TNC)
batchCount = 0
while(dataIter2.hasNext) {
val tBatch = dataIter2.next()
Expand Down
Loading

0 comments on commit 18bd6b9

Please sign in to comment.