jon.recoil.org

Source file code_matcher.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
open! Import
module Format = Stdlib.Format
module Filename = Stdlib.Filename
(* When [true], the deprecated end marker is accepted as a synonym for the
   current one instead of being reported as an error. Off by default.
   NOTE(review): presumably toggled by the driver's [-allow-deriving-end] flag
   mentioned in the deprecation error below; confirm at the call site. *)
let allow_deriving_end : bool ref = ref false

(* Every end marker is a floating attribute whose payload must be an empty
   structure. *)
let declare_end_marker name context =
  Attribute.Floating.declare name context Ast_pattern.(pstr nil) ()

(* Current marker closing a block of generated code, in signatures and in
   structures respectively. *)
let end_marker_sig = declare_end_marker "ppxlib.inline.end" Signature_item
let end_marker_str = declare_end_marker "ppxlib.inline.end" Structure_item

(* Older spelling of the end marker; it is only accepted when
   [allow_deriving_end] is set. *)
let deprecated_end_marker_sig = declare_end_marker "@deriving.end" Signature_item
let deprecated_end_marker_str = declare_end_marker "@deriving.end" Structure_item

(* A signature exposing a single unary type constructor. It parameterizes
   [Transform] so one traversal object can be applied at any result shape. *)
module type T1 = sig type 'a t end

(* [Make (M)] builds a matcher that checks whether a list of [source] items, up
   to an end marker, is the same (after normalization) as a list of [expected]
   items, reporting each discrepancy through a [mismatch_handler]. [M]
   abstracts over the item kind (structure items or signature items). *)
