Commit | Line | Data |
---|---|---|
abb94f83 | 1 | import struct, contextlib, math |
a95055e8 | 2 | from . import db, lib |
6efe4e23 | 3 | from .db import bd, txnfun, dloopfun |
a95055e8 | 4 | |
# Public names exported by ``from <module> import *``.
__all__ = [
    "maybe",
    "t_bool",
    "t_int",
    "t_uint",
    "t_dbid",
    "t_float",
    "t_str",
    "t_casestr",
    "ordered",
]
cbf73d3a | 6 | |
a95055e8 FT |
# Short local aliases for the Berkeley DB exception classes.
deadlock, notfound = bd.DBLockDeadlockError, bd.DBNotFoundError
9 | ||
class simpletype(object):
    """A key type defined by a pair of encode/decode callables.

    ``encode`` turns a Python value into bytes and ``decode`` reverses it.
    Decoded values are ordered with their natural ``<`` / ``>`` operators.
    """

    def __init__(self, encode, decode):
        # Kept under short attribute names; the methods below delegate to them.
        self.enc, self.dec = encode, decode

    def encode(self, ob):
        """Return the byte encoding of *ob*."""
        return self.enc(ob)

    def decode(self, dat):
        """Return the Python value encoded in *dat*."""
        return self.dec(dat)

    def compare(self, a, b):
        """Three-way compare: -1 if a < b, 1 if a > b, otherwise 0."""
        if a < b:
            return -1
        return 1 if a > b else 0

    @classmethod
    def struct(cls, fmt):
        """Build a type that packs one value with the ``struct`` format *fmt*."""
        # ``struct`` below is the stdlib module (module-level name), not this
        # classmethod: class-body names are not visible inside methods.
        def pack(ob):
            return struct.pack(fmt, ob)

        def unpack(dat):
            return struct.unpack(fmt, dat)[0]

        return cls(pack, unpack)
31 | ||
f9b1d040 FT |
class foldtype(simpletype):
    """A simpletype whose ordering looks only at a folded form of each value.

    ``fold`` maps a decoded value to the form used for comparison (for
    example lower-casing a string). Encoding and decoding are inherited
    unchanged, so the original, unfolded value is what gets stored.
    """

    def __init__(self, encode, decode, fold):
        super().__init__(encode, decode)
        self.fold = fold

    def compare(self, a, b):
        """Three-way compare of ``fold(a)`` against ``fold(b)``."""
        folded_a = self.fold(a)
        folded_b = self.fold(b)
        return super().compare(folded_a, folded_b)
39 | ||
a95055e8 FT |
class maybe(object):
    """Wrap a key type so that ``None`` is also a valid value.

    ``None`` is encoded as the empty byte string. Any other value is encoded
    as a single NUL marker byte followed by the wrapped type's encoding.
    ``None`` sorts before every non-None value.
    """

    def __init__(self, bk):
        # The wrapped ("backing") type used for non-None values.
        self.bk = bk

    def encode(self, ob):
        if ob is None:
            return b""
        return b"\0" + self.bk.encode(ob)

    def decode(self, dat):
        if dat == b"":
            return None
        # Strip the marker byte and use the wrapped type's public decode(),
        # so any type exposing encode/decode/compare can be wrapped, not just
        # those that happen to carry a ``dec`` attribute.
        return self.bk.decode(dat[1:])

    def compare(self, a, b):
        """Three-way compare of two *decoded* values; None sorts first."""
        if a is None and b is None:
            return 0
        if a is None:
            return -1
        if b is None:
            return 1
        # a and b are decoded values, not encodings, so there is no marker
        # byte to strip: hand them to the wrapped type unchanged.
        return self.bk.compare(a, b)
