Sentiment Analysis
This notebook runs a DistilBERT sentiment analysis model entirely in the browser using ONNX Runtime. It includes a vocabulary loader, WordPiece tokenizer, and inference pipeline — all editable.
Setup
Load onnxrt, Note (FRP), and the widget library:
#require "onnxrt";;
#require "note";;
#require "js_top_worker-widget";;

(* Pull the ONNX Runtime Web bundle into the worker's global scope with
   importScripts; every ort.* API below depends on this having run. *)
let () =
  let open Js_of_ocaml in
  Js.Unsafe.meth_call Js.Unsafe.global "importScripts"
    [| Js.Unsafe.inject
         (Js.string
            "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/ort.min.js") |]
(* Point ORT's .wasm loader at the same CDN directory/version as the JS
   bundle above, so the runtime fetches a matching binary. *)
let () =
  let open Js_of_ocaml in
  let get obj key = Js.Unsafe.get obj (Js.string key) in
  let wasm_cfg = get (get (get Js.Unsafe.global "ort") "env") "wasm" in
  Js.Unsafe.set wasm_cfg (Js.string "wasmPaths")
    (Js.string "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/")

let () = print_endline "ort.js loaded"

Vocabulary
The Vocab module maps tokens to integer IDs. It parses a vocab.txt file (one token per line) and tracks special token IDs used by BERT.
module Vocab = struct
  (* WordPiece vocabulary: token -> id table plus the ids of the special
     tokens BERT inserts around (or instead of) real tokens. *)
  type t = {
    token_to_id : (string, int) Hashtbl.t;
    unk_id : int;  (* id of "[UNK]" *)
    cls_id : int;  (* id of "[CLS]" *)
    sep_id : int;  (* id of "[SEP]" *)
    pad_id : int;  (* id of "[PAD]" *)
  }

  (* Parse a vocab.txt payload: one token per line, the (0-based) line
     number is the token's id. Blank lines are skipped. A special token
     missing from the file falls back to id 0. *)
  let load_from_string text =
    let table = Hashtbl.create 32000 in
    String.split_on_char '\n' text
    |> List.iteri (fun id line ->
           match String.trim line with
           | "" -> ()
           | token -> Hashtbl.replace table token id);
    let id_of token =
      Option.value ~default:0 (Hashtbl.find_opt table token)
    in
    {
      token_to_id = table;
      unk_id = id_of "[UNK]";
      cls_id = id_of "[CLS]";
      sep_id = id_of "[SEP]";
      pad_id = id_of "[PAD]";
    }

  (* [find_token t tok] is the id bound to [tok], if any. *)
  let find_token t token = Hashtbl.find_opt t.token_to_id token
end

