Source file ppx_generator_expander.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
open! Import

(** Generator expression for a function type [input_type -> output_type] whose
    parameter carries [arg_label].

    The parameter is observed with [observer_of_core_type input_type]. For an
    [Optional] label, that observer is wrapped in [Observer.option], because an
    optional argument is received as an option. The result is produced by
    [generator_of_core_type output_type].

    With [Nolabel], the unlabelled [Generator.fn] expression is returned as is.
    Otherwise it is passed through [Generator.map] with the function built by
    [fn_map_label ~from:Nolabel ~to_:arg_label]. [fn_map_label] is defined
    elsewhere; presumably it re-labels the generated function's parameter. *)
let arrow
  ~generator_of_core_type
  ~observer_of_core_type
  ~loc
  ~arg_label
  ~input_type
  ~output_type
  =
  let param_observer =
    let base_observer = observer_of_core_type input_type in
    match arg_label with
    | Optional _ ->
      [%expr Ppx_quickcheck_runtime.Base_quickcheck.Observer.option [%e base_observer]]
    | Nolabel | Labelled _ -> base_observer
  in
  let result_generator = generator_of_core_type output_type in
  let plain_fn_generator =
    [%expr
      Ppx_quickcheck_runtime.Base_quickcheck.Generator.fn
        [%e param_observer]
        [%e result_generator]]
  in
  match arg_label with
  | Nolabel -> plain_fn_generator
  | Labelled _ | Optional _ ->
    let relabel = fn_map_label ~loc ~from:Nolabel ~to_:arg_label in
    [%expr
      Ppx_quickcheck_runtime.Base_quickcheck.Generator.map
        ~f:[%e relabel]
        [%e plain_fn_generator]]
;;

(** Combines per-component generator expressions into one [Generator.create].

    The generated closure receives [~size] and [~random]. It runs every element
    of [generator_list] via [Generator.generate] with those same two arguments,
    and hands the resulting expressions to [make_compound_expr] to build the
    final value. All locations produced here are marked as ghost. *)
let compound_generator ~loc ~make_compound_expr generator_list =
  let loc = { loc with loc_ghost = true } in
  let size_pat, size_expr = gensym "size" loc in
  let random_pat, random_expr = gensym "random" loc in
  (* Each component keeps its own (ghosted) source location, so errors in the
     generated code point back at the component's type. *)
  let generate_component component =
    let loc = { component.pexp_loc with loc_ghost = true } in
    [%expr
      Ppx_quickcheck_runtime.Base_quickcheck.Generator.generate
        [%e component]
        ~size:[%e size_expr]
        ~random:[%e random_expr]]
  in
  let combined =
    make_compound_expr ~loc (List.map generator_list ~f:generate_component)
  in
  [%expr
    Ppx_quickcheck_runtime.Base_quickcheck.Generator.create
      (fun ~size:[%p size_pat] ~random:[%p random_pat] -> [%e combined])]
;;

(** Generator expression for a compound type described by [fields].

    Each AST field is converted with [Field.create]. Its [Field.core_type] gets
    a generator from [generator_of_core_type], and [Field.expression] assembles
    the generated components into one value (presumably a tuple or record,
    depending on the [Field] module supplied). *)
let compound
  (type field)
  ~generator_of_core_type
  ~loc
  ~fields
  (module Field : Field_syntax.S with type ast = field)
  =
  let field_list = List.map fields ~f:Field.create in
  let component_generators =
    List.map field_list ~f:(fun entry -> generator_of_core_type (Field.core_type entry))
  in
  compound_generator
    ~loc
    ~make_compound_expr:(Field.expression field_list)
    component_generators
;;

(** An AST fold whose [core_type] method reports whether a type mentions a
    constructor whose [Longident.name] is in [name_set].

    The boolean accumulator is threaded through the fold. Once it is [true],
    the answer stays [true]. Type-constructor arguments are checked recursively
    with a fresh [false] accumulator. Every other type form is delegated to the
    default traversal. *)
let does_refer_to name_set =
  object (traversal)
    inherit [bool] Ast_traverse.fold as parent

    method! core_type typ found =
      match typ.ptyp_desc with
      | Ptyp_constr (constr, type_args) ->
        if found
        then true
        else
          Set.mem name_set (Longident.name constr.txt)
          || List.exists type_args ~f:(fun type_arg ->
            traversal#core_type type_arg false)
      | _ -> parent#core_type typ found
  end
;;

(** [true] when any argument type of [clause] mentions one of [rec_names], as
    determined by [does_refer_to]. *)
let clause_is_recursive
  (type clause)
  ~clause
  ~rec_names
  (module Clause : Clause_syntax.S with type t = clause)
  =
  (* The traversal object holds no state, so one instance serves every argument. *)
  let mentions_rec_name = does_refer_to rec_names in
  Clause.core_type_list clause
  |> List.exists ~f:(fun arg_type -> mentions_rec_name#core_type arg_type false)
;;

(** Generator expression for a variant type.

    [clauses] are the AST clauses, converted with [Clause.create_list].
    [variant_type] is passed to [Clause.expression] to build each constructor
    application. [rec_names] are the type names whose mention marks a clause as
    recursive (see [clause_is_recursive]).

    Clauses whose [Clause.weight] is [None] are dropped. If no clauses remain,
    [invalid] is called with "variant had no (generated) cases".

    If the remaining clauses are all recursive or all non-recursive, the result
    is a single [Generator.weighted_union] over them.

    Otherwise the result is a [let] that binds one (weight, generator) pair per
    clause, then branches on [Generator.size]:
    - at size 0, only the non-recursive clauses are used;
    - at any other size, all clauses are used, and each recursive clause's
      generator runs at [Int.pred size].

    NOTE(review): the generated identifier names depend on the evaluation order
    of the [gensym]/[gensyms] and [make_generator] calls below. Reordering them
    may change the expanded output. *)
let variant
  (type clause)
  ~generator_of_core_type
  ~loc
  ~variant_type
  ~clauses
  ~rec_names
  (module Clause : Clause_syntax.S with type ast = clause)
  =
  let clauses = Clause.create_list clauses in
  (* Generator for one clause: a generator per argument type, combined into the
     constructor application by [Clause.expression]. *)
  let make_generator clause =
    compound_generator
      ~loc:(Clause.location clause)
      ~make_compound_expr:(Clause.expression clause variant_type)
      (List.map (Clause.core_type_list clause) ~f:generator_of_core_type)
  in
  (* [(weight, generator)] tuple expression for [Generator.weighted_union], or
     [None] when the clause has no weight. *)
  let make_pair clause =
    Option.map (Clause.weight clause) ~f:(fun weight ->
      pexp_tuple
        ~loc:{ (Clause.location clause) with loc_ghost = true }
        [ weight; make_generator clause ])
  in
  (* We filter out clauses with weight None now. If we don't, then we can get code in
     [body] below that relies on bindings that don't get generated. *)
  let clauses =
    List.filter clauses ~f:(fun clause -> Option.is_some (Clause.weight clause))
  in
  (* Split into (recursive, non-recursive) clauses. *)
  match
    List.partition_tf clauses ~f:(fun clause ->
      clause_is_recursive ~clause ~rec_names (module Clause))
  with
  | [], [] -> invalid ~loc "variant had no (generated) cases"
  | [], clauses | clauses, [] ->
    (* Only one kind of clause, so one weighted union covers everything.
       NOTE(review): when every clause is recursive, this path applies no size
       decrement; confirm that is intended. *)
    let pairs = List.filter_map clauses ~f:make_pair in
    [%expr
      Ppx_quickcheck_runtime.Base_quickcheck.Generator.weighted_union
        [%e elist ~loc pairs]]
  | recursive_clauses, nonrecursive_clauses ->
    (* Identifiers used in the generated code:
       - [size] is the parameter of the size-decrementing wrappers;
       - the two [gen]s name the non-recursive and full unions;
       - one [pair] per clause names its (weight, generator) tuple. *)
    let size_pat, size_expr = gensym "size" loc in
    let nonrec_pat, nonrec_expr = gensym "gen" loc in
    let rec_pat, rec_expr = gensym "gen" loc in
    let nonrec_pats, nonrec_exprs =
      gensyms "pair" (List.map nonrecursive_clauses ~f:Clause.location)
    in
    let rec_pats, rec_exprs =
      gensyms "pair" (List.map recursive_clauses ~f:Clause.location)
    in
    (* One [let] binding per clause. After the weight filter above, every
       [make_pair] / [Clause.weight] result here is [Some], so [filter_opt]
       drops nothing. Recursive clauses wrap their generator so it runs at
       [Int.pred size]. *)
    let bindings =
      List.filter_opt
        (List.map2_exn nonrec_pats nonrecursive_clauses ~f:(fun pat clause ->
           let loc = { (Clause.location clause) with loc_ghost = true } in
           Option.map (make_pair clause) ~f:(fun expr -> value_binding ~loc ~pat ~expr))
         @ List.map2_exn rec_pats recursive_clauses ~f:(fun pat clause ->
             Option.map (Clause.weight clause) ~f:(fun weight_expr ->
               let loc = { (Clause.location clause) with loc_ghost = true } in
               let gen_expr =
                 [%expr
                   Ppx_quickcheck_runtime.Base_quickcheck.Generator.bind
                     Ppx_quickcheck_runtime.Base_quickcheck.Generator.size
                     ~f:(fun [%p size_pat] ->
                     Ppx_quickcheck_runtime.Base_quickcheck.Generator.with_size
                       ~size:(Ppx_quickcheck_runtime.Base.Int.pred [%e size_expr])
                       [%e make_generator clause])]
               in
               let expr = pexp_tuple ~loc [ weight_expr; gen_expr ] in
               value_binding ~loc ~pat ~expr)))
    in
    (* [nonrec_pat] is the union of the non-recursive pairs only; [rec_pat] is
       the union of all pairs. Size 0 selects the former and any other size the
       latter, which is intended to bound the recursion depth. *)
    let body =
      [%expr
        let [%p nonrec_pat] =
          Ppx_quickcheck_runtime.Base_quickcheck.Generator.weighted_union
            [%e elist ~loc nonrec_exprs]
        and [%p rec_pat] =
          Ppx_quickcheck_runtime.Base_quickcheck.Generator.weighted_union
            [%e elist ~loc (nonrec_exprs @ rec_exprs)]
        in
        Ppx_quickcheck_runtime.Base_quickcheck.Generator.bind
          Ppx_quickcheck_runtime.Base_quickcheck.Generator.size
          ~f:(function
          | 0 -> [%e nonrec_expr]
          | _ -> [%e rec_expr])]
    in
    pexp_let ~loc Nonrecursive bindings body
;;