jon.recoil.org

Sentiment Analysis

This notebook runs a DistilBERT sentiment analysis model entirely in the browser using ONNX Runtime. It includes a vocabulary loader, WordPiece tokenizer, and inference pipeline — all editable.

Setup

Load onnxrt, Note (FRP), and the widget library:

(* Toplevel directives: load the ONNX Runtime bindings, the Note FRP
   library, and the widget library into the notebook worker. *)
#require "onnxrt";;
#require "note";;
#require "js_top_worker-widget";;

(* Pull onnxruntime-web into the worker with importScripts (we are inside
   a Web Worker, so this is the standard way to load extra scripts). *)
let () =
  Js_of_ocaml.Js.Unsafe.meth_call Js_of_ocaml.Js.Unsafe.global "importScripts"
    [| Js_of_ocaml.Js.Unsafe.inject (Js_of_ocaml.Js.string "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/ort.min.js") |]

(* Point ort's WASM loader at the matching CDN directory so it can fetch
   its .wasm binaries; the version here must match ort.min.js above. *)
let () =
  let open Js_of_ocaml in
  let ort = Js.Unsafe.get Js.Unsafe.global (Js.string "ort") in
  let env = Js.Unsafe.get ort (Js.string "env") in
  let wasm = Js.Unsafe.get env (Js.string "wasm") in
  Js.Unsafe.set wasm (Js.string "wasmPaths") (Js.string "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/")

let () = print_endline "ort.js loaded"

Vocabulary

The Vocab module maps tokens to integer IDs. It parses a vocab.txt file (one token per line) and tracks special token IDs used by BERT.

(* Token vocabulary for BERT-style models: maps token strings to integer
   IDs and records the IDs of the special tokens. *)
