ff291b2db1589b9c52b711f3c04cce4d87261ad2
[didex.git] / didex / index.py
1 import struct, contextlib, math
2 from . import db, lib
3 from .db import bd, txnfun, dloopfun
4
5 __all__ = ["maybe", "t_bool", "t_int", "t_uint", "t_dbid", "t_float", "t_str", "t_casestr", "ordered"]
6
7 deadlock = bd.DBLockDeadlockError
8 notfound = bd.DBNotFoundError
9
10 class simpletype(object):
11     def __init__(self, encode, decode):
12         self.enc = encode
13         self.dec = decode
14
15     def encode(self, ob):
16         return self.enc(ob)
17     def decode(self, dat):
18         return self.dec(dat)
19     def compare(self, a, b):
20         if a < b:
21             return -1
22         elif a > b:
23             return 1
24         else:
25             return 0
26
27     @classmethod
28     def struct(cls, fmt):
29         return cls(lambda ob: struct.pack(fmt, ob),
30                    lambda dat: struct.unpack(fmt, dat)[0])
31
32 class foldtype(simpletype):
33     def __init__(self, encode, decode, fold):
34         super().__init__(encode, decode)
35         self.fold = fold
36
37     def compare(self, a, b):
38         return super().compare(self.fold(a), self.fold(b))
39
40 class maybe(object):
41     def __init__(self, bk):
42         self.bk = bk
43
44     def encode(self, ob):
45         if ob is None: return b""
46         return b"\0" + self.bk.encode(ob)
47     def decode(self, dat):
48         if dat == b"": return None
49         return self.bk.dec(dat[1:])
50     def compare(self, a, b):
51         if a is b is None:
52             return 0
53         elif a is None:
54             return -1
55         elif b is None:
56             return 1
57         else:
58             return self.bk.compare(a[1:], b[1:])
59
60 class compound(object):
61     def __init__(self, *parts):
62         self.parts = parts
63
64     small = object()
65     large = object()
66     def minim(self, *parts):
67         return parts + tuple([self.small] * (len(self.parts) - len(parts)))
68     def maxim(self, *parts):
69         return parts + tuple([self.large] * (len(self.parts) - len(parts)))
70
71     def encode(self, obs):
72         if len(obs) != len(self.parts):
73             raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
74         buf = bytearray()
75         for ob, part in zip(obs, self.parts):
76             if ob is self.small:
77                 buf.append(0x01)
78             elif ob is self.large:
79                 buf.append(0x02)
80             else:
81                 dat = part.encode(ob)
82                 if len(dat) < 128:
83                     buf.append(0x80 | len(dat))
84                     buf.extend(dat)
85                 else:
86                     buf.extend(struct.pack(">BI", 0, len(dat)))
87                     buf.extend(dat)
88         return bytes(buf)
89     def decode(self, dat):
90         ret = []
91         off = 0
92         for part in self.parts:
93             fl = dat[off]
94             off += 1
95             if fl & 0x80:
96                 ln = fl & 0x7f
97             elif fl == 0x01:
98                 ret.append(self.small)
99                 continue
100             elif fl == 0x02:
101                 ret.append(self.large)
102                 continue
103             else:
104                 ln = struct.unpack(">I", dat[off:off + 4])[0]
105                 off += 4
106             ret.append(part.decode(dat[off:off + ln]))
107             off += ln
108         return tuple(ret)
109     def compare(self, al, bl):
110         if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
111             raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
112         for a, b, part in zip(al, bl, self.parts):
113             if a in (self.small, self.large) or b in (self.small, self.large):
114                 if a is b:
115                     return 0
116                 if a is self.small:
117                     return -1
118                 elif b is self.small:
119                     return 1
120                 elif a is self.large:
121                     return 1
122                 elif b is self.large:
123                     return -1
124             c = part.compare(a, b)
125             if c != 0:
126                 return c
127         return 0
128
129 def floatcmp(a, b):
130     if math.isnan(a) and math.isnan(b):
131         return 0
132     elif math.isnan(a):
133         return -1
134     elif math.isnan(b):
135         return 1
136     elif a < b:
137         return -1
138     elif a > b:
139         return 1
140     else:
141         return 0
142
143 t_bool = simpletype((lambda ob: b"\x01" if ob else b"\x00"), (lambda dat: False if dat == b"x\00" else True))
144 t_int = simpletype.struct(">q")
145 t_uint = simpletype.struct(">Q")
146 t_dbid = t_uint
147 t_float = simpletype.struct(">d")
148 t_float.compare = floatcmp
149 t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
150 t_casestr = foldtype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")),
151                      (lambda st: st.lower()))
152
153 class index(object):
154     def __init__(self, db, name, datatype):
155         self.db = db
156         self.nm = name
157         self.typ = datatype
158
159 missing = object()
160
161 class ordered(index, lib.closable):
162     def __init__(self, db, name, datatype, create=True, *, tx=None):
163         super().__init__(db, name, datatype)
164         fl = bd.DB_THREAD
165         if create: fl |= bd.DB_CREATE
166         def initdb(db):
167             def compare(a, b):
168                 if a == b == "": return 0
169                 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
170             db.set_flags(bd.DB_DUPSORT)
171             db.set_bt_compare(compare)
172         self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb, tx=tx)
173         self.bk.set_get_returns_none(False)
174
175     def close(self):
176         self.bk.close()
177
178     class cursor(lib.closable):
179         def __init__(self, idx, fd, fi, ld, li, reverse):
180             self.idx = idx
181             self.typ = idx.typ
182             self.cur = self.idx.bk.cursor()
183             self.item = None
184             self.fd = fd
185             self.fi = fi
186             self.ld = ld
187             self.li = li
188             self.rev = reverse
189
190         def close(self):
191             if self.cur is not None:
192                 self.cur.close()
193                 self.cur = None
194
195         def __iter__(self):
196             return self
197
198         def _decode(self, d):
199             k, v = d
200             k = self.typ.decode(k)
201             v = struct.unpack(">Q", v)[0]
202             return k, v
203
204         @dloopfun
205         def first(self):
206             try:
207                 if self.fd is missing:
208                     self.item = self._decode(self.cur.first())
209                 else:
210                     k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
211                     if not self.fi:
212                         while self.typ.compare(k, self.fd) == 0:
213                             k, v = self._decode(self.cur.next())
214                     self.item = k, v
215             except notfound:
216                 self.item = StopIteration
217
218         @dloopfun
219         def last(self):
220             try:
221                 if self.ld is missing:
222                     self.item = self._decode(self.cur.last())
223                 else:
224                     try:
225                         k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
226                     except notfound:
227                         k, v = self._decode(self.cur.last())
228                     if self.li:
229                         while self.typ.compare(k, self.ld) == 0:
230                             k, v = self._decode(self.cur.next())
231                         while self.typ.compare(k, self.ld) > 0:
232                             k, v = self._decode(self.cur.prev())
233                     else:
234                         while self.typ.compare(k, self.ld) >= 0:
235                             k, v = self._decode(self.cur.prev())
236                     self.item = k, v
237             except notfound:
238                 self.item = StopIteration
239
240         @dloopfun
241         def next(self):
242             try:
243                 k, v = self.item = self._decode(self.cur.next())
244                 if (self.ld is not missing and
245                     ((self.li and self.typ.compare(k, self.ld) > 0) or
246                      (not self.li and self.typ.compare(k, self.ld) >= 0))):
247                     self.item = StopIteration
248             except notfound:
249                 self.item = StopIteration
250
251         @dloopfun
252         def prev(self):
253             try:
254                 self.item = self._decode(self.cur.prev())
255                 if (self.fd is not missing and
256                     ((self.fi and self.typ.compare(k, self.fd) < 0) or
257                      (not self.fi and self.typ.compare(k, self.fd) <= 0))):
258                     self.item = StopIteration
259             except notfound:
260                 self.item = StopIteration
261
262         def __next__(self):
263             if self.item is None:
264                 if not self.rev:
265                     self.next()
266                 else:
267                     self.prev()
268             if self.item is StopIteration:
269                 raise StopIteration()
270             ret, self.item = self.item, None
271             return ret
272
273         def skip(self, n=1):
274             try:
275                 for i in range(n):
276                     next(self)
277             except StopIteration:
278                 return
279
280     def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
281         if all:
282             cur = self.cursor(self, missing, True, missing, True, reverse)
283         elif match is not missing:
284             cur = self.cursor(self, match, True, match, True, reverse)
285         elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
286             if ge is not missing:
287                 fd, fi = ge, True
288             elif gt is not missing:
289                 fd, fi = gt, False
290             else:
291                 fd, fi = missing, True
292             if le is not missing:
293                 ld, li = le, True
294             elif lt is not missing:
295                 ld, li = lt, False
296             else:
297                 ld, li = missing, True
298             cur = self.cursor(self, fd, fi, ld, li, reverse)
299         else:
300             raise NameError("invalid get() specification")
301         done = False
302         try:
303             if not reverse:
304                 cur.first()
305             else:
306                 cur.last()
307             done = True
308             return cur
309         finally:
310             if not done:
311                 cur.close()
312
313     @txnfun(lambda self: self.db.env.env)
314     def put(self, key, id, *, tx):
315         obid = struct.pack(">Q", id)
316         if not self.db.ob.has_key(obid, txn=tx.tx):
317             raise ValueError("no such object in database: " + str(id))
318         try:
319             self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
320         except bd.DBKeyExistError:
321             return False
322         return True
323
324     @txnfun(lambda self: self.db.env.env)
325     def remove(self, key, id, *, tx):
326         obid = struct.pack(">Q", id)
327         if not self.db.ob.has_key(obid, txn=tx.tx):
328             raise ValueError("no such object in database: " + str(id))
329         cur = self.bk.cursor(txn=tx.tx)
330         try:
331             try:
332                 cur.get_both(self.typ.encode(key), obid)
333             except notfound:
334                 return False
335             cur.delete()
336         finally:
337             cur.close()
338         return True