{"id":14958809,"url":"https://github.com/thekevinscott/ml-classifier","last_synced_at":"2025-08-18T19:31:29.702Z","repository":{"id":32700395,"uuid":"139576119","full_name":"thekevinscott/ml-classifier","owner":"thekevinscott","description":"A tool for quickly training image classifiers in the browser","archived":false,"fork":false,"pushed_at":"2022-12-08T19:13:27.000Z","size":11163,"stargazers_count":113,"open_issues_count":13,"forks_count":16,"subscribers_count":5,"default_branch":"master","last_synced_at":"2024-12-04T23:06:22.147Z","etag":null,"topics":["image-classification","image-classifier","machine-learning","machinelearning","tensorflow","tensorflow-examples","tensorflow-experiments","tensorflow-tutorials","tensorflowjs"],"latest_commit_sha":null,"homepage":"https://thekevinscott.github.io/ml-classifier-ui/","language":"JavaScript","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/thekevinscott.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2018-07-03T11:52:38.000Z","updated_at":"2024-08-06T08:32:34.000Z","dependencies_parsed_at":"2023-01-14T22:00:44.148Z","dependency_job_id":null,"html_url":"https://github.com/thekevinscott/ml-classifier","commit_stats":null,"previous_names":[],"tags_count":1,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/thekevinscott%2Fml-classifier","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/thekevinscott%2Fml-classifier/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/thekevinscott%2Fml-classifier/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/
GitHub/repositories/thekevinscott%2Fml-classifier/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/thekevinscott","download_url":"https://codeload.github.com/thekevinscott/ml-classifier/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":229859898,"owners_count":18135536,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["image-classification","image-classifier","machine-learning","machinelearning","tensorflow","tensorflow-examples","tensorflow-experiments","tensorflow-tutorials","tensorflowjs"],"created_at":"2024-09-24T13:18:19.929Z","updated_at":"2024-12-18T12:10:03.310Z","avatar_url":"https://github.com/thekevinscott.png","language":"JavaScript","funding_links":[],"categories":[],"sub_categories":[],"readme":"# ML Classifier\n\nML Classifier is a machine learning engine for quickly training image classification models in your browser. 
Models can be saved with a single command, and the resulting models reused to make image classification predictions.\n\nThis package is intended as a companion for [`ml-classifier-ui`](https://github.com/thekevinscott/ml-classifier-ui), which provides a web frontend in React for uploading data and seeing results.\n\n## Walkthrough\n\nA walkthrough of the code can be found in the article [Image Classification in the Browser with Javascript](https://thekevinscott.com/image-classification-with-javascript/).\n\n## Demo\n\nAn interactive [demo can be found here](https://thekevinscott.github.io/ml-classifier-ui/).\n\n![Demo](https://github.com/thekevinscott/ml-classifier-ui/raw/master/example/public/example.gif)\n*Screenshot of demo*\n\n## Getting Started\n\n### Installation\n\n`ml-classifier` can be installed via `yarn` or `npm`:\n\n```\nyarn add ml-classifier\n```\n\nor\n\n```\nnpm install ml-classifier\n```\n\n### Quick Start\n\nStart by instantiating a new MLClassifier.\n\n```\nimport MLClassifier from 'ml-classifier';\n\nconst mlClassifier = new MLClassifier();\n```\n\nThen, train the model:\n\n```\nawait mlClassifier.train(imageData, {\n  callbacks: {\n    onTrainBegin: () =\u003e {\n      console.log('training begins');\n    },\n    onBatchEnd: (batch: any,logs: any) =\u003e {\n      console.log('Loss is: ' + logs.loss.toFixed(5));\n    }\n  },\n});\n```\n\nAnd get predictions:\n\n```\nconst prediction = await mlClassifier.predict(data);\n```\n\nWhen you have a trained model you're happy with, save it with:\n\n```\nmlClassifier.save();\n```\n\n## Using the saved model\n\nWhen you hit save, Tensorflow.js will download a weights file and a model topology file.\n\nYou'll need to combine both into a single `json` file. 
Open up your model topology file and at the top level of the JSON file, make sure to add a `weightsManifest` key pointing to your weights, like:\n\n```\n{\n  \"weightsManifest\": \"ml-classifier-class1-class2.weights.bin\",\n  \"modelTopology\": {\n    ...\n  }\n}\n```\n\nWhen using the model in your app, there's a few things to keep in mind:\n\n1. You need to make sure you transform images into the correct dimensions, depending on the pretrained model it was trained with. (For MOBILENET, this would be 1x224x224x3).\n2. You must create a pretrained model matching the dimensions used to train. An example is below for MOBILENET.\n3. You must first run your images through the pretrained model to activate them.\n4. After getting the final prediction, you must take the arg max.\n5. You'll get back a number indicating your class.\n\nFull example for MOBILENET:\n\n```\n    const loadImage = (src) =\u003e new Promise((resolve, reject) =\u003e {\n      const image = new Image();\n      image.src = src;\n      image.crossOrigin = 'Anonymous';\n      image.onload = () =\u003e resolve(image);\n      image.onerror = (err) =\u003e reject(err);\n    });\n\n    const pretrainedModelURL = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';\n\n    tf.loadModel(pretrainedModelURL).then(model =\u003e {\n      const layer = model.getLayer('conv_pw_13_relu');\n      return tf.model({\n        inputs: [model.inputs[0]],\n        outputs: layer.output,\n      });\n    }).then(pretrainedModel =\u003e {\n      return tf.loadModel('/model.json').then(model =\u003e {\n        return loadImage('/trees/tree1.png').then(loadedImage =\u003e {\n          const image = tf.reshape(tf.fromPixels(loadedImage), [1,224,224,3]);\n          const pretrainedModelPrediction = pretrainedModel.predict(image);\n          const modelPrediction = model.predict(pretrainedModelPrediction);\n          const prediction = modelPrediction.as1D().argMax().dataSync()[0];\n          
console.log(prediction);\n        });\n      });\n    }).catch(err =\u003e {\n      console.error('Error', err);\n    });\n```\n\n## API Documentation\n\nStart by instantiating a new instance of `MLClassifier` with:\n\n```\nconst mlClassifier = new MLClassifier();\n```\n\nThis will begin loading the pretrained model and provide you with an object onto which to add data and train.\n\n### `constructor`\n\n`MLClassifier` accepts a number of callbacks for beginning and end of various methods.\n\nYou can provide a custom pretrained model as a `pretrainedModel`.\n\nYou can provide a custom training model as a `trainingModel`.\n\n#### Parameters\n\n  * **pretrainedModel** (`string | tf.Model`) *Optional* - A string denoting which pretrained model to load from an internal config. Valid strings can be found on the exported object `PRETRAINED_MODELS`. You can also specify a preloaded pretrained model directly.\n  * **trainingModel** (`tf.Model | Function`) *Optional* - A custom model to use during training. 
Can be provided as a `tf.Model` or as a function that accepts `{xs: [...], ys: [...]}`, the number of `classes`, and the `params` provided to `train`.\n  * **onLoadStart** (`Function`) *Optional* - A callback for when `load` (loading the pre-trained model) is first called.\n  * **onLoadComplete** (`Function`) *Optional* - A callback for when `load` (loading the pre-trained model) is complete.\n  * **onAddDataStart** (`Function`) *Optional* - A callback for when `addData` is first called.\n  * **onAddDataComplete** (`Function`) *Optional* - A callback for when `addData` is complete.\n  * **onClearDataStart** (`Function`) *Optional* - A callback for when `clearData` is first called.\n  * **onClearDataComplete** (`Function`) *Optional* - A callback for when `clearData` is complete.\n  * **onTrainStart** (`Function`) *Optional* - A callback for when `train` is first called.\n  * **onTrainComplete** (`Function`) *Optional* - A callback for when `train` is complete.\n  * **onEvaluateStart** (`Function`) *Optional* - A callback for when `evaluate` is first called.\n  * **onEvaluateComplete** (`Function`) *Optional* - A callback for when `evaluate` is complete.\n  * **onPredictStart** (`Function`) *Optional* - A callback for when `predict` is first called.\n  * **onPredictComplete** (`Function`) *Optional* - A callback for when `predict` is complete.\n  * **onSaveStart** (`Function`) *Optional* - A callback for when `save` is first called.\n  * **onSaveComplete** (`Function`) *Optional* - A callback for when `save` is complete.\n\n\n#### Example\n```\nimport MLClassifier, {\n  PRETRAINED_MODELS,\n} from 'ml-classifier';\n\nconst mlClassifier = new MLClassifier({\n  pretrainedModel: PRETRAINED_MODELS.MOBILENET,\n\n  onLoadStart: () =\u003e console.log('onLoadStart'),\n  onLoadComplete: () =\u003e console.log('onLoadComplete'),\n  onAddDataStart: () =\u003e console.log('onAddDataStart'),\n  onAddDataComplete: () =\u003e console.log('onAddDataComplete'),\n  onClearDataStart: () =\u003e 
console.log('onClearDataStart'),\n  onClearDataComplete: () =\u003e console.log('onClearDataComplete'),\n  onTrainStart: () =\u003e console.log('onTrainStart'),\n  onTrainComplete: () =\u003e console.log('onTrainComplete'),\n  onEvaluateStart: () =\u003e console.log('onEvaluateStart'),\n  onEvaluateComplete: () =\u003e console.log('onEvaluateComplete'),\n  onPredictStart: () =\u003e console.log('onPredictStart'),\n  onPredictComplete: () =\u003e console.log('onPredictComplete'),\n  onSaveStart: () =\u003e console.log('onSaveStart'),\n  onSaveComplete: () =\u003e console.log('onSaveComplete'),\n});\n```\n\nExample of specifying a preloaded pretrained model:\n\n```\nimport MLClassifier from 'ml-classifier';\n\nconst mlClassifier = tf.loadModel('... some pretrained model ...').then(model =\u003e {\n  return new MLClassifier({\n    pretrainedModel: model,\n  });\n});\n```\n\n### `addData`\n\nThis method takes an array of incoming images, an optional array of labels, and an optional `dataType`.\n\n#### Example\n\n```\nimport MLClassifier from 'ml-classifier';\nconst mlClassifier = new MLClassifier();\nmlClassifier.addData(images, labels, 'train');\n```\n\n#### Parameters\n\n* **images** (`Array\u003ctf.Tensor3D | ImageData | HTMLImageElement | string\u003e`) - an array of 3D tensors, ImageData (output from a canvas `toPixels`), a native browser `Image`, or a string representing the image `src`. Images can be any size, but will be cropped and sized down to match the pretrained model.\n* **labels** (`string[]`) - an array of strings, matching the images passed above.\n* **dataType** (`string`) *Optional* - an enum specifying which data type the images match. Data types can be `train` for data used in `model.train()`, and `eval`, for data used in `model.evaluate()`. 
If no argument is supplied, `dataType` will default to `train`.\n\n#### Returns\n\nNothing.\n\n### `train`\n\n`train` begins training on the given dataset.\n\n#### Example\n\n```\nimport MLClassifier from 'ml-classifier';\nconst mlClassifier = new MLClassifier();\nmlClassifier.addData(images, labels, DataType.TRAIN);\nmlClassifier.train({\n  callbacks: {\n    onTrainBegin: () =\u003e {\n      console.log('training begins');\n    },\n  },\n});\n```\n\n#### Parameters\n\n* **params** (`Object`) *Optional* - a set of parameters that will be passed directly to `model.fit`. [View the Tensorflow.JS docs](https://js.tensorflow.org/api/0.12.0/#tf.Model.fit) for an up-to-date list of arguments.\n\n#### Returns\n\n`train` returns the resolved promise from `fit`, an object containing loss and accuracy.\n\n### `evaluate`\n\n`evaluate` is used to evaluate a model's performance.\n\n#### Example\n\n```\nimport MLClassifier from 'ml-classifier';\nconst mlClassifier = new MLClassifier();\nmlClassifier.addData(images, labels, DataType.TRAIN);\nmlClassifier.train();\nmlClassifier.addData(evaluationImages, labels, DataType.EVALUATE);\nmlClassifier.evaluate();\n```\n\n#### Parameters\n\n* **params** (`Object`) *Optional* - a set of parameters that will be passed directly to `model.evaluate`. [View the Tensorflow.JS docs](https://js.tensorflow.org/api/0.12.0/#tf.Sequential.evaluate) for an up-to-date list of arguments.\n\n#### Returns\n\n`evaluate` returns a `tf.Scalar` representing the result of `evaluate`.\n\n### `predict`\n\n`predict` is used to make a specific prediction using a saved model.\n\n#### Example\n\n```\nimport MLClassifier from 'ml-classifier';\nconst mlClassifier = new MLClassifier();\nmlClassifier.addData(images, labels, DataType.TRAIN);\nmlClassifier.train();\nmlClassifier.predict(imageToPredict);\n```\n\n#### Parameters\n\n* **image** (`tf.Tensor3D`) - a single image encoded as a `tf.Tensor3D`. 
The image can be any size, but will be cropped and sized down to match the pretrained model.\n\n#### Returns\n\n`predict` will return a string matching the prediction.\n\n### `save`\n\n`save` is a proxy to `tf.model.save`, and will initiate a download from the browser, or save to local storage.\n\n#### Example\n\n```\nimport MLClassifier from 'ml-classifier';\nconst mlClassifier = new MLClassifier();\nmlClassifier.addData(images, labels, DataType.TRAIN);\nmlClassifier.train();\nmlClassifier.save('path-to-save');\n```\n\n#### Parameters\n\n* **handlerOrUrl** (`io.IOHandler | string`) *Optional* - an argument to be passed to `model.save`. If omitted, the model's unique labels will be concatenated together in the form of `class1-class2-class3`.\n* **params** (`Object`) *Optional* - a set of parameters that will be passed directly to `model.save`. [View the Tensorflow.JS docs](https://js.tensorflow.org/api/0.12.0/#tf.Model.save) for an up-to-date list of arguments.\n\n\n### `getModel`\n\n`getModel` will return the trained Tensorflow.js model. Calling this method prior to calling `mlClassifier.train` will return `null`.\n\n#### Example\n\n```\nimport MLClassifier from 'ml-classifier';\nconst mlClassifier = new MLClassifier();\nmlClassifier.addData(images, labels, DataType.TRAIN);\nmlClassifier.train();\nmlClassifier.getModel();\n```\n\n#### Parameters\n\nNone.\n\n#### Returns\n\nThe trained Tensorflow.js model.\n\n### `clearData`\n\n`clearData` will clear out saved data.\n\n#### Example\n```\nimport MLClassifier from 'ml-classifier';\nconst mlClassifier = new MLClassifier();\nmlClassifier.addData(images, labels, DataType.TRAIN);\nmlClassifier.clearData(DataType.TRAIN);\n```\n\n#### Parameters\n\n* **dataType** (`DataType`) *Optional* - specifies which data to clear. 
If no argument is provided, all data will be cleared.\n\n#### Returns\n\nNothing.\n\n## Contributing\n\nContributions are welcome!\n\nYou can start up a local copy of `ml-classifier` with:\n\n```\nyarn watch\n```\n\n`ml-classifier` is written in TypeScript.\n\n### Tests\n\nTests are a work in progress. Currently, the test suite consists only of unit tests. Pull requests for additional tests are welcome!\n\nRun tests with:\n\n```\nyarn test\n```\n\n## Author\n\n* [Kevin Scott](https://thekevinscott.com)\n\n## License\n\nThis project is licensed under the MIT License - see the LICENSE file for details.\n\n![](https://ga-beacon.appspot.com/UA-112845439-4/ml-classifier/readme)\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fthekevinscott%2Fml-classifier","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fthekevinscott%2Fml-classifier","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fthekevinscott%2Fml-classifier/lists"}