1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
(* Local alias for [Exn.Cancelled]. It is raised by [check] and passed to
   fibers' cancel functions by [cancel], wrapping the cancellation reason. *)
exception Cancelled = Exn.Cancelled
(* Lifecycle of a cancellation context:
   - [On]: active and not cancelled.
   - [Cancelling (ex, bt)]: cancelled with reason [ex]; [bt] is the raw
     backtrace captured at the moment the context was marked as cancelling.
   - [Finished]: not (or no longer) active. [create] returns contexts in this
     state; [activate] switches them to [On]. *)
type state =
  | On
  | Cancelling of exn * Printexc.raw_backtrace
  | Finished
(* A cancellation context. Contexts form a tree: [activate] appends a context
   to its parent's [children], and every fiber is listed in the [fibers] of
   the context it currently belongs to (see [move_fiber_to]). *)
type t = {
  id : Trace.id;
  mutable state : state;
  children : t Lwt_dllist.t; (* Sub-contexts linked in by [activate] *)
  fibers : fiber_context Lwt_dllist.t; (* Fibers currently in this context *)
  protected : bool; (* If set, cancelling an ancestor does not cancel this context *)
  domain : Domain.id;         (* Domain that created the context; [cancel] rejects calls from any other domain *)
}
(* Per-fiber state used by the cancellation machinery. *)
and fiber_context = {
  tid : Trace.id;
  mutable cancel_context : t; (* The context this fiber is currently in *)
  mutable cancel_node : fiber_context Lwt_dllist.node option; (* Our entry in [cancel_context.fibers], if any *)
  mutable cancel_fn : exn -> unit;  (* Invoked by [cancel] with [Cancelled ex]; reset to [ignore] just before it is called *)
  mutable vars : Hmap.t; (* Values attached to the fiber; see [Fiber_context.with_vars] *)
}
(* Effect whose result is a [fiber_context]. It is performed by [protect],
   [sub_checked], [sub_unchecked] and [Fiber_context.get_vars] to find the
   context they operate on. The handler is installed elsewhere; presumably it
   answers with the context of the calling fiber. *)
type _ Effect.t += Get_context : fiber_context Effect.t
(* Print the state of context [t] on formatter [f] as "on",
   "cancelling(<exn>)" or "finished", followed by " (protected)" when [t] is
   a protected context. *)
let pp_state f t =
  let () =
    match t.state with
    | Finished -> Fmt.string f "finished"
    | Cancelling (reason, _) -> Fmt.pf f "cancelling(%a)" Fmt.exn reason
    | On -> Fmt.string f "on"
  in
  if t.protected then Fmt.pf f " (protected)"
(* Print a fiber as its numeric trace ID. *)
let pp_fiber f fiber =
  let tid = (fiber.tid :> int) in
  Fmt.int f tid
(* [pp_lwt_dlist ~sep pp f t] prints the elements of the doubly-linked list
   [t] on [f] using [pp], calling [sep f ()] between consecutive elements
   (but not before the first one). *)
let pp_lwt_dlist ~sep pp f t =
  let need_sep = ref false in
  let print_item item =
    if !need_sep then sep f () else need_sep := true;
    pp f item
  in
  Lwt_dllist.iter_l print_item t
(* [dump f t] prints the tree of contexts rooted at [t]: the state of [t],
   the comma-separated IDs of its fibers in square brackets, and then each
   child context on its own line inside a vertical box indented by 2. *)
let rec dump f t =
  let pp_fibers = pp_lwt_dlist ~sep:(Fmt.any ",") pp_fiber in
  Fmt.pf f "@[<v2>%a [%a]%a@]" pp_state t pp_fibers t.fibers pp_children t.children
(* Print every context in [ts] with [dump], each preceded by a cut. *)
and pp_children f ts =
  let pp_child child =
    Fmt.cut f ();
    dump f child
  in
  Lwt_dllist.iter_l pp_child ts
(* [true] iff the context is in the [On] state, i.e. it is neither being
   cancelled nor finished. *)
let is_on { state; _ } =
  match state with
  | Cancelling _ | Finished -> false
  | On -> true
(* Return [()] if the context is [On].
   Raises [Cancelled ex] if the context is being cancelled with reason [ex].
   Raises [Invalid_argument] if the context is finished. *)
let check { state; _ } =
  match state with
  | Finished -> invalid_arg "Cancellation context finished!"
  | Cancelling (reason, _) -> raise (Cancelled reason)
  | On -> ()
(* Non-raising counterpart of [check]: returns the exception that [check]
   would raise for [t], or [None] if the context is [On]. *)
let get_error t =
  match t.state with
  | Cancelling (reason, _) -> Some (Cancelled reason)
  | Finished -> Some (Invalid_argument "Cancellation context finished!")
  | On -> None
(* [true] iff the context is in the [Finished] state. *)
let is_finished t =
  match t.state with
  | On | Cancelling _ -> false
  | Finished -> true
(* Move [fiber] into context [t]: append it to [t.fibers], point the fiber at
   [t], then unlink it from the list it was previously in (if any) and record
   its new list node. *)
let move_fiber_to t fiber =
  (* Add to the new context first... *)
  let entry = Lwt_dllist.add_r fiber t.fibers in
  fiber.cancel_context <- t;
  (* ...then drop the entry in the old context. *)
  begin match fiber.cancel_node with
    | Some previous -> Lwt_dllist.remove previous
    | None -> ()
  end;
  fiber.cancel_node <- Some entry
(* Allocate a new context with a fresh trace ID (reported to the tracer
   together with [purpose]), no children and no fibers, owned by the calling
   domain. The result starts in the [Finished] state and is not linked to any
   parent; use [activate] for that. *)
let create ~protected purpose =
  let id = Trace.mint_id () in
  Trace.create_cc id purpose;
  {
    id;
    state = Finished;
    protected;
    children = Lwt_dllist.create ();
    fibers = Lwt_dllist.create ();
    domain = Domain.self ();
  }
(* [activate t ~parent] switches [t] to [On] and appends it to
   [parent.children]. It returns a function that undoes this: it marks [t] as
   [Finished] and unlinks it from the parent.
   Asserts that [t] is [Finished] and [parent] is not when activating, and
   that [parent] is still not [Finished] when the returned function runs. *)
let activate t ~parent =
  let finished cc =
    match cc.state with
    | Finished -> true
    | On | Cancelling _ -> false
  in
  assert (finished t);
  assert (not (finished parent));
  t.state <- On;
  let node = Lwt_dllist.add_r t parent.children in
  let deactivate () =
    assert (not (finished parent));
    t.state <- Finished;
    Lwt_dllist.remove node
  in
  deactivate
(* [with_cc ~ctx ~parent ~protected purpose fn] runs [fn t], where [t] is a
   fresh context activated as a child of [parent], with the fiber [ctx] moved
   into [t] for the duration of the call.
   Unless [protected] is set, [parent] is checked first, so this raises
   without running [fn] if the parent is cancelled or finished.
   Whether [fn] returns or raises, the fiber is moved back to [parent], [t]
   is deactivated and the tracer is told that the context was exited; the
   result or exception of [fn] is then passed on to the caller. *)
let with_cc ~ctx:fiber ~parent ~protected purpose fn =
  if not protected then check parent;
  let t = create ~protected purpose in
  let deactivate = activate t ~parent in
  move_fiber_to t fiber;
  let finish () =
    move_fiber_to parent fiber;
    deactivate ();
    Trace.exit_cc ()
  in
  match fn t with
  | result -> finish (); result
  | exception ex -> finish (); raise ex
(* [protect fn] runs [fn ()] inside a new protected child of the current
   fiber's context, so cancelling an enclosing context does not propagate
   into it. The parent is not checked before [fn] starts (see [with_cc]), and
   this function performs no cancellation check after [fn] returns. *)
let protect fn =
  let fiber = Effect.perform Get_context in
  let parent = fiber.cancel_context in
  with_cc ~ctx:fiber ~parent ~protected:true Protect (fun _cc -> fn ())
(* [cancel_internal t ex acc_fibers] marks [t] and its descendants as
   [Cancelling] with reason [ex], skipping protected sub-contexts as well as
   contexts that are already cancelling. It returns [acc_fibers] extended
   with the fibers of every context it newly marked.
   Raises [Invalid_argument] if it reaches a finished context. *)
let rec cancel_internal t ex acc_fibers =
  match t.state with
  | Cancelling _ -> acc_fibers
  | Finished -> invalid_arg "Cancellation context finished!"
  | On ->
    t.state <- Cancelling (ex, Printexc.get_raw_backtrace ());
    Trace.error t.id ex;
    let with_own_fibers =
      Lwt_dllist.fold_r (fun fiber acc -> fiber :: acc) t.fibers acc_fibers
    in
    Lwt_dllist.fold_r (cancel_child ex) t.children with_own_fibers
(* Fold step used for the children of a context being cancelled: protected
   children are left alone, all others are cancelled recursively. *)
and cancel_child ex t acc =
  match t.protected with
  | true -> acc
  | false -> cancel_internal t ex acc
(* Raises [Invalid_argument] unless the calling domain is the one recorded in
   [t.domain]. *)
let check_our_domain t =
  let current = Domain.self () in
  if not (current = t.domain) then
    invalid_arg "Cancellation context accessed from wrong domain!"
(* [cancel t ex] cancels context [t] with reason [ex].
   It first marks [t] and its non-protected descendants as cancelling (see
   [cancel_internal]). Then, for each fiber in the newly cancelled contexts
   and in list order, it resets the fiber's cancel function to [ignore] and
   calls the previous one with [Cancelled ex].
   Exceptions raised by cancel functions do not stop the remaining ones from
   being called: they are collected together with their backtraces, merged
   with [Exn.combine] and re-raised at the end.
   Does nothing if [t] is already being cancelled.
   Raises [Invalid_argument] if called from the wrong domain or if a finished
   context is encountered. *)
let cancel t ex =
  check_our_domain t;
  let fibers = cancel_internal t ex [] in
  let cex = Cancelled ex in
  (* Run one fiber's cancel function, returning the failure (if any). *)
  let notify fiber =
    let fn = fiber.cancel_fn in
    fiber.cancel_fn <- ignore;
    match fn cex with
    | () -> None
    | exception ex2 -> Some (ex2, Printexc.get_raw_backtrace ())
  in
  match List.filter_map notify fibers with
  | [] -> ()
  | first :: rest ->
    let ex, bt = List.fold_left Exn.combine first rest in
    Printexc.raise_with_backtrace ex bt
(* [sub_checked ?name purpose fn] runs [fn t] in a new unprotected child [t]
   of the current fiber's context and returns its result. If [name] is given
   it is attached to the new context's trace ID before [fn] runs.
   Since the child is unprotected, [with_cc] checks the parent first and
   raises if it is cancelled or finished. *)
let sub_checked ?name purpose fn =
  let fiber = Effect.perform Get_context in
  let run cc =
    Option.iter (Trace.name cc.id) name;
    fn cc
  in
  with_cc ~ctx:fiber ~parent:fiber.cancel_context ~protected:false purpose run
(* [sub fn] is [sub_checked] with purpose [Sub] and no name. *)
let sub fn =
  sub_checked ?name:None Sub fn
(* [sub_unchecked purpose fn] runs [fn t] in a new unprotected child [t] of
   the current fiber's context, discards its unit result and returns the
   parent context, which the caller can then inspect. *)
let sub_unchecked purpose fn =
  let fiber = Effect.perform Get_context in
  let parent = fiber.cancel_context in
  let body cc =
    fn cc;
    parent
  in
  with_cc ~ctx:fiber ~parent ~protected:false purpose body
(* Operations on the per-fiber context record. *)
module Fiber_context = struct
  type t = fiber_context

  (* Trace ID of the fiber. *)
  let tid fiber = fiber.tid

  (* The cancellation context the fiber is currently in. *)
  let cancellation_context fiber = fiber.cancel_context

  (* The error (if any) reported by the outer [get_error] for the fiber's
     current cancellation context. *)
  let get_error fiber = get_error fiber.cancel_context

  (* Install [fn] as the function [cancel] will call for this fiber. *)
  let set_cancel_fn fiber fn = fiber.cancel_fn <- fn

  (* Replace the fiber's cancel function with a no-op. *)
  let clear_cancel_fn fiber = fiber.cancel_fn <- ignore

  (* Create a fiber context with a fresh trace ID, report it to the tracer,
     and register it in [cc.fibers]. *)
  let make ~cc ~vars =
    let tid = Trace.mint_id () in
    Trace.create_fiber tid ~cc:cc.id;
    let fiber =
      { tid; vars; cancel_context = cc; cancel_fn = ignore; cancel_node = None }
    in
    let node = Lwt_dllist.add_r fiber cc.fibers in
    fiber.cancel_node <- Some node;
    fiber

  (* Create a fiber in a new unprotected [Root] context. That context has no
     parent, so it is switched to [On] directly rather than via [activate]. *)
  let make_root () =
    let root = create ~protected:false Root in
    root.state <- On;
    make ~cc:root ~vars:Hmap.empty

  (* Report the fiber's exit to the tracer and unlink it from the fiber list
     of its cancellation context. *)
  let destroy fiber =
    Trace.exit_fiber fiber.tid;
    match fiber.cancel_node with
    | Some node -> Lwt_dllist.remove node
    | None -> ()

  (* The values currently attached to the fiber. *)
  let vars fiber = fiber.vars

  (* The values attached to the context obtained via [Get_context]. *)
  let get_vars () =
    let fiber = Effect.perform Get_context in
    vars fiber

  (* [with_vars fiber vars fn] runs [fn ()] with the fiber's values replaced
     by [vars], restoring the previous values afterwards whether [fn]
     returns or raises. *)
  let with_vars fiber vars fn =
    let saved = fiber.vars in
    let restore () = fiber.vars <- saved in
    fiber.vars <- vars;
    match fn () with
    | result -> restore (); result
    | exception ex -> restore (); raise ex
end