Update to support/require oneDNN >= 3.x (flashlight#1137)
Summary:
oneDNN breaks its API on major version updates; this change updates Flashlight's oneDNN backend to support the new API. The changes mostly consolidate primitives, primitive descriptors, and operators.
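For context, the central pattern change: oneDNN 2.x built a standalone operation descriptor (desc) and then wrapped it in a primitive_desc, while oneDNN 3.x removes the intermediate descriptor and constructs the primitive_desc directly from the engine, propagation kind, memory descriptors, and flags. The sketch below contrasts the two styles for batch normalization; it is illustrative only, and engine, srcMd, dstMd, and eps stand in for values the caller would supply.

#include <dnnl.hpp>

// oneDNN 2.x style (removed by this commit): op descriptor, then wrapper.
//   auto opDesc = dnnl::batch_normalization_forward::desc(
//       dnnl::prop_kind::forward_training, srcMd, eps,
//       dnnl::normalization_flags::use_scale_shift);
//   auto pd =
//       dnnl::batch_normalization_forward::primitive_desc(opDesc, engine);

// oneDNN 3.x style: construct the primitive descriptor directly. The
// combined use_scale_shift flag is also split into use_scale | use_shift.
dnnl::batch_normalization_forward::primitive_desc makeBnFwdPrimDesc(
    const dnnl::engine& engine,
    const dnnl::memory::desc& srcMd,
    const dnnl::memory::desc& dstMd,
    float eps) {
  return dnnl::batch_normalization_forward::primitive_desc(
      engine,
      dnnl::prop_kind::forward_training,
      srcMd,
      dstMd,
      eps,
      dnnl::normalization_flags::use_scale |
          dnnl::normalization_flags::use_shift);
}

The same split appears at execution time: the combined DNNL_ARG_SCALE_SHIFT argument becomes separate DNNL_ARG_SCALE and DNNL_ARG_SHIFT entries, which is why the batch-norm payload in this diff carries separate weight and bias tensors.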

Pull Request resolved: flashlight#1137

Test Plan: CI + local tests to ensure feature parity on what's supported

Reviewed By: lshamis

Differential Revision: D50206237

Pulled By: jacobkahn

fbshipit-source-id: 2631659e9db401b143047a9a80e6ffdc16134c76
jacobkahn authored and facebook-github-bot committed Oct 22, 2023
1 parent f354e7f commit 94dc10b
Showing 13 changed files with 198 additions and 186 deletions.
18 changes: 13 additions & 5 deletions .circleci/config.yml
@@ -37,7 +37,6 @@ commands:
       - when:
           condition:
             or:
-              - equal: ["macos", << parameters.platform >>]
               - equal: ["macos-arm", << parameters.platform >>]
           steps:
             - run:
@@ -280,11 +279,20 @@ commands:
       - run:
           name: "Create micromamba environment for oneDNN"
           command: |
-            micromamba create -n onednn -y
-            echo "micromamba activate onednn" >> $BASH_ENV
+            micromamba create -n flashlight -y
+            echo "micromamba activate flashlight" >> $BASH_ENV
       - run:
-          name: "Install oneDNN"
-          command: micromamba install -y -c conda-forge onednn=2.7.2
+          name: "Install oneDNN in env"
+          command: micromamba install onednn -c conda-forge -y
+      - when:
+          condition:
+            or:
+              - equal: ["linux-arm", << parameters.platform >>]
+              - equal: ["linux", << parameters.platform >>]
+          steps:
+            - run:
+                name: "Install compiler toolchains needed with oneDNN"
+                command: micromamba install gxx=11 -c conda-forge -y
 
 # Primary job for installing all dependencies based on platform,
 # backend, and autograd backend impl
18 changes: 8 additions & 10 deletions .github/actions/install_core_deps/action.yml
@@ -24,11 +24,11 @@ runs:
         sudo apt update
         sudo apt install arrayfire-cmake=3.8.1-2 arrayfire-headers=3.8.1-2 arrayfire-cpu3-mkl=3.8.1-2 arrayfire-cpu3-dev=3.8.1-2
       if: runner.os == 'Linux' && inputs.backend == 'ArrayFire'
-      shell: bash -l {0}
+      shell: bash -el {0}
     - name: "Install ArrayFire (macOS)"
       run: brew install arrayfire
       if: runner.os == 'macOS' && inputs.backend == 'ArrayFire'
-      shell: bash -l {0}
+      shell: bash -el {0}
     - name: "Install ArrayFire (Windows)"
       run: |
         choco install --no-progress wget -y
@@ -38,20 +38,18 @@ runs:
         7z.exe x $INSTALLER_NAME -o"C:\Program Files\ArrayFire" -y
         rm $INSTALLER_NAME
       if: runner.os == 'Windows' && inputs.backend == 'ArrayFire'
-      shell: bash -l {0}
+      shell: bash -el {0}
     # oneDNN
-    - name: Install oneDNN 2.7.2 with microconda
-      uses: mamba-org/provision-with-micromamba@main
+    - name: Install oneDNN with micromamba
+      uses: mamba-org/setup-micromamba@v1
       with:
-        environment-file: false
         environment-name: flashlight
-        channels: conda-forge
-        extra-specs: onednn=2.7.2
-      if: inputs.backend == 'oneDNN' || inputs.autograd_backend == 'oneDNN'
+        create-args: onednn -c conda-forge
+      if: (inputs.backend == 'oneDNN' || inputs.autograd_backend == 'oneDNN')
     # MPI
     - name: "Install OpenMPI (Linux)"
       run: |
         sudo apt update
         sudo apt install -y openmpi-bin libopenmpi-dev
       if: runner.os == 'Linux' && inputs.distributed_backend != '' && inputs.distributed_backend != 'Stub'
-      shell: bash -l {0}
+      shell: bash -el {0}
10 changes: 5 additions & 5 deletions .github/workflows/build.yml
@@ -8,7 +8,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-20.04, windows-2022, macOS-12]
+        os: [ubuntu-22.04, windows-2022, macOS-12]
         backend: [ArrayFire, oneDNN, Stub]
         autograd_backend: [oneDNN]
         distributed_backend: [Stub]
@@ -19,12 +19,12 @@ jobs:
           - os: windows-2022
             distributed_backend: Gloo
         include:
-          - os: ubuntu-20.04
+          - os: ubuntu-22.04
            backend: oneDNN
             autograd_backend: oneDNN
             distributed_backend: Gloo
           # Configuration using only stubs and no autograd backend
-          - os: ubuntu-20.04
+          - os: ubuntu-22.04
             backend: Stub
             distributed_backend: Stub
     defaults:
@@ -74,7 +74,7 @@ jobs:
   # needs: build_core # TODO: this won't work until Github Actions enables specific matrix dependencies
   # strategy:
   #   matrix:
-  #     os: [ubuntu-20.04, macOS-12]
+  #     os: [ubuntu-22.04, macOS-12]
   #     backend: [ArrayFire]
   #     autograd_backend: [oneDNN]
   #     pkg: [runtime, speech, vision, text]
@@ -112,7 +112,7 @@ jobs:
 
   build_core_wasm:
     name: "Build WebAssembly libraries with Emscripten compilers + Flashlight core + stub backend"
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-22.04
     defaults:
       run:
         shell: bash -l {0}
76 changes: 44 additions & 32 deletions flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp
@@ -76,12 +76,15 @@ dnnl::memory::dims getInputOutputDims(
 
 struct OneDnnBatchNormPayload : detail::AutogradPayloadData {
   dnnl::batch_normalization_forward::primitive_desc fwdPrimDesc;
-  Tensor weightsDnnl; // combined weight and bias
-  dnnl::memory::dims weightsDnnlDims;
+  Tensor weights; // combined weight and bias
+  Tensor bias;
+  dnnl::memory::dims weightsDims;
+  dnnl::memory::dims biasDims;
   dnnl::memory::desc outputMemoryDescriptor;
   dnnl::memory meanMemory;
   dnnl::memory varMemory;
   dnnl::memory weightsMemory;
+  dnnl::memory biasMemory;
 };
 
 } // namespace
@@ -142,14 +145,12 @@ Tensor OneDnnAutogradExtension::batchnorm(
 
-  // DNNL only accepts weight and bias as a combined input.
-  // https://git.io/JLn9X
-  payload->weightsDnnl = fl::concatenate(0, weightNonempty, biasNonempty);
 
+  payload->weights = weightNonempty;
+  payload->bias = biasNonempty;
+  payload->weightsDims = detail::convertToDnnlDims({nfeatures});
+  payload->biasDims = detail::convertToDnnlDims({nfeatures});
   auto inputOutputDims = getInputOutputDims(minAxis, maxAxis, input, nfeatures);
 
-  auto inputOutputMemDesc =
-      dnnl::memory::desc({inputOutputDims}, dType, formatNCHW);
-  payload->weightsDnnlDims = detail::convertToDnnlDims({2, nfeatures});
 
   // Memory for forward
   const detail::DnnlMemoryWrapper inputMemory(
       input, inputOutputDims, formatNCHW);
@@ -161,30 +162,38 @@ Tensor OneDnnAutogradExtension::batchnorm(
       runningVar, {runningVar.dim(0)}, formatX);
   // combined scale and shift (weight and bias)
   const detail::DnnlMemoryWrapper weightsMemory(
-      payload->weightsDnnl, payload->weightsDnnlDims, format2d);
+      payload->weights, payload->weightsDims, formatX);
+  const detail::DnnlMemoryWrapper biasMemory(
+      payload->bias, payload->biasDims, formatX);
   payload->meanMemory = meanMemory.getMemory();
   payload->varMemory = varMemory.getMemory();
   payload->weightsMemory = weightsMemory.getMemory();
+  payload->biasMemory = biasMemory.getMemory();
   // Primitives and descriptors
   auto kind = train ? dnnl::prop_kind::forward_training
                     : dnnl::prop_kind::forward_inference;
   // https://fburl.com/6latj733
   dnnl::normalization_flags flag = train
       ? dnnl::normalization_flags::none
       : dnnl::normalization_flags::use_global_stats;
-  flag = flag | dnnl::normalization_flags::use_scale_shift;
-  auto fwdDesc = dnnl::batch_normalization_forward::desc(
-      kind, inputOutputMemDesc, epsilon, flag);
-  payload->fwdPrimDesc =
-      dnnl::batch_normalization_forward::primitive_desc(fwdDesc, dnnlEngine);
+  flag = flag | dnnl::normalization_flags::use_scale |
+      dnnl::normalization_flags::use_shift;
+  payload->fwdPrimDesc = dnnl::batch_normalization_forward::primitive_desc(
+      dnnlEngine,
+      kind,
+      inputMemory.getDescriptor(),
+      outputMemory.getDescriptor(),
+      epsilon,
+      flag);
   payload->outputMemoryDescriptor = outputMemory.getDescriptor();
   auto bn = dnnl::batch_normalization_forward(payload->fwdPrimDesc);
   std::unordered_map<int, dnnl::memory> bnFwdArgs = {
       {DNNL_ARG_SRC, inputMemory.getMemory()},
       {DNNL_ARG_MEAN, meanMemory.getMemory()},
       {DNNL_ARG_VARIANCE, varMemory.getMemory()},
       {DNNL_ARG_DST, outputMemory.getMemory()},
-      {DNNL_ARG_SCALE_SHIFT, weightsMemory.getMemory()}};
+      {DNNL_ARG_SCALE, weightsMemory.getMemory()},
+      {DNNL_ARG_SHIFT, biasMemory.getMemory()}};
 
   // Execute
   std::vector<dnnl::primitive> network;
@@ -217,17 +226,17 @@ std::tuple<Tensor, Tensor, Tensor> OneDnnAutogradExtension::batchnormBackward(
 
   auto maxAxis = *std::max_element(axes.begin(), axes.end());
   auto minAxis = *std::min_element(axes.begin(), axes.end());
-  bool axesContinuous = (axes.size() == (maxAxis - minAxis + 1));
+  const bool axesContinuous = (axes.size() == (maxAxis - minAxis + 1));
   if (!axesContinuous) {
     throw std::invalid_argument("axis array should be continuous");
   }
 
-  int nfeatures = getNfeatures(input.shape(), axes);
+  const int nfeatures = getNfeatures(input.shape(), axes);
   auto inputOutputDims = getInputOutputDims(minAxis, maxAxis, input, nfeatures);
 
   auto gradInput = Tensor(input.shape(), input.type());
-  auto gradWeightsDNNL =
-      Tensor(payload->weightsDnnl.shape(), payload->weightsDnnl.type());
+  auto gradWeights = Tensor(payload->weights.shape(), payload->weights.type());
+  auto gradBias = Tensor(payload->bias.shape(), payload->bias.type());
 
   const detail::DnnlMemoryWrapper inputMemory(
       input, inputOutputDims, formatNCHW);
@@ -238,38 +247,41 @@ std::tuple<Tensor, Tensor, Tensor> OneDnnAutogradExtension::batchnormBackward(
   const detail::DnnlMemoryWrapper gradInputMem(
       gradInput, inputOutputDims, formatNCHW);
   const detail::DnnlMemoryWrapper gradWeightsMem(
-      gradWeightsDNNL, payload->weightsDnnlDims, format2d);
+      gradWeights, payload->weightsDims, formatX);
+  const detail::DnnlMemoryWrapper gradBiasMem(
+      gradBias, payload->biasDims, formatX);
 
   // Primitives and descriptors
-  auto bwdDesc = dnnl::batch_normalization_backward::desc(
+  auto bwdPrimitiveDesc = dnnl::batch_normalization_backward::primitive_desc(
+      dnnlEngine,
       dnnl::prop_kind::backward,
       gradOutputMem.getDescriptor(),
       payload->outputMemoryDescriptor,
+      gradOutputMem.getDescriptor(),
       epsilon,
-      dnnl::normalization_flags::use_scale_shift);
-  auto bwdPrimDesc = dnnl::batch_normalization_backward::primitive_desc(
-      bwdDesc, dnnlEngine, payload->fwdPrimDesc);
+      dnnl::normalization_flags::use_scale |
+          dnnl::normalization_flags::use_shift,
+      payload->fwdPrimDesc // hint
+  );
   auto bwdPrim =
-      std::make_shared<dnnl::batch_normalization_backward>(bwdPrimDesc);
+      std::make_shared<dnnl::batch_normalization_backward>(bwdPrimitiveDesc);
 
   // Execute
   std::vector<dnnl::primitive> networkBackwards;
   std::vector<std::unordered_map<int, dnnl::memory>> bwdArgs = {
       {{DNNL_ARG_SRC, inputMemory.getMemory()},
        {DNNL_ARG_MEAN, payload->meanMemory},
        {DNNL_ARG_VARIANCE, payload->varMemory},
-       {DNNL_ARG_SCALE_SHIFT, payload->weightsMemory},
+       {DNNL_ARG_SCALE, payload->weightsMemory},
+       {DNNL_ARG_SHIFT, payload->biasMemory},
        {DNNL_ARG_DIFF_SRC, gradInputMem.getMemory()},
        {DNNL_ARG_DIFF_DST, gradOutputMem.getMemory()},
-       {DNNL_ARG_DIFF_SCALE_SHIFT, gradWeightsMem.getMemory()}}};
+       {DNNL_ARG_DIFF_SCALE, gradWeightsMem.getMemory()},
+       {DNNL_ARG_DIFF_SHIFT, gradBiasMem.getMemory()}}};
   networkBackwards.push_back(*bwdPrim);
   detail::executeNetwork(networkBackwards, bwdArgs);
 
-  return {
-      gradInput,
-      gradWeightsDNNL(fl::range(0, nfeatures)), // weights grad
-      gradWeightsDNNL(fl::range(nfeatures, 2 * nfeatures)) // bias grad
-  };
+  return {gradInput, gradWeights, gradBias};
 };
 
 } // namespace fl
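
Since the direct primitive_desc constructors used above exist only in oneDNN 3.x, building against an older oneDNN fails at these call sites. A minimal compile-time guard a consumer of this backend might add, assuming oneDNN's standard version macros from dnnl_version.h (a sketch, not part of this commit):

#include <oneapi/dnnl/dnnl_version.h>

// Fail the build early with a clear message instead of cryptic
// constructor-overload errors at each oneDNN 3.x call site.
#if DNNL_VERSION_MAJOR < 3
#error "This oneDNN backend requires oneDNN >= 3.0"
#endif

On the build side, a constraint like find_package(dnnl 3.0 CONFIG REQUIRED) in CMake would enforce the same lower bound, assuming oneDNN's exported dnnl package config.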