Source file switch.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
(* The state of a switch: how many fibers are attached to it, any failure reported so far,
   the handlers to run when it is released, and its cancellation context. *)
type t = {
  mutable fibers : int;         (* Total, including daemon_fibers and the main function *)
  (* Number of fibers currently running inside [with_daemon]. *)
  mutable daemon_fibers : int;
  (* The failure(s) reported via [fail] so far, combined into one exception, with a backtrace. *)
  mutable exs : (exn * Printexc.raw_backtrace) option;
  (* Held while [on_release] (or one of its nodes) is being read or modified. *)
  on_release_lock : Mutex.t;
  mutable on_release : (unit -> unit) Lwt_dllist.t option;      (* [None] when closed. *)
  waiter : unit Single_waiter.t;              (* The main [top]/[sub] function may wait here for fibers to finish. *)
  (* The cancellation context belonging to this switch. *)
  cancel : Cancel.t;
}

(* A handle on a registered release handler, allowing it to be removed again.
   [Null] refers to no handler at all.
   [Hook (lock, node)] pairs the handler's list node with the mutex that is taken
   while the node is being modified. *)
type hook =
  | Null
  | Hook : Mutex.t * (unit -> unit) Lwt_dllist.node -> hook

(* A hook that is not attached to any handler; trying to remove it reports [false]. *)
let null_hook = Null

(* Sentinel written into a handler node once its handler has been taken or removed.
   It is only ever compared by physical identity; calling it is a bug. *)
let cancelled = fun () -> assert false

(* [try_remove_hook hook] detaches the handler behind [hook] from its list and replaces
   the stored function with the [cancelled] sentinel, all while holding the hook's mutex.
   Returns [true] if the node still held a real handler (i.e. this call is the one that
   cancelled it), and [false] for [Null] or for a node that was already marked. *)
let try_remove_hook hook =
  match hook with
  | Null -> false
  | Hook (lock, node) ->
    Mutex.lock lock;
    Lwt_dllist.remove node;
    let previous = Lwt_dllist.get node in
    Lwt_dllist.set node cancelled;
    Mutex.unlock lock;
    (* Physical comparison: [cancelled] is used purely as a marker value. *)
    not (previous == cancelled)

(* Like [try_remove_hook], but discards the information about whether anything was removed. *)
let remove_hook hook =
  let (_ : bool) = try_remove_hook hook in
  ()

(* Pretty-print the switch on formatter [f]: its id, its current fiber count,
   and then the dump of its cancellation context. *)
let dump f t =
  let id = (t.cancel.id :> int) in
  let fiber_count = t.fibers in
  Fmt.pf f "@[<v2>Switch %d (%d extra fibers):@,%a@]" id fiber_count Cancel.dump t.cancel

(* A switch is finished exactly when its cancellation context reports being finished. *)
let is_finished { cancel; _ } = Cancel.is_finished cancel

(* Ensure the switch is still usable from here: it must not be finished, and the calling
   domain must be the one recorded in its cancellation context.
   A switch that is merely being cancelled passes this check.
   @raise Invalid_argument if either condition fails. *)
let check_our_domain t =
  if is_finished t then invalid_arg "Switch finished!"
  else begin
    let here = Domain.self () in
    if here <> t.cancel.domain then invalid_arg "Switch accessed from wrong domain!"
  end