59 | ||
bd14729f FT |
class compound(object):
    """A key type made of a fixed sequence of part types (a tuple key).

    Each tuple element is encoded by the corresponding part type and framed
    as follows:

    * ``0x80 | n`` then the data, when the encoding is shorter than 128 bytes;
    * ``0x00``, a 4-byte big-endian length, then the data, otherwise;
    * a lone ``0x01`` / ``0x02`` for the ``small`` / ``large`` sentinels.

    ``small`` and ``large`` compare below / above every real value. They are
    used by ``minim``/``maxim`` to turn a key prefix into range bounds.
    """

    small = object()
    large = object()

    def __init__(self, *parts):
        self.parts = parts

    def minim(self, *parts):
        """Pad the prefix *parts* with ``small`` up to the full key length."""
        return parts + (self.small,) * (len(self.parts) - len(parts))

    def maxim(self, *parts):
        """Pad the prefix *parts* with ``large`` up to the full key length."""
        return parts + (self.large,) * (len(self.parts) - len(parts))

    def encode(self, obs):
        """Encode the tuple *obs*; raise ValueError if its length is wrong."""
        if len(obs) != len(self.parts):
            raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + str(len(self.parts)))
        buf = bytearray()
        for ob, part in zip(obs, self.parts):
            if ob is self.small:
                buf.append(0x01)
            elif ob is self.large:
                buf.append(0x02)
            else:
                dat = part.encode(ob)
                if len(dat) < 128:
                    buf.append(0x80 | len(dat))
                else:
                    buf.extend(struct.pack(">BI", 0, len(dat)))
                buf.extend(dat)
        return bytes(buf)

    def decode(self, dat):
        """Decode bytes produced by encode() back into a tuple."""
        ret = []
        off = 0
        for part in self.parts:
            fl = dat[off]
            off += 1
            if fl & 0x80:
                ln = fl & 0x7f
            elif fl == 0x01:
                ret.append(self.small)
                continue
            elif fl == 0x02:
                ret.append(self.large)
                continue
            else:
                # Long form: explicit 4-byte big-endian length.
                ln = struct.unpack(">I", dat[off:off + 4])[0]
                off += 4
            ret.append(part.decode(dat[off:off + ln]))
            off += ln
        return tuple(ret)

    def compare(self, al, bl):
        """Three-way compare two decoded tuples, honouring the sentinels."""
        if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
            raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + str(len(self.parts)))
        small, large = self.small, self.large
        for a, b, part in zip(al, bl, self.parts):
            # Identity checks only: ``in`` would call the values' __eq__.
            if a is small or a is large or b is small or b is large:
                if a is b:
                    return 0
                if a is small:
                    return -1
                if b is small:
                    return 1
                if a is large:
                    return 1
                return -1  # b is large
            c = part.compare(a, b)
            if c != 0:
                return c
        return 0
128 | ||
abb94f83 FT |
def floatcmp(a, b):
    """Three-way compare of floats that gives NaN a fixed place.

    NaN compares equal to NaN and below every other value, so every pair of
    floats gets a definite -1/0/1 result.
    """
    a_nan, b_nan = math.isnan(a), math.isnan(b)
    if a_nan or b_nan:
        # Both NaN -> 0; only a NaN -> -1; only b NaN -> 1.
        return int(b_nan) - int(a_nan)
    return (a > b) - (a < b)
142 | ||
# Ready-made key types.
# t_bool stores one byte; only b"\x00" decodes as False. (The NUL byte must
# be written "\x00": "x\00" would be the two bytes 'x', NUL.)
t_bool = simpletype((lambda ob: b"\x01" if ob else b"\x00"),
                    (lambda dat: dat != b"\x00"))
t_int = simpletype.struct(">q")    # signed 64-bit, big-endian
t_uint = simpletype.struct(">Q")   # unsigned 64-bit, big-endian
t_dbid = t_uint                    # same ">Q" packing used for object ids
t_float = simpletype.struct(">d")  # 64-bit double, big-endian
t_float.compare = floatcmp         # NaN-aware ordering instead of < / >
t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
# Case-insensitive ordering: strings are stored as given, compared lower-cased.
t_casestr = foldtype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")),
                     (lambda st: st.lower()))
