Source file cancel.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
(* Alias of [Exn.Cancelled], re-exported here so that users of this module can
   match on it without referring to [Exn]. [check] raises it (wrapping the
   cancellation reason) when a context is cancelling. *)
exception Cancelled = Exn.Cancelled

(* Lifecycle of a cancellation context:
   - [On]: the context is active and linked into the tree.
   - [Cancelling (ex, bt)]: cancellation was requested with reason [ex]; [bt] is the
     raw backtrace captured by [cancel_internal] when it set this state.
   - [Finished]: the context is not linked into the tree (either not yet activated,
     see [create]/[activate], or already deactivated). [check] rejects such contexts. *)
type state =
  | On
  | Cancelling of exn * Printexc.raw_backtrace
  | Finished

(* There is a tree of cancellation contexts for each domain.
   A fiber is always in exactly one context, but can move to a new child and back (see [sub]).
   While a fiber is performing a cancellable operation, it sets a cancel function.
   When a context is cancelled, we call each fiber's cancellation function (first replacing it with [ignore]).
   Cancelling always happens from the fiber's own domain.
   An operation may either finish normally or be cancelled (not both).
   If a function can succeed in a separate domain,
   the user's cancel function is responsible for ensuring that this is done atomically. *)
type t = {
  id : Trace.id;                        (* Identifies this context in trace events *)
  mutable state : state;
  children : t Lwt_dllist.t;            (* Active child contexts; [activate] adds, its undo function removes *)
  fibers : fiber_context Lwt_dllist.t;  (* Fibers currently in this context *)
  protected : bool;                     (* If true, cancelling an ancestor does not propagate into this subtree *)
  domain : Domain.id;         (* Prevent access from other domains *)
}
and fiber_context = {
  tid : Trace.id;                       (* Identifies this fiber in trace events *)
  mutable cancel_context : t;           (* The context this fiber is currently in *)
  mutable cancel_node : fiber_context Lwt_dllist.node option; (* Our entry in [cancel_context.fibers] *)
  mutable cancel_fn : exn -> unit;  (* Encourage the current operation to finish *)
  mutable vars : Hmap.t;                (* Fiber-local variables (see [Fiber_context.with_vars]) *)
}

(* Effect performed to obtain the current fiber's context.
   NOTE(review): the handler is not in this file; presumably installed by the scheduler. *)
type _ Effect.t += Get_context : fiber_context Effect.t

(* Print the state of context [t] as "on", "cancelling(EX)" or "finished",
   followed by " (protected)" when [t] is a protected context. *)
let pp_state f t =
  let () =
    match t.state with
    | Finished -> Fmt.string f "finished"
    | Cancelling (ex, _) -> Fmt.pf f "cancelling(%a)" Fmt.exn ex
    | On -> Fmt.string f "on"
  in
  if t.protected then Fmt.pf f " (protected)"

(* Print a fiber as its numeric trace id. *)
let pp_fiber f { tid; _ } =
  Fmt.pf f "%d" (tid :> int)

(* Print every element of the dllist [t] from left to right using [pp],
   calling [sep] between consecutive elements (not before the first one). *)
let pp_lwt_dlist ~sep pp f t =
  let is_first = ref true in
  Lwt_dllist.iter_l
    (fun item ->
       if not !is_first then sep f () else is_first := false;
       pp f item)
    t

(* Debug printer for a whole cancellation subtree: the context's state, the ids of
   its fibers in brackets, then each child context on its own (indented) line. *)
let rec dump f t =
  let pp_fibers = pp_lwt_dlist ~sep:(Fmt.any ",") pp_fiber in
  Fmt.pf f "@[<v2>%a [%a]%a@]" pp_state t pp_fibers t.fibers pp_children t.children
and pp_children f children =
  Lwt_dllist.iter_l
    (fun child ->
       Fmt.cut f ();
       dump f child)
    children

(* [true] only while the context is active and not cancelling. *)
let is_on { state; _ } =
  match state with
  | Cancelling _ | Finished -> false
  | On -> true

(* Return unit if [t] is on.
   @raise Cancelled wrapping the reason if [t] is cancelling.
   @raise Invalid_argument if [t] is finished. *)
let check { state; _ } =
  match state with
  | Finished -> invalid_arg "Cancellation context finished!"
  | Cancelling (ex, _) -> raise (Cancelled ex)
  | On -> ()

(* Non-raising counterpart of [check]: [None] if [t] is on, otherwise
   [Some e] where [e] is the exception [check] would raise. *)
let get_error { state; _ } =
  match state with
  | Finished -> Some (Invalid_argument "Cancellation context finished!")
  | Cancelling (ex, _) -> Some (Cancelled ex)
  | On -> None

(* [true] iff the context is not (or no longer) linked into the tree. *)
let is_finished { state; _ } =
  match state with
  | On | Cancelling _ -> false
  | Finished -> true

(* Move [fiber] into context [t]:
   append it to [t.fibers], make [t] its current context, unlink its node from the
   previous context's fiber list (if it had one) and remember the new node. *)
let move_fiber_to t fiber =
  let new_node = Lwt_dllist.add_r fiber t.fibers in     (* Add to new context *)
  fiber.cancel_context <- t;
  Option.iter Lwt_dllist.remove fiber.cancel_node;      (* Remove from old context *)
  fiber.cancel_node <- Some new_node

(* Allocate a new cancellation context for this domain and announce it to [Trace]
   with [purpose]. The context starts in state [Finished] with no children and no
   fibers: it is not linked into the cancellation tree until [activate] is called. *)
let create ~protected purpose =
  let id = Trace.mint_id () in
  Trace.create_cc id purpose;
  {
    id;
    state = Finished;
    children = Lwt_dllist.create ();
    fibers = Lwt_dllist.create ();
    protected;
    domain = Domain.self ();
  }

(* Link the (finished) context [t] into the tree as a child of [parent] and switch it on.
   Returns an undo function that marks [t] finished again and unlinks it from
   [parent.children]. Both directions assert that [parent] is not finished. *)
let activate t ~parent =
  assert (is_finished t);
  assert (not (is_finished parent));
  t.state <- On;
  let node = Lwt_dllist.add_r t parent.children in
  let deactivate () =
    assert (not (is_finished parent));
    t.state <- Finished;
    Lwt_dllist.remove node
  in
  deactivate

(* Runs [fn] with a fresh cancellation context.
   A new context (with the given [protected] flag and trace [purpose]) is activated as a
   child of [parent], and [fiber] is moved into it while [fn] runs. Whether [fn] returns or
   raises, [fiber] is moved back to [parent], the child is deactivated and [Trace.exit_cc]
   is called.
   @raise Cancelled if [protected] is false and [parent] is cancelling.
   @raise Invalid_argument if [protected] is false and [parent] is finished.
   Exceptions raised by [fn] are re-raised with their original backtrace. It is captured
   before cleanup runs, so nothing cleanup does can overwrite it. *)
let with_cc ~ctx:fiber ~parent ~protected purpose fn =
  if not protected then check parent;
  let t = create ~protected purpose in
  let deactivate = activate t ~parent in
  move_fiber_to t fiber;
  let cleanup () =
    move_fiber_to parent fiber;
    deactivate ();
    Trace.exit_cc ()
  in
  match fn t with
  | x -> cleanup (); x
  | exception ex ->
    let bt = Printexc.get_raw_backtrace () in
    cleanup ();
    Printexc.raise_with_backtrace ex bt

(* Run [fn ()] in a new protected child of the current fiber's context, so that
   cancelling an ancestor context does not reach the fibers inside it. *)
let protect fn =
  let ctx = Effect.perform Get_context in
  let parent = ctx.cancel_context in
  (* Note: there is no need to check the new context after [fn] returns;
     the goal of cancellation is only to finish the thread promptly, not to report the error.
     We also do not check the parent context, to make sure the caller has a chance to handle the result. *)
  let run_protected _cc = fn () in
  with_cc ~ctx ~parent ~protected:true Protect run_protected

(* Mark the cancellation tree rooted at [t] as Cancelling (stopping at protected sub-contexts),
   and return a list of all fibers in the newly-cancelling contexts. Since modifying the cancellation
   tree can only be done from our domain, this is effectively an atomic operation. Once it returns,
   new (non-protected) fibers cannot be added to any of the cancelling contexts.
   Fibers found are prepended to [acc_fibers].
   @raise Invalid_argument if a finished context is reached. *)
let rec cancel_internal t ex acc_fibers =
  match t.state with
  | Finished -> invalid_arg "Cancellation context finished!"
  (* Already cancelling: this subtree was handled when it entered that state. *)
  | Cancelling _ -> acc_fibers
  | On ->
    (* Backtrace of the most recently raised exception on this domain.
       NOTE(review): presumably the one that caused the cancellation when called from a handler. *)
    let bt = Printexc.get_raw_backtrace () in
    t.state <- Cancelling (ex, bt);
    Trace.error t.id ex;
    (* Collect this context's fibers, then recurse into its children. *)
    let acc_fibers = Lwt_dllist.fold_r List.cons t.fibers acc_fibers in
    Lwt_dllist.fold_r (cancel_child ex) t.children acc_fibers
(* Protected children (and everything below them) are skipped. *)
and cancel_child ex t acc =
  if t.protected then acc
  else cancel_internal t ex acc

(* Guard against touching a cancellation context from a domain other than
   the one that created it.
   @raise Invalid_argument if called from another domain. *)
let check_our_domain t =
  let here = Domain.self () in
  if here <> t.domain then
    invalid_arg "Cancellation context accessed from wrong domain!"

(* Cancel [t] and all its non-protected descendants with reason [ex].
   First the subtree is marked as cancelling (see [cancel_internal]). Then the cancel function
   of each fiber in the newly-cancelling contexts is called with [Cancelled ex], in the order
   returned by [cancel_internal]. Each function is replaced by [ignore] before it is invoked.
   If any cancel functions raise, the remaining ones are still called. Their exceptions are
   then merged, in call order, with [Exn.combine] and re-raised with the combined backtrace.
   @raise Invalid_argument if called from a domain other than [t]'s, or if [t] is finished. *)
let cancel t ex =
  check_our_domain t;
  let fibers = cancel_internal t ex [] in
  let cex = Cancelled ex in
  (* Invoke one fiber's cancel function, collecting any exception it raises.
     [errors] is accumulated in reverse so the fold stays tail-recursive
     even when the context holds very many fibers. *)
  let cancel_fiber errors fiber =
    let fn = fiber.cancel_fn in
    fiber.cancel_fn <- ignore;
    match fn cex with
    | () -> errors
    | exception ex2 ->
      let bt = Printexc.get_raw_backtrace () in
      (ex2, bt) :: errors
  in
  (* An empty [fibers] simply yields no errors. This avoids both a separate emptiness test
     and a polymorphic comparison on closure-holding records. *)
  match List.rev (List.fold_left cancel_fiber [] fibers) with
  | [] -> ()
  | first :: rest ->
    let combined, bt = List.fold_left Exn.combine first rest in
    Printexc.raise_with_backtrace combined bt

(* Run [fn] in a new non-protected child of the current fiber's context, using trace
   [purpose]. If [name] is given, it is attached to the new context's trace id.
   @raise Cancelled if the current context is already cancelling. *)
let sub_checked ?name purpose fn =
  let ctx = Effect.perform Get_context in
  with_cc ~ctx ~parent:ctx.cancel_context ~protected:false purpose (fun cc ->
      (match name with
       | Some n -> Trace.name cc.id n
       | None -> ());
      fn cc)

(* Run [fn] in a new, unnamed, non-protected child of the current fiber's
   cancellation context, traced with purpose [Sub]. *)
let sub fn =
  sub_checked ?name:None Sub fn

(* Like [sub], but it's OK if the new context is cancelled.
   Instead of checking, this returns the parent context on exit so that the caller can
   check it. [fn] must return unit. *)
let sub_unchecked purpose fn =
  let ctx = Effect.perform Get_context in
  let parent = ctx.cancel_context in
  let run_in_child cc = fn cc; parent in
  with_cc ~ctx ~parent ~protected:false purpose run_in_child

(* Per-fiber state: trace id, current cancellation context, current cancel
   function and fiber-local variables. *)
module Fiber_context = struct
  type t = fiber_context

  (* The trace id assigned to this fiber by [make]. *)
  let tid fiber = fiber.tid

  (* The cancellation context the fiber is currently in. *)
  let cancellation_context fiber = fiber.cancel_context

  (* [None] while the fiber's context is on; otherwise the exception [check]
     would raise for that context. *)
  let get_error fiber = get_error fiber.cancel_context

  (* Install [fn] as the function [cancel] will call for this fiber. *)
  let set_cancel_fn fiber fn =
    fiber.cancel_fn <- fn

  (* Reset the fiber's cancel function to [ignore]. *)
  let clear_cancel_fn fiber =
    set_cancel_fn fiber ignore

  (* Create a fiber in context [cc] with variables [vars], record it with [Trace],
     and append it to [cc.fibers]. *)
  let make ~cc ~vars =
    let tid = Trace.mint_id () in
    Trace.create_fiber tid ~cc:cc.id;
    let fiber =
      { tid; cancel_context = cc; cancel_node = None; cancel_fn = ignore; vars }
    in
    let node = Lwt_dllist.add_r fiber cc.fibers in
    fiber.cancel_node <- Some node;
    fiber

  (* Create a fiber in a brand-new, unprotected root context. The root has no
     parent, so it is switched on directly rather than via [activate]. *)
  let make_root () =
    let root = create ~protected:false Root in
    root.state <- On;
    make ~cc:root ~vars:Hmap.empty

  (* Emit the trace exit event and unlink the fiber from its context's fiber list. *)
  let destroy fiber =
    Trace.exit_fiber fiber.tid;
    match fiber.cancel_node with
    | Some node -> Lwt_dllist.remove node
    | None -> ()

  let vars fiber = fiber.vars

  (* The variables of the fiber performing this call. *)
  let get_vars () =
    let fiber = Effect.perform Get_context in
    fiber.vars

  (* Run [fn] with [t]'s variables temporarily set to [vars], restoring the
     previous value afterwards whether [fn] returns or raises. *)
  let with_vars t vars fn =
    let saved = t.vars in
    t.vars <- vars;
    Fun.protect ~finally:(fun () -> t.vars <- saved) fn
end