M-% || M-x query-replace

Sachin’s random thoughts.

06 Feb 2024

Rust WASM Neural Net PART 1 - Overview

Introduction

I created a Neural Net that runs in your browser[1] to recognize handwritten digits[2] (primarily from the MNIST dataset). It is written in Rust and compiled to WebAssembly so that it runs with some speed[3].

main link

You will probably want to train the model a bit before trying an example.

What is this?

Architecture

[Architecture diagram: Website, Model, API, and Weights/Data]

Model

code

The model has the following features:

  • 1 layer of ReLU
  • 1 layer of LogSoftmax
  • Written in Rust, with the ndarray crate as its only helper
  • Optimized for single sample and batch training
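
The list above only names the layers, so here is a rough sketch of what a forward pass and its loss gradient could look like. It assumes a hidden dense layer followed by ReLU and an output dense layer followed by LogSoftmax; the `Dense` and `Net` names, the row-major weight layout, and the helper functions are all hypothetical, not the author's code. It also uses plain `Vec<f64>` rather than ndarray so that it compiles with no dependencies.

```rust
/// Fully connected layer computing `weights * input + bias`.
/// Hypothetical layout: `weights` is row-major with shape (outputs, inputs).
pub struct Dense {
    pub weights: Vec<f64>,
    pub bias: Vec<f64>,
    pub inputs: usize,
    pub outputs: usize,
}

impl Dense {
    pub fn forward(&self, x: &[f64]) -> Vec<f64> {
        assert_eq!(x.len(), self.inputs, "input length must match the layer width");
        (0..self.outputs)
            .map(|o| {
                // Row `o` holds the weights feeding output neuron `o`.
                let row = &self.weights[o * self.inputs..(o + 1) * self.inputs];
                row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f64>() + self.bias[o]
            })
            .collect()
    }
}

/// ReLU applied element-wise: max(0, x).
pub fn relu(x: &[f64]) -> Vec<f64> {
    x.iter().map(|&v| v.max(0.0)).collect()
}

/// Numerically stable log-softmax: z_i - m - ln(sum_j exp(z_j - m)), where m = max_j z_j.
pub fn log_softmax(z: &[f64]) -> Vec<f64> {
    let m = z.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let log_sum = z.iter().map(|&v| (v - m).exp()).sum::<f64>().ln();
    z.iter().map(|&v| v - m - log_sum).collect()
}

/// Negative log-likelihood of the true label, given log-probabilities.
pub fn nll_loss(log_probs: &[f64], label: usize) -> f64 {
    -log_probs[label]
}

/// Gradient of nll_loss(log_softmax(z), label) with respect to the logits z,
/// which simplifies to softmax(z) - one_hot(label).
pub fn logits_grad(log_probs: &[f64], label: usize) -> Vec<f64> {
    log_probs
        .iter()
        .enumerate()
        .map(|(i, &lp)| {
            let target = if i == label { 1.0 } else { 0.0 };
            lp.exp() - target
        })
        .collect()
}

/// Assumed architecture: hidden = ReLU(W1 x + b1), output = LogSoftmax(W2 hidden + b2).
pub struct Net {
    pub hidden: Dense,
    pub output: Dense,
}

impl Net {
    /// Log-probabilities for each output class (10 for MNIST digits).
    pub fn forward(&self, x: &[f64]) -> Vec<f64> {
        let h = relu(&self.hidden.forward(x));
        log_softmax(&self.output.forward(&h))
    }

    /// Predicted class: the index with the highest log-probability.
    pub fn predict(&self, x: &[f64]) -> usize {
        self.forward(x)
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i)
            .expect("network has at least one output")
    }
}
```

Pairing LogSoftmax with a negative log-likelihood loss keeps the output gradient as simple as softmax(z) minus the one-hot label. For batch training, one common approach (not necessarily the one used here) is to average that gradient over the samples in a batch before applying a single weight update.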

Website

code

The website has the following features and more:

  • Built with the Yew framework
  • Training of the model in a parallel web worker
  • Pre-cached samples with async replenishment
  • Loading of samples on-demand for inference
  • Grid to draw your own digit
  • Various inputs to tune the training and inference
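
One way to picture "pre-cached samples with async replenishment" is a queue that hands out samples and signals once it runs low. The sketch below is hypothetical: `Sample`, `SampleCache`, and the low-water threshold are illustrative names, and it does not touch Yew or web workers at all. The caller is assumed to start an asynchronous fetch (for example, to the API) whenever `next_sample` reports that a refill is needed, then call `replenish` when the data arrives.

```rust
use std::collections::VecDeque;

/// Hypothetical MNIST sample: 784 grayscale pixels (28x28) and its digit label.
#[derive(Clone, Debug, PartialEq)]
pub struct Sample {
    pub pixels: Vec<u8>,
    pub label: u8,
}

/// A pre-filled queue of samples that signals when it should be topped up.
pub struct SampleCache {
    queue: VecDeque<Sample>,
    /// A refill is requested once fewer than this many samples remain.
    low_water: usize,
    /// True while a refill request is in flight, so we never ask twice.
    refill_pending: bool,
}

impl SampleCache {
    pub fn new(initial: Vec<Sample>, low_water: usize) -> Self {
        SampleCache { queue: initial.into(), low_water, refill_pending: false }
    }

    /// Pops the next sample. The boolean is true exactly when the caller
    /// should start an asynchronous fetch for more samples.
    pub fn next_sample(&mut self) -> (Option<Sample>, bool) {
        let sample = self.queue.pop_front();
        let request_refill = !self.refill_pending && self.queue.len() < self.low_water;
        if request_refill {
            self.refill_pending = true;
        }
        (sample, request_refill)
    }

    /// Called when an asynchronous fetch completes.
    pub fn replenish(&mut self, batch: Vec<Sample>) {
        self.queue.extend(batch);
        self.refill_pending = false;
    }

    pub fn len(&self) -> usize {
        self.queue.len()
    }

    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }
}
```

The `refill_pending` flag is the design choice worth noting: without it, every sample taken while the queue is low would fire another fetch before the first one returns.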

API

code

The API has the following features and more:

  • Load an MNIST sample
  • Load a block of MNIST samples
  • Getter and setter for the model weights
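
The original MNIST distribution ships images in the IDX format: a big-endian header (magic number 0x00000803, image count, rows, columns) followed by one unsigned byte per pixel. The sketch below shows how loading a single sample or a block of samples could work on top of those raw bytes. Whether the API actually reads IDX files is an assumption, and `IdxImages`, `sample`, and `block` are hypothetical names.

```rust
/// Zero-copy view over an MNIST IDX image file (magic number 0x00000803).
pub struct IdxImages<'a> {
    pub count: usize,
    pub rows: usize,
    pub cols: usize,
    data: &'a [u8],
}

/// Reads a big-endian u32 at `offset`, or None if the buffer is too short.
fn read_u32_be(bytes: &[u8], offset: usize) -> Option<u32> {
    let chunk = bytes.get(offset..offset + 4)?;
    Some(u32::from_be_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
}

impl<'a> IdxImages<'a> {
    pub fn parse(bytes: &'a [u8]) -> Result<Self, String> {
        if read_u32_be(bytes, 0) != Some(0x0000_0803) {
            return Err("not an IDX unsigned-byte image file".into());
        }
        let count = read_u32_be(bytes, 4).ok_or("truncated header")? as usize;
        let rows = read_u32_be(bytes, 8).ok_or("truncated header")? as usize;
        let cols = read_u32_be(bytes, 12).ok_or("truncated header")? as usize;
        // The 16-byte header is known to be present at this point.
        let data = &bytes[16..];
        if data.len() < count * rows * cols {
            return Err("pixel data is shorter than the header claims".into());
        }
        Ok(IdxImages { count, rows, cols, data })
    }

    /// Pixels of sample `i` (rows * cols bytes), or None if out of range.
    pub fn sample(&self, i: usize) -> Option<&'a [u8]> {
        self.block(i, 1)
    }

    /// Pixels of `len` consecutive samples starting at `start`, as one slice.
    pub fn block(&self, start: usize, len: usize) -> Option<&'a [u8]> {
        let size = self.rows * self.cols;
        if start.checked_add(len)? > self.count {
            return None;
        }
        let data: &'a [u8] = self.data;
        data.get(start * size..(start + len) * size)
    }
}
```

The matching label file uses magic number 0x00000801 and an 8-byte header (magic number and count) followed by one byte per label, so a block of labels can be sliced the same way.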

How did this happen?

  • 2 Weeks: I want to build a Neural Net with Rust as a learning experience[4]. It will be interesting to implement the activation functions and backpropagation.
  • 1 Month: Rust and WASM play nicely together; would it be possible to put the model on the web?
  • 2 Months: If I put this on the web, then I'd better use Rust for the rest of the website functionality.
  • 6 Months: If I have a website, I'd better be able to train the model on it, cache samples effectively, and run training in parallel[5] with the rendering.

```mermaid
---
title: Timeline
---
%%{init: { 'logLevel': 'debug', 'theme': 'default', 'themeVariables': {
              'git0': '#93e0e3',
              'git1': '#cc9393',
              'git2': '#f0dfaf',
              'git3': '#7f9f7f',
              'gitBranchLabel0': '#3f3f3f',
              'gitBranchLabel1': '#3f3f3f',
              'gitBranchLabel2': '#3f3f3f',
              'gitBranchLabel3': '#3f3f3f',
              'commitLabelColor': '#dcdccc',
              'commitLabelBackground': '#383838',
              'tagLabelColor': '#dcdccc',
              'tagLabelBackground': '#3f3f3f',
              'tagLabelBorder': '#303030'
       } } }%%
gitGraph TB:
   commit id: "Start Project" tag: "0 Weeks"
   branch model
   checkout model
   commit id: "Activation Functions"
   commit id: "Test/Train"
   commit id: "Structure/Clean"
   checkout main
   merge model id: "Start UI/Finish Model" tag: "2 Weeks"
   branch site
   checkout site
   commit id: "Initial UI"
   checkout main
   commit id: "Start API" tag: "3 Weeks"
   branch api
   checkout api
   commit id: "Initial API"
   commit id: "Automate training"
   checkout main
   merge api id: "Finish API" tag: "1 Month"
   checkout site
   commit id: "Put model in WASM"
   checkout main
   merge site id: "Midpoint UI" tag: "2 Months"
   checkout site
   commit id: "Cache data in parallel"
   commit id: "Deploy"
   checkout main
   merge site id: "Finish UI" tag: "6 Months"
   commit id: "Finish Project" tag: "6.25 Months"
```

Everyone, this is what we call scope creep.

What is next?

I will be creating 3 more posts:

Part 2

Details about the model with a little bit of math and code

Part 3

Details about the frontend and how I was able to jam in all the features I wanted

Part 4

Details about the API, deployment, and integration


  1. By this, I mean the computations and the model all run on your machine instead of on some server somewhere. This is also known as running client-side.

  2. Trained on digits written by Census Bureau employees and high-school students (source).

  3. Not a ton of speed…but some speed.

  4. This is heavily inspired by ConvNetJS.

  5. And complicate everything by using web workers.