M-% || M-x query-replace

Sachin’s random thoughts.

18 Feb 2024

Rust WASM Neural Net PART 3 - Website

Components

Main Components (organized by thread[1]):

  • UI rendering and interaction
  • Model training and inference
stateDiagram
   direction TB

   classDef colors fill:#3f3f3f,color:#dcdccc,stroke-width:1.5px,stroke:#dcdccc,font-family:Hack
   ui: UI Worker
   infer: Inference Process
   api: Weights Process
   model: Model Worker
   data1: Data Process
   data2: Data Process
   data3: Data Process

   ui --> infer
   ui --> model
   infer --> ui
   ui --> api
   model --> ui
   api --> ui
   model --> data1
   model --> data2
   model --> data3
   data1 --> model
   data2 --> model
   data3 --> model

   class ui colors
   class infer colors
   class api colors
   class model colors
   class data1 colors
   class data2 colors
   class data3 colors

Essentially, the rendering thread offloads model training to a web worker, which then spawns data processes[2] to replenish a cache of data that is used to train the model. Inference is computationally light[3] and is done directly on the rendering thread. The tricky part is managing these workers and processes so that the UI stays responsive while training runs efficiently.

Places to Compute

Web Workers

From MDN Web Docs

Web Workers makes it possible to run a script operation in a background thread separate from the main execution thread of a web application. The advantage of this is that laborious processing can be performed in a separate thread, allowing the main (usually the UI) thread to run without being blocked/slowed down.

Web workers are super useful, but creating communication channels can be confusing[4]. In an earlier version, I tried to do all my data caching in a separate web worker, but to simplify communication I decided to cache in a process instead. However, running the model in its own web worker allowed rendering to run smoothly and independently.

Processes

I created processes mostly with spawn_local from wasm_bindgen_futures, which is very simple:

#[inline]
pub fn spawn_local<F>(future: F)
where
    F: Future<Output = ()> + 'static,
{
    task::Task::spawn(Box::pin(future));
}

spawn_local combined with an async block (through async move) and a handle[5] lets you put some computations in the background without blocking the main thread.
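To make the handle pattern concrete without a browser, here is a minimal sketch in plain Rust. `LocalQueue` is a hypothetical stand-in for the browser's task queue (it is not part of wasm_bindgen_futures), and `fill_in_background` is an illustrative helper. The only thing borrowed from the real library is the bound on `spawn_local`.

```rust
use std::cell::RefCell;
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::task::{Context, Wake, Waker};

/// Waker that does nothing; the toy queue below simply re-polls pending tasks.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Hypothetical stand-in for the browser's task queue (an assumption of this sketch).
#[derive(Default)]
pub struct LocalQueue {
    tasks: VecDeque<Pin<Box<dyn Future<Output = ()>>>>,
}

impl LocalQueue {
    /// Same bounds as `wasm_bindgen_futures::spawn_local`: the future is
    /// queued and the caller continues immediately.
    pub fn spawn_local<F>(&mut self, future: F)
    where
        F: Future<Output = ()> + 'static,
    {
        self.tasks.push_back(Box::pin(future));
    }

    /// Poll queued tasks in FIFO order until all have completed.
    /// A task that never completes would make this loop spin forever.
    pub fn run(&mut self) {
        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);
        while let Some(mut task) = self.tasks.pop_front() {
            if task.as_mut().poll(&mut cx).is_pending() {
                self.tasks.push_back(task);
            }
        }
    }
}

/// Queue `n` background tasks that each report a result through a shared handle.
pub fn fill_in_background(queue: &mut LocalQueue, n: u32) -> Rc<RefCell<Vec<u32>>> {
    let results = Rc::new(RefCell::new(Vec::new()));
    for i in 0..n {
        // The clone is the "handle": it is moved into the task by `async move`.
        let handle = Rc::clone(&results);
        queue.spawn_local(async move {
            handle.borrow_mut().push(i * i);
        });
    }
    results
}
```

The caller gets `results` back right away, while it is still empty. The values only show up once the queue has had a chance to run the tasks, which is the same shape as handing an Arc<Mutex<T>> or a Yew state to the real spawn_local.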

Callbacks

Yew Callbacks essentially wrap a function and attach it to a DOM element so that it runs on user input or interaction. This is how any input propagates into the Rust code/WASM.

Main interface

I used the Yew framework. I found the library itself somewhat unstable[6]. Next time, I would go for some sort of hybrid solution where I write the UI in JavaScript/TypeScript and interact with WASM blobs as necessary.

Grid

[Figure: a 28 x 28 grid of Cell components inside a Manager]

The GridCell manages rendering and user input, while the GridManager manages overall state and sample loading. This happens through a tangle of states and props[7].

%%{
  init: {
    'theme': 'base',
    'themeVariables': {
      'primaryColor': '#3f3f3f',
      'primaryTextColor': '#dcdccc',
      'primaryBorderColor': '#dcdccc',
      'lineColor': '#303030',
      'secondaryColor': '#93e0e3',
      'tertiaryColor': '#7f9f7f',
      'fontFamily': 'Hack'
    }
  }
}%%
sequenceDiagram
    GridManager->>GridCell: Manager Callback + Initial State
    GridCell->>GridManager: Cell Callback
    GridCell-->>GridManager: User Update + New State
    GridManager-->>GridCell: Sample Update
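
The flow in the diagram can be sketched in plain Rust, without Yew. Everything here is an assumption made for illustration: `Callback` is a hypothetical stand-in for a Yew-style callback (a shared closure with an `emit` method), the method names are invented, and the "Cell Callback" registration step is left out. Instead, cells are simply re-created from the manager's state after a sample update, much like a re-render with new props.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// The grid is 28 x 28, one cell per pixel.
pub const SIDE: usize = 28;

/// Hypothetical stand-in for a Yew-style callback: a shared closure.
pub struct Callback<T>(Rc<dyn Fn(T)>);

impl<T> Callback<T> {
    pub fn new(f: impl Fn(T) + 'static) -> Self {
        Callback(Rc::new(f))
    }

    pub fn emit(&self, value: T) {
        (self.0)(value)
    }
}

/// Owns the overall grid state and loads samples into it.
pub struct GridManager {
    pixels: Rc<RefCell<Vec<f64>>>,
}

impl GridManager {
    pub fn new() -> Self {
        GridManager {
            pixels: Rc::new(RefCell::new(vec![0.0; SIDE * SIDE])),
        }
    }

    /// "Manager Callback + Initial State": what the manager hands to a cell.
    pub fn cell(&self, index: usize) -> GridCell {
        let pixels = Rc::clone(&self.pixels);
        let on_update = Callback::new(move |(i, v): (usize, f64)| {
            pixels.borrow_mut()[i] = v;
        });
        GridCell {
            index,
            value: self.pixels.borrow()[index],
            on_update,
        }
    }

    /// "Sample Update": overwrite the whole grid with a loaded sample.
    /// Panics if `sample` does not have SIDE * SIDE entries.
    pub fn load_sample(&self, sample: &[f64]) {
        self.pixels.borrow_mut().copy_from_slice(sample);
    }

    pub fn pixel(&self, index: usize) -> f64 {
        self.pixels.borrow()[index]
    }
}

/// Handles one pixel: keeps its own value and reports user edits upward.
pub struct GridCell {
    index: usize,
    value: f64,
    on_update: Callback<(usize, f64)>,
}

impl GridCell {
    /// "User Update + New State": the cell reports a user edit to the manager.
    pub fn draw(&mut self, value: f64) {
        self.value = value;
        self.on_update.emit((self.index, value));
    }

    pub fn value(&self) -> f64 {
        self.value
    }
}
```

Calling `draw` on a cell updates the manager's copy of that pixel through the callback, and `load_sample` followed by `cell(i)` gives a cell that starts from the sample's value.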

Other stuff

I used Tailwind CSS, compiling the CSS in with Trunk, and that is about it[8].

Running the Model

Training

Communication

%%{
  init: {
    'theme': 'base',
    'themeVariables': {
      'primaryColor': '#3f3f3f',
      'primaryTextColor': '#dcdccc',
      'primaryBorderColor': '#dcdccc',
      'lineColor': '#303030',
      'secondaryColor': '#93e0e3',
      'tertiaryColor': '#7f9f7f',
      'fontFamily': 'Hack'
    }
  }
}%%
sequenceDiagram
    Renderer->>Model: Start Training
    Renderer->>Model: Stop Training
    Renderer->>Model: Send status
    Renderer->>Model: Set weights
    Renderer->>Model: Set batch size
    Renderer->>Model: Set learning rate
    Renderer->>Model: Set cache size
    Model->>Renderer: ResponseSignal

Where ResponseSignal is:

struct Weights {
    weights: (Vec<Vec<f64>>, Vec<Vec<f64>>),
}
struct ResponseSignal {
    weights: Weights,
    loss: f64,
    acc: f64,
    batch_size: usize,
    lrate: f64,
    data_len: usize,
    data_futures_len: usize,
    iteration: usize,
    cache_size: usize,
}

The Model worker waits for a signal from the Renderer worker, and responds with a ResponseSignal. If the Model worker is training, then it sends a ResponseSignal every iteration.
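
As a rough sketch of that request/response loop, the arrows in the diagram map naturally onto an enum. The `RequestSignal` enum, its variant names, and the `ModelWorker` struct are all assumptions made for illustration, and the response is trimmed to a few fields. The actual transport (postMessage between workers) is left out entirely.

```rust
/// Weights of the two layers, shaped as in the post.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Weights {
    pub weights: (Vec<Vec<f64>>, Vec<Vec<f64>>),
}

/// Hypothetical request type; the variant names mirror the diagram arrows.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestSignal {
    StartTraining,
    StopTraining,
    SendStatus,
    SetWeights(Weights),
    SetBatchSize(usize),
    SetLearningRate(f64),
    SetCacheSize(usize),
}

/// Trimmed-down response (loss, accuracy and the cache counters are omitted).
#[derive(Clone, Debug, PartialEq)]
pub struct ResponseSignal {
    pub weights: Weights,
    pub batch_size: usize,
    pub lrate: f64,
    pub iteration: usize,
    pub cache_size: usize,
}

/// Hypothetical model-side state, updated one signal at a time.
#[derive(Default)]
pub struct ModelWorker {
    training: bool,
    weights: Weights,
    batch_size: usize,
    lrate: f64,
    iteration: usize,
    cache_size: usize,
}

impl ModelWorker {
    /// Apply one signal from the Renderer and answer with the current status.
    pub fn handle(&mut self, signal: RequestSignal) -> ResponseSignal {
        match signal {
            RequestSignal::StartTraining => self.training = true,
            RequestSignal::StopTraining => self.training = false,
            RequestSignal::SendStatus => {} // nothing to change, just report
            RequestSignal::SetWeights(w) => self.weights = w,
            RequestSignal::SetBatchSize(n) => self.batch_size = n,
            RequestSignal::SetLearningRate(r) => self.lrate = r,
            RequestSignal::SetCacheSize(n) => self.cache_size = n,
        }
        self.status()
    }

    pub fn is_training(&self) -> bool {
        self.training
    }

    /// Snapshot of the current state, sent back to the Renderer.
    pub fn status(&self) -> ResponseSignal {
        ResponseSignal {
            weights: self.weights.clone(),
            batch_size: self.batch_size,
            lrate: self.lrate,
            iteration: self.iteration,
            cache_size: self.cache_size,
        }
    }
}
```

Every signal gets a full status back, so the Renderer never has to track which settings it has already sent. While training, the worker would additionally call `status()` after each iteration.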

Getting Data

The Model maintains an internal cache that is lazily replenished[9]. Every iteration (whether training or not), new futures are created to replenish the cache. If the model is training but the cache is empty, it delays that iteration[10].
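
The bookkeeping behind this can be sketched as follows. The `Cache` type and its method names are hypothetical. The top-up rule (spawn just enough futures so that ready blocks plus in-flight requests reach the cache size) is also an assumption, since the exact policy is not spelled out here. The two counters correspond to data_len and data_futures_len in ResponseSignal.

```rust
use std::collections::VecDeque;

/// Hypothetical bookkeeping for a lazily replenished cache.
pub struct Cache<T> {
    blocks: VecDeque<T>, // ready blocks (data_len is blocks.len())
    in_flight: usize,    // requests still pending (data_futures_len)
    cache_size: usize,   // target number of blocks
}

impl<T> Cache<T> {
    pub fn new(cache_size: usize) -> Self {
        Cache {
            blocks: VecDeque::new(),
            in_flight: 0,
            cache_size,
        }
    }

    /// How many new futures to start this iteration (assumed top-up policy).
    pub fn futures_to_spawn(&self) -> usize {
        self.cache_size
            .saturating_sub(self.blocks.len() + self.in_flight)
    }

    /// Record that `n` requests were just started.
    pub fn mark_spawned(&mut self, n: usize) {
        self.in_flight += n;
    }

    /// Called when a request finishes and its block arrives.
    pub fn deliver(&mut self, block: T) {
        self.in_flight = self.in_flight.saturating_sub(1);
        self.blocks.push_back(block);
    }

    /// Next block to train on; `None` means "delay this iteration".
    pub fn next_block(&mut self) -> Option<T> {
        self.blocks.pop_front()
    }
}
```

Counting in-flight requests keeps the worker from piling up new requests every iteration while earlier ones are still loading.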

The trick to bringing the data back into the Model worker is to send a handle along with the spawn_local as described earlier.

struct DataSingle {
    target: u8,
    image: Vec<f64>,
}
struct Data {
    data: Vec<DataSingle>,
}
struct ModelData {
    data_vec: Arc<Mutex<VecDeque<Data>>>,
}

In short, I spawn futures like this:

for _ in 0..future_num {
    let data_vec_handle = self.data_vec.clone();
    spawn_local(async move {
        let data = get_block().await;
        data_vec_handle.lock().unwrap().push_back(data);
    });
}

Then I hand the data and the model to a small function that essentially just calls a train function, and I update my model.
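
A minimal sketch of what such a function could look like, under stated assumptions: the `Train` trait, `train_step`, and `CountingModel` are hypothetical names, and the toy model only counts samples rather than doing any real learning. The point is the shape of the step, which is to pop one block from the shared cache and pass it to the model.

```rust
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

pub struct DataSingle {
    pub target: u8,
    pub image: Vec<f64>,
}

pub struct Data {
    pub data: Vec<DataSingle>,
}

/// Hypothetical model interface; the real train function is not shown in the post.
pub trait Train {
    fn train(&mut self, batch: &Data);
}

/// Run one training iteration if a block is available.
/// Returns false when the cache is empty, i.e. the iteration is delayed.
pub fn train_step<M: Train>(model: &mut M, cache: &Arc<Mutex<VecDeque<Data>>>) -> bool {
    // Hold the lock only long enough to pop one block.
    let block = cache.lock().unwrap().pop_front();
    match block {
        Some(data) => {
            model.train(&data);
            true
        }
        None => false,
    }
}

/// Toy model for demonstration only: it just counts the samples it has seen.
pub struct CountingModel {
    pub samples_seen: usize,
}

impl Train for CountingModel {
    fn train(&mut self, batch: &Data) {
        self.samples_seen += batch.data.len();
    }
}
```

Releasing the lock before training matters because the spawned futures need the same lock to push new blocks.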


  1. Threads in the sense of parallel processes. Everything runs on a single thread because it runs in your browser; web workers, however, act as “threads”.

  2. Through the use of spawn_local in wasm_bindgen_futures.

  3. Also through spawn_local in wasm_bindgen_futures.

  4. Two reasons: you can only send JS types (objects need to be Send and Sync), and you can only send them through the postMessage interface.

  5. A handle, in the generic sense, could be a Yew state, an Arc<Mutex<T>>, or something else.

  6. I had to pin to a specific commit to get the newest features without breaking my project every couple of weeks.

  7. Yew made this part especially hard, and I think I should have done it in React.

  8. I thought this would be much more painful than it was.

  9. Without a cache, you have to wait for the data to load every single iteration. One request takes about as long as 5 parallel requests, so it was essential to parallelize.

  10. The delay overhead is minimal, so it is like polling until the data arrives (while waiting for signals from the Renderer thread).