1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
open! Stdlib
(* Debug flag named "times": [f] below calls [times ()] to decide whether
   to print this pass's elapsed time to stderr. *)
let times = Debug.find "times"
open Code
(* [remove_last l] returns [l] without its final element.

   Implemented with two [List.rev]s so it runs in constant stack space:
   block bodies can be arbitrarily long, and the previous direct recursion
   ([x :: remove_last r]) consumed one stack frame per element.

   @raise Assert_failure if [l] is empty. In this file it is only applied
   to a block body whose last instruction [tail_call] has just matched, so
   the body is never empty there. *)
let remove_last l =
  match List.rev l with
  | [] -> assert false
  | _last :: rev_prefix -> List.rev rev_prefix
(* [tail_call x f l] inspects only the final instruction of [l]. If it is
   [Let (y, Apply { f = g; args; _ })] with [y] the same variable as [x] and
   [g] the same variable as [f], the call's [args] are returned. In every
   other case, including an empty [l], the result is [None]. *)
let tail_call x f l =
  let rec last_instr = function
    | [] -> None
    | [ instr ] -> Some instr
    | _ :: rest -> last_instr rest
  in
  match last_instr l with
  | Some (Let (y, Apply { f = g; args; _ }))
    when Var.compare x y = 0 && Var.compare f g = 0 ->
      Some args
  | Some _ | None -> None
(* [rewrite_block (f, f_params, f_pc, args) pc blocks] looks at block [pc].

   If the block ends with [Return x] and its last instruction binds [x] to a
   call of [f] with as many arguments as [f_params], then the call is dropped
   from the body. [Return x] is replaced by [Branch (f_pc, args')], where
   [args'] is [args] with each variable of [f_params] replaced by the
   corresponding call argument (via [Subst.build_mapping]). For each
   (parameter, call argument) pair, [Code.Var.propagate_name] is called
   first (presumably to carry debug names over - confirm).

   Any other block leaves [blocks] unchanged.

   NOTE(review): [Var.Map.find] raises [Not_found] if [args] mentions a
   variable the mapping has no entry for (presumably one that is not in
   [f_params]). *)
let rewrite_block (f, f_params, f_pc, args) pc blocks =
  let block = Addr.Map.find pc blocks in
  (* Build the replacement block once the call has been validated; the
     substitution and name propagation happen in the same order as before
     the record is constructed. *)
  let as_jump f_args =
    let subst = Subst.build_mapping f_params f_args in
    List.iter2 f_params f_args ~f:(fun p a -> Code.Var.propagate_name p a);
    { params = block.params
    ; body = remove_last block.body
    ; branch = Branch (f_pc, List.map args ~f:(fun v -> Var.Map.find v subst))
    }
  in
  match block.branch with
  | Return x -> (
      match tail_call x f block.body with
      | Some f_args when List.length f_params = List.length f_args ->
          Addr.Map.add pc (as_jump f_args) blocks
      | Some _ | None -> blocks)
  | _ -> blocks
(* [traverse f pc visited blocks] walks the control-flow graph depth-first
   from [pc] and applies [rewrite_block f] to every block it reaches.
   Successors are enumerated with [Code.fold_children_skip_try_body] on the
   already-rewritten map. By its name, that function does not descend into
   try bodies, presumably because a call there is not in tail position
   (confirm). Blocks already in [visited] are skipped.

   Returns the enlarged visited set together with the updated block map. *)
let rec traverse f pc visited blocks =
  if Addr.Set.mem pc visited
  then visited, blocks
  else
    let visited = Addr.Set.add pc visited in
    let blocks = rewrite_block f pc blocks in
    Code.fold_children_skip_try_body
      blocks
      pc
      (fun child (visited, blocks) -> traverse f child visited blocks)
      (visited, blocks)
(* Tail-call pass entry point. For each closure visited by [fold_closures]
   that has a name ([Some f]) and whose parameter list and entry-continuation
   arguments have the same length, the blocks reachable from its entry block
   are traversed. Self tail calls found there are rewritten into jumps back
   to that entry block (see [rewrite_block]).

   If the "times" debug flag is set, the elapsed time is printed to stderr.
   Returns [p] with its block map replaced. *)
let f p =
  let t = Timer.make () in
  let rewrite_closure name params (entry_pc, entry_args) blocks =
    match name with
    | Some fn when List.length params = List.length entry_args ->
        let _visited, blocks =
          traverse (fn, params, entry_pc, entry_args) entry_pc Addr.Set.empty blocks
        in
        blocks
    | Some _ | None -> blocks
  in
  let blocks = fold_closures p rewrite_closure p.blocks in
  if times () then Format.eprintf " tail calls: %a@." Timer.print t;
  { p with blocks }