1 import struct, contextlib, math
3 from .db import bd, txnfun, dloopfun
# Public names exported by this module.
__all__ = [
    "maybe",
    "t_bool", "t_int", "t_uint", "t_dbid", "t_float",
    "t_str", "t_casestr",
    "ordered",
]

# Short module-level aliases for two of `bd`'s exception classes.
deadlock = bd.DBLockDeadlockError
notfound = bd.DBNotFoundError
10 class simpletype(object):
11 def __init__(self, encode, decode):
17 def decode(self, dat):
19 def compare(self, a, b):
29 return cls(lambda ob: struct.pack(fmt, ob),
30 lambda dat: struct.unpack(fmt, dat)[0])
32 class foldtype(simpletype):
33 def __init__(self, encode, decode, fold):
34 super().__init__(encode, decode)
37 def compare(self, a, b):
38 return super().compare(self.fold(a), self.fold(b))
41 def __init__(self, bk):
45 if ob is None: return b""
46 return b"\0" + self.bk.encode(ob)
47 def decode(self, dat):
48 if dat == b"": return None
49 return self.bk.dec(dat[1:])
50 def compare(self, a, b):
58 return self.bk.compare(a[1:], b[1:])
60 class compound(object):
61 def __init__(self, *parts):
66 def minim(self, *parts):
67 return parts + tuple([self.small] * (len(self.parts) - len(parts)))
68 def maxim(self, *parts):
69 return parts + tuple([self.large] * (len(self.parts) - len(parts)))
71 def encode(self, obs):
72 if len(obs) != len(self.parts):
73 raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
75 for ob, part in zip(obs, self.parts):
78 elif ob is self.large:
83 buf.append(0x80 | len(dat))
86 buf.extend(struct.pack(">BI", 0, len(dat)))
89 def decode(self, dat):
92 for part in self.parts:
98 ret.append(self.small)
101 ret.append(self.large)
104 ln = struct.unpack(">I", dat[off:off + 4])[0]
106 ret.append(part.decode(dat[off:off + ln]))
109 def compare(self, al, bl):
110 if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
111 raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
112 for a, b, part in zip(al, bl, self.parts):
113 if a in (self.small, self.large) or b in (self.small, self.large):
118 elif b is self.small:
120 elif a is self.large:
122 elif b is self.large:
124 c = part.compare(a, b)
130 if math.isnan(a) and math.isnan(b):
# Booleans are stored as a single byte: b"\x01" for true, b"\x00" for false.
# Decoding maps the false byte back to False and every other value to True.
# NOTE: the false byte must be spelled b"\x00"; the literal b"x\00" is the
# two-byte string "x" + NUL, which made stored False values decode as True.
t_bool = simpletype(
    lambda ob: b"\x01" if ob else b"\x00",
    lambda dat: dat != b"\x00",
)
# Fixed-width numeric types, each packed with a big-endian struct format
# (8 bytes per value).
t_int = simpletype.struct(">q")    # signed 64-bit integer
t_uint = simpletype.struct(">Q")   # unsigned 64-bit integer
t_float = simpletype.struct(">d")  # IEEE double
# Floats order themselves via floatcmp instead of simpletype's default compare.
t_float.compare = floatcmp

# Text types, stored as UTF-8 bytes.
t_str = simpletype(
    lambda ob: ob.encode("utf-8"),
    lambda dat: dat.decode("utf-8"),
)
# Same storage as t_str, but foldtype applies the fold function (lowercasing
# here) to both operands before comparing, giving case-insensitive ordering.
t_casestr = foldtype(
    lambda ob: ob.encode("utf-8"),
    lambda dat: dat.decode("utf-8"),
    lambda st: st.lower(),
)
154 def __init__(self, db, name, datatype):
161 class ordered(index, lib.closable):
162 def __init__(self, db, name, datatype, create=True, *, tx=None):
163 super().__init__(db, name, datatype)
165 if create: fl |= bd.DB_CREATE
168 if a == b == "": return 0
169 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
170 db.set_flags(bd.DB_DUPSORT)
171 db.set_bt_compare(compare)
172 self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb, tx=tx)
173 self.bk.set_get_returns_none(False)
178 class cursor(lib.closable):
179 def __init__(self, idx, fd, fi, ld, li, reverse):
182 self.cur = self.idx.bk.cursor()
191 if self.cur is not None:
198 def _decode(self, d):
200 k = self.typ.decode(k)
201 v = struct.unpack(">Q", v)[0]
207 if self.fd is missing:
208 self.item = self._decode(self.cur.first())
210 k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
212 while self.typ.compare(k, self.fd) == 0:
213 k, v = self._decode(self.cur.next())
216 self.item = StopIteration
221 if self.ld is missing:
222 self.item = self._decode(self.cur.last())
225 k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
227 k, v = self._decode(self.cur.last())
229 while self.typ.compare(k, self.ld) == 0:
230 k, v = self._decode(self.cur.next())
231 while self.typ.compare(k, self.ld) > 0:
232 k, v = self._decode(self.cur.prev())
234 while self.typ.compare(k, self.ld) >= 0:
235 k, v = self._decode(self.cur.prev())
238 self.item = StopIteration
243 k, v = self.item = self._decode(self.cur.next())
244 if (self.ld is not missing and
245 ((self.li and self.typ.compare(k, self.ld) > 0) or
246 (not self.li and self.typ.compare(k, self.ld) >= 0))):
247 self.item = StopIteration
249 self.item = StopIteration
254 self.item = self._decode(self.cur.prev())
255 if (self.fd is not missing and
256 ((self.fi and self.typ.compare(k, self.fd) < 0) or
257 (not self.fi and self.typ.compare(k, self.fd) <= 0))):
258 self.item = StopIteration
260 self.item = StopIteration
263 if self.item is None:
268 if self.item is StopIteration:
269 raise StopIteration()
270 ret, self.item = self.item, None
277 except StopIteration:
280 def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
282 cur = self.cursor(self, missing, True, missing, True, reverse)
283 elif match is not missing:
284 cur = self.cursor(self, match, True, match, True, reverse)
285 elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
286 if ge is not missing:
288 elif gt is not missing:
291 fd, fi = missing, True
292 if le is not missing:
294 elif lt is not missing:
297 ld, li = missing, True
298 cur = self.cursor(self, fd, fi, ld, li, reverse)
300 raise NameError("invalid get() specification")
313 @txnfun(lambda self: self.db.env.env)
314 def put(self, key, id, *, tx):
315 obid = struct.pack(">Q", id)
316 if not self.db.ob.has_key(obid, txn=tx.tx):
317 raise ValueError("no such object in database: " + str(id))
319 self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
320 except bd.DBKeyExistError:
324 @txnfun(lambda self: self.db.env.env)
325 def remove(self, key, id, *, tx):
326 obid = struct.pack(">Q", id)
327 if not self.db.ob.has_key(obid, txn=tx.tx):
328 raise ValueError("no such object in database: " + str(id))
329 cur = self.bk.cursor(txn=tx.tx)
332 cur.get_both(self.typ.encode(key), obid)