module Make (M : sig
  type t

  val get_loc : t -> Location.t
  val end_marker : (t, unit) Attribute.Floating.t
  val deprecated_end_marker : (t, unit) Attribute.Floating.t

  module Transform (T : T1) : sig
    val apply :
      < structure_item : structure_item T.t
      ; signature_item : signature_item T.t
      ; .. > ->
      t T.t
  end

  val parse : Lexing.lexbuf -> t list
  val pp : Format.formatter -> t -> unit
  val to_sexp : t -> Sexp.t

  val update_locs_to_include_doc_comments : t list -> t list
end) =
struct
  (* Split [l] at its first end marker.

     On success, returns the items before the marker, in their original order,
     together with the start position of the marker item itself. If [l]
     contains no marker, the error is located at [pos]. The deprecated marker
     is accepted only when [!allow_deriving_end] is set; otherwise it yields an
     error located at the deprecated attribute. A [Failure] raised while
     converting an item is treated as "this item is not a marker". *)
  let extract_prefix ~pos l =
    let found acc marker = Ok (List.rev acc, (M.get_loc marker).loc_start) in
    let rec loop acc = function
      | [] ->
          let loc =
            { Location.loc_start = pos; loc_end = pos; loc_ghost = false }
          in
          Error
            ( Location.Error.createf ~loc "ppxlib: [@@@@@@%s] attribute missing"
                (Attribute.Floating.name M.end_marker),
              [] )
      | x :: l -> (
          match Attribute.Floating.convert_res [ M.end_marker ] x with
          | Ok (Some ()) -> found acc x
          | Error e -> Error e
          | exception Failure _ | Ok None -> (
              match
                Attribute.Floating.convert_res [ M.deprecated_end_marker ] x
              with
              | Ok (Some ()) ->
                  if !allow_deriving_end then found acc x
                  else
                    Error
                      ( Location.Error.createf ~loc:(M.get_loc x)
                          "ppxlib: [@@@@@@%s] is deprecated, please use \
                           [@@@@@@%s]. If you need the deprecated attribute \
                           temporarily, pass [-allow-deriving-end] to the ppx \
                           driver."
                          (Attribute.Floating.name M.deprecated_end_marker)
                          (Attribute.Floating.name M.end_marker),
                        [] )
              | Error e -> Error e
              | exception Failure _ | Ok None -> loop (x :: acc) l))
    in
    loop [] l

  (* When checking whether the parsed code matches the expected code, there
     are certain expected differences between the AST generated by ppxlib and
     the AST parsed by the compiler and translated to the ppxlib AST version.

     Here we normalize the AST so that these expected differences don't cause
     spurious error messages about the round-trip check failing. *)
  let traverse_normalize =
    object
      inherit Ast_traverse.map as super

      (* Ignore locations. *)
      method! location _ = Location.none
      method! location_stack _ = []

      (* Drop erasable attributes encoding syntactic arity. *)
      method! attributes attrs =
        List.filter attrs ~f:(fun { attr_name; _ } ->
            not (String.starts_with attr_name.txt ~prefix:"jane.erasable._builtin"))
        |> super#attributes

      (* Reconcile how value binding constraints are handled between the two
         versions. As of 2024-01-12, the AST generated by the compiler parse
         and translation inserts an extra [Pexp_constraint] on the expression:
         when the pattern's and the expression's constraint types are equal,
         the expression-side constraint is dropped.
         NOTE(review): the types are compared before [super] normalizes them,
         so their locations take part in the comparison. *)
      method! value_binding ({ pvb_pat; pvb_expr; _ } as pvb) =
        super#value_binding
          (match (pvb_pat.ppat_desc, pvb_expr.pexp_desc) with
          | ( Ppat_constraint (_, Some pat_type, []),
              Pexp_constraint (expr, Some expr_type, []) )
            when Poly.( = ) pat_type expr_type ->
              { pvb with pvb_expr = expr }
          | _ -> pvb)
    end

  module M_map = M.Transform (struct
    type 'a t = 'a -> 'a
  end)

  (* Apply [traverse_normalize] to a single item. *)
  let normalize x = M_map.apply traverse_normalize x

  (* [last prev l] is the last element of [l], or [prev] when [l] is empty. *)
  let rec last prev = function [] -> prev | x :: l -> last x l

  (* Dump [generated] and [round_trip] as s-expressions into temporary files
     and return a textual diff of the two, for inclusion in an error message.
     [patdiff] is tried first, falling back to [diff]. Exit status 1 is taken
     to mean "differences found" (the [diff] convention).

     The redirection is POSIX [> file 2>&1] rather than the Bash-only [&>]:
     [Sys.command] runs the command through the system shell. In a POSIX
     shell such as dash, [cmd &> file] runs [cmd] in the background and merely
     truncates [file], so no diff output (or exit status) was ever captured. *)
  let diff_asts ~generated ~round_trip =
    let with_temp_file f =
      Exn.protectx
        (Filename.temp_file "ppxlib" "")
        ~finally:Stdlib.Sys.remove ~f
    in
    with_temp_file (fun fn1 ->
        with_temp_file (fun fn2 ->
            with_temp_file (fun out ->
                let dump fn ast =
                  Out_channel.with_file fn ~f:(fun oc ->
                      let ppf = Format.formatter_of_out_channel oc in
                      Sexp.pp_hum ppf (M.to_sexp ast);
                      Format.pp_print_flush ppf ())
                in
                dump fn1 generated;
                dump fn2 round_trip;
                (* [run_diff tool] runs [tool fn1 fn2] with stdout and stderr
                   captured in [out]; true iff it exited with status 1. *)
                let run_diff tool =
                  let cmd =
                    Printf.sprintf "%s %s %s > %s 2>&1" tool
                      (Filename.quote fn1) (Filename.quote fn2)
                      (Filename.quote out)
                  in
                  Stdlib.Sys.command cmd = 1
                in
                let ok =
                  run_diff
                    "patdiff -ascii -alt-old generated -alt-new \
                     'generated->printed->parsed'"
                  || run_diff
                       "diff --label generated --label \
                        'generated->printed->parsed'"
                in
                if ok then In_channel.read_all out
                else "<no differences produced by diff>")))

  (* Parse [s], which must contain exactly one item. *)
  let parse_string s =
    match M.parse (Lexing.from_string s) with [ x ] -> x | _ -> assert false

  (* Walk [expected] and [source] in lockstep, calling
     [mismatch_handler loc items] for every discrepancy:
     - surplus [source] items are reported once, with a single location
       spanning all of them and no replacement items;
     - missing [source] items are reported at [end_pos] together with the
       remaining [expected] items;
     - a differing pair is reported at the source item's location with the
       expected item. Before that, the expected item must survive a
       print/parse round trip; otherwise a located error carrying an AST diff
       is raised. *)
  let rec match_loop ~end_pos ~mismatch_handler ~expected ~source =
    match (expected, source) with
    | [], [] -> ()
    | [], x :: l ->
        let loc =
          { (M.get_loc x) with loc_end = (M.get_loc (last x l)).loc_end }
        in
        mismatch_handler loc []
    | _, [] ->
        let loc =
          { Location.loc_ghost = false; loc_start = end_pos; loc_end = end_pos }
        in
        mismatch_handler loc expected
    | x :: expected, y :: source ->
        let loc = M.get_loc y in
        let x = normalize x in
        let y = normalize y in
        if Poly.( <> ) x y then (
          let round_trip =
            normalize (parse_string (Format.asprintf "%a@." M.pp x))
          in
          if Poly.( <> ) x round_trip then
            Location.raise_errorf ~loc
              "ppxlib: the corrected code doesn't round-trip.\n\
               This is probably a bug in the OCaml printer:\n\
               %s"
              (diff_asts ~generated:x ~round_trip);
          mismatch_handler loc [ x ]);
        match_loop ~end_pos ~mismatch_handler ~expected ~source

  (* Extend item locations to cover attached doc comments, split [source] at
     the end marker, and match the prefix against [expected]. Returns [Error]
     when the end marker is missing or is the rejected deprecated form. *)
  let do_match ~pos ~expected ~mismatch_handler source =
    let open Result in
    let source = M.update_locs_to_include_doc_comments source in
    extract_prefix ~pos source >>| fun (source, end_pos) ->
    match_loop ~end_pos ~mismatch_handler ~expected ~source
