Skip to content

Commit

Permalink
[vulkan] test_app for mobilenetV2 on vulkan api (pytorch#48924)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#48924

Test Plan: Imported from OSS

Reviewed By: SS-JIA

Differential Revision: D25365000

Pulled By: IvanKobzarev

fbshipit-source-id: 79295b5781d2494681dbb4e4a741de49ff9c058c
  • Loading branch information
IvanKobzarev authored and facebook-github-bot committed Dec 7, 2020
1 parent 36df253 commit 21ba48f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 17 deletions.
20 changes: 10 additions & 10 deletions android/test_app/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@ android {
//}
flavorDimensions "model", "build", "activity"
productFlavors {
mbq {
mnet {
dimension "model"
applicationIdSuffix ".mbq"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"")
addManifestPlaceholders([APP_NAME: "MBQ"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"")
applicationIdSuffix ".mnet"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet.pt\"")
addManifestPlaceholders([APP_NAME: "MNET"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
}
mbvulkan {
mnetVulkan {
dimension "model"
applicationIdSuffix ".mbvulkan"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2-vulkan.pt\"")
applicationIdSuffix ".mnet_vulkan"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet_vulkan.pt\"")
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
addManifestPlaceholders([APP_NAME: "MBQ"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbvulkan\"")
addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
}
resnet18 {
dimension "model"
Expand Down
37 changes: 30 additions & 7 deletions aten/src/ATen/native/vulkan/ops/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,40 @@ vTensor pack_weights(
}

// shader KO4C4HW_to_image
float image[4 * C_4][OC_4][KH * KW][4];
memset(image, 0.f, 16 * C_4 * OC_4 * KH * KW * sizeof(float));
struct Image3D {
float* data_;
uint32_t dim0_, dim1_, dim2_;

Image3D(uint32_t dim0, uint32_t dim1, uint32_t dim2) {
dim0_ = dim0;
dim1_ = dim1;
dim2_ = dim2;
data_ = new float[dim0 * dim1 * dim2 * 4];
memset(data_, 0.f, dim0 * dim1 * dim2 * 4 * sizeof(float));
}

inline uint32_t idx(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3) {
return i3 + i2 * 4 + i1 * 4 * dim2_ + i0 * 4 * dim2_ * dim1_;
}

void set(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, float value) {
data_[idx(i0, i1, i2, i3)] = value;
}

float get(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3) {
return data_[idx(i0, i1, i2, i3)];
}
} image{4 * C_4, OC_4, KH * KW};

for (uint32_t sx = 0; sx < C_4; ++sx) {
for (uint32_t sy = 0; sy < OC_4; ++sy) {
for (uint32_t sz = 0; sz < (KH * KW); ++sz) {
for (uint32_t vi = 0; vi < 4; ++vi) {
int bufferVIdx = 4 * sx * KH * KW + 4 * sy * C_4 * KH * KW + 4 * sz;
image[4 * sx + 0][sy][sz][vi] = dst[4 * (bufferVIdx + 0) + vi];
image[4 * sx + 1][sy][sz][vi] = dst[4 * (bufferVIdx + 1) + vi];
image[4 * sx + 2][sy][sz][vi] = dst[4 * (bufferVIdx + 2) + vi];
image[4 * sx + 3][sy][sz][vi] = dst[4 * (bufferVIdx + 3) + vi];
image.set(4 * sx + 0, sy, sz, vi, dst[4 * (bufferVIdx + 0) + vi]);
image.set(4 * sx + 1, sy, sz, vi, dst[4 * (bufferVIdx + 1) + vi]);
image.set(4 * sx + 2, sy, sz, vi, dst[4 * (bufferVIdx + 2) + vi]);
image.set(4 * sx + 3, sy, sz, vi, dst[4 * (bufferVIdx + 3) + vi]);
}
}
}
Expand All @@ -143,7 +166,7 @@ vTensor pack_weights(
for (uint32_t sy = 0; sy < H; ++sy) {
for (uint32_t sz = 0; sz < D; ++sz) {
for (uint32_t szvi = 0; szvi < 4; ++szvi) {
dst_weight_ptr[W * sy + sx + (4 * sz + szvi) * W * H] = image[sx][sy][sz][szvi];
dst_weight_ptr[W * sy + sx + (4 * sz + szvi) * W * H] = image.get(sx, sy, sz, szvi);
}
}
}
Expand Down

0 comments on commit 21ba48f

Please sign in to comment.