Source file attrs.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
open! Base
open! Ppxlib

module To_lift = struct
  (* Single-field wrapper around values captured from attribute payloads (see
     [default], [drop_default] and [drop_if] below, which wrap the payload
     expression in it). NOTE(review): the name suggests callers are expected to
     lift/splice the wrapped expression into generated code - confirm against
     the expander. [@@unboxed] gives the record the same runtime representation
     as the wrapped value, so wrapping allocates nothing. *)
  type 'a t = { to_lift : 'a } [@@unboxed]
end

open To_lift

(* [@sexp.default <expr>] on a record field. The payload must be a single
   expression item carrying no attributes of its own; that expression becomes
   the attribute's value, wrapped in [To_lift.t]. *)
let default =
  let single_expr = Ast_pattern.(pstr (pstr_eval __ nil ^:: nil)) in
  Attribute.declare
    "sexp.default"
    Attribute.Context.label_declaration
    single_expr
    (fun to_lift -> { to_lift })
;;

(* [@sexp_drop_default] on a record field, with an optional payload. An empty
   payload yields [None]. A single attribute-free expression item yields that
   expression wrapped in [To_lift.t]. NOTE(review): the expression is presumably
   a custom equality used to decide whether to drop the field - confirm in the
   expander. *)
let drop_default =
  let expr_or_empty =
    Ast_pattern.(pstr (alt_option (pstr_eval __ nil ^:: nil) nil))
  in
  Attribute.declare
    "sexp.sexp_drop_default"
    Attribute.Context.label_declaration
    expr_or_empty
    (Option.map ~f:(fun to_lift -> { to_lift }))
;;

(* Flag attribute [@sexp_drop_default.equal] on a record field; the payload
   must be empty. NOTE(review): the "@" in the declared name appears to follow
   the ppxlib naming convention that keeps the dotted suffix after it
   together when matching - confirm against the ppxlib [Name] docs. *)
let drop_default_equal =
  let empty_payload = Ast_pattern.pstr Ast_pattern.nil in
  Attribute.declare
    "sexp.@sexp_drop_default.equal"
    Attribute.Context.label_declaration
    empty_payload
    ()
;;

(* Flag attribute [@sexp_drop_default.compare] on a record field; the payload
   must be empty. The "@" in the declared name plays the same role as in
   [drop_default_equal]. *)
let drop_default_compare =
  let context = Attribute.Context.label_declaration in
  Attribute.declare
    "sexp.@sexp_drop_default.compare"
    context
    (Ast_pattern.pstr Ast_pattern.nil)
    ()
;;

(* Flag attribute [@sexp_drop_default.sexp] on a record field; the payload must
   be empty. The "@" in the declared name plays the same role as in
   [drop_default_equal]. *)
let drop_default_sexp =
  let no_payload = Ast_pattern.(pstr nil) in
  Attribute.declare
    "sexp.@sexp_drop_default.sexp"
    Attribute.Context.label_declaration
    no_payload
    ()
;;

(* [@sexp_drop_if <expr>] on a record field. The payload must be a single
   expression item with no attributes; the expression is returned wrapped in
   [To_lift.t]. NOTE(review): presumably a predicate deciding whether the field
   is omitted - confirm in the expander. *)
let drop_if =
  let wrap to_lift = { to_lift } in
  let single_expr = Ast_pattern.(pstr (pstr_eval __ nil ^:: nil)) in
  Attribute.declare
    "sexp.sexp_drop_if"
    Attribute.Context.label_declaration
    single_expr
    wrap
;;

(* Flag attribute [@sexp.opaque] on a core type; takes no payload. *)
let opaque =
  Attribute.declare
    "sexp.opaque"
    Attribute.Context.core_type
    (Ast_pattern.pstr Ast_pattern.nil)
    ()
;;

(* Flag attribute [@sexp.omit_nil] on a record field; takes no payload. *)
let omit_nil =
  let no_payload = Ast_pattern.pstr Ast_pattern.nil in
  Attribute.declare "sexp.omit_nil" Attribute.Context.label_declaration no_payload ()
;;

(* Flag attribute [@sexp.option] on a record field; takes no payload. *)
let option =
  let field_context = Attribute.Context.label_declaration in
  Attribute.declare "sexp.option" field_context (Ast_pattern.pstr Ast_pattern.nil) ()
;;

(* Flag attribute [@sexp.list] on a record field; takes no payload. The same
   name is also declared for constructors, exceptions and polymorphic-variant
   tags ([list_variant], [list_exception], [list_poly]). *)
let list =
  let empty = Ast_pattern.(pstr nil) in
  Attribute.declare "sexp.list" Attribute.Context.label_declaration empty ()
;;

(* Flag attribute [@sexp.array] on a record field; takes no payload. *)
let array =
  Attribute.declare
    "sexp.array"
    Attribute.Context.label_declaration
    (Ast_pattern.pstr Ast_pattern.nil)
    ()
;;

(* Flag attribute [@sexp.bool] on a record field; takes no payload. *)
let bool =
  let no_payload = Ast_pattern.pstr Ast_pattern.nil in
  Attribute.declare "sexp.bool" Attribute.Context.label_declaration no_payload ()
;;

(* Flag attribute [@sexp.list] on a variant constructor declaration; takes no
   payload. It is declared separately from the record-field [list] attribute
   because each declaration is tied to one attachment context. *)
let list_variant =
  let ctx = Attribute.Context.constructor_declaration in
  Attribute.declare "sexp.list" ctx (Ast_pattern.pstr Ast_pattern.nil) ()
;;

(* Flag attribute [@sexp.list] on an exception declaration; takes no payload. *)
let list_exception =
  Attribute.declare
    "sexp.list"
    Attribute.Context.type_exception
    (Ast_pattern.pstr Ast_pattern.nil)
    ()
;;

(* Flag attribute [@sexp.list] on a polymorphic-variant tag ([Rtag]); takes no
   payload. *)
let list_poly =
  let empty = Ast_pattern.(pstr nil) in
  Attribute.declare "sexp.list" Attribute.Context.rtag empty ()
;;

(* Flag attribute [@@sexp.allow_extra_fields] on a type declaration; takes no
   payload. Misplaced uses are rejected by [fail_if_allow_extra_field_td]. *)
let allow_extra_fields_td =
  let no_payload = Ast_pattern.pstr Ast_pattern.nil in
  Attribute.declare
    "sexp.allow_extra_fields"
    Attribute.Context.type_declaration
    no_payload
    ()
;;

(* Flag attribute [@sexp.allow_extra_fields] on a variant constructor
   declaration; takes no payload. Misplaced uses are rejected by
   [fail_if_allow_extra_field_cd]. *)
let allow_extra_fields_cd =
  let ctx = Attribute.Context.constructor_declaration in
  Attribute.declare "sexp.allow_extra_fields" ctx (Ast_pattern.pstr Ast_pattern.nil) ()
;;

(* [@sexp_grammar.custom <expr>] on a core type. The payload must be a single
   expression, which is returned unchanged as the attribute's value. *)
let grammar_custom =
  Attribute.declare
    "sexp_grammar.custom"
    Attribute.Context.core_type
    (Ast_pattern.single_expr_payload Ast_pattern.__)
    Fn.id
;;

(* [@sexp_grammar.any] on a core type, with an optional string-literal payload.
   The value is [Some s] when the payload is the string literal s, and [None]
   when the payload is empty. NOTE(review): the string is presumably a display
   name for the grammar - confirm in the expander. *)
let grammar_any =
  let string_literal = Ast_pattern.(single_expr_payload (estring __)) in
  let empty = Ast_pattern.(pstr nil) in
  Attribute.declare
    "sexp_grammar.any"
    Attribute.Context.core_type
    (Ast_pattern.alt_option string_literal empty)
    Fn.id
;;

(* Declares [@sexp_grammar.tag] for the attachment point [context]. The payload
   must be one attribute-free expression that is a sequence of [key = value]
   applications, e.g. [k1 = v1; k2 = v2]. The attribute's value is the list of
   (key, value) expression pairs in source order. Each element is matched at
   its own location, so a malformed element is reported where it appears. *)
let tag_attribute_for_context context =
  (* Matches an unlabelled application of [=] and captures both operands as a
     pair. *)
  let key_value_pair =
    Ast_pattern.(
      pexp_apply (pexp_ident (lident (string "="))) (no_label __ ^:: no_label __ ^:: nil)
      |> pack2)
  in
  (* Flattens the right-nested [Pexp_sequence] chain [e1; e2; ...; en] into
     [[e1; e2; ...; en]]; a non-sequence expression becomes a singleton. *)
  let rec flatten_sequence expr =
    match expr.pexp_desc with
    | Pexp_sequence (first, rest) -> first :: flatten_sequence rest
    | _ -> [ expr ]
  in
  (* Custom pattern: runs [element_pattern] on every element of a sequence
     expression and hands the list of captured results to the continuation. *)
  let each_sequence_element element_pattern =
    Ast_pattern.of_func (fun ctx _loc expr k ->
      let match_element element =
        Ast_pattern.to_func element_pattern ctx element.pexp_loc element Fn.id
      in
      k (List.map (flatten_sequence expr) ~f:match_element))
  in
  Attribute.declare
    "sexp_grammar.tag"
    context
    Ast_pattern.(pstr (pstr_eval (each_sequence_element key_value_pair) nil ^:: nil))
    Fn.id
;;

(* [@sexp_grammar.tag] instantiated for each supported attachment point:
   core types, record fields, constructors and polymorphic-variant tags. *)
let tag_type = tag_attribute_for_context Attribute.Context.Core_type
let tag_ld = tag_attribute_for_context Attribute.Context.Label_declaration
let tag_cd = tag_attribute_for_context Attribute.Context.Constructor_declaration
let tag_poly = tag_attribute_for_context Attribute.Context.Rtag

(* Declares [@sexp_grammar.tags <expr>] for the attachment point [context]. The
   payload must be a single expression, which is returned unchanged. *)
let tags_attribute_for_context context =
  let single_expr = Ast_pattern.(single_expr_payload __) in
  Attribute.declare "sexp_grammar.tags" context single_expr Fn.id
;;

(* [@sexp_grammar.tags] instantiated for each supported attachment point:
   core types, record fields, constructors and polymorphic-variant tags. *)
let tags_type = tags_attribute_for_context Attribute.Context.Core_type
let tags_ld = tags_attribute_for_context Attribute.Context.Label_declaration
let tags_cd = tags_attribute_for_context Attribute.Context.Constructor_declaration
let tags_poly = tags_attribute_for_context Attribute.Context.Rtag

(* Raises a located error at [loc] saying that attribute [attr] may only be used
   on type [description]. Never returns, so it can stand in any branch. *)
let invalid_attribute ~loc attr description =
  let attribute_name = Attribute.name attr in
  Location.raise_errorf
    ~loc
    "ppx_sexp_conv: [@%s] is only allowed on type [%s]."
    attribute_name
    description
;;

(* Raises a located error at [loc] if constructor declaration [x] carries
   [@allow_extra_fields]; otherwise returns unit. NOTE(review): it rejects the
   attribute unconditionally, so callers presumably invoke it only on
   constructors whose arguments are not an inline record - confirm at call
   sites. *)
let fail_if_allow_extra_field_cd ~loc x =
  match Attribute.get allow_extra_fields_cd x with
  | None -> ()
  | Some () ->
    Location.raise_errorf
      ~loc
      "ppx_sexp_conv: [@@allow_extra_fields] is only allowed on inline records."
;;

(* Raises a located error at [loc] if type declaration [x] carries
   [@@allow_extra_fields]; otherwise returns unit. When [x] is a variant with at
   least one inline-record constructor, the message suggests moving the
   attribute onto that constructor instead. NOTE(review): the attribute is
   rejected whatever the declaration kind, records included, so callers
   presumably only invoke this for non-record declarations - confirm at call
   sites. *)
let fail_if_allow_extra_field_td ~loc x =
  match Attribute.get allow_extra_fields_td x with
  | None -> ()
  | Some () ->
    let has_inline_record_constructor =
      match x.ptype_kind with
      | Ptype_variant cds ->
        List.exists cds ~f:(fun cd ->
          match cd.pcd_args with
          | Pcstr_record _ -> true
          | Pcstr_tuple _ -> false)
      | Ptype_abstract | Ptype_record _ | Ptype_open -> false
    in
    if has_inline_record_constructor
    then
      Location.raise_errorf
        ~loc
        "ppx_sexp_conv: [@@@@allow_extra_fields] only works on records. For inline \
         records, do: type t = A of { a : int } [@@allow_extra_fields] | B [@@@@deriving \
         sexp]"
    else
      Location.raise_errorf
        ~loc
        "ppx_sexp_conv: [@@@@allow_extra_fields] is only allowed on records."
;;