bin: Fixed float decoding bug.
[coe.git] / coe / bin.py
CommitLineData
e39561b7 1import io, math
e8a122ff
FT
2from . import data
3
# Primary type codes, stored in the low three bits of every tag byte.
T_END, T_INT, T_STR, T_BIT, T_NIL, T_SYM, T_CON = range(7)

# Secondary codes (tag byte >> 3), interpreted per primary type.
INT_REF = 1                                     # T_INT: payload is a back-reference index
STR_SYM = 1                                     # T_STR: string is a symbol with empty namespace
BIT_BFLOAT, BIT_DFLOAT = 1, 2                   # T_BIT: BFLOAT payload is a float; DFLOAT is not handled by encoder/decoder as written
CON_SEQ, CON_SET, CON_MAP, CON_OBJ = range(4)   # T_CON: container kind
NIL_FALSE, NIL_TRUE = 1, 2                      # T_NIL: secondary 0 is None
26
82855e29
FT
class encoder(object):
    """Serialize Python values into the module's tagged binary format.

    Each value begins with a one-byte tag built by ``enctag``: the primary
    type code (``T_*``) in the low three bits and a type-specific secondary
    code shifted left by three.  With *backrefs* enabled, every call to
    ``writetag`` consumes the next reference number, and an object that is
    written again (matched by ``id``) is emitted as a ``T_INT``/``INT_REF``
    back-reference instead of being serialized a second time.
    """

    def __init__(self, *, backrefs=True):
        self.backrefs = backrefs
        # id(obj) -> reference number assigned when obj was first tagged.
        self.reftab = {}
        # Next reference number; advances on every writetag call while
        # backrefs is enabled.
        self.nextref = 0
        # Symbol namespace -> reference number of the first symbol written
        # with that namespace, so later symbols can point back at it.
        self.nstab = {}

    @staticmethod
    def enctag(pri, sec):
        """Return the one-byte tag for primary code *pri* and secondary *sec*."""
        return bytes([(sec << 3) | pri])

    def writetag(self, dst, pri, sec, datum):
        """Write a tag to *dst* and allocate its reference number.

        If *datum* is not None and has not been recorded yet, it is stored in
        ``reftab`` so later occurrences become back-references.  Returns the
        allocated reference number, or None when backrefs are disabled.
        """
        dst.write(self.enctag(pri, sec))
        if self.backrefs:
            # A number is consumed even when datum is None (ints, nil,
            # booleans) because the decoder adds every decoded value to its
            # reference table.
            ref = self.nextref
            self.nextref += 1
            if datum is not None and id(datum) not in self.reftab:
                self.reftab[id(datum)] = ref
            return ref
        return None

    @staticmethod
    def encint(x):
        """Encode integer *x* as a signed LEB128 varint; return a bytearray.

        Seven bits per byte, least-significant group first; bit 0x80 marks a
        continuation byte and bit 0x40 of the final byte carries the sign.
        """
        ret = bytearray()
        if x >= 0:
            b = x & 0x7f
            x >>= 7
            # Continue while bits remain or the final byte would read as negative.
            while (x > 0) or (b & 0x40) != 0:
                ret.append(0x80 | b)
                b = x & 0x7f
                x >>= 7
            ret.append(b)
        else:
            b = x & 0x7f
            x >>= 7
            # Continue until only sign extension remains and the sign bit is set.
            while x < -1 or (b & 0x40) == 0:
                ret.append(0x80 | b)
                b = x & 0x7f
                x >>= 7
            ret.append(b)
        return ret

    @staticmethod
    def writestr(dst, text):
        """Write *text* as UTF-8 followed by a NUL terminator.

        Raises ValueError if *text* contains a NUL character: the terminator
        would end the string early on decode and desynchronize the stream.
        """
        if "\0" in text:
            raise ValueError("cannot encode string containing NUL: " + repr(text))
        dst.write(text.encode("utf-8"))
        dst.write(b'\0')

    @staticmethod
    def writefloat(dst, x):
        """Write float *x* as a length-prefixed (mantissa, exponent) varint pair.

        Finite non-zero values use an integral mantissa ``mnt`` and the
        exponent of its leading bit, i.e.
        ``x == mnt * 2 ** (exp - (mnt.bit_length() - 1))``.  Special values
        use mantissa 0 with the exponent as a code:
        0 -> +0.0, 1 -> -0.0, 2 -> +inf, 3 -> -inf, 4 -> nan.
        """
        if x == 0.0:
            # -0.0 == 0.0, so the sign must be read with copysign; the
            # decoder maps exponent code 1 back to -0.0.
            mnt, exp = (0, 1) if math.copysign(1.0, x) < 0 else (0, 0)
        elif math.isinf(x):
            mnt, exp = (0, 2) if x > 0 else (0, 3)
        elif math.isnan(x):
            mnt, exp = (0, 4)
        else:
            mnt, exp = math.frexp(x)
            # frexp yields |mnt| in [0.5, 1); renormalize to [1, 2) so that
            # exp is the exponent of the leading bit.
            mnt *= 2
            exp -= 1
            # Double until integral; multiplying by 2 is exact here.
            while mnt != int(mnt):
                mnt *= 2
            mnt = int(mnt)
        buf = bytearray()
        buf.extend(encoder.encint(mnt))
        buf.extend(encoder.encint(exp))
        dst.write(encoder.encint(len(buf)))
        dst.write(buf)

    def dumpseq(self, dst, seq):
        """Dump every element of *seq*, then a T_END terminator."""
        for v in seq:
            self.dump(dst, v)
        dst.write(self.enctag(T_END, 0))

    def dumpmap(self, dst, val):
        """Dump each key/value pair of mapping *val*, then a T_END terminator."""
        for k, v in val.items():
            self.dump(dst, k)
            self.dump(dst, v)
        dst.write(self.enctag(T_END, 0))

    def dump(self, dst, datum):
        """Write *datum* (recursing into containers) to the binary stream *dst*.

        Raises ValueError for values of unsupported types.
        """
        ref = self.reftab.get(id(datum))
        if ref is not None:
            # Already written: emit a back-reference, which is not numbered.
            dst.write(self.enctag(T_INT, INT_REF))
            dst.write(self.encint(ref))
            return
        if datum is None:
            self.writetag(dst, T_NIL, 0, None)
        # Booleans must be tested before int, since bool subclasses int.
        elif datum is False:
            self.writetag(dst, T_NIL, NIL_FALSE, None)
        elif datum is True:
            self.writetag(dst, T_NIL, NIL_TRUE, None)
        elif isinstance(datum, int):
            self.writetag(dst, T_INT, 0, None)
            dst.write(self.encint(datum))
        elif isinstance(datum, str):
            self.writetag(dst, T_STR, 0, datum)
            self.writestr(dst, datum)
        elif isinstance(datum, (bytes, bytearray)):
            self.writetag(dst, T_BIT, 0, datum)
            dst.write(self.encint(len(datum)))
            dst.write(datum)
        elif isinstance(datum, float):
            self.writetag(dst, T_BIT, BIT_BFLOAT, datum)
            self.writefloat(dst, datum)
        elif isinstance(datum, data.symbol):
            if datum.ns == "":
                # Namespace-less symbols are encoded as flagged strings.
                self.writetag(dst, T_STR, STR_SYM, datum)
                self.writestr(dst, datum.name)
            else:
                nsref = self.nstab.get(datum.ns)
                if nsref is None:
                    # First symbol seen in this namespace: flag byte 0, then
                    # the namespace and the name inline.
                    nsref = self.writetag(dst, T_SYM, 0, datum)
                    dst.write(b'\0')
                    self.writestr(dst, datum.ns)
                    self.writestr(dst, datum.name)
                    if nsref is not None:
                        self.nstab[datum.ns] = nsref
                else:
                    # Flag byte 1: namespace is copied from the symbol at nsref.
                    self.writetag(dst, T_SYM, 0, datum)
                    dst.write(b'\x01')
                    dst.write(self.encint(nsref))
                    self.writestr(dst, datum.name)
        elif isinstance(datum, list):
            self.writetag(dst, T_CON, CON_SEQ, datum)
            self.dumpseq(dst, datum)
        elif isinstance(datum, set):
            self.writetag(dst, T_CON, CON_SET, datum)
            self.dumpseq(dst, datum)
        elif isinstance(datum, dict):
            self.writetag(dst, T_CON, CON_MAP, datum)
            self.dumpmap(dst, datum)
        elif isinstance(datum, data.obj):
            self.writetag(dst, T_CON, CON_OBJ, datum)
            # Type name comes from the class attribute ``typename`` (None if
            # absent), followed by the instance attributes as a map.
            self.dump(dst, getattr(type(datum), "typename", None))
            self.dumpmap(dst, datum.__dict__)
        else:
            raise ValueError("unsupported object type: " + repr(datum))