end

(* [Make] instantiated for implementations: items are structure items, and the
   block is closed by [end_marker_str]. *)
module Str = Make (struct
  type t = structure_item

  let get_loc { pstr_loc; _ } = pstr_loc
  let end_marker = end_marker_str
  let deprecated_end_marker = deprecated_end_marker_str

  module Transform (T : T1) = struct
    let apply traversal = traversal#structure_item
  end

  let parse lexbuf = Parse.implementation lexbuf
  let pp ppf item = Pprintast.structure_item ppf item
  let to_sexp item = Ast_traverse.sexp_of#structure_item item

  let update_locs_to_include_doc_comments items =
    Utils.update_locs_to_include_doc_comments#structure items
end)


(* [Make] instantiated for interfaces: items are signature items, and the
   block is closed by [end_marker_sig]. *)
module Sig = Make (struct
  type t = signature_item

  let get_loc { psig_loc; _ } = psig_loc
  let end_marker = end_marker_sig
  let deprecated_end_marker = deprecated_end_marker_sig

  module Transform (T : T1) = struct
    let apply traversal = traversal#signature_item
  end

  let parse lexbuf = (Parse.interface lexbuf).psg_items
  let pp ppf item = Pprintast.signature_item ppf item
  let to_sexp item = Ast_traverse.sexp_of#signature_item item

  let update_locs_to_include_doc_comments items =
    Utils.update_locs_to_include_doc_comments#signature_items items
end)

(* Match structure items, returning the marker error as a [result] instead of
   raising it. *)
let match_structure_res ~pos ~expected ~mismatch_handler source =
  Str.do_match ~pos ~expected ~mismatch_handler source

(* Like [match_structure_res], but on failure raises the primary located error
   (the first component of the error pair), discarding the rest. *)
let match_structure ~pos ~expected ~mismatch_handler l =
  match match_structure_res ~pos ~expected ~mismatch_handler l with
  | Ok matched -> matched
  | Error (err, _) -> Location.Error.raise err

(* Match signature items, returning the marker error as a [result] instead of
   raising it. *)
let match_signature_res ~pos ~expected ~mismatch_handler source =
  Sig.do_match ~pos ~expected ~mismatch_handler source

(* Like [match_signature_res], but on failure raises the primary located error
   (the first component of the error pair), discarding the rest. *)
let match_signature ~pos ~expected ~mismatch_handler l =
  match match_signature_res ~pos ~expected ~mismatch_handler l with
  | Ok matched -> matched
  | Error (err, _) -> Location.Error.raise err