module Vocab = struct
  type t = {
    token_to_id : (string, int) Hashtbl.t;  (* token -> vocabulary index *)
    unk_id : int;  (* [UNK]: out-of-vocabulary fallback *)
    cls_id : int;  (* [CLS]: sequence-start marker *)
    sep_id : int;  (* [SEP]: sequence-end marker *)
    pad_id : int;  (* [PAD]: padding filler *)
  }

  (* Build a vocabulary from the contents of a vocab.txt file: one token
     per line, a token's ID being its 0-based line number.  Blank lines
     are skipped but still consume an ID.  A special token that is absent
     from the file falls back to ID 0. *)
  let load_from_string text =
    let table = Hashtbl.create 32000 in
    String.split_on_char '\n' text
    |> List.iteri (fun id line ->
           match String.trim line with
           | "" -> ()
           | token -> Hashtbl.replace table token id);
    let special token =
      Option.value (Hashtbl.find_opt table token) ~default:0
    in
    {
      token_to_id = table;
      unk_id = special "[UNK]";
      cls_id = special "[CLS]";
      sep_id = special "[SEP]";
      pad_id = special "[PAD]";
    }

  (* Look up a token's ID, if present. *)
  let find_token t token = Hashtbl.find_opt t.token_to_id token
end

Tokenizer

The Tokenizer module implements BERT's WordPiece tokenization. It lowercases input, splits on whitespace and punctuation, then greedily matches subwords from the vocabulary.

(* WordPiece tokenizer matching BERT's basic + wordpiece pipeline:
   lowercase, split on whitespace and punctuation, then greedy
   longest-match-first subword segmentation against the vocabulary. *)
module Tokenizer = struct
  type encoded = {
    input_ids : int array;       (* token IDs, [PAD]-padded to max_length *)
    attention_mask : int array;  (* 1 for real tokens, 0 for padding *)
  }

  (* ASCII punctuation set used by BERT's basic tokenizer. *)
  let is_punctuation c =
    match c with
    | '!' | '"' | '#' | '$' | '%' | '&' | '\'' | '(' | ')' | '*' | '+'
    | ',' | '-' | '.' | '/' | ':' | ';' | '<' | '=' | '>' | '?' | '@'
    | '[' | '\\' | ']' | '^' | '_' | '`' | '{' | '|' | '}' | '~' -> true
    | _ -> false

  (* Split a word into maximal runs of non-punctuation characters and
     single punctuation characters, preserving order. *)
  let split_on_punctuation word =
    let len = String.length word in
    if len = 0 then []
    else begin
      let tokens = ref [] in
      let buf = Buffer.create 16 in
      let flush () =
        if Buffer.length buf > 0 then begin
          tokens := Buffer.contents buf :: !tokens;
          Buffer.clear buf
        end
      in
      for i = 0 to len - 1 do
        let c = word.[i] in
        if is_punctuation c then begin
          flush ();
          tokens := String.make 1 c :: !tokens
        end else Buffer.add_char buf c
      done;
      flush ();
      List.rev !tokens
    end

  (* Greedy longest-match-first WordPiece.  Non-initial subwords carry
     the "##" continuation prefix.  Bug fix: if any position cannot be
     matched, the ENTIRE word maps to [UNK] — matching HuggingFace's
     WordpieceTokenizer — instead of emitting the partial subword
     matches followed by [UNK]. *)
  let wordpiece_tokenize vocab word =
    let len = String.length word in
    if len = 0 then []
    else begin
      let tokens = ref [] in
      let start = ref 0 in
      let failed = ref false in
      while !start < len && not !failed do
        let found = ref false in
        let sub_end = ref len in
        while !sub_end > !start && not !found do
          let piece = String.sub word !start (!sub_end - !start) in
          let sub = if !start > 0 then "##" ^ piece else piece in
          (match Vocab.find_token vocab sub with
           | Some _id ->
               tokens := sub :: !tokens;
               start := !sub_end;
               found := true
           | None -> decr sub_end)
        done;
        if not !found then failed := true
      done;
      if !failed then [ "[UNK]" ] else List.rev !tokens
    end

  (* Tokenize [text] into fixed-length model inputs:
     [CLS] subwords... [SEP], truncated so the total never exceeds
     [max_length], then padded with [PAD].  The attention mask is 1 for
     every real token (including [CLS]/[SEP]) and 0 for padding. *)
  let encode vocab text ~max_length =
    let text = String.lowercase_ascii text in
    let words =
      String.split_on_char ' ' text
      |> List.concat_map (String.split_on_char '\t')
      |> List.concat_map (String.split_on_char '\n')
      |> List.filter (fun s -> s <> "")
    in
    let subtokens =
      words
      |> List.concat_map split_on_punctuation
      |> List.concat_map (wordpiece_tokenize vocab)
    in
    let lookup tok =
      match Vocab.find_token vocab tok with
      | Some id -> id
      | None -> vocab.Vocab.unk_id
    in
    (* Reserve two slots for [CLS] and [SEP]. *)
    let max_tokens = max_length - 2 in
    let subtokens =
      if List.length subtokens > max_tokens then
        List.filteri (fun i _ -> i < max_tokens) subtokens
      else subtokens
    in
    let ids = List.map lookup subtokens in
    let token_ids = [ vocab.Vocab.cls_id ] @ ids @ [ vocab.Vocab.sep_id ] in
    let real_len = List.length token_ids in
    let input_ids = Array.make max_length vocab.Vocab.pad_id in
    let attention_mask = Array.make max_length 0 in
    List.iteri (fun i id -> input_ids.(i) <- id) token_ids;
    for i = 0 to real_len - 1 do
      attention_mask.(i) <- 1
    done;
    { input_ids; attention_mask }
end

Helper Functions

Create int64 tensors (required by DistilBERT) and compute softmax over logits:

open Onnxrt

(* Build an ONNX Runtime int64 tensor from an OCaml int array.
   DistilBERT's input_ids / attention_mask inputs must be int64, which on
   the JS side means a BigInt64Array. *)
let make_int64_tensor (data : int array) (dims : int array) : Tensor.t =
  let ort_obj = Js_of_ocaml.Js.Unsafe.get Js_of_ocaml.Js.Unsafe.global (Js_of_ocaml.Js.string "ort") in
  let tensor_ctor = Js_of_ocaml.Js.Unsafe.get ort_obj (Js_of_ocaml.Js.string "Tensor") in
  (* "%dn" evaluates to a JS BigInt literal (e.g. "42n") — js_of_ocaml has
     no native BigInt representation, so each element is built via eval. *)
  let js_data = Js_of_ocaml.Js.array (Array.map (fun x -> Js_of_ocaml.Js.Unsafe.eval_string (Printf.sprintf "%dn" x)) data) in
  let bigint64_ctor = Js_of_ocaml.Js.Unsafe.get Js_of_ocaml.Js.Unsafe.global (Js_of_ocaml.Js.string "BigInt64Array") in
  let bigint64_arr = Js_of_ocaml.Js.Unsafe.meth_call bigint64_ctor "from" [| Js_of_ocaml.Js.Unsafe.inject js_data |] in
  let js_dims = Js_of_ocaml.Js.Unsafe.inject (Js_of_ocaml.Js.array (Array.map (fun d -> Js_of_ocaml.Js.Unsafe.inject d) dims)) in
  (* new ort.Tensor("int64", data, dims) *)
  let js_tensor = Js_of_ocaml.Js.Unsafe.new_obj tensor_ctor [| Js_of_ocaml.Js.Unsafe.inject (Js_of_ocaml.Js.string "int64"); Js_of_ocaml.Js.Unsafe.inject bigint64_arr; js_dims |] in
  (* NOTE(review): assumes Onnxrt's Tensor.t is represented as a
     (js_object, bool) pair, the bool presumably being a disposed/ownership
     flag — this relies on library internals; confirm against onnxrt. *)
  (Obj.magic (Js_of_ocaml.Js.Unsafe.coerce js_tensor, false) : Tensor.t)

(* Numerically stable softmax: shift by the max before exponentiating so
   exp never overflows. *)
let softmax (logits : float array) : float array =
  let max_val = Array.fold_left max neg_infinity logits in
  let exps = Array.map (fun x -> exp (x -. max_val)) logits in
  let sum = Array.fold_left ( +. ) 0.0 exps in
  Array.map (fun e -> e /. sum) exps

(* Fetch a URL with a synchronous XMLHttpRequest (third `open` argument is
   false).  Blocking is acceptable here because this runs inside a worker,
   not on the main thread.  No status check: a failed request simply
   yields an empty or partial responseText. *)
let fetch_text_sync url =
  let xhr = Js_of_ocaml.Js.Unsafe.new_obj (Js_of_ocaml.Js.Unsafe.get Js_of_ocaml.Js.Unsafe.global (Js_of_ocaml.Js.string "XMLHttpRequest")) [||] in
  Js_of_ocaml.Js.Unsafe.meth_call xhr "open" [| Js_of_ocaml.Js.Unsafe.inject (Js_of_ocaml.Js.string "GET"); Js_of_ocaml.Js.Unsafe.inject (Js_of_ocaml.Js.string url); Js_of_ocaml.Js.Unsafe.inject Js_of_ocaml.Js._false |];
  Js_of_ocaml.Js.Unsafe.meth_call xhr "send" [||];
  Js_of_ocaml.Js.to_string (Js_of_ocaml.Js.Unsafe.get xhr (Js_of_ocaml.Js.string "responseText"))

let () = print_endline "Helpers defined"

Load Model

Fetch the vocabulary file (synchronous XHR in the worker) and load the quantized DistilBERT model. A status widget updates as loading progresses:

(* Mutable model state, filled in as loading completes. *)
let max_length = 128
let vocab = ref None
let session = ref None

(* Loading status as a Note signal so the status widget re-renders on
   every update. *)
let model_status_e, send_model_status = Note.E.create ()
let model_status = Note.S.hold "Loading vocabulary..." model_status_e

(* Render a status message in a bordered monospace box. *)
let status_view msg =
  let open Widget.View in
  Element {
    tag = "div";
    attrs = [
      Style ("padding", "0.75em 1em");
      Style ("border-radius", "6px");
      Style ("font-family", "monospace");
      Style ("border", "1px solid currentColor");
      Style ("opacity", "0.8");
    ];
    children = [Text msg]
  }

(* Mount the widget once, then keep it in sync with the signal.  The logr
   must be held, otherwise the subscription would be GC'd. *)
let () = Widget.display ~id:"model-status" ~handlers:[] (status_view "Loading vocabulary...")
let _logr = Note.S.log (Note.S.map status_view model_status) (Widget.update ~id:"model-status")
let () = Note.Logr.hold _logr

(* The vocabulary loads synchronously (blocking XHR inside the worker)... *)
let () =
  let v = Vocab.load_from_string (fetch_text_sync "vocab.txt") in
  vocab := Some v;
  send_model_status "Vocabulary loaded. Loading model..."

(* ...while the ONNX session loads asynchronously under Lwt; [session]
   stays None until it resolves, which `analyze` checks for. *)
let () =
  Lwt.async (fun () ->
    let open Lwt.Syntax in
    let* s = Session.create "model_quantized.onnx" () in
    session := Some s;
    send_model_status "Model ready.";
    Lwt.return_unit)

Analyze Sentiment

Type or edit the text below, then click Analyze — the classification runs on demand and the result propagates through Note signals to update the widget.

(* Input text and result are Note signals; [analyze] is invoked from the
   button's click handler defined in the widget-mounting cell below. *)
let input_e, send_input = Note.E.create ()
let input_text = Note.S.hold "This movie was absolutely wonderful, I loved every minute of it!" input_e
let result_e, send_result = Note.E.create ()
let result_s = Note.S.hold "Type something and click Analyze." result_e

(* Run the model on [text] and push a human-readable verdict into
   [result_s].  Reports a message (rather than failing) if the vocab or
   session has not finished loading yet. *)
let analyze text =
  match !vocab, !session with
  | Some v, Some s ->
      send_result "Running inference...";
      Lwt.async (fun () ->
        let open Lwt.Syntax in
        let encoded = Tokenizer.encode v text ~max_length in
        (* DistilBERT expects int64 inputs of shape [1; max_length]. *)
        let input_ids_tensor = make_int64_tensor encoded.Tokenizer.input_ids [| 1; max_length |] in
        let attention_mask_tensor = make_int64_tensor encoded.Tokenizer.attention_mask [| 1; max_length |] in
        let* outputs = Session.run s [ ("input_ids", input_ids_tensor); ("attention_mask", attention_mask_tensor) ] in
        let logits_tensor = List.assoc "logits" outputs in
        let logits_data = Tensor.to_bigarray1_exn Dtype.Float32 logits_tensor in
        (* NOTE(review): assumes exactly two logits ordered
           [NEGATIVE; POSITIVE], as in the SST-2 fine-tuned DistilBERT —
           confirm against the model's label config. *)
        let logits = [| Bigarray.Array1.get logits_data 0; Bigarray.Array1.get logits_data 1 |] in
        let probs = softmax logits in
        let label, confidence = if probs.(1) > probs.(0) then ("POSITIVE", probs.(1)) else ("NEGATIVE", probs.(0)) in
        let emoji = if label = "POSITIVE" then "👍" else "👎" in
        send_result (Printf.sprintf "%s %s (%.1f%% confident)" emoji label (confidence *. 100.0));
        (* NOTE(review): if Session.run raises, these disposals are
           skipped and the tensors leak — consider Lwt.finalize. *)
        Tensor.dispose input_ids_tensor;
        Tensor.dispose attention_mask_tensor;
        Tensor.dispose logits_tensor;
        Lwt.return_unit)
  | _ -> send_result "Model not loaded yet — wait a moment and try again."
(* UI: a textarea plus an Analyze button, and a monospace panel that
   shows the classification result. *)
let input_view text =
  let open Widget.View in
  let textarea =
    Element {
      tag = "textarea";
      attrs = [
        Property ("rows", "3");
        Style ("width", "100%");
        Style ("padding", "0.75em");
        Style ("border-radius", "6px");
        Style ("border", "1px solid currentColor");
        Style ("font-size", "1em");
        Style ("font-family", "inherit");
        Style ("background", "transparent");
        Style ("color", "inherit");
        Style ("resize", "vertical");
        Style ("opacity", "0.9");
        Handler ("input", "text_changed");
      ];
      children = [ Text text ];
    }
  in
  let analyze_button =
    Element {
      tag = "button";
      attrs = [
        Style ("padding", "0.5em 1.5em");
        Style ("border-radius", "6px");
        Style ("border", "1px solid currentColor");
        Style ("background", "transparent");
        Style ("color", "inherit");
        Style ("font-size", "1em");
        Style ("cursor", "pointer");
        Handler ("click", "analyze");
      ];
      children = [ Text "Analyze" ];
    }
  in
  Element {
    tag = "div";
    attrs = [
      Style ("display", "flex");
      Style ("flex-direction", "column");
      Style ("gap", "0.75em");
    ];
    children = [ textarea; analyze_button ];
  }

let result_view result =
  let open Widget.View in
  Element {
    tag = "div";
    attrs = [
      Style ("font-family", "monospace");
      Style ("padding", "0.75em 1em");
      Style ("border-radius", "6px");
      Style ("border", "1px solid currentColor");
      Style ("opacity", "0.8");
    ];
    children = [ Text result ];
  }

(* Mount both widgets, wiring the textarea to the input signal and the
   button to the analyzer. *)
let () =
  Widget.display ~id:"sentiment-input"
    ~handlers:[
      ("text_changed", fun v -> send_input (Option.value ~default:"" v));
      ("analyze", fun _ -> analyze (Note.S.value input_text));
    ]
    (input_view (Note.S.value input_text))

let () =
  Widget.display ~id:"sentiment-result" ~handlers:[]
    (result_view (Note.S.value result_s))

(* Hold both logrs so the signals stay observed: input_text must not be
   GC'd, and result_s re-renders the result widget on every change. *)
let _logr_input = Note.S.log input_text (fun _ -> ())
let () = Note.Logr.hold _logr_input
let _logr_result =
  Note.S.log (Note.S.map result_view result_s) (Widget.update ~id:"sentiment-result")
let () = Note.Logr.hold _logr_result