Tokenizer
The Tokenizer module implements BERT's WordPiece tokenization. It lowercases input, splits on whitespace and punctuation, then greedily matches subwords from the vocabulary.
module Tokenizer = struct
  (* Result of encoding one sentence: both arrays have length [max_length]. *)
  type encoded = {
    input_ids : int array;       (* [CLS] ids... [SEP] then [PAD] padding *)
    attention_mask : int array;  (* 1 over real tokens, 0 over padding *)
  }

  (* ASCII punctuation, matching BERT's basic tokenizer's split set. *)
  let is_punctuation c =
    match c with
    | '!' | '"' | '#' | '$' | '%' | '&' | '\'' | '(' | ')' | '*' | '+' | ','
    | '-' | '.' | '/' | ':' | ';' | '<' | '=' | '>' | '?' | '@' | '[' | '\\'
    | ']' | '^' | '_' | '`' | '{' | '|' | '}' | '~' ->
        true
    | _ -> false

  (* Split a word into runs of non-punctuation and single-character
     punctuation tokens, e.g. "don't" -> ["don"; "'"; "t"]. *)
  let split_on_punctuation word =
    let len = String.length word in
    if len = 0 then []
    else begin
      let tokens = ref [] in
      let buf = Buffer.create 16 in
      for i = 0 to len - 1 do
        let c = word.[i] in
        if is_punctuation c then begin
          if Buffer.length buf > 0 then begin
            tokens := Buffer.contents buf :: !tokens;
            Buffer.clear buf
          end;
          tokens := String.make 1 c :: !tokens
        end else
          Buffer.add_char buf c
      done;
      if Buffer.length buf > 0 then
        tokens := Buffer.contents buf :: !tokens;
      List.rev !tokens
    end

  (* Greedy longest-match-first WordPiece split of a single word.
     Continuation pieces carry the "##" prefix. Following the reference
     BERT implementation: if any remaining part of the word cannot be
     matched, the ENTIRE word maps to a single "[UNK]" token — partial
     matches are discarded rather than emitted alongside it. *)
  let wordpiece_tokenize vocab word =
    let len = String.length word in
    if len = 0 then []
    else begin
      let tokens = ref [] in
      let start = ref 0 in
      let failed = ref false in
      while !start < len && not !failed do
        let found = ref false in
        let sub_end = ref len in
        while !sub_end > !start && not !found do
          let sub =
            if !start > 0 then
              "##" ^ String.sub word !start (!sub_end - !start)
            else String.sub word !start (!sub_end - !start)
          in
          match Vocab.find_token vocab sub with
          | Some _id ->
              tokens := sub :: !tokens;
              start := !sub_end;
              found := true
          | None -> decr sub_end
        done;
        if not !found then failed := true
      done;
      (* Whole-word fallback on failure. *)
      if !failed then [ "[UNK]" ] else List.rev !tokens
    end

  (* Encode [text] for DistilBERT: ASCII-lowercase, split on whitespace
     then punctuation, WordPiece each word, wrap in [CLS]/[SEP], truncate
     to [max_length] tokens and pad with [PAD]. *)
  let encode vocab text ~max_length =
    let text = String.lowercase_ascii text in
    let words =
      String.split_on_char ' ' text
      |> List.concat_map (String.split_on_char '\t')
      |> List.concat_map (String.split_on_char '\n')
      |> List.filter (fun s -> s <> "")
    in
    let subtokens =
      words
      |> List.concat_map split_on_punctuation
      |> List.concat_map (wordpiece_tokenize vocab)
    in
    let lookup tok =
      match Vocab.find_token vocab tok with
      | Some id -> id
      | None -> vocab.Vocab.unk_id
    in
    (* Reserve two slots for [CLS] and [SEP]. *)
    let max_tokens = max_length - 2 in
    let subtokens =
      if List.length subtokens > max_tokens then
        List.filteri (fun i _ -> i < max_tokens) subtokens
      else subtokens
    in
    let ids = List.map lookup subtokens in
    let token_ids = [ vocab.Vocab.cls_id ] @ ids @ [ vocab.Vocab.sep_id ] in
    let real_len = List.length token_ids in
    let input_ids = Array.make max_length vocab.Vocab.pad_id in
    let attention_mask = Array.make max_length 0 in
    List.iteri (fun i id -> input_ids.(i) <- id) token_ids;
    for i = 0 to real_len - 1 do
      attention_mask.(i) <- 1
    done;
    { input_ids; attention_mask }
end

Helper Functions
Create int64 tensors (required by DistilBERT) and compute softmax over logits:
open Onnxrt
(* Build an ONNX Runtime "int64" tensor from OCaml ints.

   onnxruntime-web requires int64 tensor data to be a BigInt64Array, for
   which js_of_ocaml offers no direct binding — hence the raw Js.Unsafe
   plumbing below. *)
let make_int64_tensor (data : int array) (dims : int array) : Tensor.t =
  (* The global [ort] namespace installed by importScripts earlier. *)
  let ort_obj = Js_of_ocaml.Js.Unsafe.get
    Js_of_ocaml.Js.Unsafe.global (Js_of_ocaml.Js.string "ort") in
  let tensor_ctor = Js_of_ocaml.Js.Unsafe.get ort_obj
    (Js_of_ocaml.Js.string "Tensor") in
  (* Evaluate the literal "<n>n" per element to obtain a JS BigInt (no
     BigInt constructor binding is in scope). Assumes every element
     prints exactly with %d — token ids always do. *)
  let js_data =
    Js_of_ocaml.Js.array
      (Array.map
         (fun x ->
           Js_of_ocaml.Js.Unsafe.eval_string (Printf.sprintf "%dn" x))
         data)
  in
  let bigint64_ctor =
    Js_of_ocaml.Js.Unsafe.get Js_of_ocaml.Js.Unsafe.global
      (Js_of_ocaml.Js.string "BigInt64Array") in
  (* BigInt64Array.from(array-of-BigInt) copies into typed storage. *)
  let bigint64_arr =
    Js_of_ocaml.Js.Unsafe.meth_call bigint64_ctor "from"
      [| Js_of_ocaml.Js.Unsafe.inject js_data |]
  in
  let js_dims =
    Js_of_ocaml.Js.Unsafe.inject
      (Js_of_ocaml.Js.array
         (Array.map (fun d -> Js_of_ocaml.Js.Unsafe.inject d) dims))
  in
  (* new ort.Tensor("int64", data, dims) *)
  let js_tensor =
    Js_of_ocaml.Js.Unsafe.new_obj tensor_ctor
      [| Js_of_ocaml.Js.Unsafe.inject (Js_of_ocaml.Js.string "int64");
         Js_of_ocaml.Js.Unsafe.inject bigint64_arr;
         js_dims |]
  in
  (* NOTE(review): pairs the raw JS tensor with [false] and Obj.magics the
     pair into Tensor.t — presumably mirroring onnxrt's internal
     representation (value plus some boolean flag). Confirm against the
     Onnxrt library before touching this. *)
  (Obj.magic (Js_of_ocaml.Js.Unsafe.coerce js_tensor, false) : Tensor.t)
