-
Notifications
You must be signed in to change notification settings - Fork 445
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Check in sample showcasing Rust projection of Winrt (#342)
* Init rust squeezenet project * Add winrt-rs submodule * update .gitmodules * Add squeezenet code * More efficient sorting * refactoring * Make pretty * add comments * Add README.md * nit * Use macro to clean up code * actually use winml nuget 1.4.0 * PR comments * Update to latest winrt-rs master Co-authored-by: Ryan Lai <[email protected]>
- Loading branch information
Ryan Lai
and
Ryan Lai
authored
Aug 24, 2020
1 parent
27fd18e
commit 3d2e468
Showing
6 changed files
with
196 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
[submodule "Tools/WinML-Dashboard/deps/Netron"] | ||
path = Tools/WinMLDashboard/deps/Netron | ||
url = https://github.com/lutzroeder/Netron.git | ||
[submodule "Samples/RustSqueezenet/winrt-rs"] | ||
path = Samples/RustSqueezenet/winrt-rs | ||
url = https://github.com/microsoft/winrt-rs.git |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
[package] | ||
name = "rust_squeezenet" | ||
version = "0.1.0" | ||
authors = ["Microsoft"] | ||
edition = "2018" | ||
|
||
[build] | ||
target-dir = "target" | ||
|
||
[dependencies] | ||
winrt = { path = "./winrt-rs" } | ||
# NOTE: winrt_macros is needed as a dependency because Rust 1.46 is needed and hasn't been released yet. | ||
winrt_macros = { git = "https://github.com/microsoft/winrt-rs", version = "0.7.2" } | ||
|
||
[build-dependencies] | ||
winrt = { path = "./winrt-rs" } | ||
|
||
# Nuget packages | ||
[package.metadata.winrt.dependencies] | ||
"Microsoft.Windows.SDK.Contracts" = "10.0.19041.1" | ||
"Microsoft.AI.MachineLearning" = "1.4.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# SqueezeNet Rust sample | ||
This is a desktop application that uses SqueezeNet, a pre-trained machine learning model, to detect the predominant object in an image selected by the user from a file. | ||
|
||
Note: SqueezeNet was trained to work with image sizes of 224x224, so you must provide an image of size 224X224. | ||
|
||
## Prerequisites | ||
- [Install Rustup](https://www.rust-lang.org/tools/install) | ||
- Install cargo-winrt through command prompt. Until Rust 1.46 is released, cargo-winrt should be installed through the winrt-rs git repository. | ||
- ```cargo install --git https://github.com/microsoft/winrt-rs cargo-winrt``` | ||
|
||
## Build and Run the sample | ||
1. This project requires Rust 1.46, which is currently in Beta. Rust release dates can be found [here](https://forge.rust-lang.org/). Rust Beta features can be enabled by running the following commands through command prompt in this current project directory after installation of Rustup : | ||
- ``` rustup install beta ``` | ||
- ``` rustup override set beta ``` | ||
2. Install the WinRT nuget dependencies with this command: ``` cargo winrt install ``` | ||
3. Build the project by running ```cargo build``` for debug and ```cargo build --release``` for release. | ||
4. Run the sample by running this command through the command prompt. ``` cargo winrt run ``` | ||
- Another option would be to run the executable directly. Should be ```<git enlistment>\Samples\RustSqueezeNet\target\debug\rust_squeezenet.exe``` | ||
|
||
## Sample output | ||
``` | ||
C:\Repos\Windows-Machine-Learning\Samples\RustSqueezeNet> cargo winrt run | ||
Finished installing WinRT dependencies in 0.47s | ||
Finished dev [unoptimized + debuginfo] target(s) in 0.12s | ||
Running `target\debug\rust_squeezenet.exe` | ||
Loading model C:\Repos\Windows-Machine-Learning\RustSqueezeNet\target\debug\Squeezenet.onnx | ||
Creating session | ||
Loading image file C:\Repos\Windows-Machine-Learning\RustSqueezeNet\target\debug\kitten_224.png | ||
Evaluating | ||
Results: | ||
tabby tabby cat 0.9314611 | ||
Egyptian cat 0.06530659 | ||
tiger cat 0.0029267797 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
macro_rules! copy_file { | ||
($file:expr, $destination:expr) => { | ||
match fs::copy($file, | ||
$destination) { | ||
Ok(file) => file, | ||
Err(error) => panic!("Problem copying the file {} to {}: {:?}", $file, $destination, error), | ||
}; | ||
} | ||
} | ||
|
||
fn copy_resources() { | ||
use std::fs; | ||
let profile = std::env::var("PROFILE").unwrap(); | ||
if profile == "debug" { | ||
copy_file!("..\\..\\SharedContent\\media\\fish.png",".\\target\\debug\\fish.png"); | ||
copy_file!("..\\..\\SharedContent\\media\\fish.png",".\\target\\debug\\kitten_224.png"); | ||
copy_file!("..\\..\\SharedContent\\models\\SqueezeNet.onnx",".\\target\\debug\\SqueezeNet.onnx"); | ||
copy_file!("..\\SqueezeNetObjectDetection\\Desktop\\cpp\\Labels.txt",".\\target\\debug\\Labels.txt"); | ||
} | ||
else if profile == "release" { | ||
copy_file!("..\\..\\SharedContent\\media\\fish.png",".\\target\\release\\fish.png"); | ||
copy_file!("..\\..\\SharedContent\\media\\fish.png",".\\target\\release\\kitten_224.png"); | ||
copy_file!("..\\..\\SharedContent\\models\\SqueezeNet.onnx",".\\target\\release\\SqueezeNet.onnx"); | ||
copy_file!("..\\SqueezeNetObjectDetection\\Desktop\\cpp\\Labels.txt",".\\target\\release\\Labels.txt"); | ||
} | ||
} | ||
|
||
fn main() { | ||
winrt::build!( | ||
types | ||
microsoft::ai::machine_learning::* | ||
windows::graphics::imaging::* | ||
); | ||
copy_resources(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
include!(concat!(env!("OUT_DIR"), "/winrt.rs")); | ||
|
||
macro_rules! handle_io_error_as_winrt_error { | ||
($expression:expr, $error_message:expr) => { | ||
match $expression { | ||
Ok(val) => val, | ||
Err(_err) => return Err(winrt::Error::new(winrt::ErrorCode(Error::last_os_error().raw_os_error().unwrap() as u32), $error_message)), | ||
} | ||
} | ||
} | ||
|
||
fn main() -> winrt::Result<()> { | ||
use microsoft::ai::machine_learning::*; | ||
use winrt::ComInterface; | ||
|
||
let model_path = get_current_dir()? + "\\Squeezenet.onnx"; | ||
println!("Loading model {}", model_path); | ||
let learning_model = LearningModel::load_from_file_path(model_path)?; | ||
|
||
let device = LearningModelDevice::create(LearningModelDeviceKind::Cpu)?; | ||
|
||
println!("Creating session"); | ||
let session = LearningModelSession::create_from_model_on_device(learning_model, device)?; | ||
|
||
let image_file_path = get_current_dir()? + "\\kitten_224.png"; | ||
println!("Loading image file {}", image_file_path); | ||
let input_image_videoframe = load_image_file(image_file_path)?; | ||
let input_image_feature_value = ImageFeatureValue::create_from_video_frame(input_image_videoframe)?; | ||
let binding = LearningModelBinding::create_from_session(&session)?; | ||
binding.bind("data_0", input_image_feature_value)?; | ||
|
||
println!("Evaluating"); | ||
let results = LearningModelSession::evaluate(&session,binding, "RunId")?; | ||
|
||
let result_lookup = results.outputs()?.lookup("softmaxout_1")?; | ||
let result_itensor_float : ITensorFloat = result_lookup.try_query()?; | ||
let result_vector_view = result_itensor_float.get_as_vector_view()?; | ||
println!("Results:"); | ||
print_results(result_vector_view)?; | ||
Ok(()) | ||
} | ||
|
||
// Print the evaluation results. | ||
fn print_results(results: windows::foundation::collections::IVectorView<f32>) -> winrt::Result<()> { | ||
let labels = load_labels()?; | ||
let mut sorted_results : std::vec::Vec<(f32,u32)> = Vec::new(); | ||
for i in 0..results.size()? { | ||
let result = (results.get_at(i)?, i); | ||
sorted_results.push(result); | ||
} | ||
sorted_results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap()); | ||
|
||
// Display the top results | ||
for i in 0..3 { | ||
println!(" {} {}", labels[sorted_results[i].1 as usize], sorted_results[i].0) | ||
} | ||
Ok(()) | ||
} | ||
|
||
// Return the path of the current directory of the executable | ||
fn get_current_dir() -> winrt::Result<String> { | ||
use std::env; | ||
use std::io::Error; | ||
let current_exe = handle_io_error_as_winrt_error!(env::current_exe(), "Failed to get current directory of executable."); | ||
let current_dir = current_exe.parent().unwrap(); | ||
Ok(current_dir.display().to_string()) | ||
} | ||
|
||
// Load all the SqueezeNet labeels and return in a vector of Strings. | ||
fn load_labels() -> winrt::Result<std::vec::Vec<String>> { | ||
use std::io::Error; | ||
use std::fs::File; | ||
use std::io::{prelude::*, BufReader}; | ||
|
||
let mut labels : std::vec::Vec<String> = Vec::new(); | ||
let labels_file_path = get_current_dir()? + "\\Labels.txt"; | ||
let file = handle_io_error_as_winrt_error!(File::open(labels_file_path), "Failed to load labels."); | ||
let reader = BufReader::new(file); | ||
for line in reader.lines() { | ||
let line_str = handle_io_error_as_winrt_error!(line,"Failed to read lines."); | ||
let mut tokenized_line: Vec<&str> = line_str.split(',').collect(); | ||
let index = tokenized_line[0].parse::<usize>().unwrap(); | ||
labels.resize(index+1, "".to_string()); | ||
tokenized_line.remove(0); | ||
labels[index] = tokenized_line.join(""); | ||
} | ||
Ok(labels) | ||
} | ||
|
||
// load image file given a path and return Videoframe | ||
fn load_image_file(image_file_path: String) -> winrt::Result<windows::media::VideoFrame> { | ||
use windows::graphics::imaging::*; | ||
use windows::media::*; | ||
use windows::storage::*; | ||
|
||
let file = StorageFile::get_file_from_path_async(image_file_path)?.get()?; | ||
let stream = file.open_async(FileAccessMode::Read)?.get()?; | ||
let decoder = BitmapDecoder::create_async(&stream)?.get()?; | ||
let software_bitmap = decoder.get_software_bitmap_async()?.get()?; | ||
let image_videoframe = VideoFrame::create_with_software_bitmap(software_bitmap)?; | ||
Ok(image_videoframe) | ||
} |