forked from tensorflow/tfjs-models
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generalize KNN Image Classifier to KNN Classifier, taking an arbitrar…
…y embedding as input. (tensorflow#36) This PR generalizes the KNN Image Classifier to be a generic KNN Classifier for any embedding. This PR also: - Removes the dependency on MobileNet. It is now the user's responsibility to pass a MobileNet embedding to the model.
- Loading branch information
Nikhil Thorat
authored
Jun 19, 2018
1 parent
f1a5eea
commit 790ac1c
Showing
18 changed files
with
485 additions
and
363 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# KNN Classifier | ||
|
||
This package provides a utility for creating a classifier using the | ||
[K-Nearest Neighbors](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm) | ||
algorithm. | ||
|
||
This package is different from the other packages in this repository in that it | ||
doesn't provide a model with weights, but rather a utility for constructing a | ||
KNN model using activations from another model or any other tensors you can | ||
associate with a class. | ||
|
||
You can see example code [here](https://github.com/tensorflow/tfjs-models/tree/master/knn-classifier/demo). | ||
|
||
## Usage example | ||
|
||
```js | ||
import * as tf from '@tensorflow/tfjs'; | ||
import * as mobilenetModule from '@tensorflow-models/mobilenet'; | ||
import * as knnClassifier from '@tensorflow-models/knn-classifier'; | ||
|
||
// Create the classifier. | ||
const classifier = knnClassifier.create(); | ||
|
||
// Load mobilenet. | ||
const mobilenet = await mobilenetModule.load(); | ||
|
||
// Add MobileNet activations to the model repeatedly for all classes. | ||
const img0 = tf.fromPixels(...); | ||
const logits0 = mobilenet.infer(img0, 'conv_preds'); | ||
classifier.addExample(logits, 0); | ||
|
||
const img1 = tf.fromPixels(...); | ||
const logits1 = mobilenet.infer(img1, 'conv_preds'); | ||
classifier.addExample(logits, 1); | ||
|
||
// Make a prediction. | ||
const x = tf.fromPixels(...); | ||
const xlogits = mobilenet.infer(x, 'conv_preds'); | ||
console.log('Predictions:'); | ||
console.log(classifier.predictClass(xlogits)); | ||
``` | ||
|
||
## API | ||
|
||
#### Creating a classifier | ||
`knnClassifier` is the module name, which is automatically included when you use | ||
the <script src> method. | ||
|
||
```ts | ||
classifier = knnClassifier.create() | ||
``` | ||
|
||
Returns a `KNNImageClassifier`. | ||
|
||
#### Adding examples | ||
|
||
```ts | ||
classifier.addExample( | ||
example: tf.Tensor, | ||
classIndex: number | ||
): void; | ||
``` | ||
|
||
Args: | ||
- **example:** An example to add to the dataset, usually an activation from | ||
another model. | ||
- **classIndex:** The class index of the example. | ||
|
||
#### Making a prediction | ||
|
||
```ts | ||
classifier.predictClass( | ||
input: tf.Tensor, | ||
k = 3 | ||
): Promise<{classIndex: number, confidences: {[classId: number]: number}}>; | ||
``` | ||
|
||
Args: | ||
- **input:** An example to make a prediction on, usually an activation from | ||
another model. | ||
- **k:** The K value to use in K-nearest neighbors. The algorithm will first | ||
find the K nearest examples from those it was previously shown, and then choose | ||
the class that appears the most as the final prediction for the input example. | ||
Defaults to 3. If examples < k, k = examples. | ||
|
||
Returns an object with a top classIndex, and confidences mapping all class | ||
indices to their confidence. | ||
|
||
#### Misc | ||
|
||
##### Clear all examples for a class. | ||
|
||
```ts | ||
classifier.clearClass(classIndex: number) | ||
``` | ||
|
||
Args: | ||
- **classIndex:** The class to clear all examples for. | ||
|
||
##### Clear all examples from all classes | ||
|
||
```ts | ||
classifier.clearAllClasses() | ||
``` | ||
|
||
##### Get the example count for each class | ||
|
||
```ts | ||
classifier.getClassExampleCount(): {[classId: number]: number} | ||
``` | ||
|
||
Returns an object that maps classId to example count for that class. | ||
|
||
##### Get the full dataset, useful for saving state. | ||
|
||
```ts | ||
classifier.getClassifierDataset(): {[classId: number]: Tensor2D} | ||
``` | ||
|
||
##### Set the full dataset, useful for restoring state. | ||
|
||
```ts | ||
classifier.setClassifierDataset(dataset: {[classId: number]: Tensor2D}) | ||
``` | ||
|
||
Args: | ||
- **dataset:** The class dataset matrices map. Can be retrieved from | ||
getClassDatsetMatrices. Useful for restoring state. | ||
|
||
##### Get the total number of classes | ||
|
||
```ts | ||
classifier.getNumClasses(): number | ||
``` | ||
|
||
##### Dispose the classifier and all internal state | ||
|
||
Clears up WebGL memory. Useful if you no longer need the classifier in your | ||
application. | ||
|
||
```ts | ||
classifier.dispose() | ||
``` |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,9 +45,9 @@ | |
version "1.1.0" | ||
resolved "https://registry.yarnpkg.com/@protobufjs/utf8/-/utf8-1.1.0.tgz#a777360b5b39a1a2e5106f8e858f2fd2d060c570" | ||
|
||
"@tensorflow-models/posenet@0.1.2": | ||
version "0.1.2" | ||
resolved "https://registry.yarnpkg.com/@tensorflow-models/posenet/-/posenet-0.1.2.tgz#621849eaddc53a6a1fd5a34ccf6a22329addc6ee" | ||
"@tensorflow-models/mobilenet@0.1.1": | ||
version "0.1.1" | ||
resolved "https://registry.yarnpkg.com/@tensorflow-models/mobilenet/-/mobilenet-0.1.1.tgz#bfe159b0e0abee31da421cc36e5051e4622bcd03" | ||
|
||
"@tensorflow/[email protected]": | ||
version "0.4.1" | ||
|
Oops, something went wrong.