(* Ensure the switch is neither finished nor cancelled.
   @raise Invalid_argument if the switch has finished.
   Otherwise the decision is delegated to [Cancel.check] on the switch's context. *)
let check t =
  match is_finished t with
  | true -> invalid_arg "Switch finished!"
  | false -> Cancel.check t.cancel

(* The error, if any, currently recorded by the switch's cancellation context. *)
let get_error t =
  t.cancel |> Cancel.get_error

(* [combine_exn ex prev] merges a new failure into the failures seen so far.
   With no earlier failure, [ex] is returned unchanged; otherwise the earlier one is
   combined with [ex] using [Exn.combine], with the earlier failure as first argument. *)
let combine_exn ex prev =
  match prev with
  | Some earlier -> Exn.combine earlier ex
  | None -> ex

(* [fail ?bt t ex] records [ex] (with backtrace [bt], empty by default) as a failure of [t]
   and then cancels [t]'s context with [ex].
   If the cancellation itself raises, that exception is recorded as a further failure
   instead of being propagated.
   Note: raises if [t] is finished or called from wrong domain. *)
let fail ?(bt=Exn.empty_backtrace) t ex =
  check_our_domain t;
  t.exs <- Some (combine_exn (ex, bt) t.exs);
  match Cancel.cancel t.cancel ex with
  | () -> ()
  | exception cancel_ex ->
    let cancel_bt = Printexc.get_raw_backtrace () in
    t.exs <- Some (combine_exn (cancel_ex, cancel_bt) t.exs)

(* Register one more fiber with the switch.
   [check] runs first, so nothing is counted if the switch is finished or cancelled. *)
let inc_fibers t =
  check t;
  t.fibers <- succ t.fibers

(* Unregister one fiber from the switch.
   - If daemon fibers exist and they are now the only fibers left, the switch's
     context is cancelled with [Exit].
   - If no fibers are left at all, the waiter is woken (if something is sleeping on it). *)
let dec_fibers t =
  let remaining = t.fibers - 1 in
  t.fibers <- remaining;
  let only_daemons_left = t.daemon_fibers > 0 && remaining = t.daemon_fibers in
  if only_daemons_left then Cancel.cancel t.cancel Exit;
  (* Re-read the counter here rather than reusing [remaining]. *)
  if t.fibers = 0 then Single_waiter.wake_if_sleeping t.waiter

(* [with_op t fn] counts the caller as a fiber of [t] for the duration of [fn ()],
   so [t] cannot become idle until [fn] returns or raises.
   The count is released again in either case. *)
let with_op t fn =
  inc_fibers t;
  let finally () = dec_fibers t in
  Fun.protect ~finally fn

(* [with_daemon t fn] runs [fn ()] counted both as a fiber and as a daemon fiber of [t].
   Both counts are released when [fn] returns or raises; the daemon count is
   dropped first, then the general fiber count. *)
let with_daemon t fn =
  inc_fibers t;
  t.daemon_fibers <- t.daemon_fibers + 1;
  let finally () =
    t.daemon_fibers <- t.daemon_fibers - 1;
    dec_fibers t
  in
  Fun.protect ~finally fn

(* Unwrap a result whose error case carries an exception:
   [Ok x] gives [x], [Error ex] raises [ex]. *)
let or_raise result =
  match result with
  | Error ex -> raise ex
  | Ok x -> x

(* [await_idle t] blocks until [t] has no fibers left and then runs its release handlers.

   Steps:
   1. Sleep on [t.waiter] until the fiber count drops to zero.
   2. Under [on_release_lock], take every handler out of [on_release] (marking each node
      with [cancelled] so a concurrent hook removal sees it as already taken) and set
      [on_release] to [None], which is how later registrations detect a closed switch.
   3. Run the collected handlers inside [Cancel.protect]
      (NOTE(review): presumably this shields them from cancellation; confirm in [Cancel]).
      A handler that raises does not stop the others: its exception is reported with [fail],
      together with the backtrace captured at the point it was caught.
   4. If the handlers attached new fibers, start over.

   Handlers are pushed onto the front of the queue as the list is walked, so they run in
   the reverse of the order in which [Lwt_dllist.iter_node_l] visits them. *)
let rec await_idle t =
  (* Wait for fibers to finish: *)
  while t.fibers > 0 do
    Trace.try_get t.cancel.id;
    Single_waiter.await_protect t.waiter "Switch.await_idle" t.cancel.id
  done;
  (* Collect on_release handlers: *)
  let queue = ref [] in
  let enqueue n =
    let fn = Lwt_dllist.get n in
    Lwt_dllist.set n cancelled;
    queue := fn :: !queue
  in
  Mutex.lock t.on_release_lock;
  Option.iter (Lwt_dllist.iter_node_l enqueue) t.on_release;
  t.on_release <- None;
  Mutex.unlock t.on_release_lock;
  (* Run on_release handlers *)
  !queue |> List.iter (fun fn ->
      try Cancel.protect fn
      with ex ->
        (* Grab the backtrace before doing anything else, so the failure recorded on the
           switch points at the handler that raised instead of carrying an empty trace. *)
        let bt = Printexc.get_raw_backtrace () in
        fail ~bt t ex
    );
  if t.fibers > 0 then await_idle t

(* If a failure has been recorded on [t], re-raise it with its stored backtrace;
   otherwise do nothing. *)
let maybe_raise_exs t =
  Option.iter (fun (ex, bt) -> Printexc.raise_with_backtrace ex bt) t.exs

(* Build a fresh switch around the cancellation context [cancel].
   It starts with a single fiber (the main function), no daemons, no recorded failure,
   and an empty, open list of release handlers. *)
let create cancel =
  {
    (* Fiber accounting *)
    fibers = 1;         (* The main function counts as a fiber *)
    daemon_fibers = 0;
    (* Failure state *)
    exs = None;
    (* Release handlers and their lock *)
    on_release_lock = Mutex.create ();
    on_release = Some (Lwt_dllist.create ());
    (* Waiting and cancellation *)
    waiter = Single_waiter.create ();
    cancel;
  }

(* [run_internal t fn] runs the main function [fn t] and then shuts the switch down.

   Whether [fn] returns or raises, the main function's fiber count is released, the switch
   waits to become idle (running release handlers), and any failure recorded on the
   switch is re-raised. If [fn] raised, its exception is first reported with [fail],
   which also cancels any fibers still running; the final re-raise is then
   expected to fire, hence the [assert false]. *)
let run_internal t fn =
  (* Common tail of both outcomes: drain the switch, then surface recorded failures. *)
  let finish () =
    await_idle t;
    Trace.get t.cancel.id;
    maybe_raise_exs t
  in
  match fn t with
  | result ->
    dec_fibers t;
    finish ();        (* Raises if something failed while finishing *)
    result
  | exception ex ->
    let bt = Printexc.get_raw_backtrace () in
    (* The main function failed: switch off to cancel whatever is still running. *)
    dec_fibers t;
    fail ~bt t ex;
    finish ();
    assert false

(* [run ?name fn] creates a switch inside a new cancellation context obtained from
   [Cancel.sub_checked] (tagged as [Switch], optionally named [name]) and runs [fn] with it. *)
let run ?name fn =
  let body cc =
    let t = create cc in
    run_internal t fn
  in
  Cancel.sub_checked ?name Switch body

(* Like [run], but the new cancellation context is created with [~protected:true]
   as a child of the calling fiber's current context.
   If [name] is given, it is attached to the new context's id for tracing. *)
let run_protected ?name fn =
  let ctx = Effect.perform Cancel.Get_context in
  let parent = ctx.cancel_context in
  Cancel.with_cc ~ctx ~parent ~protected:true Switch (fun cancel ->
      Option.iter (Trace.name cancel.id) name;
      let t = create cancel in
      run_internal t fn
    )

(* [run_in t fn] temporarily moves the calling fiber into [t]'s cancellation context
   while [fn ()] runs, and moves it back afterwards (also when [fn] raises).
   Because the call is wrapped in [with_op], [t] cannot finish before [fn] is done,
   and cancelling [t] cancels [fn]. *)
let run_in t fn =
  with_op t (fun () ->
      let ctx = Effect.perform Cancel.Get_context in
      let previous = ctx.cancel_context in
      let restore () = Cancel.move_fiber_to previous ctx in
      Cancel.move_fiber_to t.cancel ctx;
      match fn () with
      | () -> restore ()
      | exception ex -> restore (); raise ex
    )

(* Raised when a release handler is registered on a switch that is already closed and
   the handler, which is then run immediately, itself raises.
   Carries a message and the exception raised by the handler. *)
exception Release_error of string * exn

(* Teach [Printexc] how to display [Release_error]: the message first, followed by the
   exception that was being handled. All other exceptions are left to other printers. *)
let () =
  let print_release_error = function
    | Release_error (msg, ex) -> Some (Fmt.str "@[<v2>%s@,while handling %a@]" msg Exn.pp ex)
    | _ -> None
  in
  Printexc.register_printer print_release_error

(* [on_release_full t fn] appends [fn] to [t]'s release handlers and returns its list node.

   If [t] is already closed (its handler list is [None]), [fn] is run immediately inside
   [Cancel.protect] and the call then always raises:
   @raise Invalid_argument if [fn] completed normally.
   @raise Release_error wrapping the exception of [fn], with its backtrace, if [fn] raised. *)
let on_release_full t fn =
  (* Only the list manipulation happens under the lock. *)
  Mutex.lock t.on_release_lock;
  let added = Option.map (Lwt_dllist.add_r fn) t.on_release in
  Mutex.unlock t.on_release_lock;
  match added with
  | Some node -> node
  | None ->
    begin match Cancel.protect fn with
      | () -> invalid_arg "Switch finished!"
      | exception ex ->
        let bt = Printexc.get_raw_backtrace () in
        let wrapped = Release_error ("Switch finished!", ex) in
        Printexc.raise_with_backtrace wrapped bt
    end

(* Register [fn] as a release handler of [t], without keeping a handle to remove it later. *)
let on_release t fn =
  let (_ : _ Lwt_dllist.node) = on_release_full t fn in
  ()

(* Register [fn] as a release handler of [t] and return a hook that can later be passed
   to [remove_hook] or [try_remove_hook] to unregister it. *)
let on_release_cancellable t fn =
  let node = on_release_full t fn in
  Hook (t.on_release_lock, node)