Skip to content

Commit

Permalink
Generalize KNN Image Classifier to KNN Classifier, taking an arbitrar…
Browse files Browse the repository at this point in the history
…y embedding as input. (tensorflow#36)

This PR generalizes the KNN Image Classifier to be a generic KNN Classifier for any embedding.

This PR also:
- Removes the dependency on MobileNet. It is now the user's responsibility to pass a MobileNet embedding to the model.
  • Loading branch information
Nikhil Thorat authored Jun 19, 2018
1 parent f1a5eea commit 790ac1c
Show file tree
Hide file tree
Showing 18 changed files with 485 additions and 363 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"[javascript]": {
"editor.formatOnSave": true
},
"editor.rulers": [80],
"clang-format.style": "Google",
"files.insertFinalNewline": true,
"editor.detectIndentation": false,
Expand Down
File renamed without changes.
143 changes: 143 additions & 0 deletions knn-classifier/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# KNN Classifier

This package provides a utility for creating a classifier using the
[K-Nearest Neighbors](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)
algorithm.

This package is different from the other packages in this repository in that it
doesn't provide a model with weights, but rather a utility for constructing a
KNN model using activations from another model or any other tensors you can
associate with a class.

You can see example code [here](https://github.com/tensorflow/tfjs-models/tree/master/knn-classifier/demo).

## Usage example

```js
import * as tf from '@tensorflow/tfjs';
import * as mobilenetModule from '@tensorflow-models/mobilenet';
import * as knnClassifier from '@tensorflow-models/knn-classifier';

// Create the classifier.
const classifier = knnClassifier.create();

// Load mobilenet.
const mobilenet = await mobilenetModule.load();

// Add MobileNet activations to the model repeatedly for all classes.
const img0 = tf.fromPixels(...);
const logits0 = mobilenet.infer(img0, 'conv_preds');
classifier.addExample(logits, 0);

const img1 = tf.fromPixels(...);
const logits1 = mobilenet.infer(img1, 'conv_preds');
classifier.addExample(logits, 1);

// Make a prediction.
const x = tf.fromPixels(...);
const xlogits = mobilenet.infer(x, 'conv_preds');
console.log('Predictions:');
console.log(classifier.predictClass(xlogits));
```

## API

#### Creating a classifier
`knnClassifier` is the module name, which is automatically included when you use
the <script src> method.

```ts
classifier = knnClassifier.create()
```

Returns a `KNNImageClassifier`.

#### Adding examples

```ts
classifier.addExample(
example: tf.Tensor,
classIndex: number
): void;
```

Args:
- **example:** An example to add to the dataset, usually an activation from
another model.
- **classIndex:** The class index of the example.

#### Making a prediction

```ts
classifier.predictClass(
input: tf.Tensor,
k = 3
): Promise<{classIndex: number, confidences: {[classId: number]: number}}>;
```

Args:
- **input:** An example to make a prediction on, usually an activation from
another model.
- **k:** The K value to use in K-nearest neighbors. The algorithm will first
find the K nearest examples from those it was previously shown, and then choose
the class that appears the most as the final prediction for the input example.
Defaults to 3. If examples < k, k = examples.

Returns an object with a top classIndex, and confidences mapping all class
indices to their confidence.

#### Misc

##### Clear all examples for a class.

```ts
classifier.clearClass(classIndex: number)
```

Args:
- **classIndex:** The class to clear all examples for.

##### Clear all examples from all classes

```ts
classifier.clearAllClasses()
```

##### Get the example count for each class

```ts
classifier.getClassExampleCount(): {[classId: number]: number}
```

Returns an object that maps classId to example count for that class.

##### Get the full dataset, useful for saving state.

```ts
classifier.getClassifierDataset(): {[classId: number]: Tensor2D}
```

##### Set the full dataset, useful for restoring state.

```ts
classifier.setClassifierDataset(dataset: {[classId: number]: Tensor2D})
```

Args:
- **dataset:** The class dataset matrices map. Can be retrieved from
getClassDatsetMatrices. Useful for restoring state.

##### Get the total number of classes

```ts
classifier.getNumClasses(): number
```

##### Dispose the classifier and all internal state

Clears up WebGL memory. Useful if you no longer need the classifier in your
application.

```ts
classifier.dispose()
```
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The camera demo shows how to create a custom classifier with 3 classes that can
cd into the demos folder:

```sh
cd knn-image-classifier/demos
cd knn-classifier/demos
```

Install dependencies and prepare the build directory:
Expand All @@ -33,17 +33,17 @@ Install yalc:
npm i -g yalc
```

cd into the knn-image-classifier folder:
cd into the knn-classifier folder:
```sh
cd knn-image-classifier
cd knn-classifier
```

Install dependencies:
```sh
yarn
```

Publish knn-image-classifier locally:
Publish knn-classifier locally:
```sh
yalc push
```
Expand All @@ -55,19 +55,7 @@ cd demos
yarn
```

Link the local knn-image-classifier to the demos:
```sh
yalc link \@tensorflow-models/knn-image-classifier
```

Start the dev demo server:
```sh
yarn watch
```

To get future updates from the knn-image-classifier source code:
```
# cd up into the knn-image-classifier directory
cd ../
yarn build && yalc push
```
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
* limitations under the License.
* =============================================================================
*/
import * as mobilenetModule from '@tensorflow-models/mobilenet';
import * as tf from '@tensorflow/tfjs';
import * as knn from '@tensorflow-models/knn-image-classifier';
import Stats from 'stats.js';

import * as knnClassifier from '../src/index';

const videoWidth = 300;
const videoHeight = 250;
const stats = new Stats();
Expand All @@ -26,11 +28,12 @@ const stats = new Stats();
const NUM_CLASSES = 3;

// K value for KNN
const TOPK = 10;
const TOPK = 3;

const infoTexts = [];
let training = -1;
let model;
let classifier;
let mobilenet;
let video;

function isAndroid() {
Expand All @@ -52,7 +55,7 @@ function isMobile() {
async function setupCamera() {
if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
throw new Error(
'Browser API navigator.mediaDevices.getUserMedia not available');
'Browser API navigator.mediaDevices.getUserMedia not available');
}

const video = document.getElementById('video');
Expand Down Expand Up @@ -105,61 +108,59 @@ function setupGui() {
}
}

/**
* Load the KNN model
*/
async function loadKNN() {
const model = await knn.load(NUM_CLASSES, TOPK);
return model;
}

/**
* Sets up a frames per second panel on the top-left of the window
*/
function setupFPS() {
stats.showPanel(0); // 0: fps, 1: ms, 2: mb, 3+: custom
stats.showPanel(0); // 0: fps, 1: ms, 2: mb, 3+: custom
document.body.appendChild(stats.dom);
}

/**
* Animation function called on each frame, running prediction
*/
function animate() {
async function animate() {
stats.begin();

// Get image data from video element
const image = tf.fromPixels(video);
let logits;
// 'conv_preds' is the logits activation of MobileNet.
const infer = () => mobilenet.infer(image, 'conv_preds');

// Train class if one of the buttons is held down
if (training != -1) {
logits = infer();
// Add current image to classifier
model.addImage(image, training);
classifier.addExample(logits, training);
}

// If the classifier has examples for any classes, make a prediction!
const numClasses = classifier.getNumClasses();
if (numClasses > 0) {
logits = infer();

const res = await classifier.predictClass(logits, TOPK);
for (let i = 0; i < NUM_CLASSES; i++) {
// Make the predicted class bold
if (res.classIndex == i) {
infoTexts[i].style.fontWeight = 'bold';
} else {
infoTexts[i].style.fontWeight = 'normal';
}

const classExampleCount = classifier.getClassExampleCount();
// Update info text
if (classExampleCount[i] > 0) {
const conf = res.confidences[i] * 100;
infoTexts[i].innerText = ` ${classExampleCount[i]} examples - ${conf}%`;
}
}
}

// If any examples have been added, run predict
const exampleCount = model.getClassExampleCount();
if (Math.max(...exampleCount) > 0) {
model.predictClass(image)
.then((res) => {
for (let i = 0; i < NUM_CLASSES; i++) {
// Make the predicted class bold
if (res.classIndex == i) {
infoTexts[i].style.fontWeight = 'bold';
} else {
infoTexts[i].style.fontWeight = 'normal';
}

// Update info text
if (exampleCount[i] > 0) {
const conf = res.confidences[i] * 100;
infoTexts[i].innerText = ` ${exampleCount[i]} examples - ${conf}%`;
}
}
})
// Dispose image when done
.then(() => image.dispose());
} else {
image.dispose();
image.dispose();
if (logits != null) {
logits.dispose();
}

stats.end();
Expand All @@ -172,8 +173,8 @@ function animate() {
* available camera devices, and setting off the animate function.
*/
export async function bindPage() {
// Load the KNN model
model = await loadKNN();
classifier = knnClassifier.create();
mobilenet = await mobilenetModule.load();

document.getElementById('loading').style.display = 'none';
document.getElementById('main').style.display = 'block';
Expand All @@ -189,7 +190,7 @@ export async function bindPage() {
} catch (e) {
let info = document.getElementById('info');
info.textContent = 'this browser does not support video capture,' +
'or this device does not have a camera';
'or this device does not have a camera';
info.style.display = 'block';
throw e;
}
Expand All @@ -199,6 +200,6 @@ export async function bindPage() {
}

navigator.getUserMedia = navigator.getUserMedia ||
navigator.webkitGetUserMedia || navigator.mozGetUserMedia;
navigator.webkitGetUserMedia || navigator.mozGetUserMedia;
// kick off the demo
bindPage();
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"node": ">=8.9.0"
},
"dependencies": {
"@tensorflow-models/knn-image-classifier": "0.1.0",
"@tensorflow-models/mobilenet": "0.1.1",
"@tensorflow/tfjs": "0.11.4",
"stats.js": "^0.17.0"
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@
version "1.1.0"
resolved "https://registry.yarnpkg.com/@protobufjs/utf8/-/utf8-1.1.0.tgz#a777360b5b39a1a2e5106f8e858f2fd2d060c570"

"@tensorflow-models/posenet@0.1.2":
version "0.1.2"
resolved "https://registry.yarnpkg.com/@tensorflow-models/posenet/-/posenet-0.1.2.tgz#621849eaddc53a6a1fd5a34ccf6a22329addc6ee"
"@tensorflow-models/mobilenet@0.1.1":
version "0.1.1"
resolved "https://registry.yarnpkg.com/@tensorflow-models/mobilenet/-/mobilenet-0.1.1.tgz#bfe159b0e0abee31da421cc36e5051e4622bcd03"

"@tensorflow/[email protected]":
version "0.4.1"
Expand Down
Loading

0 comments on commit 790ac1c

Please sign in to comment.