167
def dump(dst, datum, *, backrefs=True):
    """Encode *datum* onto the writable binary stream *dst* and return *dst*.

    A fresh ``encoder`` is used for each call.  With ``backrefs=False``,
    repeated objects are written in full every time; cyclic structures then
    recurse without bound.
    """
    encoder(backrefs=backrefs).dump(dst, datum)
    return dst
171
e8a122ff
FT
class fmterror(Exception):
    """Base class for errors in the encoded input format."""


class eoferror(fmterror):
    """Raised when the input ends before a complete value has been read."""

    def __init__(self):
        super(eoferror, self).__init__("unexpected end-of-data")


class referror(fmterror):
    """Raised when a back-reference index is outside the reference table."""

    def __init__(self):
        super(referror, self).__init__("bad backref")
182
e8a122ff
FT
class decoder(object):
    """Deserialize values produced by ``encoder``.

    Every value decoded from a tag (other than a back-reference) is appended
    to ``reftab`` in the order its tag is read, mirroring the encoder's
    reference numbering so ``T_INT``/``INT_REF`` tags can be resolved.
    Detected format errors raise ``fmterror`` or one of its subclasses.
    """

    def __init__(self):
        # Decoded values, indexed by reference number.
        self.reftab = []
        # Object type name -> class built by makeobjtype, so objects with the
        # same type name share one class.
        self.namedtypes = {}

    @staticmethod
    def byte(fp):
        """Read one byte of *fp* and return it as an int.

        Raises eoferror if the stream is exhausted.
        """
        b = fp.read(1)
        if b == b"":
            raise eoferror()
        return b[0]

    @staticmethod
    def loadint(fp):
        """Read a signed LEB128 varint (the inverse of ``encoder.encint``)."""
        ret = 0
        p = 0
        while True:
            b = decoder.byte(fp)
            ret += (b & 0x7f) << p
            p += 7
            if (b & 0x80) == 0:
                break
        # Bit 0x40 of the final byte is the sign bit: sign-extend.
        if (b & 0x40) != 0:
            ret = ret - (1 << p)
        return ret

    @staticmethod
    def loadstr(fp):
        """Read a NUL-terminated UTF-8 string; the terminator is consumed."""
        buf = bytearray()
        while True:
            b = decoder.byte(fp)
            if b == 0:
                break
            buf.append(b)
        return buf.decode("utf-8")

    def loadsym(self, fp):
        """Read the body of a T_SYM value and return the symbol.

        A flag byte with bit 0 set means the namespace is copied from an
        earlier symbol, named by its reference number; otherwise the
        namespace follows inline as a string.  The name always follows.
        """
        h = self.byte(fp)
        if h & 0x1:
            nsref = self.loadint(fp)
            if not 0 <= nsref < len(self.reftab):
                raise fmterror("illegal namespace ref: " + str(nsref))
            nssym = self.reftab[nsref]
            if not isinstance(nssym, data.symbol):
                raise fmterror("illegal namespace ref: " + str(nsref))
            ns = nssym.ns
        else:
            ns = self.loadstr(fp)
        nm = self.loadstr(fp)
        return data.symbol.get(ns, nm)

    def loadlist(self, fp, buf):
        """Append tagged values to *buf* until a T_END tag; return *buf*."""
        while True:
            tag = self.byte(fp)
            if tag == T_END:
                return buf
            buf.append(self.loadtagged(fp, tag))

    def loadmap(self, fp, buf):
        """Read key/value pairs into *buf* until a T_END tag; return *buf*.

        A T_END in value position also ends the map; the dangling key is
        dropped.
        """
        while True:
            tag = self.byte(fp)
            if tag == T_END:
                return buf
            key = self.loadtagged(fp, tag)
            tag = self.byte(fp)
            if tag == T_END:
                return buf
            buf[key] = self.loadtagged(fp, tag)

    def makeobjtype(self, nm):
        """Build the class used for decoded objects of type name *nm*."""
        return data.namedtype.make(str(nm), (data.obj, object), {}, typename=nm)

    def loadobj(self, fp, ref=False):
        """Read a CON_OBJ body (type name, then attribute map); return the object.

        With *ref*, a reference slot is reserved before the type name is read,
        matching the encoder, which numbers the object's tag first.  The slot
        is filled as soon as the instance exists, so attribute values can
        refer back to the object.
        """
        if ref:
            refid = len(self.reftab)
            self.reftab.append(None)
        nm = self.load(fp)
        typ = self.namedtypes.get(nm)
        if typ is None:
            typ = self.namedtypes[nm] = self.makeobjtype(nm)
        ret = typ()
        if ref:
            self.reftab[refid] = ret
        ret.__dict__.update(self.loadmap(fp, {}))
        return ret

    def addref(self, obj):
        """Append *obj* to the reference table and return it."""
        self.reftab.append(obj)
        return obj

    @staticmethod
    def loadfloat(raw):
        """Decode a BIT_BFLOAT payload (mantissa varint, exponent varint).

        Mantissa 0 marks a special value selected by the exponent code:
        0 -> +0.0, 1 -> -0.0, 2 -> +inf, 3 -> -inf, anything else -> nan.
        Raises fmterror if a finite value is too large for a float.
        """
        buf = io.BytesIO(raw)
        mnt = decoder.loadint(buf)
        exp = decoder.loadint(buf)
        if mnt == 0:
            if exp == 0:
                return 0.0
            elif exp == 1:
                return -0.0
            elif exp == 2:
                return float("inf")
            elif exp == 3:
                return -float("inf")
            return float("nan")
        try:
            # exp is the exponent of the mantissa's leading bit.
            return math.ldexp(mnt, exp - (mnt.bit_length() - 1))
        except OverflowError:
            raise fmterror("float out of range") from None

    def loadtagged(self, fp, tag):
        """Decode the value introduced by the already-read tag byte *tag*."""
        pri, sec = (tag & 0x7), (tag & 0xf8) >> 3
        if pri == T_END:
            raise fmterror("unexpected end-tag")
        elif pri == T_INT:
            if sec == INT_REF:
                idx = self.loadint(fp)
                if not 0 <= idx < len(self.reftab):
                    raise referror()
                # Back-references are not themselves added to reftab.
                return self.reftab[idx]
            return self.addref(self.loadint(fp))
        elif pri == T_STR:
            ret = self.loadstr(fp)
            if sec == STR_SYM:
                return self.addref(data.symbol.get("", ret))
            return self.addref(ret)
        elif pri == T_BIT:
            ln = self.loadint(fp)
            if ln < 0:
                # loadint is signed, and fp.read() with a negative size reads
                # to end of stream, which would silently swallow later data.
                raise fmterror("negative byte-string length: " + str(ln))
            ret = fp.read(ln)
            if len(ret) < ln:
                raise eoferror()
            if sec == BIT_BFLOAT:
                return self.addref(self.loadfloat(ret))
            # Any other secondary (including BIT_DFLOAT) yields the raw bytes.
            return self.addref(ret)
        elif pri == T_NIL:
            if sec == NIL_TRUE:
                return self.addref(True)
            elif sec == NIL_FALSE:
                return self.addref(False)
            return self.addref(None)
        elif pri == T_SYM:
            return self.addref(self.loadsym(fp))
        elif pri == T_CON:
            if sec == CON_MAP:
                return self.loadmap(fp, self.addref({}))
            elif sec == CON_OBJ:
                return self.loadobj(fp, ref=True)
            # CON_SEQ, CON_SET and unrecognized container kinds decode as lists.
            return self.loadlist(fp, self.addref([]))
        else:
            raise fmterror("unknown primary: " + str(pri))

    def load(self, fp):
        """Read one tag from *fp* and decode the value it introduces."""
        tag = self.byte(fp)
        return self.loadtagged(fp, tag)
339
def load(fp):
    """Decode and return a single value from binary stream *fp* using a fresh decoder."""
    reader = decoder()
    value = reader.load(fp)
    return value