(* Numerically stable softmax: shift logits by their maximum before
   exponentiating so large values cannot overflow, then normalize. *)
let softmax (logits : float array) : float array =
  let shift = Array.fold_left max neg_infinity logits in
  let scores = Array.map (fun l -> exp (l -. shift)) logits in
  let total = Array.fold_left ( +. ) 0.0 scores in
  Array.map (fun s -> s /. total) scores
(* Blocking GET via XMLHttpRequest (synchronous XHR is permitted inside a
   worker); returns the response body as an OCaml string. *)
let fetch_text_sync url =
  let open Js_of_ocaml in
  let xhr =
    Js.Unsafe.new_obj
      (Js.Unsafe.get Js.Unsafe.global (Js.string "XMLHttpRequest"))
      [||]
  in
  (* Third argument [false] selects synchronous mode. *)
  Js.Unsafe.meth_call xhr "open"
    [| Js.Unsafe.inject (Js.string "GET");
       Js.Unsafe.inject (Js.string url);
       Js.Unsafe.inject Js._false |];
  Js.Unsafe.meth_call xhr "send" [||];
  Js.to_string (Js.Unsafe.get xhr (Js.string "responseText"))

let () = print_endline "Helpers defined"

Load Model
Fetch the vocabulary file (synchronous XHR in the worker) and load the quantized DistilBERT model. A status widget updates as loading progresses:
(* Sequence length the tensors are padded/truncated to below. *)
let max_length = 128

(* Filled in once loading completes (see the loaders further down). *)
let vocab = ref None
let session = ref None

(* Status line as FRP: loaders push events, the signal holds the latest. *)
let model_status_e, send_model_status = Note.E.create ()
let model_status = Note.S.hold "Loading vocabulary..." model_status_e
(* Render a status message as a bordered monospace box. *)
let status_view msg =
  let open Widget.View in
  let style name value = Style (name, value) in
  Element
    { tag = "div";
      attrs =
        [ style "padding" "0.75em 1em";
          style "border-radius" "6px";
          style "font-family" "monospace";
          style "border" "1px solid currentColor";
          style "opacity" "0.8" ];
      children = [ Text msg ] }
(* Mount the status widget with its initial text, then keep it in sync:
   every [model_status] change re-renders the view into the same slot. *)
let () =
  Widget.display ~id:"model-status" ~handlers:[]
    (status_view "Loading vocabulary...")

let _logr = Note.S.log
    (Note.S.map status_view model_status)
    (Widget.update ~id:"model-status")

(* Retain the logger for the program's lifetime — if it were collected,
   updates would silently stop. *)
let () = Note.Logr.hold _logr
(* Fetch and parse the vocabulary synchronously, then announce progress. *)
let () =
  let loaded = Vocab.load_from_string (fetch_text_sync "vocab.txt") in
  vocab := Some loaded;
  send_model_status "Vocabulary loaded. Loading model..."

(* Create the ONNX session asynchronously; the status flips when ready. *)
let () =
  Lwt.async (fun () ->
      let open Lwt.Syntax in
      let* sess = Session.create "model_quantized.onnx" () in
      session := Some sess;
      send_model_status "Model ready.";
      Lwt.return_unit)

Analyze Sentiment
Type or edit the text below — the model classifies it reactively via Note signals whenever you click Analyze.
(* Current textarea contents, updated on every "input" event. *)
let input_e, send_input = Note.E.create ()
let input_text = Note.S.hold
    "This movie was absolutely wonderful, I loved every minute of it!"
    input_e

(* Classification result text shown beneath the Analyze button. *)
let result_e, send_result = Note.E.create ()
let result_s = Note.S.hold "Type something and click Analyze." result_e
(* Tokenize [text], run the DistilBERT session, and publish the predicted
   sentiment through [send_result]. If vocab or session is not loaded yet,
   only a message is emitted. *)