a95055e8 FT |
152 | |
class index(object):
    """Common base for index types.

    Remembers the owning database object, the index name and the key type
    (an object with encode/decode/compare) used for index keys.
    """

    def __init__(self, db, name, datatype):
        self.db, self.nm, self.typ = db, name, datatype


# Sentinel for "argument not supplied". It is distinct from None because
# None can be a real key value (see ``maybe``).
missing = object()
160 | ||
class ordered(index, lib.closable):
    """An ordered (B-tree) index mapping keys to object ids.

    Keys are encoded with ``datatype`` and ordered by its ``compare`` method.
    The underlying database is opened with DB_DUPSORT, so a key may map to
    several object ids. Object ids are stored as 8-byte big-endian unsigned
    integers (``">Q"``).
    """

    def __init__(self, db, name, datatype, create=True, *, tx=None):
        super().__init__(db, name, datatype)
        fl = bd.DB_THREAD
        if create:
            fl |= bd.DB_CREATE

        def initdb(db):
            # ``db`` here is the Berkeley DB handle being opened, not self.db.
            def compare(a, b):
                # Two empty keys cannot be decoded by most types, so treat
                # them as equal. NOTE(review): presumably needed because
                # bsddb3 probes the comparator with two empty keys. Those
                # keys are bytes under Python 3, so test emptiness rather
                # than equality with the str "".
                if not a and not b:
                    return 0
                return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
            db.set_flags(bd.DB_DUPSORT)
            db.set_bt_compare(compare)

        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb, tx=tx)
        self.bk.set_get_returns_none(False)

    def close(self):
        self.bk.close()

    class cursor(lib.closable):
        """Iterator over a range of ``(key, id)`` pairs of an ordered index.

        ``fd``/``ld`` are the lower/upper bound keys (``missing`` means
        unbounded). ``fi``/``li`` say whether each bound is inclusive.
        ``item`` holds the next pair to return: ``None`` when the cursor must
        be advanced first, or ``StopIteration`` once the range is exhausted.
        """

        def __init__(self, idx, fd, fi, ld, li, reverse):
            self.idx = idx
            self.typ = idx.typ
            self.cur = self.idx.bk.cursor()
            self.item = None
            self.fd = fd
            self.fi = fi
            self.ld = ld
            self.li = li
            self.rev = reverse

        def close(self):
            if self.cur is not None:
                self.cur.close()
                self.cur = None

        def __iter__(self):
            return self

        def _decode(self, d):
            # Turn a raw (key bytes, id bytes) record into (key, id).
            k, v = d
            k = self.typ.decode(k)
            v = struct.unpack(">Q", v)[0]
            return k, v

        def _belowfirst(self, k):
            # True if key k lies before the lower bound of the range.
            if self.fd is missing:
                return False
            c = self.typ.compare(k, self.fd)
            return c < 0 if self.fi else c <= 0

        def _abovelast(self, k):
            # True if key k lies past the upper bound of the range.
            if self.ld is missing:
                return False
            c = self.typ.compare(k, self.ld)
            return c > 0 if self.li else c >= 0

        @dloopfun
        def first(self):
            """Position on the first pair in range (StopIteration if none)."""
            try:
                if self.fd is missing:
                    k, v = self._decode(self.cur.first())
                else:
                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
                    if not self.fi:
                        # Exclusive bound: skip every entry equal to fd.
                        while self.typ.compare(k, self.fd) == 0:
                            k, v = self._decode(self.cur.next())
                # The first candidate may already lie past the upper bound,
                # e.g. match= on a key that is not present.
                self.item = StopIteration if self._abovelast(k) else (k, v)
            except notfound:
                self.item = StopIteration

        @dloopfun
        def last(self):
            """Position on the last pair in range (StopIteration if none)."""
            try:
                if self.ld is missing:
                    k, v = self._decode(self.cur.last())
                else:
                    try:
                        k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
                    except notfound:
                        # Every key is below ld: start from the very end.
                        k, v = self._decode(self.cur.last())
                    if self.li:
                        # Inclusive bound: move past all entries equal to ld...
                        try:
                            while self.typ.compare(k, self.ld) == 0:
                                k, v = self._decode(self.cur.next())
                        except notfound:
                            # ...which may run off the end of the index; the
                            # final record is then the last one equal to ld.
                            k, v = self._decode(self.cur.last())
                        while self.typ.compare(k, self.ld) > 0:
                            k, v = self._decode(self.cur.prev())
                    else:
                        while self.typ.compare(k, self.ld) >= 0:
                            k, v = self._decode(self.cur.prev())
                # The last candidate may already lie before the lower bound.
                self.item = StopIteration if self._belowfirst(k) else (k, v)
            except notfound:
                self.item = StopIteration

        @dloopfun
        def next(self):
            """Advance forwards, stopping past the upper bound."""
            try:
                k, v = self.item = self._decode(self.cur.next())
                if self._abovelast(k):
                    self.item = StopIteration
            except notfound:
                self.item = StopIteration

        @dloopfun
        def prev(self):
            """Step backwards, stopping before the lower bound."""
            try:
                # k must be bound here: the lower-bound check below uses it.
                k, v = self.item = self._decode(self.cur.prev())
                if self._belowfirst(k):
                    self.item = StopIteration
            except notfound:
                self.item = StopIteration

        def __next__(self):
            if self.item is None:
                if not self.rev:
                    self.next()
                else:
                    self.prev()
            if self.item is StopIteration:
                raise StopIteration()
            ret, self.item = self.item, None
            return ret

        def skip(self, n=1):
            """Consume up to n pairs, stopping quietly at the end of the range."""
            try:
                for _ in range(n):
                    next(self)
            except StopIteration:
                return

    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
        """Return a cursor over the ``(key, id)`` pairs selected by the arguments.

        ``all=True`` selects the whole index. Otherwise ``match`` selects
        keys comparing equal to it. Otherwise ``ge``/``gt`` and ``le``/``lt``
        give optional lower and upper bounds (``ge`` wins over ``gt``, ``le``
        over ``lt``). ``reverse=True`` iterates in descending order.

        The returned cursor is already positioned; the caller is expected to
        close it. Raises NameError if no selection was given.
        """
        if all:
            cur = self.cursor(self, missing, True, missing, True, reverse)
        elif match is not missing:
            cur = self.cursor(self, match, True, match, True, reverse)
        elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
            if ge is not missing:
                fd, fi = ge, True
            elif gt is not missing:
                fd, fi = gt, False
            else:
                fd, fi = missing, True
            if le is not missing:
                ld, li = le, True
            elif lt is not missing:
                ld, li = lt, False
            else:
                ld, li = missing, True
            cur = self.cursor(self, fd, fi, ld, li, reverse)
        else:
            # NameError kept for compatibility with existing callers.
            raise NameError("invalid get() specification")
        try:
            if reverse:
                cur.last()
            else:
                cur.first()
        except BaseException:
            # Do not leak the Berkeley DB cursor if positioning fails.
            cur.close()
            raise
        return cur

    @txnfun(lambda self: self.db.env.env)
    def put(self, key, id, *, tx):
        """Add the pair (key, id) to the index using transaction ``tx.tx``.

        Returns True if the pair was added and False if it was already
        present. Raises ValueError if ``self.db.ob`` has no object with
        that id.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        try:
            self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
        except bd.DBKeyExistError:
            return False
        return True

    @txnfun(lambda self: self.db.env.env)
    def remove(self, key, id, *, tx):
        """Remove the pair (key, id) from the index using transaction ``tx.tx``.

        Returns True if the pair was removed and False if it was not present.
        Raises ValueError if ``self.db.ob`` has no object with that id.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        cur = self.bk.cursor(txn=tx.tx)
        try:
            try:
                cur.get_both(self.typ.encode(key), obid)
            except notfound:
                return False
            cur.delete()
        finally:
            cur.close()
        return True