Skip to content
This repository has been archived by the owner on May 9, 2022. It is now read-only.

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Robert-Steiner committed Apr 28, 2021
1 parent c8a779d commit 6d85084
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 18 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions xayn-ai-ffi-wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ xayn-ai = { path = "../xayn-ai" }

[dev-dependencies]
wasm-bindgen-test = "0.3.23"
itertools = "0.10.0"

[lib]
crate-type = ["cdylib"]
Expand Down
170 changes: 158 additions & 12 deletions xayn-ai-ffi-wasm/src/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,31 +119,177 @@ impl WXaynAi {
#[cfg(target_arch = "wasm32")]
#[cfg(test)]
mod tests {
use super::*;

use std::iter::repeat;

use itertools::izip;
use wasm_bindgen_test::wasm_bindgen_test;
use xayn_ai::{Relevance, UserFeedback};

use super::*;
use crate::error::ExternError;

/// Path to the current vocabulary file.
pub const VOCAB: &[u8] = include_bytes!("../../data/rubert_v0000/vocab.txt");

/// Path to the current onnx model file.
pub const MODEL: &[u8] = include_bytes!("../../data/rubert_v0000/model.onnx");

fn test_histories() -> Vec<JsValue> {
let len = 6;
let ids = (0..len).map(|idx| idx.to_string()).collect::<Vec<_>>();

let relevances = repeat(Relevance::Low)
.take(len / 2)
.chain(repeat(Relevance::High).take(len - len / 2));
let feedbacks = repeat(UserFeedback::Irrelevant)
.take(len / 2)
.chain(repeat(UserFeedback::Relevant).take(len - len / 2));

let history = izip!(ids, relevances, feedbacks)
.map(|(id, relevance, user_feedback)| {
JsValue::from_serde(&DocumentHistory {
id: id.as_str().into(),
relevance,
user_feedback,
})
.unwrap()
})
.collect::<Vec<_>>();

history
}

fn test_documents() -> Vec<JsValue> {
let len = 10;
let ids = (0..len).map(|idx| idx.to_string()).collect::<Vec<_>>();

let snippets = (0..len)
.map(|idx| format!("snippet {}", idx))
.collect::<Vec<_>>();
let ranks = 0..len as usize;

let document = izip!(ids, snippets, ranks)
.map(|(id, snippet, rank)| {
JsValue::from_serde(&Document {
id: id.as_str().into(),
snippet,
rank,
})
.unwrap()
})
.collect::<Vec<_>>();

document
}

#[wasm_bindgen_test]
fn test_reranker() {
fn test_rerank() {
let mut xaynai = WXaynAi::new(VOCAB, MODEL, None).unwrap();
xaynai.rerank(test_histories(), test_documents()).unwrap();
}

let document = Document {
id: "1".into(),
rank: 0,
snippet: "abc".to_string(),
};
let js_document = JsValue::from_serde(&document).unwrap();
#[wasm_bindgen_test]
fn test_serialize() {
let xaynai = WXaynAi::new(VOCAB, MODEL, None).unwrap();
xaynai.serialize().unwrap();
}

#[wasm_bindgen_test]
fn test_faults() {
let xaynai = WXaynAi::new(VOCAB, MODEL, None).unwrap();
let faults = xaynai.faults();
assert!(faults.is_empty())
}

#[wasm_bindgen_test]
fn test_analytics() {
let xaynai = WXaynAi::new(VOCAB, MODEL, None).unwrap();
let analytics = xaynai.analytics();
assert!(analytics.is_null())
}

#[wasm_bindgen_test]
fn test_vocab_invalid() {
let error: ExternError = WXaynAi::new(&[], MODEL, None)
.map_err(|e| JsValue::into_serde(&e).unwrap())
.err()
.unwrap();

let ranks = xaynai.rerank(vec![], vec![js_document]).unwrap();
assert_eq!(ranks, [0]);
assert_eq!(error.code, CCode::InitAi as i32);
assert!(error
.message
.contains("Failed to initialize the ai: Failed to build the tokenizer: "));
}

// #[wasm_bindgen_test]
// fn test_model_invalid() {
// let error: ExternError = WXaynAi::new(VOCAB, &[], None)
// .map_err(|e| JsValue::into_serde(&e).unwrap())
// .err()
// .unwrap();

// assert_eq!(error.code, CCode::InitAi as i32);
// assert!(error.message.contains("Failed to initialize the ai: "));
// }

#[wasm_bindgen_test]
fn test_history_invalid() {
let mut xaynai = WXaynAi::new(VOCAB, MODEL, None).unwrap();
let error: ExternError = xaynai
.rerank(vec![JsValue::from("invalid")], test_documents())
.map_err(|e| JsValue::into_serde(&e).unwrap())
.err()
.unwrap();

assert_eq!(error.code, CCode::HistoriesDeserialization as i32);
assert!(error
.message
.contains("Failed to deserialize the collection of histories: invalid type: "));
}

#[wasm_bindgen_test]
fn test_history_empty() {
let mut xaynai = WXaynAi::new(VOCAB, MODEL, None).unwrap();
xaynai.rerank(vec![], test_documents()).unwrap();
}

#[wasm_bindgen_test]
fn test_documents_invalid() {
let mut xaynai = WXaynAi::new(VOCAB, MODEL, None).unwrap();
let error: ExternError = xaynai
.rerank(test_histories(), vec![JsValue::from("invalid")])
.map_err(|e| JsValue::into_serde(&e).unwrap())
.err()
.unwrap();

assert_eq!(error.code, CCode::DocumentsDeserialization as i32);
assert!(error
.message
.contains("Failed to deserialize the collection of documents: invalid type: "));
}

#[wasm_bindgen_test]
fn test_documents_empty() {
let mut xaynai = WXaynAi::new(VOCAB, MODEL, None).unwrap();
xaynai.rerank(test_histories(), vec![]).unwrap();
}

#[wasm_bindgen_test]
fn test_serialized_empty() {
WXaynAi::new(VOCAB, MODEL, Some(Box::new([]))).unwrap();
}

#[wasm_bindgen_test]
fn test_serialized_invalid() {
let error: ExternError = WXaynAi::new(VOCAB, MODEL, Some(Box::new([1, 2, 3])))
.map_err(|e| JsValue::into_serde(&e).unwrap())
.err()
.unwrap();

let ser = xaynai.serialize().unwrap();
assert!(ser.length() != 0)
assert_eq!(error.code, CCode::RerankerDeserialization as i32);
assert!(error.message.contains(
"Failed to deserialize the reranker database: Unsupported serialized data. "
));
}
}
6 changes: 4 additions & 2 deletions xayn-ai-ffi-wasm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::utils::IntoJsResult;

// placeholder / later we can have a crate that contains common code for c-ffi and wasm
#[repr(i32)]
#[cfg_attr(test, derive(Clone, Copy, Debug))]
pub enum CCode {
/// A warning or uncritical error.
Fault = -2,
Expand All @@ -28,9 +29,10 @@ impl CCode {
}

#[derive(Serialize)]
#[cfg_attr(test, derive(serde::Deserialize, Debug))]
pub struct ExternError {
code: i32,
message: String,
pub(crate) code: i32,
pub(crate) message: String,
}

impl ExternError {
Expand Down
2 changes: 1 addition & 1 deletion xayn-ai/src/analytics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::{cmp::Ordering, collections::HashMap};

use anyhow::bail;
use displaydoc::Display;
use thiserror::Error;
use serde::Serialize;
use thiserror::Error;

use crate::{
data::{
Expand Down
6 changes: 3 additions & 3 deletions xayn-ai/src/data/document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub struct Document {
pub snippet: String,
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct DocumentHistory {
/// unique identifier of this document
pub id: DocumentId,
Expand All @@ -30,14 +30,14 @@ pub struct DocumentHistory {
pub user_feedback: UserFeedback,
}

#[derive(Clone, Copy, Debug, PartialEq, Deserialize)]
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum UserFeedback {
Relevant,
Irrelevant,
None,
}

#[derive(Clone, Copy, Debug, PartialEq, Deserialize)]
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum Relevance {
Low,
Medium,
Expand Down

0 comments on commit 6d85084

Please sign in to comment.