(* Source file: helpers.ml *)

open! Base
open! Ppxlib
open Ast_builder.Default

(* [lhs --> rhs] builds the unguarded match case [| lhs -> rhs]. *)
let ( --> ) = fun lhs rhs -> case ~lhs ~rhs ~guard:None

(* Utility functions *)

(* Rewrites a core type so that every type variable ['a] becomes the wildcard [_];
   all other nodes (and the attributes/locations of the rewritten nodes) are kept. *)
let replace_variables_by_underscores =
  let anonymize =
    object
      inherit Ast_traverse.map as super

      method! core_type_desc desc =
        match desc with
        | Ptyp_var _ -> Ptyp_any
        | other -> super#core_type_desc other
    end
  in
  fun ty -> anonymize#core_type ty
;;

(* Builds a map from each type-parameter name in [tps] to a fresh name created with
   [Fresh_name.of_string_loc] from its first occurrence. A later parameter with an
   already-seen name is skipped (no fresh name is created for it); the typechecker will
   report such duplicates after expansion. *)
let make_rigid_types tps =
  List.fold tps ~init:(Map.empty (module String)) ~f:(fun acc tp ->
    if Map.mem acc tp.txt
    then acc
    else Map.set acc ~key:tp.txt ~data:(Fresh_name.of_string_loc tp))
;;

(* Returns the located name to use for type variable [name]: the fresh name recorded in
   [rigid_types] if there is one, otherwise [name] itself located at [loc]. *)
let find_rigid_type ~loc ~rigid_types name =
  match Map.find rigid_types name with
  | None ->
    (* Unbound type names are left alone; the typechecker will raise after expansion. *)
    { txt = name; loc }
  | Some fresh -> Fresh_name.to_string_loc fresh
;;

(* Returns a function that rewrites a core type, turning every type variable ['s] into
   a type constructor named by [find_rigid_type ~rigid_types s], so that the result
   refers to locally abstract types instead of unification variables. The rest of each
   node (location, attributes) is preserved. *)
let make_type_rigid ~rigid_types =
  let rigidify =
    object
      inherit Ast_traverse.map as super

      method! core_type ty =
        match ty.ptyp_desc with
        | Ptyp_var name ->
          let ident = find_rigid_type ~loc:ty.ptyp_loc ~rigid_types name in
          { ty with ptyp_desc = Ptyp_constr (Located.map_lident ident, []) }
        | desc -> { ty with ptyp_desc = super#core_type_desc desc }
    end
  in
  rigidify#core_type
;;

(* Generates the quantified type [ 'a .. 'z. (make_mono_type t ('a .. 'z)) ] or
   [type a .. z. make_mono_type t (a .. z)] when [use_rigid_variables] is true.
   Annotations are needed for non-regular recursive datatypes and for GADTs whose
   constructors have constrained return types. Unfortunately, putting rigid variables
   everywhere does not work because of certain types with constraints. We thus only use
   rigid variables for sum types, which include all GADTs. *)

(* [tvars_of_core_type typ] lists the names of the type variables ['a] occurring in
   [typ], each name once, in the order of its first occurrence during the traversal. *)
let tvars_of_core_type : core_type -> string list =
  let collector =
    object
      inherit [string list] Ast_traverse.fold as super

      method! core_type ty seen =
        match ty.ptyp_desc with
        | Ptyp_var name when List.mem seen name ~equal:String.equal -> seen
        | Ptyp_var name -> name :: seen
        | _ -> super#core_type ty seen
    end
  in
  (* Names are accumulated newest-first, so reverse to restore first-occurrence order. *)
  fun typ -> collector#core_type typ [] |> List.rev
;;

(* Builds the value binding [let func_name = body], with type annotations that give the
   generated function a suitably general type:
   - when [typ] mentions type variables ['a .. 'z] (as found by [tvars_of_core_type]),
     the pattern is constrained with the explicitly polymorphic type ['a .. 'z. typ];
   - when [td] is a variant type, [body] becomes
     [fun (type a) .. (type z) -> (body : typ')], with one locally abstract type per
     parameter in [tps], and [typ'] being [typ] with its variables made rigid by
     [make_type_rigid];
   - otherwise [body] is constrained by [typ] only if [typ] has no type variables. *)
let constrained_function_binding
  (loc : Location.t)
  (td : type_declaration)
  (typ : core_type)
  ~(tps : string loc list)
  ~(func_name : string)
  (body : expression)
  =
  let tvars = tvars_of_core_type typ in
  let pat =
    let plain = pvar ~loc func_name in
    match tvars with
    | [] -> plain
    | _ :: _ ->
      let bound = List.map tvars ~f:(fun txt -> { txt; loc }) in
      ppat_constraint ~loc plain (ptyp_poly ~loc bound typ)
  in
  let expr =
    match td.ptype_kind with
    | Ptype_variant _ ->
      (* Rigid (locally abstract) variables are only used for sum types; see the comment
         above [tvars_of_core_type]. *)
      let rigid_types = make_rigid_types tps in
      let annotated = pexp_constraint ~loc body (make_type_rigid ~rigid_types typ) in
      List.fold_right tps ~init:annotated ~f:(fun tp acc ->
        pexp_newtype ~loc (find_rigid_type ~loc:tp.loc ~rigid_types tp.txt) acc)
    | _ ->
      (match tvars with
       | [] -> pexp_constraint ~loc body typ
       | _ :: _ -> body)
  in
  value_binding ~loc ~pat ~expr
;;

(* Wraps [body] in one non-recursive [let ... in] per group in [binds], with the first
   group outermost; empty groups produce no [let]. *)
let with_let ~loc ~binds body =
  List.fold_right binds ~init:body ~f:(fun bindings acc ->
    match bindings with
    | [] -> acc
    | _ :: _ -> pexp_let ~loc Nonrecursive bindings acc)
;;

(* When [types] is non-empty, returns [let open struct type ... end in body], where the
   anonymous structure contains one recursive [type] item per declaration in [types];
   otherwise returns [body] unchanged. *)
let with_types ~loc ~types body =
  match types with
  | [] -> body
  | _ :: _ ->
    let items = List.map types ~f:(fun decl -> pstr_type ~loc Recursive [ decl ]) in
    let opened = open_infos ~loc ~override:Fresh ~expr:(pmod_structure ~loc items) in
    pexp_open ~loc opened body
;;

(* Builds the one-argument function [fun v -> e]. Here [v] is a fresh variable name from
   [gen_symbol ~prefix:"x"], and [e] is [apply ~arg] with [arg] the expression for [v]. *)
let fresh_lambda ~loc apply =
  let name = gen_symbol ~prefix:"x" () in
  let body = apply ~arg:(evar ~loc name) in
  pexp_fun ~loc Nolabel None (pvar ~loc name) body
;;

(* A conservative, purely syntactic test for "this expression is a value":
   - identifiers, constants, [fun]/[function] abstractions and [lazy] expressions are
     accepted;
   - type constraints, coercions and [fun (type a) -> ...] binders are looked through;
   - tuples, constructor and polymorphic-variant applications and record expressions are
     accepted only when every sub-expression they contain is accepted;
   - every other expression form is rejected. *)
let rec is_value_expression expr =
  let all_are_values exprs = List.for_all exprs ~f:is_value_expression in
  let optional_is_value maybe = Option.for_all maybe ~f:is_value_expression in
  match expr.pexp_desc with
  (* Wrappers that only carry type information: judge what is inside. *)
  | Pexp_constraint (inner, (_ : core_type))
  | Pexp_coerce (inner, (_ : core_type option), (_ : core_type))
  | Pexp_newtype ((_ : string loc), inner) -> is_value_expression inner
  (* Syntactic values. *)
  | Pexp_ident _ | Pexp_constant _ | Pexp_fun _ | Pexp_function _ | Pexp_lazy _ -> true
  (* Allocations that are values exactly when their contents are. *)
  | Pexp_tuple items -> all_are_values items
  | Pexp_construct (_, arg) | Pexp_variant (_, arg) -> optional_is_value arg
  | Pexp_record (fields, base) ->
    List.for_all fields ~f:(fun (_, field) -> is_value_expression field)
    && optional_is_value base
  (* Not values, or not always values: rejected as a conservative approximation. *)
  | Pexp_unreachable
  | Pexp_let _
  | Pexp_apply _
  | Pexp_match _
  | Pexp_try _
  | Pexp_field _
  | Pexp_setfield _
  | Pexp_array _
  | Pexp_ifthenelse _
  | Pexp_sequence _
  | Pexp_while _
  | Pexp_for _
  | Pexp_send _
  | Pexp_new _
  | Pexp_setinstvar _
  | Pexp_override _
  | Pexp_letmodule _
  | Pexp_letexception _
  | Pexp_assert _
  | Pexp_poly _
  | Pexp_object _
  | Pexp_pack _
  | Pexp_open _
  | Pexp_letop _
  | Pexp_extension _ -> false
;;

(* Runs Ppxlib's [type_is_recursive] check on [tds] and returns its result. The check
   does not descend into core types that carry the [Attrs.opaque] attribute or that are
   written [_ sexp_opaque], so references hidden behind those do not count as recursion. *)
let really_recursive_respecting_opaque rec_flag tds =
  let checker =
    object
      inherit type_is_recursive rec_flag tds as super

      method! core_type ctype =
        let has_opaque_attribute =
          Option.is_some (Attribute.get ~mark_as_seen:false Attrs.opaque ctype)
        in
        if has_opaque_attribute
        then ()
        else (
          match ctype with
          | [%type: [%t? _] sexp_opaque] -> ()
          | _ -> super#core_type ctype)
    end
  in
  checker#go ()
;;

(* An AST mapper that removes every attribute from the syntax it traverses.
   - Attribute lists attached to nodes are replaced by [].
   - Floating attribute items ([Pstr_attribute], [Psig_attribute], [Pctf_attribute],
     [Pcf_attribute]) are dropped from the enclosing structure, signature, class
     signature or class structure.

   After filtering, each container is handed to the inherited traversal ([super#...]).
   Without that step, the remaining items would be returned unmapped. Attributes nested
   inside them would then survive, for example in a [let] binding, a nested module,
   a class field or an extension payload.

   [attribute] is overridden to raise. Every attribute is expected to be removed by one
   of the overrides above before the traversal reaches it individually, so reaching one
   signals a case this mapper does not handle. *)
let strip_attributes =
  object
    inherit Ast_traverse.map as super

    method! attribute attr =
      Location.raise_errorf ~loc:attr.attr_loc "failed to strip attribute from syntax"

    method! attributes _ = []

    method! signature items =
      super#signature
        (List.filter items ~f:(fun item ->
           match item.psig_desc with
           | Psig_attribute _ -> false
           | _ -> true))

    method! structure items =
      super#structure
        (List.filter items ~f:(fun item ->
           match item.pstr_desc with
           | Pstr_attribute _ -> false
           | _ -> true))

    method! class_signature csig =
      super#class_signature
        { csig with
          pcsig_fields =
            List.filter csig.pcsig_fields ~f:(fun field ->
              match field.pctf_desc with
              | Pctf_attribute _ -> false
              | _ -> true)
        }

    method! class_structure cstr =
      super#class_structure
        { cstr with
          pcstr_fields =
            List.filter cstr.pcstr_fields ~f:(fun field ->
              match field.pcf_desc with
              | Pcf_attribute _ -> false
              | _ -> true)
        }
  end
;;