Added binary encoder.
[coe.git] / coe / bin.py
1 from . import data
2
3 T_END = 0
4 T_INT = 1
5 T_STR = 2
6 T_BIT = 3
7 T_NIL = 4
8 T_SYM = 5
9 T_CON = 6
10
11 INT_REF = 1
12
13 STR_SYM = 1
14
15 BIT_BFLOAT = 1
16 BIT_DFLOAT = 2
17
18 CON_SEQ = 0
19 CON_SET = 1
20 CON_MAP = 2
21 CON_OBJ = 3
22
23 NIL_FALSE = 1
24 NIL_TRUE = 2
25
26 class encoder(object):
27     def __init__(self, *, backrefs=True):
28         self.backrefs = backrefs
29         self.reftab = {}
30         self.nextref = 0
31         self.nstab = {}
32
33     @staticmethod
34     def enctag(pri, sec):
35         return bytes([(sec << 3) | pri])
36
37     def writetag(self, dst, pri, sec, datum):
38         dst.write(self.enctag(pri, sec))
39         if self.backrefs:
40             ref = self.nextref
41             self.nextref += 1
42             if datum is not None and id(datum) not in self.reftab:
43                 self.reftab[id(datum)] = ref
44             return ref
45         return None
46
47     @staticmethod
48     def encint(x):
49         ret = bytearray()
50         if x >= 0:
51             b = x & 0x7f
52             x >>= 7
53             while (x > 0) or (b & 0x40) != 0:
54                 ret.append(0x80 | b)
55                 b = x & 0x7f
56                 x >>= 7
57             ret.append(b)
58         elif x < 0:
59             b = x & 0x7f
60             x >>= 7
61             while x < -1 or (b & 0x40) == 0:
62                 ret.append(0x80 | b)
63                 b = x & 0x7f
64                 x >>= 7
65             ret.append(b)
66         return ret
67
68     @staticmethod
69     def writestr(dst, text):
70         dst.write(text.encode("utf-8"))
71         dst.write(b'\0')
72
73     def dumpseq(self, dst, seq):
74         for v in seq:
75             self.dump(dst, v)
76         dst.write(self.enctag(T_END, 0))
77
78     def dumpmap(self, dst, val):
79         for k, v in val.items():
80             self.dump(dst, k)
81             self.dump(dst, v)
82         dst.write(self.enctag(T_END, 0))
83
84     def dump(self, dst, datum):
85         ref = self.reftab.get(id(datum))
86         if ref is not None:
87             dst.write(self.enctag(T_INT, INT_REF))
88             dst.write(self.encint(ref))
89             return
90         if datum == None:
91             self.writetag(dst, T_NIL, 0, None)
92         elif datum == False:
93             self.writetag(dst, T_NIL, NIL_FALSE, None)
94         elif datum == True:
95             self.writetag(dst, T_NIL, NIL_TRUE, None)
96         elif isinstance(datum, int):
97             self.writetag(dst, T_INT, 0, None)
98             dst.write(self.encint(datum))
99         elif isinstance(datum, str):
100             self.writetag(dst, T_STR, 0, datum)
101             self.writestr(dst, datum)
102         elif isinstance(datum, (bytes, bytearray)):
103             self.writetag(dst, T_BIT, 0, datum)
104             dst.write(self.encint(len(datum)))
105             dst.write(datum)
106         elif isinstance(datum, data.symbol):
107             if datum.ns == "":
108                 self.writetag(dst, T_STR, STR_SYM, datum)
109                 self.writestr(dst, datum.name)
110             else:
111                 nsref = self.nstab.get(datum.ns)
112                 if nsref is None:
113                     nsref = self.writetag(dst, T_SYM, 0, datum)
114                     dst.write(b'\0')
115                     self.writestr(dst, datum.ns)
116                     self.writestr(dst, datum.name)
117                     if nsref is not None:
118                         self.nstab[datum.ns] = nsref
119                 else:
120                     self.writetag(dst, T_SYM, 0, datum)
121                     dst.write(b'\x01')
122                     dst.write(self.encint(nsref))
123                     self.writestr(dst, datum.name)
124         elif isinstance(datum, list):
125             self.writetag(dst, T_CON, CON_SEQ, datum)
126             self.dumpseq(dst, datum)
127         elif isinstance(datum, set):
128             self.writetag(dst, T_CON, CON_SET, datum)
129             self.dumpseq(dst, datum)
130         elif isinstance(datum, dict):
131             self.writetag(dst, T_CON, CON_MAP, datum)
132             self.dumpmap(dst, datum)
133         elif isinstance(datum, data.obj):
134             self.writetag(dst, T_CON, CON_OBJ, datum)
135             self.dump(dst, getattr(type(datum), "typename", None))
136             self.dumpmap(dst, datum.__dict__)
137         else:
138             raise ValueError("unsupported object type: " + repr(datum))
139
140 def dump(dst, datum):
141     encoder().dump(dst, datum)
142     return dst
143
144 class fmterror(Exception):
145     pass
146
147 class eoferror(fmterror):
148     def __init__(self):
149         super().__init__("unexpected end-of-data")
150
151 class referror(fmterror):
152     def __init__(self):
153         super().__init__("bad backref")
154
155 class namedtype(type):
156     pass
157
158 class decoder(object):
159     def __init__(self):
160         self.reftab = []
161         self.namedtypes = {}
162
163     @staticmethod
164     def byte(fp):
165         b = fp.read(1)
166         if b == b"":
167             raise eoferror()
168         return b[0]
169
170     @staticmethod
171     def loadint(fp):
172         ret = 0
173         p = 0
174         while True:
175             b = decoder.byte(fp)
176             ret += (b & 0x7f) << p
177             p += 7
178             if (b & 0x80) == 0:
179                 break
180         if (b & 0x40) != 0:
181             ret = ret - (1 << p)
182         return ret
183
184     @staticmethod
185     def loadstr(fp):
186         buf = bytearray()
187         while True:
188             b = decoder.byte(fp)
189             if b == 0:
190                 break
191             buf.append(b)
192         return buf.decode("utf-8")
193
194     def loadsym(self, fp):
195         h = self.byte(fp)
196         if h & 0x1:
197             nsref = self.loadint(fp)
198             if not 0 <= nsref < len(self.reftab):
199                 raise fmterror("illegal namespace ref: " + str(nsref))
200             nssym = self.reftab[nsref]
201             if not isinstance(nssym, data.symbol):
202                 raise fmterror("illegal namespace ref: " + str(nsref))
203             ns = nssym.ns
204         else:
205             ns = self.loadstr(fp)
206         nm = self.loadstr(fp)
207         ret = data.symbol.get(ns, nm)
208         return ret
209
210     def loadlist(self, fp, buf):
211         while True:
212             tag = self.byte(fp)
213             if tag == T_END:
214                 return buf
215             buf.append(self.loadtagged(fp, tag))
216
217     def loadmap(self, fp, buf):
218         while True:
219             tag = self.byte(fp)
220             if tag == T_END:
221                 return buf
222             key = self.loadtagged(fp, tag)
223             tag = self.byte(fp)
224             if tag == T_END:
225                 return buf
226             buf[key] = self.loadtagged(fp, tag)
227
228     def loadobj(self, fp, ref=False):
229         if ref:
230             refid = len(self.reftab)
231             self.reftab.append(None)
232         nm = self.load(fp)
233         typ = self.namedtypes.get(nm)
234         if typ is None:
235             typ = self.namedtypes[nm] = namedtype(str(nm), (data.obj, object), {})
236             typ.typename = nm
237         ret = typ()
238         if ref:
239             self.reftab[refid] = ret
240         # st = fp.tell()
241         # print(">", nm, hex(st))
242         ret.__dict__.update(self.loadmap(fp, {}))
243         # print("<", nm, hex(fp.tell()), hex(st))
244         return ret
245
246     def addref(self, obj):
247         self.reftab.append(obj)
248         return obj
249
250     def loadtagged(self, fp, tag):
251         pri, sec = (tag & 0x7), (tag & 0xf8) >> 3
252         if pri == T_END:
253             raise fmterror("unexpected end-tag")
254         elif pri == T_INT:
255             if sec == INT_REF:
256                 idx = self.loadint(fp)
257                 if not 0 <= idx < len(self.reftab):
258                     raise referror()
259                 # print(idx, self.reftab[idx], hex(fp.tell()))
260                 return self.reftab[idx]
261             return self.addref(self.loadint(fp))
262         elif pri == T_STR:
263             ret = self.addref(self.loadstr(fp))
264             if sec == STR_SYM:
265                 return data.symbol.get("", ret)
266             return ret
267         elif pri == T_BIT:
268             ln = self.loadint(fp)
269             ret = self.addref(fp.read(ln))
270             if len(ret) < ln:
271                 raise eoferror()
272             return ret
273         elif pri == T_NIL:
274             if sec == NIL_TRUE:
275                 return self.addref(True)
276             elif sec == NIL_FALSE:
277                 return self.addref(False)
278             return self.addref(None)
279         elif pri == T_SYM:
280             return self.addref(self.loadsym(fp))
281         elif pri == T_CON:
282             if sec == CON_MAP:
283                 return self.loadmap(fp, self.addref({}))
284             elif sec == CON_OBJ:
285                 return self.loadobj(fp, ref=True)
286             else:
287                 return self.loadlist(fp, self.addref([]))
288         else:
289             raise fmterror("unknown primary: " + str(pri))
290
291     def load(self, fp):
292         tag = self.byte(fp)
293         return self.loadtagged(fp, tag)
294
295 def load(fp):
296     decoder().load(fp)