let analyze text =
  match !vocab, !session with
  | Some v, Some s ->
      send_result "Running inference...";
      Lwt.async (fun () ->
          let open Lwt.Syntax in
          let encoded = Tokenizer.encode v text ~max_length in
          (* The model expects int64 inputs of shape [1; max_length]. *)
          let input_ids_tensor =
            make_int64_tensor encoded.Tokenizer.input_ids [| 1; max_length |] in
          let attention_mask_tensor =
            make_int64_tensor encoded.Tokenizer.attention_mask [| 1; max_length |] in
          let* outputs =
            Session.run s
              [ ("input_ids", input_ids_tensor);
                ("attention_mask", attention_mask_tensor) ] in
          (* NOTE(review): List.assoc raises Not_found if the model has no
             "logits" output; inside Lwt.async that would hit the global
             async exception hook. *)
          let logits_tensor = List.assoc "logits" outputs in
          let logits_data = Tensor.to_bigarray1_exn Dtype.Float32 logits_tensor in
          (* Reads exactly two logits — assumes index 0 = NEGATIVE and
             index 1 = POSITIVE (the usual SST-2 export); confirm for
             other checkpoints. *)
          let logits =
            [| Bigarray.Array1.get logits_data 0;
               Bigarray.Array1.get logits_data 1 |] in
          let probs = softmax logits in
          let label, confidence =
            if probs.(1) > probs.(0) then ("POSITIVE", probs.(1))
            else ("NEGATIVE", probs.(0)) in
          let emoji = if label = "POSITIVE" then "👍" else "👎" in
          send_result (Printf.sprintf "%s %s (%.1f%% confident)"
            emoji label (confidence *. 100.0));
          (* Free tensor-backed resources only after the result is read. *)
          Tensor.dispose input_ids_tensor;
          Tensor.dispose attention_mask_tensor;
          Tensor.dispose logits_tensor;
          Lwt.return_unit)
  | _ ->
      send_result "Model not loaded yet — wait a moment and try again."
(* Editable textarea plus an Analyze button, stacked in a flex column.
   "text_changed" and "input"/"analyze" handler keys are resolved against
   the handler table supplied to Widget.display. *)
let input_view text =
  let open Widget.View in
  let style k v = Style (k, v) in
  let textarea =
    Element
      { tag = "textarea";
        attrs =
          [ Property ("rows", "3");
            style "width" "100%";
            style "padding" "0.75em";
            style "border-radius" "6px";
            style "border" "1px solid currentColor";
            style "font-size" "1em";
            style "font-family" "inherit";
            style "background" "transparent";
            style "color" "inherit";
            style "resize" "vertical";
            style "opacity" "0.9";
            Handler ("input", "text_changed") ];
        children = [ Text text ] }
  in
  let button =
    Element
      { tag = "button";
        attrs =
          [ style "padding" "0.5em 1.5em";
            style "border-radius" "6px";
            style "border" "1px solid currentColor";
            style "background" "transparent";
            style "color" "inherit";
            style "font-size" "1em";
            style "cursor" "pointer";
            Handler ("click", "analyze") ];
        children = [ Text "Analyze" ] }
  in
  Element
    { tag = "div";
      attrs =
        [ style "display" "flex";
          style "flex-direction" "column";
          style "gap" "0.75em" ];
      children = [ textarea; button ] }
(* Render the classification result in a bordered monospace box. *)
let result_view result =
  let open Widget.View in
  let style k v = Style (k, v) in
  Element
    { tag = "div";
      attrs =
        [ style "font-family" "monospace";
          style "padding" "0.75em 1em";
          style "border-radius" "6px";
          style "border" "1px solid currentColor";
          style "opacity" "0.8" ];
      children = [ Text result ] }
(* Mount the two widgets. "text_changed" feeds the textarea's value into
   [input_e] (empty string if the event carries no value); "analyze"
   snapshots the current text and runs inference. *)
let () =
  Widget.display ~id:"sentiment-input"
    ~handlers:[
      "text_changed", (fun v -> send_input (Option.value ~default:"" v));
      "analyze", (fun _ -> analyze (Note.S.value input_text));
    ]
    (input_view (Note.S.value input_text))

let () =
  Widget.display ~id:"sentiment-result" ~handlers:[]
    (result_view (Note.S.value result_s))

(* Keep [input_text] observed (the sink does nothing) so S.value reflects
   the latest edits even though no view renders it directly. *)
let _logr_input = Note.S.log input_text (fun _ -> ())
let () = Note.Logr.hold _logr_input

(* Re-render the result widget whenever [result_s] changes; hold the
   logger so the subscription stays alive. *)
let _logr_result = Note.S.log
    (Note.S.map result_view result_s)
    (Widget.update ~id:"sentiment-result")
let () = Note.Logr.hold _logr_result