bin: Make decoder type creation overridable.
[coe.git] / coe / bin.py
1 from . import data
2
3 T_END = 0
4 T_INT = 1
5 T_STR = 2
6 T_BIT = 3
7 T_NIL = 4
8 T_SYM = 5
9 T_CON = 6
10
11 INT_REF = 1
12
13 STR_SYM = 1
14
15 BIT_BFLOAT = 1
16 BIT_DFLOAT = 2
17
18 CON_SEQ = 0
19 CON_SET = 1
20 CON_MAP = 2
21 CON_OBJ = 3
22
23 NIL_FALSE = 1
24 NIL_TRUE = 2
25
26 class encoder(object):
27     def __init__(self, *, backrefs=True):
28         self.backrefs = backrefs
29         self.reftab = {}
30         self.nextref = 0
31         self.nstab = {}
32
33     @staticmethod
34     def enctag(pri, sec):
35         return bytes([(sec << 3) | pri])
36
37     def writetag(self, dst, pri, sec, datum):
38         dst.write(self.enctag(pri, sec))
39         if self.backrefs:
40             ref = self.nextref
41             self.nextref += 1
42             if datum is not None and id(datum) not in self.reftab:
43                 self.reftab[id(datum)] = ref
44             return ref
45         return None
46
47     @staticmethod
48     def encint(x):
49         ret = bytearray()
50         if x >= 0:
51             b = x & 0x7f
52             x >>= 7
53             while (x > 0) or (b & 0x40) != 0:
54                 ret.append(0x80 | b)
55                 b = x & 0x7f
56                 x >>= 7
57             ret.append(b)
58         elif x < 0:
59             b = x & 0x7f
60             x >>= 7
61             while x < -1 or (b & 0x40) == 0:
62                 ret.append(0x80 | b)
63                 b = x & 0x7f
64                 x >>= 7
65             ret.append(b)
66         return ret
67
68     @staticmethod
69     def writestr(dst, text):
70         dst.write(text.encode("utf-8"))
71         dst.write(b'\0')
72
73     def dumpseq(self, dst, seq):
74         for v in seq:
75             self.dump(dst, v)
76         dst.write(self.enctag(T_END, 0))
77
78     def dumpmap(self, dst, val):
79         for k, v in val.items():
80             self.dump(dst, k)
81             self.dump(dst, v)
82         dst.write(self.enctag(T_END, 0))
83
84     def dump(self, dst, datum):
85         ref = self.reftab.get(id(datum))
86         if ref is not None:
87             dst.write(self.enctag(T_INT, INT_REF))
88             dst.write(self.encint(ref))
89             return
90         if datum == None:
91             self.writetag(dst, T_NIL, 0, None)
92         elif datum == False:
93             self.writetag(dst, T_NIL, NIL_FALSE, None)
94         elif datum == True:
95             self.writetag(dst, T_NIL, NIL_TRUE, None)
96         elif isinstance(datum, int):
97             self.writetag(dst, T_INT, 0, None)
98             dst.write(self.encint(datum))
99         elif isinstance(datum, str):
100             self.writetag(dst, T_STR, 0, datum)
101             self.writestr(dst, datum)
102         elif isinstance(datum, (bytes, bytearray)):
103             self.writetag(dst, T_BIT, 0, datum)
104             dst.write(self.encint(len(datum)))
105             dst.write(datum)
106         elif isinstance(datum, data.symbol):
107             if datum.ns == "":
108                 self.writetag(dst, T_STR, STR_SYM, datum)
109                 self.writestr(dst, datum.name)
110             else:
111                 nsref = self.nstab.get(datum.ns)
112                 if nsref is None:
113                     nsref = self.writetag(dst, T_SYM, 0, datum)
114                     dst.write(b'\0')
115                     self.writestr(dst, datum.ns)
116                     self.writestr(dst, datum.name)
117                     if nsref is not None:
118                         self.nstab[datum.ns] = nsref
119                 else:
120                     self.writetag(dst, T_SYM, 0, datum)
121                     dst.write(b'\x01')
122                     dst.write(self.encint(nsref))
123                     self.writestr(dst, datum.name)
124         elif isinstance(datum, list):
125             self.writetag(dst, T_CON, CON_SEQ, datum)
126             self.dumpseq(dst, datum)
127         elif isinstance(datum, set):
128             self.writetag(dst, T_CON, CON_SET, datum)
129             self.dumpseq(dst, datum)
130         elif isinstance(datum, dict):
131             self.writetag(dst, T_CON, CON_MAP, datum)
132             self.dumpmap(dst, datum)
133         elif isinstance(datum, data.obj):
134             self.writetag(dst, T_CON, CON_OBJ, datum)
135             self.dump(dst, getattr(type(datum), "typename", None))
136             self.dumpmap(dst, datum.__dict__)
137         else:
138             raise ValueError("unsupported object type: " + repr(datum))
139
140 def dump(dst, datum):
141     encoder().dump(dst, datum)
142     return dst
143
144 class fmterror(Exception):
145     pass
146
147 class eoferror(fmterror):
148     def __init__(self):
149         super().__init__("unexpected end-of-data")
150
151 class referror(fmterror):
152     def __init__(self):
153         super().__init__("bad backref")
154
155 class namedtype(type):
156     def __new__(cls, *args, typename=None, **kwargs):
157         self = super().__new__(cls, *args, **kwargs)
158         self.typename = typename
159         return self
160
161 class decoder(object):
162     def __init__(self):
163         self.reftab = []
164         self.namedtypes = {}
165
166     @staticmethod
167     def byte(fp):
168         b = fp.read(1)
169         if b == b"":
170             raise eoferror()
171         return b[0]
172
173     @staticmethod
174     def loadint(fp):
175         ret = 0
176         p = 0
177         while True:
178             b = decoder.byte(fp)
179             ret += (b & 0x7f) << p
180             p += 7
181             if (b & 0x80) == 0:
182                 break
183         if (b & 0x40) != 0:
184             ret = ret - (1 << p)
185         return ret
186
187     @staticmethod
188     def loadstr(fp):
189         buf = bytearray()
190         while True:
191             b = decoder.byte(fp)
192             if b == 0:
193                 break
194             buf.append(b)
195         return buf.decode("utf-8")
196
197     def loadsym(self, fp):
198         h = self.byte(fp)
199         if h & 0x1:
200             nsref = self.loadint(fp)
201             if not 0 <= nsref < len(self.reftab):
202                 raise fmterror("illegal namespace ref: " + str(nsref))
203             nssym = self.reftab[nsref]
204             if not isinstance(nssym, data.symbol):
205                 raise fmterror("illegal namespace ref: " + str(nsref))
206             ns = nssym.ns
207         else:
208             ns = self.loadstr(fp)
209         nm = self.loadstr(fp)
210         ret = data.symbol.get(ns, nm)
211         return ret
212
213     def loadlist(self, fp, buf):
214         while True:
215             tag = self.byte(fp)
216             if tag == T_END:
217                 return buf
218             buf.append(self.loadtagged(fp, tag))
219
220     def loadmap(self, fp, buf):
221         while True:
222             tag = self.byte(fp)
223             if tag == T_END:
224                 return buf
225             key = self.loadtagged(fp, tag)
226             tag = self.byte(fp)
227             if tag == T_END:
228                 return buf
229             buf[key] = self.loadtagged(fp, tag)
230
231     def makeobjtype(self, nm):
232         return namedtype(str(nm), (data.obj, object), {}, typename=nm)
233
234     def loadobj(self, fp, ref=False):
235         if ref:
236             refid = len(self.reftab)
237             self.reftab.append(None)
238         nm = self.load(fp)
239         typ = self.namedtypes.get(nm)
240         if typ is None:
241             typ = self.namedtypes[nm] = self.makeobjtype(nm)
242         ret = typ()
243         if ref:
244             self.reftab[refid] = ret
245         # st = fp.tell()
246         # print(">", nm, hex(st))
247         ret.__dict__.update(self.loadmap(fp, {}))
248         # print("<", nm, hex(fp.tell()), hex(st))
249         return ret
250
251     def addref(self, obj):
252         self.reftab.append(obj)
253         return obj
254
255     def loadtagged(self, fp, tag):
256         pri, sec = (tag & 0x7), (tag & 0xf8) >> 3
257         if pri == T_END:
258             raise fmterror("unexpected end-tag")
259         elif pri == T_INT:
260             if sec == INT_REF:
261                 idx = self.loadint(fp)
262                 if not 0 <= idx < len(self.reftab):
263                     raise referror()
264                 # print(idx, self.reftab[idx], hex(fp.tell()))
265                 return self.reftab[idx]
266             return self.addref(self.loadint(fp))
267         elif pri == T_STR:
268             ret = self.addref(self.loadstr(fp))
269             if sec == STR_SYM:
270                 return data.symbol.get("", ret)
271             return ret
272         elif pri == T_BIT:
273             ln = self.loadint(fp)
274             ret = self.addref(fp.read(ln))
275             if len(ret) < ln:
276                 raise eoferror()
277             return ret
278         elif pri == T_NIL:
279             if sec == NIL_TRUE:
280                 return self.addref(True)
281             elif sec == NIL_FALSE:
282                 return self.addref(False)
283             return self.addref(None)
284         elif pri == T_SYM:
285             return self.addref(self.loadsym(fp))
286         elif pri == T_CON:
287             if sec == CON_MAP:
288                 return self.loadmap(fp, self.addref({}))
289             elif sec == CON_OBJ:
290                 return self.loadobj(fp, ref=True)
291             else:
292                 return self.loadlist(fp, self.addref([]))
293         else:
294             raise fmterror("unknown primary: " + str(pri))
295
296     def load(self, fp):
297         tag = self.byte(fp)
298         return self.loadtagged(fp, tag)
299
300 def load(fp):
301     decoder().load(fp)