1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
(* Re-export of [Exn.Cancelled] under this module, so code here can raise and match it as
   [Cancelled]. It wraps the exception a context was cancelled with (see [check] / [cancel]). *)
exception Cancelled = Exn.Cancelled
(* Lifecycle state of a cancellation context:
   - [On]: active and not cancelled.
   - [Cancelling (ex, bt)]: cancelled with [ex]; [bt] is the raw backtrace captured with
     [Printexc.get_raw_backtrace] at the moment the context was cancelled.
   - [Finished]: not (or no longer) linked into the tree. [create] starts contexts in this
     state and the function returned by [activate] returns them to it. Using a finished
     context raises [Invalid_argument] (see [check]). *)
type state =
| On
| Cancelling of exn * Printexc.raw_backtrace
| Finished
(* A cancellation context. Contexts form a tree: [activate] links a context into its
   parent's [children]. A fiber's current context is recorded both in that context's
   [fibers] list and in the fiber's own [cancel_context] field. *)
type t = {
(* Trace identifier for this context (minted in [create]). *)
id : Trace.id;
mutable state : state;
(* Active child contexts: [activate] appends here and its returned function removes. *)
children : t Lwt_dllist.t;
(* Fibers currently attached to this context (maintained by [move_fiber_to] and
   [Fiber_context.make] / [Fiber_context.destroy]). *)
fibers : fiber_context Lwt_dllist.t;
(* When true, cancellation of an ancestor does not propagate into this context
   (see [cancel_child]). *)
protected : bool;
(* Domain that created the context; [cancel] rejects calls from any other domain. *)
domain : Domain.id;
}
(* Per-fiber state. *)
and fiber_context = {
(* Trace identifier for the fiber. *)
tid : Trace.id;
(* The context the fiber currently belongs to. *)
mutable cancel_context : t;
(* This fiber's node in [cancel_context.fibers], used to detach it when it moves or is
   destroyed. *)
mutable cancel_node : fiber_context Lwt_dllist.node option;
(* Called with [Cancelled ex] when the fiber's context is cancelled; [cancel] resets it to
   [ignore] before calling it. *)
mutable cancel_fn : exn -> unit;
(* Fiber-local variables; temporarily replaced by [Fiber_context.with_vars]. *)
mutable vars : Hmap.t;
}
(* Performing [Get_context] yields the current fiber's [fiber_context]. The handler is not
   defined in this file (presumably installed by the scheduler - confirm). *)
type _ Effect.t += Get_context : fiber_context Effect.t
(* [pp_state f t] prints the state of [t] as "on", "cancelling(EXN)" or "finished",
   followed by " (protected)" when [t] is a protected context. *)
let pp_state f t =
  let print_state = function
    | On -> Fmt.string f "on"
    | Cancelling (ex, _) -> Fmt.pf f "cancelling(%a)" Fmt.exn ex
    | Finished -> Fmt.string f "finished"
  in
  print_state t.state;
  if t.protected then Fmt.pf f " (protected)"
(* Prints a fiber as its numeric trace id. *)
let pp_fiber f { tid; _ } =
  let numeric_id : int = (tid :> int) in
  Fmt.pf f "%d" numeric_id
(* [pp_lwt_dlist ~sep pp f t] prints every element of [t] from left to right with [pp],
   calling [sep] between consecutive elements (never before the first). *)
let pp_lwt_dlist ~sep pp f t =
  let need_sep = ref false in
  let print_item item =
    if !need_sep then sep f () else need_sep := true;
    pp f item
  in
  Lwt_dllist.iter_l print_item t
(* [dump f t] prints the cancellation tree rooted at [t]: its state, the comma-separated
   ids of its fibers in brackets, then each child context (recursively), each preceded
   by a cut inside a vertical box indented by 2. *)
let rec dump f t =
  let pp_fibers = pp_lwt_dlist ~sep:(Fmt.any ",") pp_fiber in
  Fmt.pf f "@[<v2>%a [%a]%a@]" pp_state t pp_fibers t.fibers pp_children t.children

(* Prints each child context of a list, preceding every one with a cut. *)
and pp_children f ts =
  Lwt_dllist.iter_l (fun child -> Fmt.cut f (); dump f child) ts
(* True iff the context is active and has not been cancelled or finished. *)
let is_on { state; _ } =
  match state with
  | Cancelling _ | Finished -> false
  | On -> true
(* [check t] returns normally when [t] is [On].
   @raise Cancelled wrapping the cancellation exception if [t] is being cancelled
   @raise Invalid_argument if [t] has finished *)
let check { state; _ } =
  match state with
  | Finished -> invalid_arg "Cancellation context finished!"
  | Cancelling (cause, _) -> raise (Cancelled cause)
  | On -> ()
(* Non-raising counterpart of [check]: returns the exception [check] would raise for [t],
   or [None] if [t] is [On]. *)
let get_error { state; _ } =
  match state with
  | Finished -> Some (Invalid_argument "Cancellation context finished!")
  | Cancelling (cause, _backtrace) -> Some (Cancelled cause)
  | On -> None
(* True iff the context is in the [Finished] state. *)
let is_finished { state; _ } =
  match state with
  | On | Cancelling _ -> false
  | Finished -> true
(* [move_fiber_to t fiber] attaches [fiber] to context [t]: appends it to [t.fibers],
   points [fiber.cancel_context] at [t], and detaches it from the list it was previously
   in (if any). *)
let move_fiber_to t fiber =
  let previous = fiber.cancel_node in
  let node = Lwt_dllist.add_r fiber t.fibers in
  fiber.cancel_context <- t;
  fiber.cancel_node <- Some node;
  match previous with
  | None -> ()
  | Some old_node -> Lwt_dllist.remove old_node
(* [create ~protected purpose] makes a new, empty context owned by the current domain and
   registers it with [Trace.create_cc]. The context starts [Finished] and is not linked
   into any tree; [activate] does that. *)
let create ~protected purpose =
  let id = Trace.mint_id () in
  Trace.create_cc id purpose;
  {
    id;
    state = Finished;
    children = Lwt_dllist.create ();
    fibers = Lwt_dllist.create ();
    protected;
    domain = Domain.self ();
  }
(* [activate t ~parent] switches the unlinked context [t] to [On] and appends it to
   [parent.children]. It returns a function that switches [t] back to [Finished] and
   unlinks it again.
   Asserts that [t] is [Finished] on entry and that [parent] is not [Finished], both on
   entry and when the returned function runs.
   State tests use [is_finished] rather than polymorphic [=]/[<>]: [state] carries an
   [exn] and a [Printexc.raw_backtrace], which polymorphic comparison is not designed
   for. *)
let activate t ~parent =
  assert (is_finished t);
  assert (not (is_finished parent));
  t.state <- On;
  let node = Lwt_dllist.add_r t parent.children in
  fun () ->
    assert (not (is_finished parent));
    t.state <- Finished;
    Lwt_dllist.remove node
(* [with_cc ~ctx:fiber ~parent ~protected purpose fn] runs [fn cc] in a fresh child
   context [cc] of [parent], with [fiber] moved into [cc] for the duration.
   Afterwards, whether [fn] returns or raises, [fiber] is moved back to [parent], [cc] is
   deactivated and [Trace.exit_cc] is called.
   If [protected] is false, [parent] is checked first.
   @raise Cancelled or Invalid_argument from [check parent] when not [protected]
   @raise any exception raised by [fn], re-raised with its original backtrace *)
let with_cc ~ctx:fiber ~parent ~protected purpose fn =
  if not protected then check parent;
  let t = create ~protected purpose in
  let deactivate = activate t ~parent in
  move_fiber_to t fiber;
  let cleanup () =
    move_fiber_to parent fiber;
    deactivate ();
    Trace.exit_cc ()
  in
  match fn t with
  | x -> cleanup (); x
  | exception ex ->
    (* Capture the backtrace before [cleanup] runs: anything it does that raises (even if
       caught internally) would otherwise overwrite the backtrace of [ex]. *)
    let bt = Printexc.get_raw_backtrace () in
    cleanup ();
    Printexc.raise_with_backtrace ex bt
(* [protect fn] runs [fn ()] in a new protected child of the current fiber's context.
   Cancelling an ancestor does not propagate into a protected context (see
   [cancel_child]). The parent context is not checked before running [fn], and nothing
   is checked after it returns, so the caller always gets [fn]'s result. *)
let protect fn =
  let ctx = Effect.perform Get_context in
  let parent = ctx.cancel_context in
  with_cc ~ctx ~parent ~protected:true Protect (fun _cc -> fn ())
(* [cancel_internal t ex acc] switches [t] from [On] to [Cancelling ex], reports the
   error to the tracer, and recurses into each unprotected child. It returns the fibers
   of every newly cancelled context prepended to [acc].
   A context already [Cancelling] contributes nothing and its children are not visited.
   @raise Invalid_argument if [t] is [Finished] *)
let rec cancel_internal t ex acc_fibers =
  match t.state with
  | Cancelling _ -> acc_fibers
  | Finished -> invalid_arg "Cancellation context finished!"
  | On ->
    t.state <- Cancelling (ex, Printexc.get_raw_backtrace ());
    Trace.error t.id ex;
    let with_ours =
      Lwt_dllist.fold_r (fun fiber acc -> fiber :: acc) t.fibers acc_fibers
    in
    Lwt_dllist.fold_r (cancel_child ex) t.children with_ours

(* Propagates cancellation into [child] unless it is protected. *)
and cancel_child ex child acc =
  match child.protected with
  | true -> acc
  | false -> cancel_internal child ex acc
(* @raise Invalid_argument unless called from the domain that created the context. *)
let check_our_domain { domain; _ } =
  let current = Domain.self () in
  if current <> domain then
    invalid_arg "Cancellation context accessed from wrong domain!"
(* [cancel t ex] marks [t] and its unprotected descendants as cancelled with [ex]. It then
   calls the cancel function of every fiber in those contexts with [Cancelled ex]. It does
   nothing if [t] is already cancelling.
   Each cancel function is replaced by [ignore] before it is called, so it runs at most
   once. Exceptions raised by cancel functions do not stop the others from running. They
   are collected in fiber order, combined with [Exn.combine], and re-raised once all
   fibers have been processed.
   @raise Invalid_argument if called from a domain other than the one that created [t],
     or if [t] is [Finished] *)
let cancel t ex =
  check_our_domain t;
  let fibers = cancel_internal t ex [] in
  let cex = Cancelled ex in
  let call_cancel_fn errors fiber =
    let fn = fiber.cancel_fn in
    fiber.cancel_fn <- ignore;
    match fn cex with
    | () -> errors
    | exception ex2 ->
      let bt = Printexc.get_raw_backtrace () in
      (ex2, bt) :: errors
  in
  (* Tail-recursive fold; [List.rev] restores the original fiber order of the errors.
     Pattern matching replaces the old polymorphic [fibers <> []] test on records that
     contain closures. *)
  match List.rev (List.fold_left call_cancel_fn [] fibers) with
  | [] -> ()
  | first :: rest ->
    let combined_ex, combined_bt = List.fold_left Exn.combine first rest in
    Printexc.raise_with_backtrace combined_ex combined_bt
(* [sub_checked ?name purpose fn] runs [fn cc] in a new unprotected child context [cc] of
   the current fiber's context. If [name] is given, the context is named in the trace.
   Because the child is unprotected, [with_cc] checks the current context on entry. *)
let sub_checked ?name purpose fn =
  let ctx = Effect.perform Get_context in
  let run_in cc =
    (match name with
     | None -> ()
     | Some n -> Trace.name cc.id n);
    fn cc
  in
  with_cc ~ctx ~parent:ctx.cancel_context ~protected:false purpose run_in
(* [sub fn] is [sub_checked] with purpose [Sub] and no trace name. *)
let sub fn =
  sub_checked ?name:None Sub fn
(* [sub_unchecked purpose fn] runs [fn cc] (which returns unit) in a new unprotected child
   [cc] of the current context. It returns the parent context so the caller can inspect
   it afterwards. The parent is still checked on entry by [with_cc], since the child is
   unprotected. *)
let sub_unchecked purpose fn =
  let ctx = Effect.perform Get_context in
  let parent = ctx.cancel_context in
  let body cc = fn cc; parent in
  with_cc ~ctx ~parent ~protected:false purpose body
(* Per-fiber state: trace id, current cancellation context, cancel function and
   fiber-local variables. *)
module Fiber_context = struct
  type t = fiber_context

  let tid t = t.tid
  let cancellation_context t = t.cancel_context

  (* The top-level [get_error] applied to the fiber's current context. *)
  let get_error t = get_error t.cancel_context

  (* Installs [fn] as the function [cancel] will call, replacing any previous one. *)
  let set_cancel_fn t fn =
    t.cancel_fn <- fn

  (* Removes the fiber's cancel function by replacing it with [ignore]. *)
  let clear_cancel_fn t =
    t.cancel_fn <- ignore

  (* [make ~cc ~vars] creates a fiber attached to context [cc] with variables [vars]. It
     mints a trace id, records the fiber with [Trace.create_fiber] and appends it to
     [cc.fibers]. The cancel function starts as [ignore]. *)
  let make ~cc ~vars =
    let tid = Trace.mint_id () in
    Trace.create_fiber tid ~cc:cc.id;
    let t = { tid; cancel_context = cc; cancel_node = None; cancel_fn = ignore; vars } in
    (* The node can only be created once [t] exists, hence the second step. *)
    t.cancel_node <- Some (Lwt_dllist.add_r t cc.fibers);
    t

  (* Creates a fresh unprotected root context and a fiber attached to it with no
     variables. The root is switched [On] directly because it has no parent to
     [activate] under. *)
  let make_root () =
    let cc = create ~protected:false Root in
    cc.state <- On;
    make ~cc ~vars:Hmap.empty

  (* Records the fiber's exit in the trace and detaches it from its context's fibers. *)
  let destroy t =
    Trace.exit_fiber t.tid;
    Option.iter Lwt_dllist.remove t.cancel_node

  let vars t = t.vars

  (* The variables of the fiber performing the call (obtained via [Get_context]). *)
  let get_vars () =
    vars (Effect.perform Get_context)

  (* [with_vars t vars fn] runs [fn ()] with [t]'s variables temporarily set to [vars],
     then restores the previous ones, whether [fn] returns or raises. Exceptions from
     [fn] are re-raised with their original backtrace. *)
  let with_vars t vars fn =
    let old_vars = t.vars in
    t.vars <- vars;
    let cleanup () = t.vars <- old_vars in
    match fn () with
    | x -> cleanup (); x
    | exception ex ->
      let bt = Printexc.get_raw_backtrace () in
      cleanup ();
      Printexc.raise_with_backtrace ex bt
end