Rust WASM Neural Net PART 3 - Website
Components
Main Components (organized by thread[^1]):
- UI rendering and interaction
- Model training and inference
```mermaid
stateDiagram
    direction TB
    classDef colors fill:#3f3f3f,color:#dcdccc,stroke-width:1.5px,stroke:#dcdccc,font-family:Hack
    ui: UI Worker
    infer: Inference Process
    api: Weights Process
    model: Model Worker
    data1: Data Process
    data2: Data Process
    data3: Data Process
    ui --> infer
    ui --> model
    infer --> ui
    ui --> api
    model --> ui
    api --> ui
    model --> data1
    model --> data2
    model --> data3
    data1 --> model
    data2 --> model
    data3 --> model
    class ui colors
    class infer colors
    class api colors
    class model colors
    class data1 colors
    class data2 colors
    class data3 colors
```
Essentially, the rendering thread offloads all model interaction to a web worker, which then spawns data processes[^2] to replenish a cache of data that is then used to train the model. Inference is computationally light[^3] and done on the rendering worker directly. The tricky part is managing these workers/processes to keep the UI responsive while training efficiently.
Places to Compute
Web Workers
From MDN Web Docs:

> Web Workers makes it possible to run a script operation in a background thread separate from the main execution thread of a web application. The advantage of this is that laborious processing can be performed in a separate thread, allowing the main (usually the UI) thread to run without being blocked/slowed down.
Web workers are super useful, but creating communication channels can be confusing[^4]. In an earlier version, I tried to do all my data caching in a separate web worker, but to simplify communication I decided to cache in a process. However, running the model in its own web worker allowed rendering to run smoothly and independently.
Processes
I created processes mostly with `spawn_local` from `wasm_bindgen_futures`, which is simply:
```rust
#[inline]
pub fn spawn_local<F>(future: F)
where
    F: Future<Output = ()> + 'static,
{
    task::Task::spawn(Box::pin(future));
}
```
`spawn_local`, combined with an `async` closure (through `async move`) and a handle[^5], allows you to put some computations in the background without blocking the main thread.
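The handle pattern can be sketched outside the browser with `std::thread::spawn` standing in for `spawn_local` (a native analogue for illustration, not the WASM API): the background task gets a clone of an `Arc<Mutex<T>>` handle, and the spawner keeps its own clone to read results back later.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

// Spawn a background task that writes its result through a shared handle.
// On wasm this role is played by `spawn_local(async move { ... })`;
// `thread::spawn` is a native stand-in here.
fn compute_in_background(handle: Arc<Mutex<Vec<u32>>>) {
    let worker_handle = Arc::clone(&handle);
    let worker = thread::spawn(move || {
        let computed = 6 * 7; // stand-in for an expensive computation
        worker_handle.lock().unwrap().push(computed);
    });
    worker.join().unwrap();
}

fn main() {
    let results = Arc::new(Mutex::new(Vec::new()));
    compute_in_background(Arc::clone(&results));
    assert_eq!(*results.lock().unwrap(), vec![42]);
}
```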
Callbacks
Yew Callbacks essentially take a function and attach it to a DOM element to run when there is user input/interaction. This is how any input propagates into the Rust/WASM code.
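At its core, a yew callback is just a cheaply clonable wrapper around a function, so a parent can hand it down as a prop and a child can fire it on a DOM event. A minimal sketch of the idea (my own simplification, not yew's actual implementation):

```rust
use std::cell::Cell;
use std::rc::Rc;

// Minimal sketch of the idea behind yew's Callback: a cheaply clonable
// wrapper around a function, passed down as a prop and fired on an event.
struct Callback<IN>(Rc<dyn Fn(IN)>);

impl<IN> Clone for Callback<IN> {
    fn clone(&self) -> Self {
        Callback(Rc::clone(&self.0))
    }
}

impl<IN> Callback<IN> {
    fn from<F: Fn(IN) + 'static>(f: F) -> Self {
        Callback(Rc::new(f))
    }
    // Invoke the wrapped function, e.g. from an event handler.
    fn emit(&self, value: IN) {
        (self.0)(value)
    }
}

fn main() {
    let clicks = Rc::new(Cell::new(0u32));
    let clicks_handle = Rc::clone(&clicks);
    // The parent creates the callback and hands clones to child components.
    let on_click = Callback::from(move |n: u32| clicks_handle.set(clicks_handle.get() + n));
    on_click.clone().emit(1); // simulated DOM event
    on_click.emit(1);
    assert_eq!(clicks.get(), 2);
}
```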
Main interface
I used the yew framework. I think the library itself was somewhat unstable[^6]. Next time, I would try some sort of hybrid solution where I write the UI in JavaScript/TypeScript and interact with WASM blobs as necessary.
Grid
The `GridCell` components manage rendering and user input, while the `GridManager` manages overall state and sample loading. This works through a tangle of states and props[^7].
```mermaid
%%{ init: { 'theme': 'base', 'themeVariables': { 'primaryColor': '#3f3f3f', 'primaryTextColor': '#dcdccc', 'primaryBorderColor': '#dcdccc', 'lineColor': '#303030', 'secondaryColor': '#93e0e3', 'tertiaryColor': '#7f9f7f', 'fontFamily': 'Hack' } } }%%
sequenceDiagram
    GridManager->>GridCell: Manager Callback + Initial State
    GridCell->>GridManager: Cell Callback
    GridCell-->>GridManager: User Update + New State
    GridManager-->>GridCell: Sample Update
```
Other stuff
I used Tailwind CSS by compiling the CSS in with Trunk, and that is about it[^8].
Running the Model
Training
Communication
```mermaid
%%{ init: { 'theme': 'base', 'themeVariables': { 'primaryColor': '#3f3f3f', 'primaryTextColor': '#dcdccc', 'primaryBorderColor': '#dcdccc', 'lineColor': '#303030', 'secondaryColor': '#93e0e3', 'tertiaryColor': '#7f9f7f', 'fontFamily': 'Hack' } } }%%
sequenceDiagram
    Renderer->>Model: Start Training
    Renderer->>Model: Stop Training
    Renderer->>Model: Send status
    Renderer->>Model: Set weights
    Renderer->>Model: Set batch size
    Renderer->>Model: Set learning rate
    Renderer->>Model: Set cache size
    Model->>Renderer: ResponseSignal
```
Where `ResponseSignal` is:
```rust
struct Weights {
    weights: (Vec<Vec<f64>>, Vec<Vec<f64>>),
}

struct ResponseSignal {
    weights: Weights,
    loss: f64,
    acc: f64,
    batch_size: usize,
    lrate: f64,
    data_len: usize,
    data_futures_len: usize,
    iteration: usize,
    cache_size: usize,
}
```
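The Renderer-to-Model messages in the diagram above can be sketched as a plain enum that the Model worker's loop dispatches on. The variant names here are my own, not the project's actual code:

```rust
// Hypothetical command enum mirroring the Renderer -> Model messages in the
// sequence diagram; the real project's names and payloads may differ.
#[derive(Debug, PartialEq)]
enum ModelCommand {
    StartTraining,
    StopTraining,
    SendStatus,
    SetWeights(Vec<Vec<f64>>, Vec<Vec<f64>>),
    SetBatchSize(usize),
    SetLearningRate(f64),
    SetCacheSize(usize),
}

// The Model worker's loop would match on each incoming command like this.
fn describe(cmd: &ModelCommand) -> &'static str {
    match cmd {
        ModelCommand::StartTraining => "start training",
        ModelCommand::StopTraining => "stop training",
        ModelCommand::SendStatus => "send status",
        ModelCommand::SetWeights(..) => "set weights",
        ModelCommand::SetBatchSize(_) => "set batch size",
        ModelCommand::SetLearningRate(_) => "set learning rate",
        ModelCommand::SetCacheSize(_) => "set cache size",
    }
}

fn main() {
    assert_eq!(describe(&ModelCommand::SetBatchSize(32)), "set batch size");
    assert_eq!(describe(&ModelCommand::StartTraining), "start training");
}
```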
The `Model` worker waits for a signal from the `Renderer` worker and responds with a `ResponseSignal`. If the `Model` worker is training, then it sends a `ResponseSignal` every iteration.
Getting Data
The `Model` maintains an internal cache that is lazily replenished[^9]. Every iteration (whether training or not), new futures are created to replenish the cache. If the model is training but the cache is empty, then that iteration is delayed[^10].
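The per-iteration bookkeeping can be sketched like this: top up the number of in-flight fetches so that cached plus in-flight blocks reach a target cache size, and report whether a training step can run right now. This is a simplification with illustrative names, not the project's actual code:

```rust
use std::collections::VecDeque;

// Sketch of the per-iteration cache logic: spawn enough new fetch futures
// to bring (cached + in-flight) up to `target`, and report whether a
// training step can run now or must be delayed.
fn replenish_and_check(
    cache: &VecDeque<Vec<f64>>, // cached data blocks
    in_flight: &mut usize,      // futures already fetching data
    target: usize,              // desired cache size
) -> (usize, bool) {
    let missing = target.saturating_sub(cache.len() + *in_flight);
    *in_flight += missing; // in the real code, spawn `missing` futures here
    let can_train = !cache.is_empty(); // empty cache => delay this iteration
    (missing, can_train)
}

fn main() {
    let mut in_flight = 0;
    let cache: VecDeque<Vec<f64>> = VecDeque::new();
    // Empty cache: spawn 5 futures and delay the training step.
    assert_eq!(replenish_and_check(&cache, &mut in_flight, 5), (5, false));
    // Next iteration: fetches are already in flight, nothing new to spawn.
    assert_eq!(replenish_and_check(&cache, &mut in_flight, 5), (0, false));
}
```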
The trick to bringing the data back into the `Model` worker is to send a handle along with the `spawn_local`, as described earlier.
```rust
struct DataSingle {
    target: u8,
    image: Vec<f64>,
}

struct Data {
    data: Vec<DataSingle>,
}

struct ModelData {
    data_vec: Arc<Mutex<VecDeque<Data>>>,
}
```
In short, I spawn futures like this:

```rust
for _ in 0..future_num {
    let data_vec_handle = self.data_vec.clone();
    spawn_local(async move {
        let data = get_block().await;
        data_vec_handle.lock().unwrap().push_back(data);
    });
}
```
Then I can hand the data and model to a small function that essentially just calls a `train` function, and I update my model.
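The consuming side of the cache is then just a pop plus a train call. A rough sketch with a dummy `train` (the real one runs the forward/backward passes and updates the weights):

```rust
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

struct DataSingle { target: u8, image: Vec<f64> }
struct Data { data: Vec<DataSingle> }

// Dummy stand-in for the real training step.
fn train(model_loss: &mut f64, batch: &Data) {
    *model_loss -= 0.1 * batch.data.len() as f64; // pretend the loss improves
}

// One training iteration: pop a cached block (if any) and train on it.
// Returns false when the cache is empty and the iteration must be delayed.
fn train_step(cache: &Arc<Mutex<VecDeque<Data>>>, model_loss: &mut f64) -> bool {
    let batch = cache.lock().unwrap().pop_front();
    match batch {
        Some(batch) => {
            train(model_loss, &batch);
            true
        }
        None => false,
    }
}

fn main() {
    let cache = Arc::new(Mutex::new(VecDeque::new()));
    cache.lock().unwrap().push_back(Data {
        data: vec![DataSingle { target: 3, image: vec![0.0; 784] }],
    });
    let mut loss = 1.0;
    assert!(train_step(&cache, &mut loss));  // trains on the cached block
    assert!(!train_step(&cache, &mut loss)); // cache now empty: delay
}
```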
[^1]: Threads in the sense of parallel processes. Everything runs on a single thread because it runs in your browser; however, web workers are "threads".

[^2]: Through the use of `spawn_local` in `wasm_bindgen_futures`.

[^3]: Also through `spawn_local` in `wasm_bindgen_futures`.

[^4]: Two reasons: you can only send JS types (objects need to be Send and Sync), and only through the postMessage interface.

[^5]: A handle in the generic sense could be a yew state, `Arc<Mutex<T>>`, or something else.

[^6]: I had to pin to a specific commit to get the newest features without breaking my project every couple of weeks.

[^7]: Yew made this part especially hard, and I think I should have done it in React.

[^8]: I thought this would be much more painful than it was.

[^9]: Without a cache, you have to wait for the data to load every single iteration. One request runs in similar time to five parallel requests, so it was essential to parallelize.

[^10]: The delay overhead is minimal, so it is like polling until the data arrives (while waiting for signals from the `Renderer` thread).