Rust WASM Neural Net PART 3 - Website
Components
Main Components (organized by thread[1]):
- UI rendering and interaction
- Model training and inference
```mermaid
stateDiagram
    direction TB
    ui: UI Worker
    infer: Inference Process
    api: Weights Process
    model: Model Worker
    data1: Data Process
    data2: Data Process
    data3: Data Process
    ui --> infer
    infer --> ui
    ui --> api
    api --> ui
    ui --> model
    model --> ui
    model --> data1
    data1 --> model
    model --> data2
    data2 --> model
    model --> data3
    data3 --> model
```
Essentially, the rendering thread offloads all model interaction to a web worker, which then spawns data processes[2] to replenish a cache of data used to train the model. Inference is computationally light[3] and is done on the rendering thread directly. The tricky part is managing these workers and processes so that the UI stays responsive while training stays efficient.
Places to Compute
Web Workers
From MDN Web Docs
Web Workers makes it possible to run a script operation in a background thread separate from the main execution thread of a web application. The advantage of this is that laborious processing can be performed in a separate thread, allowing the main (usually the UI) thread to run without being blocked/slowed down.
Web workers are super useful, but setting up communication channels between them can be confusing[4]. In an earlier version, I tried to do all my data caching in a separate web worker, but to simplify communication I decided to cache in a process instead. Running the model in its own web worker, on the other hand, let rendering run smoothly and independently.
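Because of the restriction in footnote 4, nested Rust data has to be turned into something JS can carry before it crosses a worker boundary. One common approach, shown here as an assumption rather than what this project actually did, is to flatten each weight matrix into a single f64 buffer plus its row lengths; a flat buffer like `values` maps naturally onto a typed array such as `Float64Array`. `FlatMatrix`, `flatten` and `unflatten` are hypothetical names.

```rust
/// Hypothetical transfer format: one contiguous buffer plus row lengths.
#[derive(Debug, Clone, PartialEq)]
pub struct FlatMatrix {
    pub rows: Vec<usize>, // length of each original row
    pub values: Vec<f64>, // all rows concatenated in order
}

/// Flatten a nested matrix (e.g. one half of the Weights tuple) for sending.
pub fn flatten(matrix: &[Vec<f64>]) -> FlatMatrix {
    let rows: Vec<usize> = matrix.iter().map(|row| row.len()).collect();
    let values: Vec<f64> = matrix.iter().flatten().copied().collect();
    FlatMatrix { rows, values }
}

/// Rebuild the nested matrix on the receiving side.
/// Panics if `rows` does not add up to `values.len()`.
pub fn unflatten(flat: &FlatMatrix) -> Vec<Vec<f64>> {
    let mut out = Vec::with_capacity(flat.rows.len());
    let mut start = 0;
    for &len in &flat.rows {
        out.push(flat.values[start..start + len].to_vec());
        start += len;
    }
    out
}
```

Each matrix round-trips exactly, so the receiving worker can rebuild the same `Vec<Vec<f64>>` it would have had if nested types could be sent directly.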
Processes
I created processes mostly with spawn_local from wasm_bindgen_futures, whose implementation is very simple:
```rust
#[inline]
pub fn spawn_local<F>(future: F)
where
    F: Future<Output = ()> + 'static,
{
    task::Task::spawn(Box::pin(future));
}
```
spawn_local, combined with an async block (async move) and a handle[5], lets you put some computations in the background without blocking the main thread.
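spawn_local only queues a future; something else has to poll it. To illustrate that idea, here is a toy single-threaded executor in plain std Rust. It copies the shape of the excerpt above (box, pin, enqueue) but is not the wasm_bindgen_futures implementation: `QUEUE`, `NoopWake` and `run_until_idle` are made-up names, and in the browser the JS event loop drives queued tasks instead of a loop like this. Since everything still runs on one thread, only work that actually awaits something (such as a network fetch) really happens "in the background".

```rust
use std::cell::RefCell;
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Wake, Waker};

type LocalTask = Pin<Box<dyn Future<Output = ()>>>;

thread_local! {
    // Run queue for this thread; tasks do not need to be Send.
    static QUEUE: RefCell<VecDeque<LocalTask>> = RefCell::new(VecDeque::new());
}

/// Toy stand-in for wasm_bindgen_futures::spawn_local: box, pin, enqueue.
pub fn spawn_local<F>(future: F)
where
    F: Future<Output = ()> + 'static,
{
    QUEUE.with(|q| q.borrow_mut().push_back(Box::pin(future)));
}

/// Waker that does nothing; pending tasks are simply polled again later.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Busy-poll queued tasks round-robin until the queue is empty.
/// A task that stays Pending forever would make this spin, so it is toy-only.
pub fn run_until_idle() {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    loop {
        // Pop inside a short borrow so a task may call spawn_local while polled.
        let next = QUEUE.with(|q| q.borrow_mut().pop_front());
        let Some(mut task) = next else { break };
        if task.as_mut().poll(&mut cx).is_pending() {
            QUEUE.with(|q| q.borrow_mut().push_back(task));
        }
    }
}
```

Spawning a few `async move` blocks that each hold a clone of an `Rc<RefCell<...>>` handle, then calling `run_until_idle`, fills the shared value; this is the same handle pattern the post uses for the data cache.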
Callbacks
A Yew Callback essentially wraps a function and is attached to a DOM element so that it runs on user input or interaction. This is how all input propagates into the Rust code (WASM).
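Under the hood, a callback like this is little more than a reference-counted closure that can be cloned cheaply and handed to child components. The sketch below is a simplified, hypothetical stand-in written with std only; it is not the real `yew::Callback` type, although the method name `emit` mirrors Yew's.

```rust
use std::rc::Rc;

/// Simplified stand-in for a Yew-style callback (not yew::Callback itself).
pub struct Callback<T> {
    f: Rc<dyn Fn(T)>,
}

impl<T> Clone for Callback<T> {
    // Cloning only bumps the reference count; all clones share one closure.
    fn clone(&self) -> Self {
        Callback { f: Rc::clone(&self.f) }
    }
}

impl<T> Callback<T> {
    pub fn new(f: impl Fn(T) + 'static) -> Self {
        Callback { f: Rc::new(f) }
    }

    /// Run the wrapped closure, e.g. from a DOM event handler.
    pub fn emit(&self, value: T) {
        (*self.f)(value)
    }
}
```

A parent creates the callback once and passes clones down as props; whichever child fires it, the same closure (and the state it captured) runs.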
Main interface
I used the Yew framework, which I found somewhat unstable[6]. Next time, I would use a hybrid approach: write the UI in JavaScript/TypeScript and call into WASM modules as needed.
Grid
Each GridCell manages its own rendering and user input, while the GridManager manages the overall state and sample loading. They coordinate through a tangle of state and props[7].
```mermaid
sequenceDiagram
    GridManager->>GridCell: Manager Callback + Initial State
    GridCell->>GridManager: Cell Callback
    GridCell-->>GridManager: User Update + New State
    GridManager-->>GridCell: Sample Update
```
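The four messages above can be modelled outside Yew to see how state flows in both directions. This is a hypothetical sketch, not the project's component code: plain closures stand in for Yew callbacks, a single `f64` per cell stands in for its drawn state, and the names `ToManager`, `ToCell`, `user_update` and `load_sample` are invented for illustration.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical message types: closures stand in for Yew callbacks.
type ToManager = Rc<dyn Fn(usize, f64)>; // (cell index, new value) going up
type ToCell = Box<dyn Fn(f64)>; // sample value pushed down to one cell

pub struct GridCell {
    pub index: usize,
    pub value: Rc<RefCell<f64>>,
    to_manager: ToManager, // "Manager Callback"
}

impl GridCell {
    /// The "Cell Callback" the cell hands up to the manager.
    pub fn callback(&self) -> ToCell {
        let value = Rc::clone(&self.value);
        Box::new(move |v: f64| *value.borrow_mut() = v)
    }

    /// "User Update + New State": change local state, then report it upward.
    pub fn user_update(&self, v: f64) {
        *self.value.borrow_mut() = v;
        (self.to_manager)(self.index, v);
    }
}

pub struct GridManager {
    pub pixels: Rc<RefCell<Vec<f64>>>,
    to_cells: Vec<ToCell>,
}

impl GridManager {
    /// Build the manager and `size` cells, wiring callbacks both ways.
    pub fn new(size: usize) -> (GridManager, Vec<GridCell>) {
        let pixels = Rc::new(RefCell::new(vec![0.0; size]));
        let sink = Rc::clone(&pixels);
        let to_manager: ToManager = Rc::new(move |i: usize, v: f64| sink.borrow_mut()[i] = v);
        let mut cells = Vec::with_capacity(size);
        let mut to_cells = Vec::with_capacity(size);
        for index in 0..size {
            let cell = GridCell {
                index,
                value: Rc::new(RefCell::new(0.0)), // initial state
                to_manager: Rc::clone(&to_manager),
            };
            to_cells.push(cell.callback());
            cells.push(cell);
        }
        (GridManager { pixels, to_cells }, cells)
    }

    /// "Sample Update": replace the manager's state and push it to every cell.
    /// Assumes `sample.len()` equals the number of cells.
    pub fn load_sample(&self, sample: &[f64]) {
        self.pixels.borrow_mut().copy_from_slice(sample);
        for (send, &v) in self.to_cells.iter().zip(sample) {
            send(v);
        }
    }
}
```

Even in this stripped-down form, every cell holds a handle to the manager and the manager holds one closure per cell, which hints at why the real state-and-props version got tangled.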
Other stuff
I used Tailwind CSS, compiled in with Trunk, and that is about it[8].
Running the Model
Training
Communication
```mermaid
sequenceDiagram
    Renderer->>Model: Start Training
    Renderer->>Model: Stop Training
    Renderer->>Model: Send status
    Renderer->>Model: Set weights
    Renderer->>Model: Set batch size
    Renderer->>Model: Set learning rate
    Renderer->>Model: Set cache size
    Model->>Renderer: ResponseSignal
```
Where ResponseSignal is:
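```rust
/// The two weight matrices of the network.
pub struct Weights {
    pub weights: (Vec<Vec<f64>>, Vec<Vec<f64>>),
}

/// Status message the Model worker sends back to the Renderer.
pub struct ResponseSignal {
    pub weights: Weights,
    pub loss: f64,
    pub acc: f64, // accuracy
    pub batch_size: usize,
    pub lrate: f64, // learning rate
    pub data_len: usize,
    pub data_futures_len: usize,
    pub iteration: usize,
    pub cache_size: usize,
}
```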
The Model worker waits for a signal from the Renderer worker and responds with a ResponseSignal. While the Model worker is training, it sends a ResponseSignal every iteration.
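The request/response pattern can be sketched with `std::sync::mpsc` channels standing in for `postMessage`. This is an analogy under assumptions, not the project's worker code: the `Signal` and `Status` names are hypothetical, `Status` keeps only a few ResponseSignal fields, "Set weights" and "Set cache size" are left out for brevity, and `max_iters` exists only to keep the sketch finite. A real web worker cannot block on `recv()`; it reacts in a message handler, so read the blocking/polling split as "wait for the Renderer when idle, check for new signals between iterations when training".

```rust
use std::sync::mpsc::{Receiver, Sender, TryRecvError};

/// Renderer -> Model messages (hypothetical names, subset of the diagram).
pub enum Signal {
    StartTraining,
    StopTraining,
    SendStatus,
    SetBatchSize(usize),
    SetLearningRate(f64),
}

/// Trimmed-down ResponseSignal carrying only a few fields.
#[derive(Debug, Clone, PartialEq)]
pub struct Status {
    pub iteration: usize,
    pub batch_size: usize,
    pub lrate: f64,
}

/// Model-side loop: block while idle, poll while training, reply once per pass.
/// Returns after `max_iters` iterations or when either side hangs up.
pub fn model_worker(rx: Receiver<Signal>, tx: Sender<Status>, max_iters: usize) {
    let mut training = false;
    let mut status = Status { iteration: 0, batch_size: 32, lrate: 0.01 };
    loop {
        let signal = if training {
            match rx.try_recv() {
                Ok(s) => Some(s),
                Err(TryRecvError::Empty) => None, // keep training
                Err(TryRecvError::Disconnected) => return,
            }
        } else {
            match rx.recv() {
                Ok(s) => Some(s),
                Err(_) => return, // renderer is gone
            }
        };
        match signal {
            Some(Signal::StartTraining) => training = true,
            Some(Signal::StopTraining) => training = false,
            Some(Signal::SetBatchSize(n)) => status.batch_size = n,
            Some(Signal::SetLearningRate(lr)) => status.lrate = lr,
            Some(Signal::SendStatus) | None => {}
        }
        if training {
            status.iteration += 1; // placeholder for one training step
        }
        if tx.send(status.clone()).is_err() || status.iteration >= max_iters {
            return;
        }
    }
}
```

With `SetBatchSize(64)` and `StartTraining` queued and `max_iters = 3`, the worker replies once for the batch-size change and once per training iteration, so the Renderer sees four status messages.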
Getting Data
The Model maintains an internal cache that is lazily replenished[9]. Every iteration (whether training or not), new futures are created to replenish the cache. If the model is training but the cache is empty, it delays that iteration[10].
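The replenish-or-delay rule can be written as a small piece of bookkeeping. This sketch rests on an assumption the post does not spell out: it reads the ResponseSignal fields `cache_size`, `data_len` and `data_futures_len` as "target number of cached blocks", "blocks already cached" and "fetches still in flight". `CacheState` and `Step` are hypothetical names.

```rust
/// What one iteration of the Model worker should do (hypothetical).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Step {
    Train, // training and data is available
    Delay, // training, but the cache is empty: try again next tick
    Idle,  // not training: only top up the cache
}

/// Assumed meaning of the ResponseSignal cache fields.
pub struct CacheState {
    pub cache_size: usize,       // target number of cached blocks
    pub data_len: usize,         // blocks already in the cache
    pub data_futures_len: usize, // fetch futures still in flight
}

impl CacheState {
    /// New fetch futures needed so cached + in-flight reaches the target.
    pub fn futures_to_spawn(&self) -> usize {
        self.cache_size
            .saturating_sub(self.data_len + self.data_futures_len)
    }

    /// Decide this iteration's action from the training flag and cache fill.
    pub fn next_step(&self, training: bool) -> Step {
        match (training, self.data_len) {
            (false, _) => Step::Idle,
            (true, 0) => Step::Delay,
            (true, _) => Step::Train,
        }
    }
}
```

Counting in-flight futures keeps a slow network from causing a pile-up of duplicate requests, while `saturating_sub` simply yields zero once the cache is (or will be) full.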
The trick to bringing the data back into the Model worker is to move a handle into the future passed to spawn_local, as described earlier.
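```rust
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

/// One labelled sample.
pub struct DataSingle {
    pub target: u8,      // class label
    pub image: Vec<f64>, // pixel values
}

/// A block of samples returned by one fetch.
pub struct Data {
    pub data: Vec<DataSingle>,
}

/// Shared cache of fetched blocks.
pub struct ModelData {
    pub data_vec: Arc<Mutex<VecDeque<Data>>>,
}
```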
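In short, I spawn futures like this:
```rust
for _ in 0..future_num {
    // Each future gets its own handle (a clone of the Arc) to the shared cache.
    let data_vec_handle = self.data_vec.clone();
    spawn_local(async move {
        // Fetch one block of data without blocking the worker...
        let data = get_block().await;
        // ...then push it into the cache through the handle.
        data_vec_handle.lock().unwrap().push_back(data);
    });
}
```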
Then I hand the data and the model to a small function that essentially just calls a train function, which updates the model.
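A minimal sketch of that hand-off, assuming the cache layout above: take one block out of the shared queue, release the lock, and pass the samples to a training closure. `train_from_cache` and the closure signature are hypothetical (the post does not show its train function); the `Data` types are repeated so the block compiles on its own.

```rust
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

pub struct DataSingle {
    pub target: u8,
    pub image: Vec<f64>,
}

pub struct Data {
    pub data: Vec<DataSingle>,
}

/// Pop one cached block and hand it to `train`, which returns e.g. the loss.
/// Returns None when the cache is empty, i.e. the iteration should be delayed.
pub fn train_from_cache<F>(cache: &Arc<Mutex<VecDeque<Data>>>, mut train: F) -> Option<f64>
where
    F: FnMut(&[DataSingle]) -> f64,
{
    // The lock guard is dropped at the end of this statement,
    // so fetch futures can keep pushing while training runs.
    let block = cache.lock().unwrap().pop_front()?;
    Some(train(&block.data))
}
```

Returning `Option` lets the caller treat an empty cache as "delay this iteration" without a separate check.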
[1] Threads in the sense of parallel processes. Everything runs on a single thread because it runs in your browser; web workers, however, are “threads”.
[2] Through the use of spawn_local in wasm_bindgen_futures.
[3] Also through spawn_local in wasm_bindgen_futures.
[4] Two reasons: you can only send JS types (objects need to be Send and Sync), and only through the postMessage interface.
[5] A handle, in the generic sense, could be a Yew state, an Arc<Mutex<T>>, or something else.
[6] I had to pin to a specific commit to get the newest features without breaking my project every couple of weeks.
[7] Yew made this part especially hard, and I think I should have done it in React.
[8] I thought this would be much more painful than it was.
[9] Without a cache, you have to wait for the data to load every single iteration. One request takes about as long as 5 parallel requests, so parallelizing was essential.
[10] The delay overhead is minimal, so it is like polling until the data arrives (while waiting for signals from the Renderer thread).