(* Source file lwt_engine.ml *)
(* This file is part of Lwt, released under the MIT license. See LICENSE.md for
   details, or visit https://github.com/ocsigen/lwt/blob/master/LICENSE.md. *)



(* [Lwt_sequence] is deprecated – we don't want users outside Lwt using it.
   However, it is still used internally by Lwt. So, briefly disable warning 3
   ("deprecated"), and create a local, non-deprecated alias for
   [Lwt_sequence] that can be referred to by the rest of the code in this
   module without triggering any more warnings. Warning 3 is switched back on
   right after the alias, so any other deprecated use in this file is still
   reported unless it is silenced locally. *)
[@@@ocaml.warning "-3"]
module Lwt_sequence = Lwt_sequence
[@@@ocaml.warning "+3"]

(* +-----------------------------------------------------------------+
   | Events                                                          |
   +-----------------------------------------------------------------+ *)

(* Record describing one registered event. Users hold an [event], which is a
   reference to such a record. The indirection lets [transfer] replace the
   record (stop action and node) while the handle itself stays the same. *)
type _event = {
  stop : unit Lazy.t;
  (* The stop method of the event. Forcing it unregisters the callback from
     the engine's backend; being lazy, it runs at most once. *)
  node : Obj.t Lwt_sequence.node;
  (* The node in the sequence of registered events. The real payload type
     depends on the sequence (readables, writables or timers); it is erased
     to [Obj.t] with [cast_node] so that all events share one record type. *)
}

type event = _event ref

(* Erases the payload type of a sequence node so that nodes from differently
   typed sequences can be stored in [_event.node]. It is a no-op at runtime.
   In this file the erased node is only passed to [Lwt_sequence.remove]
   (see [stop_event]), which presumably never inspects the payload. *)
external cast_node : 'a Lwt_sequence.node -> Obj.t Lwt_sequence.node = "%identity"

(* [stop_event ev] unregisters [ev]. It first unlinks the event's node from
   the engine's sequence of registered events, then forces its stop action
   to detach it from the backend. *)
let stop_event ev =
  let { stop; node } = !ev in
  Lwt_sequence.remove node;
  Lazy.force stop

(* Placeholder event record. Its stop action does nothing, and its node
   lives in a private one-element sequence, so stopping it touches no
   engine. *)
let _fake_event =
  let private_seq = Lwt_sequence.create () in
  let node = Lwt_sequence.add_l (Obj.repr ()) private_seq in
  { stop = lazy (); node }

(* A handle pointing at the placeholder record. *)
let fake_event = ref _fake_event

(* +-----------------------------------------------------------------+
   | Engines                                                         |
   +-----------------------------------------------------------------+ *)

(* Base class of all engines. Besides the backend-specific methods left
   virtual below, it records every event registered through [on_readable],
   [on_writable] and [on_timer]. This lets [destroy] stop them all,
   [transfer] re-register them on another engine, and [fake_io] trigger
   them by hand. *)
class virtual abstract = object(self)
  method virtual iter : bool -> unit
  (* Runs the backend's event loop. The argument is the [block] flag; see the
     concrete engines below for how each backend interprets it. *)
  method virtual private cleanup : unit
  (* Releases backend resources; called as the last step of [destroy]. *)
  method virtual private register_readable : Unix.file_descr -> (unit -> unit) -> unit Lazy.t
  method virtual private register_writable : Unix.file_descr -> (unit -> unit) -> unit Lazy.t
  method virtual private register_timer : float -> bool -> (unit -> unit) -> unit Lazy.t
  (* The [register_*] methods install a callback in the backend. Each
     returns a lazy value that unregisters the callback when forced. *)

  val readables = Lwt_sequence.create ()
  (* Sequence of callbacks waiting for a file descriptor to become
     readable. Each element is [(fd, f, g, ev)]:
     - [fd] is the descriptor;
     - [f] is the user callback;
     - [g] is the wrapper handed to the backend, which calls [f ev];
     - [ev] is the event handle returned to the caller. *)

  val writables = Lwt_sequence.create ()
  (* Sequence of callbacks waiting for a file descriptor to become
     writable. Elements have the same shape as in [readables]. *)

  val timers = Lwt_sequence.create ()
  (* Sequence of timers. Elements are [(delay, repeat, f, g, ev)]. *)

  (* Stops every registered event, then releases the backend with
     [cleanup]. [stop_event] removes each node from the sequence being
     iterated; this relies on [Lwt_sequence.iter_l] tolerating removal of
     the current node. *)
  method destroy =
    Lwt_sequence.iter_l (fun (_fd, _f, _g, ev) -> stop_event ev) readables;
    Lwt_sequence.iter_l (fun (_fd, _f, _g, ev) -> stop_event ev) writables;
    Lwt_sequence.iter_l (fun (_delay, _repeat, _f, _g, ev) -> stop_event ev)
      timers;
    self#cleanup

  (* Moves every registered event to [engine]. Each event is stopped here
     and registered again on [engine] with the same parameters and user
     callback. The caller's handle [ev] is then overwritten with the new
     event's record, so [stop_event ev] stops the new registration.

     NOTE(review): [engine#on_*] stores its own fresh reference (not [ev]) in
     its sequences and passes that reference to the callback. A second
     [transfer] only updates that inner reference. A handle kept from before
     the first transfer would then still hold the stale record, and stopping
     it would not stop the live event.

     Transferring an engine onto itself would keep appending to the sequence
     being iterated. Depending on how [Lwt_sequence.iter_l] treats appended
     nodes, that might never terminate. Confirm both points. *)
  method transfer (engine : abstract) =
    Lwt_sequence.iter_l (fun (fd, f, _g, ev) ->
      stop_event ev; ev := !(engine#on_readable fd f)) readables;
    Lwt_sequence.iter_l (fun (fd, f, _g, ev) ->
      stop_event ev; ev := !(engine#on_writable fd f)) writables;
    Lwt_sequence.iter_l (fun (delay, repeat, f, _g, ev) ->
      stop_event ev; ev := !(engine#on_timer delay repeat f)) timers

  (* Calls the wrapper [g] of every readable and writable event registered on
     [fd], as if the backend had reported activity on it. Descriptors are
     compared with structural equality. *)
  method fake_io fd =
    Lwt_sequence.iter_l (fun (fd', _f, g, _stop) ->
      if fd = fd' then g ()) readables;
    Lwt_sequence.iter_l (fun (fd', _f, g, _stop) ->
      if fd = fd' then g ()) writables

  (* Registers [f] to be called with the event handle when the backend
     reports [fd] as readable. The handle starts out as [_fake_event]
     because [g] must capture it before the real record can be built; the
     record needs both the backend's stop action and the sequence node. *)
  method on_readable fd f =
    let ev = ref _fake_event in
    let g () = f ev in
    let stop = self#register_readable fd g in
    ev := { stop = stop; node = cast_node (Lwt_sequence.add_r (fd, f, g, ev) readables) };
    ev

  (* Same as [on_readable], for writability. *)
  method on_writable fd f =
    let ev = ref _fake_event in
    let g () = f ev in
    let stop = self#register_writable fd g in
    ev := { stop = stop; node = cast_node (Lwt_sequence.add_r (fd, f, g, ev) writables) } ;
    ev

  (* Registers [f] as a timer with the given [delay] and [repeat] flag. The
     handle is built the same way as in [on_readable]. *)
  method on_timer delay repeat f =
    let ev = ref _fake_event in
    let g () = f ev in
    let stop = self#register_timer delay repeat g in
    ev := { stop = stop; node = cast_node (Lwt_sequence.add_r (delay, repeat, f, g, ev) timers) };
    ev

  (* Number of currently registered events of each kind. *)
  method readable_count = Lwt_sequence.length readables
  method writable_count = Lwt_sequence.length writables
  method timer_count = Lwt_sequence.length timers

  (* Does nothing by default; presumably overridden by engines that need to
     act after a fork. *)
  method fork = ()

  (* Returns [false] for every signal number by default. *)
  method forwards_signal (_signum:int) = false
end

(* Type of concrete engines: an [abstract] engine whose backend methods
   ([iter], [cleanup] and the [register_*] methods) are implemented rather
   than virtual. *)
class type t = object
  inherit abstract

  method iter : bool -> unit
  method private cleanup : unit
  method private register_readable : Unix.file_descr -> (unit -> unit) -> unit Lazy.t
  method private register_writable : Unix.file_descr -> (unit -> unit) -> unit Lazy.t
  method private register_timer : float -> bool -> (unit -> unit) -> unit Lazy.t
end

(* +-----------------------------------------------------------------+
   | The libev engine                                                |
   +-----------------------------------------------------------------+ *)

(* Opaque types for values produced by the libev C stubs: the event loop,
   I/O watchers and timer watchers. In this file OCaml code only hands them
   back to the externals below. *)
type ev_loop
type ev_io
type ev_timer

(* Selection of the backend used by a libev event loop. *)
module Ev_backend =
struct
  (* NOTE(review): values of this type cross the C boundary through [ev_init]
     and [ev_backend]. The stubs presumably decode them by constructor
     position, so keep the declaration order unchanged. *)
  type t =
    | EV_DEFAULT
    | EV_SELECT
    | EV_POLL
    | EV_EPOLL
    | EV_KQUEUE
    | EV_DEVPOLL
    | EV_PORT

  let default = EV_DEFAULT
  let select = EV_SELECT
  let poll = EV_POLL
  let epoll = EV_EPOLL
  let kqueue = EV_KQUEUE
  let devpoll = EV_DEVPOLL
  let port = EV_PORT

  (* Equality on backends. The annotation restricts it to [t]; a bare
     [( = )] would be fully polymorphic. Because [t] has only constant
     constructors, the compiler can also specialise the comparison instead
     of calling generic structural equality. *)
  let equal (a : t) (b : t) = a = b

  (* Name of the libev constant corresponding to a backend. *)
  let name = function
    | EV_DEFAULT -> "EV_DEFAULT"
    | EV_SELECT -> "EV_SELECT"
    | EV_POLL -> "EV_POLL"
    | EV_EPOLL -> "EV_EPOLL"
    | EV_KQUEUE -> "EV_KQUEUE"
    | EV_DEVPOLL -> "EV_DEVPOLL"
    | EV_PORT -> "EV_PORT"

  (* Prints [name t]. *)
  let pp fmt t = Format.pp_print_string fmt (name t)
end

(* Bindings to the libev C stubs. Their implementation is not visible here;
   the notes below only record how the [libev] class uses them. *)
(* [ev_init] creates an engine's loop from the requested backend.
   [ev_backend] is what [libev#backend] reports. [ev_stop] is called by
   [cleanup], i.e. when the engine is destroyed. *)
external ev_init : Ev_backend.t -> ev_loop = "lwt_libev_init"
external ev_backend : ev_loop -> Ev_backend.t = "lwt_libev_backend"
external ev_stop : ev_loop -> unit = "lwt_libev_stop"
(* [ev_loop loop block] is what [libev#iter] runs. [ev_unloop] is called
   when [ev_loop] raises. *)
external ev_loop : ev_loop -> bool -> unit = "lwt_libev_loop"
external ev_unloop : ev_loop -> unit = "lwt_libev_unloop"
(* The [*_init] stubs take the OCaml callback and return a watcher. [libev]
   keeps the watcher and passes it to the matching [*_stop] stub when the
   event is stopped. *)
external ev_readable_init : ev_loop -> Unix.file_descr -> (unit -> unit) -> ev_io = "lwt_libev_readable_init"
external ev_writable_init : ev_loop -> Unix.file_descr -> (unit -> unit) -> ev_io = "lwt_libev_writable_init"
external ev_io_stop : ev_loop -> ev_io -> unit = "lwt_libev_io_stop"
external ev_timer_init : ev_loop -> float -> bool -> (unit -> unit) -> ev_timer = "lwt_libev_timer_init"
external ev_timer_stop : ev_loop -> ev_timer -> unit  = "lwt_libev_timer_stop"

(* Engine based on libev. [backend] selects the backend of the loop created
   for this engine (see [Ev_backend]); it defaults to [Ev_backend.default]. *)
class libev ?(backend=Ev_backend.default) () = object
  inherit abstract

  val loop = ev_init backend
  method loop = loop

  method backend = ev_backend loop

  method private cleanup = ev_stop loop

  (* Runs the libev loop through [ev_loop], passing [block] unchanged. If it
     raises (presumably from an OCaml callback), the loop is unlooped and the
     exception is re-raised with its original backtrace. Callers therefore
     see where it was raised rather than this handler. *)
  method iter block =
    try
      ev_loop loop block
    with exn ->
      (* Capture the backtrace first: any later call that raises internally
         could overwrite it. *)
      let backtrace = Printexc.get_raw_backtrace () in
      ev_unloop loop;
      Printexc.raise_with_backtrace exn backtrace

  (* Each [register_*] creates a libev watcher for [f]; the returned lazy
     value stops that watcher when forced. *)
  method private register_readable fd f =
    let ev = ev_readable_init loop fd f in
    lazy(ev_io_stop loop ev)

  method private register_writable fd f =
    let ev = ev_writable_init loop fd f in
    lazy(ev_io_stop loop ev)

  method private register_timer delay repeat f =
    let ev = ev_timer_init loop delay repeat f in
    lazy(ev_timer_stop loop ev)
end

(* The libev engine with the default backend and no constructor argument;
   exposed as [Versioned.libev_1]. *)
class libev_deprecated = libev ()

(* +-----------------------------------------------------------------+
   | Select/poll based engines                                       |
   +-----------------------------------------------------------------+ *)

(* Type of a sleeper for the select/poll based engines. *)
type sleeper = {
  mutable time : float;
  (* The time at which the sleeper should be woken up, as an absolute
     [Unix.gettimeofday] value. Repeating timers update it each time they
     fire. *)

  mutable stopped : bool;
  (* [true] iff the event has been stopped. Stopped sleepers are not removed
     from the queue eagerly; they are skipped once they reach its head. *)

  action : unit -> unit;
  (* The action for the sleeper. *)
}

(* Priority queue of sleepers, ordered by wake-up time only. The attribute
   silences warning 3 (deprecation), presumably triggered by
   [Lwt_pqueue]. *)
module Sleep_queue =
  Lwt_pqueue.Make (struct
    type t = sleeper

    (* [stopped] and [action] play no part in the ordering. *)
    let compare s1 s2 = compare s1.time s2.time
  end)
  [@@ocaml.warning "-3"]

(* Maps keyed by file descriptor, ordered with the polymorphic [compare]. *)
module Fd_map = Map.Make(struct type t = Unix.file_descr let compare = compare end)

(* [restart_actions sleep_queue now] pops every sleeper at the head of
   [sleep_queue] whose wake-up time is [<= now], running its action, and
   returns the remaining queue. Stopped sleepers met on the way are dropped
   without running their action. *)
let rec restart_actions sleep_queue now =
  match Sleep_queue.lookup_min sleep_queue with
  | None -> sleep_queue
  | Some sleeper ->
    if sleeper.stopped then
      restart_actions (Sleep_queue.remove_min sleep_queue) now
    else if sleeper.time <= now then begin
      (* Dequeue before running the action. The action may change
         [sleeper.time] (repeating timers do), which would break the
         priority-queue invariant if the sleeper were still queued. *)
      let remaining = Sleep_queue.remove_min sleep_queue in
      sleeper.action ();
      restart_actions remaining now
    end else
      sleep_queue

(* Seconds until the earliest non-stopped sleeper in [sleep_queue] is due,
   clamped to [0.] if it is already due. Returns [-1.] when no such sleeper
   exists; [Unix.select] treats a negative timeout as an unbounded wait.
   Stopped sleepers at the head are skipped. *)
let rec get_next_timeout sleep_queue =
  match Sleep_queue.lookup_min sleep_queue with
  | None -> -1.
  | Some sleeper ->
    if sleeper.stopped then
      get_next_timeout (Sleep_queue.remove_min sleep_queue)
    else
      let remaining = sleeper.time -. Unix.gettimeofday () in
      max 0. remaining

(* [bad_fd fd] is [true] iff [Unix.fstat fd] fails with a [Unix_error]. It
   is used to single out the offending descriptors after a call fails with
   [EBADF]. *)
let bad_fd fd =
  match Unix.fstat fd with
  | (_ : Unix.stats) -> false
  | exception Unix.Unix_error _ -> true

(* Runs, from left to right, every callback registered for [fd] in [map].
   Does nothing when [fd] has no entry. Only a [Not_found] raised by the
   lookup itself is caught; exceptions from the callbacks propagate. *)
let invoke_actions fd map =
  match Fd_map.find fd map with
  | callbacks -> Lwt_sequence.iter_l (fun callback -> callback ()) callbacks
  | exception Not_found -> ()

(* Shared base of the select- and poll-style engines. Timers are kept as
   [sleeper]s in a priority queue, and descriptor callbacks in
   per-descriptor sequences. Subclasses only implement [iter]. *)
class virtual select_or_poll_based = object
  inherit abstract

  val mutable sleep_queue = Sleep_queue.empty
  (* Sleepers waiting for their wake-up time. *)

  val mutable new_sleeps = []
  (* Sleepers registered since the last iteration of the main loop. They
     only join [sleep_queue] at the start of the next [iter], so a sleeper
     registered during an iteration is not woken up by that same
     iteration. *)

  val mutable wait_readable = Fd_map.empty
  (* For each descriptor, the callbacks waiting for it to become readable. *)

  val mutable wait_writable = Fd_map.empty
  (* For each descriptor, the callbacks waiting for it to become writable. *)

  method private cleanup = ()

  (* Queues a sleeper that fires [delay] seconds from now. A repeating
     sleeper's action first reschedules the same sleeper [delay] seconds
     after the moment it runs, then calls [f]. Stopping only sets the
     [stopped] flag. *)
  method private register_timer delay repeat f =
    let deadline () = Unix.gettimeofday () +. delay in
    let sleeper =
      if not repeat then
        { time = deadline (); stopped = false; action = f }
      else begin
        let rec repeating =
          { time = deadline (); stopped = false; action = rearm }
        and rearm () =
          repeating.time <- deadline ();
          new_sleeps <- repeating :: new_sleeps;
          f ()
        in
        repeating
      end
    in
    new_sleeps <- sleeper :: new_sleeps;
    lazy (sleeper.stopped <- true)

  (* Adds [f] in front of the callbacks waiting for [fd] to become readable.
     The sequence for [fd] is created on first use. Forcing the result
     removes [f] again, and drops [fd] from the map once its sequence is
     empty. *)
  method private register_readable fd f =
    let actions =
      match Fd_map.find fd wait_readable with
      | existing -> existing
      | exception Not_found ->
        let fresh = Lwt_sequence.create () in
        wait_readable <- Fd_map.add fd fresh wait_readable;
        fresh
    in
    let node = Lwt_sequence.add_l f actions in
    lazy (
      Lwt_sequence.remove node;
      if Lwt_sequence.is_empty actions then
        wait_readable <- Fd_map.remove fd wait_readable)

  (* Same as [register_readable], for writability. *)
  method private register_writable fd f =
    let actions =
      match Fd_map.find fd wait_writable with
      | existing -> existing
      | exception Not_found ->
        let fresh = Lwt_sequence.create () in
        wait_writable <- Fd_map.add fd fresh wait_writable;
        fresh
    in
    let node = Lwt_sequence.add_l f actions in
    lazy (
      Lwt_sequence.remove node;
      if Lwt_sequence.is_empty actions then
        wait_writable <- Fd_map.remove fd wait_writable)
end

(* Engine skeleton built on a select-like primitive. Subclasses supply
   [select], which takes the read and write descriptor lists and a timeout
   and returns the descriptors reported ready. *)
class virtual select_based = object(self)
  inherit select_or_poll_based

  method private virtual select : Unix.file_descr list -> Unix.file_descr list -> float -> Unix.file_descr list * Unix.file_descr list

  method iter block =
    (* Move the sleepers registered since the previous iteration into the
       main queue. *)
    sleep_queue <-
      List.fold_left (fun queue s -> Sleep_queue.add s queue) sleep_queue new_sleeps;
    new_sleeps <- [];
    let descriptors map = Fd_map.fold (fun fd _ acc -> fd :: acc) map [] in
    let watched_r = descriptors wait_readable in
    let watched_w = descriptors wait_writable in
    (* A non-blocking iteration polls with a zero timeout. *)
    let timeout = if not block then 0. else get_next_timeout sleep_queue in
    let ready_r, ready_w =
      match self#select watched_r watched_w timeout with
      | ready -> ready
      | exception Unix.Unix_error (Unix.EINTR, _, _) -> ([], [])
      | exception Unix.Unix_error (Unix.EBADF, _, _) ->
        (* Report only the invalid descriptors; their callbacks are expected
           to observe the error themselves. *)
        (List.filter bad_fd watched_r, List.filter bad_fd watched_w)
    in
    (* Timers first, then descriptor callbacks. The maps are re-read on each
       lookup, so changes made by earlier callbacks are visible. *)
    sleep_queue <- restart_actions sleep_queue (Unix.gettimeofday ());
    List.iter (fun fd -> invoke_actions fd wait_readable) ready_r;
    List.iter (fun fd -> invoke_actions fd wait_writable) ready_w
end

(* Engine skeleton built on a poll-like primitive. Subclasses supply
   [poll], which receives [(fd, check_readable, check_writable)] requests
   and a timeout. It returns [(fd, readable, writable)] results. *)
class virtual poll_based = object(self)
  inherit select_or_poll_based

  method private virtual poll : (Unix.file_descr * bool * bool) list -> float -> (Unix.file_descr * bool * bool) list

  method iter block =
    (* Move the sleepers registered since the previous iteration into the
       main queue. *)
    sleep_queue <-
      List.fold_left (fun queue s -> Sleep_queue.add s queue) sleep_queue new_sleeps;
    new_sleeps <- [];
    (* One request per descriptor waiting for readability. Requests for
       descriptors waiting for writability are added at the front of the
       list. *)
    let requests =
      Fd_map.fold (fun fd _ acc -> (fd, true, false) :: acc) wait_readable []
    in
    let requests =
      Fd_map.fold (fun fd _ acc -> (fd, false, true) :: acc) wait_writable requests
    in
    (* A non-blocking iteration polls with a zero timeout. *)
    let timeout = if not block then 0. else get_next_timeout sleep_queue in
    let ready =
      match self#poll requests timeout with
      | results -> results
      | exception Unix.Unix_error (Unix.EINTR, _, _) -> []
      | exception Unix.Unix_error (Unix.EBADF, _, _) ->
        (* Keep only the invalid descriptors; their callbacks are expected to
           observe the error themselves. *)
        List.filter (fun (fd, _, _) -> bad_fd fd) requests
    in
    (* Timers first, then descriptor callbacks, as in the select engine. *)
    sleep_queue <- restart_actions sleep_queue (Unix.gettimeofday ());
    List.iter
      (fun (fd, readable, writable) ->
         if readable then invoke_actions fd wait_readable;
         if writable then invoke_actions fd wait_writable)
      ready
end

(* Engine using [Unix.select]. Exceptional conditions are not watched: the
   third descriptor list is always empty and its result is ignored. *)
class select = object
  inherit select_based

  method private select readable writable timeout =
    match Unix.select readable writable [] timeout with
    | ready_r, ready_w, _ -> (ready_r, ready_w)
end

(* +-----------------------------------------------------------------+
   | The current engine                                              |
   +-----------------------------------------------------------------+ *)

(* The engine currently in use. It starts as a libev engine when both
   [Lwt_config._HAVE_LIBEV] and [Lwt_config.libev_default] are true, and as
   the [select] engine otherwise. *)
let current =
  let initial : t =
    if Lwt_config._HAVE_LIBEV && Lwt_config.libev_default then
      (new libev () :> t)
    else
      (new select :> t)
  in
  ref initial

(* [get ()] returns the engine currently in use. *)
let get () =
  !current

(* [set ?transfer ?destroy engine] makes [engine] the current engine.
   - If [transfer] is true (the default), every event registered on the old
     engine is first re-registered on [engine].
   - If [destroy] is true (the default), the old engine is then destroyed.

   Setting the engine that is already current is a no-op. Otherwise
   [transfer] would re-register the engine's events onto the sequences it is
   iterating, and [destroy] would tear down the engine that remains
   installed (for libev, stopping a loop that is still in use). *)
let set ?(transfer=true) ?(destroy=true) engine =
  let engine = (engine : #t :> t) in
  (* Physical inequality: we mean "a different engine object". *)
  if engine != !current then begin
    if transfer then !current#transfer (engine :> abstract);
    if destroy then !current#destroy;
    current := engine
  end

(* Wrappers that forward to the current engine. The engine is looked up on
   every call, so they follow later calls to [set]. *)
let iter block = (get ())#iter block
let on_readable fd f = (get ())#on_readable fd f
let on_writable fd f = (get ())#on_writable fd f
let on_timer delay repeat f = (get ())#on_timer delay repeat f
let fake_io fd = (get ())#fake_io fd
let readable_count () = (get ())#readable_count
let writable_count () = (get ())#writable_count
let timer_count () = (get ())#timer_count
let fork () = (get ())#fork
let forwards_signal signum = (get ())#forwards_signal signum

(* Versioned names for the libev engine classes:
   - [libev_1] is [libev_deprecated]: default backend, no constructor
     argument;
   - [libev_2] is [libev]: optional [?backend], then [unit]. *)
module Versioned =
struct
  class libev_1 = libev_deprecated
  class libev_2 = libev
end