Source file npy.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
(* Element types as a GADT indexed by the OCaml type that holds one
   element.  NOTE(review): no definition in this view refers to [dtype];
   it may be used elsewhere or be left over — confirm before removing. *)
type _ dtype =
| Int8 : int dtype
| Uint8 : int dtype
| Float32 : float dtype
| Float64 : float dtype
(* Element type read from a header's 'descr' entry; see [parse_descr] for
   which descriptor strings map to which constructor. *)
type descr = D_int8 | D_uint8 | D_float32 | D_float64
(* A parsed .npy file: the three header fields plus the raw bytes that
   follow the header. *)
type t = {
shape : int array;      (* dimensions from the header's 'shape' tuple *)
fortran_order : bool;   (* the header's 'fortran_order' flag *)
descr : descr;          (* element type from the header's 'descr' *)
data : string;          (* undecoded payload bytes after the header *)
}
(** [find_substring haystack needle] returns [Some i] where [i] is the index
    of the first occurrence of [needle] in [haystack], or [None] if [needle]
    does not occur.  An empty [needle] matches at index 0.

    Each candidate position is compared character by character, so no
    intermediate substring is allocated per position. *)
let find_substring haystack needle =
  let nlen = String.length needle in
  let hlen = String.length haystack in
  (* [matches_at i] checks needle against haystack.[i .. i + nlen - 1];
     the caller guarantees that range is in bounds. *)
  let matches_at i =
    let rec loop k =
      k >= nlen || (haystack.[i + k] = needle.[k] && loop (k + 1))
    in
    loop 0
  in
  let rec search i =
    if i + nlen > hlen then None
    else if matches_at i then Some i
    else search (i + 1)
  in
  search 0
(** [extract_quoted_value header key] looks up [key] in a NumPy header
    dictionary literal and returns the raw text of its value.

    The key is located by the literal pattern ['key': ] (single-quoted key,
    a colon and exactly one space).  If the value starts with a single
    quote, the text up to the next single quote is returned without the
    quotes (e.g. ['<f4'] gives [<f4]).  Otherwise the value runs up to the
    next [','], ['}'] or [')'] (or the end of [header]) and is returned
    trimmed (e.g. [False]).

    Returns [Error] if the key is absent, if nothing follows the key, or if
    a quoted value has no closing quote. *)
let extract_quoted_value header key =
  let hlen = String.length header in
  let pattern = "'" ^ key ^ "': " in
  match find_substring header pattern with
  | None -> Error (Printf.sprintf "missing key: %s" key)
  | Some i ->
    let start = i + String.length pattern in
    if start >= hlen then
      Error (Printf.sprintf "truncated value for key: %s" key)
    else if header.[start] = '\'' then begin
      (* Quoted value: scan in place for the closing quote rather than
         copying the remainder of the header. *)
      let value_start = start + 1 in
      match String.index_from header value_start '\'' with
      | stop -> Ok (String.sub header value_start (stop - value_start))
      | exception Not_found ->
        Error (Printf.sprintf "unterminated string for key: %s" key)
    end
    else begin
      (* Bare value (e.g. True/False): ends at the next delimiter. *)
      let rec find_end j =
        if j >= hlen then j
        else
          match header.[j] with
          | ',' | '}' | ')' -> j
          | _ -> find_end (j + 1)
      in
      let stop = find_end start in
      Ok (String.trim (String.sub header start (stop - start)))
    end
(** [parse_descr s] maps a dtype descriptor string from the header to a
    [descr].  Only ["|i1"], ["|u1"], ["<f4"] and ["<f8"] are recognised;
    any other string yields [Error]. *)
let parse_descr = function
  | "|i1" -> Ok D_int8
  | "|u1" -> Ok D_uint8
  | "<f4" -> Ok D_float32
  | "<f8" -> Ok D_float64
  | other -> Error ("unsupported dtype: " ^ other)
(** [parse_fortran_order s] decodes the Python boolean literal stored under
    the header's 'fortran_order' key: ["True"] gives [Ok true], ["False"]
    gives [Ok false], and any other text gives [Error]. *)
let parse_fortran_order = function
  | "True" -> Ok true
  | "False" -> Ok false
  | bad -> Error (Printf.sprintf "invalid fortran_order: %s" bad)
(** [parse_shape header] extracts the 'shape' tuple from a NumPy header
    dictionary literal: ["(3, 4)"] gives [[|3; 4|]], ["(5,)"] gives
    [[|5|]], and ["()"] gives [[||]].

    Returns [Error] if the 'shape' key is missing, the tuple is not closed
    by [')'], or any entry is not a non-negative integer literal. *)
let parse_shape header =
  let pattern = "'shape': (" in
  (* Accept only integer literals that are >= 0; anything else is None. *)
  let dim_of_string p =
    match int_of_string p with
    | d when d >= 0 -> Some d
    | _ -> None
    | exception Failure _ -> None
  in
  let rec dims_of_parts acc = function
    | [] -> Ok (Array.of_list (List.rev acc))
    | p :: rest ->
      (match dim_of_string p with
       | Some d -> dims_of_parts (d :: acc) rest
       | None -> Error (Printf.sprintf "invalid shape dimension: %s" p))
  in
  match find_substring header pattern with
  | None -> Error "missing shape"
  | Some i ->
    let start = i + String.length pattern in
    (match String.index_from header start ')' with
     | exception Not_found -> Error "unterminated shape"
     | stop ->
       let shape_str = String.trim (String.sub header start (stop - start)) in
       if shape_str = "" then Ok [||]
       else
         (* A one-element tuple is written "(5,)", so empty pieces left by
            splitting on ',' are dropped rather than rejected. *)
         String.split_on_char ',' shape_str
         |> List.map String.trim
         |> List.filter (fun p -> p <> "")
         |> dims_of_parts [])
(** [of_string s] parses the complete contents of a .npy file held in [s].

    The input must begin with the magic bytes ["\x93NUMPY"], then a major
    and a minor version byte.  Version 1 stores the header length as a
    2-byte little-endian integer and the header starts at offset 10.
    Versions 2 and 3 use a 4-byte little-endian length and the header starts
    at offset 12.  The header is a Python dict literal from which 'descr',
    'fortran_order' and 'shape' are read.  Every byte after the header is
    returned as [data].

    Returns [Error msg] for:
    - a too-short input or bad magic;
    - an unsupported major version;
    - a truncated header;
    - an unsupported dtype or malformed header fields;
    - a data section shorter than product(shape) * itemsize. *)
let of_string s =
  let len = String.length s in
  let byte i = Char.code s.[i] in
  (* Chain result-returning steps, stopping at the first Error. *)
  let ( >>= ) r f = match r with Ok x -> f x | Error e -> Error e in
  let item_size = function
    | D_int8 | D_uint8 -> 1
    | D_float32 -> 4
    | D_float64 -> 8
  in
  if len < 10 then Error "too short for .npy file"
  else if String.sub s 0 6 <> "\x93NUMPY" then Error "bad magic number"
  else
    (* Byte 7 (minor version) does not affect parsing. *)
    let layout =
      match byte 6 with
      | 1 -> Ok (byte 8 lor (byte 9 lsl 8), 10)
      | 2 | 3 ->
        (* Version 3 differs from 2 only in allowing a UTF-8 header; the
           keys looked up below are ASCII, so both parse the same way. *)
        if len < 12 then Error "truncated header"
        else
          Ok
            ( byte 8 lor (byte 9 lsl 8) lor (byte 10 lsl 16)
              lor (byte 11 lsl 24),
              12 )
      | major ->
        Error (Printf.sprintf "unsupported .npy format version: %d" major)
    in
    layout >>= fun (header_len, header_offset) ->
    (* [header_len < 0] catches overflow of the 4-byte length on platforms
       whose native int is only 31 bits wide. *)
    if header_len < 0 || header_offset + header_len > len then
      Error "truncated header"
    else
      let header = String.sub s header_offset header_len in
      extract_quoted_value header "descr" >>= parse_descr >>= fun descr ->
      extract_quoted_value header "fortran_order" >>= parse_fortran_order
      >>= fun fortran_order ->
      parse_shape header >>= fun shape ->
      let data_offset = header_offset + header_len in
      let available = len - data_offset in
      (* An empty shape is a 0-d array holding exactly one element. *)
      let needed = Array.fold_left ( * ) (item_size descr) shape in
      if available < needed then Error "truncated data"
      else
        let data = String.sub s data_offset available in
        Ok { shape; fortran_order; descr; data }
(** [shape t] is the dimension array read from the header's 'shape'. *)
let shape { shape; _ } = shape

(** [fortran_order t] is the header's 'fortran_order' flag. *)
let fortran_order { fortran_order; _ } = fortran_order
(** [data_int8 t] decodes every payload byte as a signed 8-bit value into a
    flat C-layout bigarray with one element per byte.  Returns [None] for
    any [descr] other than [D_int8]. *)
let data_int8 t =
  match t.descr with
  | D_uint8 | D_float32 | D_float64 -> None
  | D_int8 ->
    let bytes = t.data in
    let out =
      Bigarray.Array1.create Bigarray.int8_signed Bigarray.c_layout
        (String.length bytes)
    in
    String.iteri
      (fun idx ch ->
         let raw = Char.code ch in
         (* Reinterpret the unsigned byte 0..255 as two's-complement. *)
         Bigarray.Array1.set out idx (if raw < 128 then raw else raw - 256))
      bytes;
    Some out
(** [data_uint8 t] copies every payload byte (0..255) into a flat C-layout
    bigarray with one element per byte.  Returns [None] for any [descr]
    other than [D_uint8]. *)
let data_uint8 t =
  match t.descr with
  | D_int8 | D_float32 | D_float64 -> None
  | D_uint8 ->
    let src = t.data in
    let out =
      Bigarray.Array1.create Bigarray.int8_unsigned Bigarray.c_layout
        (String.length src)
    in
    String.iteri (fun idx ch -> Bigarray.Array1.set out idx (Char.code ch)) src;
    Some out
(** [read_le_int32 s off] decodes the bytes [s.[off]] .. [s.[off + 3]] as a
    little-endian 32-bit integer; [s.[off]] is the least significant byte.
    @raise Invalid_argument if any of those indices is out of bounds. *)
let read_le_int32 s off =
  let acc = ref 0l in
  (* Walk from the most significant byte down, shifting earlier bytes up. *)
  for k = 3 downto 0 do
    let byte = Int32.of_int (Char.code s.[off + k]) in
    acc := Int32.logor (Int32.shift_left !acc 8) byte
  done;
  !acc
(** [read_le_int64 s off] decodes the bytes [s.[off]] .. [s.[off + 7]] as a
    little-endian 64-bit integer; [s.[off]] is the least significant byte.
    @raise Invalid_argument if any of those indices is out of bounds. *)
let read_le_int64 s off =
  (* Accumulate from the most significant byte (k = 7) down to k = 0. *)
  let rec go k acc =
    if k < 0 then acc
    else
      let byte = Int64.of_int (Char.code s.[off + k]) in
      go (k - 1) (Int64.logor (Int64.shift_left acc 8) byte)
  in
  go 7 0L
(** [data_float32 t] decodes the payload as consecutive 4-byte little-endian
    words, reinterpreted via [Int32.float_of_bits], into a flat C-layout
    bigarray.  Trailing bytes that do not form a whole element are ignored.
    Returns [None] for any [descr] other than [D_float32]. *)
let data_float32 t =
  match t.descr with
  | D_int8 | D_uint8 | D_float64 -> None
  | D_float32 ->
    let raw = t.data in
    let count = String.length raw / 4 in
    let out = Bigarray.Array1.create Bigarray.float32 Bigarray.c_layout count in
    let rec fill idx =
      if idx < count then begin
        out.{idx} <- Int32.float_of_bits (read_le_int32 raw (idx * 4));
        fill (idx + 1)
      end
    in
    fill 0;
    Some out
(** [data_float64 t] decodes the payload as consecutive 8-byte little-endian
    words, reinterpreted via [Int64.float_of_bits], into a flat C-layout
    bigarray.  Trailing bytes that do not form a whole element are ignored.
    Returns [None] for any [descr] other than [D_float64]. *)
let data_float64 t =
  match t.descr with
  | D_int8 | D_uint8 | D_float32 -> None
  | D_float64 ->
    let raw = t.data in
    let count = String.length raw / 8 in
    let out = Bigarray.Array1.create Bigarray.float64 Bigarray.c_layout count in
    let rec fill idx =
      if idx < count then begin
        out.{idx} <- Int64.float_of_bits (read_le_int64 raw (idx * 8));
        fill (idx + 1)
      end
    in
    fill 0;
    Some out