Add Content-Length to SP responses.
[wrw.git] / wrw / util.py
1 import inspect, math
2 import req, dispatch, session, form, resp, proto
3
def wsgiwrap(callable):
    """Turn a request handler into an `(env, startreq)` entry point.

    The returned function forwards its two arguments, together with
    `callable`, to `dispatch.handleenv` and returns that call's result.
    The handler that was wrapped stays reachable through the
    `__wrapped__` attribute of the returned function.
    """
    def wrapper(env, startreq):
        # All real work is delegated to the dispatch module.
        result = dispatch.handleenv(env, startreq, callable)
        return result
    setattr(wrapper, "__wrapped__", callable)
    return wrapper
9
def formparams(callable):
    """Decorate a handler so its keyword arguments come from form data.

    The wrapper reads `form.formdata(req)`, turns its items into keyword
    arguments and always adds the request itself under the name `req`.
    Unless `callable` accepts `**kwargs`, names that are not among its
    declared parameters are dropped. Every declared parameter that has
    no default value must then be present; otherwise an HTTP 400
    "Missing parameter" error is raised.

    The wrapped handler is exposed as `wrapper.__wrapped__`.

    NOTE: `inspect.getargspec` is the Python 2 era introspection call
    and is not available on recent Python 3 releases.
    """
    spec = inspect.getargspec(callable)
    # Leading parameters without a default value are mandatory.
    required = spec.args[:len(spec.args) - len(spec.defaults or ())]

    def wrapper(req):
        args = dict(form.formdata(req).items())
        args["req"] = req
        if not spec.keywords:
            # No **kwargs to absorb extras: discard undeclared names.
            for unknown in [nm for nm in args if nm not in spec.args]:
                del args[unknown]
        for param in required:
            if param not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(param), "' is required but not supplied."))
        return callable(**args)
    wrapper.__wrapped__ = callable
    return wrapper
26
class funplex(object):
    """Multiplex requests over handlers keyed by the first path element.

    Handlers live in the `dir` dictionary. Positional handlers are
    registered under the `__name__` of their innermost wrapped function,
    keyword handlers under the keyword used.
    """

    def __init__(self, *funs, **nfuns):
        self.dir = {}
        for fun in funs:
            self.dir[self.unwrap(fun).__name__] = fun
        self.dir.update(nfuns)

    @staticmethod
    def unwrap(fun):
        """Follow `__wrapped__` links and return the innermost object."""
        inner = fun
        while hasattr(inner, "__wrapped__"):
            inner = inner.__wrapped__
        return inner

    def __call__(self, req):
        """Route `req` to the handler named by its first path element.

        An empty `pathinfo` raises a redirect to `req.uriname + "/"`.
        A path that does not start with "/" or names no registered
        handler raises `resp.notfound()`. A bare "/" selects the
        handler registered as `__index__`. The handler is called with
        `req.shift(n)`, where `n` counts the leading slash plus the
        element that was matched (just 1 for the index handler).
        """
        path = req.pathinfo
        if path == "":
            raise resp.redirect(req.uriname + "/")
        if path[:1] != "/":
            raise resp.notfound()
        tail = path[1:]
        if tail == "":
            key, consumed = "__index__", 1
        else:
            key = tail.partition("/")[0]
            consumed = len(key) + 1
        if key not in self.dir:
            raise resp.notfound()
        return self.dir[key](req.shift(consumed))

    def add(self, fun):
        """Register `fun` under its unwrapped `__name__`; return `fun`."""
        key = self.unwrap(fun).__name__
        self.dir[key] = fun
        return fun

    def name(self, name):
        """Return a decorator registering its argument under `name`."""
        def dec(fun):
            self.dir[name] = fun
            return fun
        return dec
64
def persession(data=None):
    """Build a decorator that keeps one handler object per session.

    The decorated `callable` is used as a factory and as the session
    key: on the first request of a session it is called and the result
    is stored in the session under `callable`. Every request is then
    answered by calling `.handle(req)` on the stored object.

    When `data` is given, `data()` is additionally stored in the
    session under `data` (once), and `callable` is invoked with `data`
    as its only argument instead of with no arguments.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is not None:
                    if data not in sess:
                        sess[data] = data()
                    # NOTE(review): the `data` factory itself, not the
                    # stored `sess[data]` object, is what gets passed
                    # on here -- confirm that this is intended.
                    handler = callable(data)
                else:
                    handler = callable()
                sess[callable] = handler
            return sess[callable].handle(req)
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
80
class preiter(object):
    """Iterator wrapper that always fetches one item ahead.

    Constructing a `preiter` immediately pulls the first item out of
    `real`, so any code the underlying iterable runs before producing
    that item executes at construction time rather than on the first
    iteration step. After that, each step returns the buffered item
    and fetches the following one.

    Both the Python 2 (`next`) and the Python 3 (`__next__`) iterator
    methods are provided, so `iter()`, `for` loops and the `next()`
    builtin work on either interpreter line.
    """
    __slots__ = ["bk", "bki", "_next"]
    # Sentinel stored in `_next` once the underlying iterator is spent.
    end = object()

    def __init__(self, real):
        self.bk = real
        self.bki = iter(real)
        self._next = None
        # Prime the look-ahead buffer; the placeholder None returned by
        # this first call is deliberately discarded.
        self.next()

    def __iter__(self):
        return self

    def next(self):
        """Return the buffered item and pre-fetch the one after it.

        Raises StopIteration once the underlying iterator is exhausted.
        """
        if self._next is self.end:
            raise StopIteration()
        ret = self._next
        try:
            self._next = next(self.bki)
        except StopIteration:
            self._next = self.end
        return ret

    # Python 3 looks up `__next__`; without this alias the object
    # returned by `__iter__` is not accepted as an iterator there.
    __next__ = next

    def close(self):
        """Close the wrapped iterable if it has a `close` method."""
        if hasattr(self.bk, "close"):
            self.bk.close()
106
def pregen(callable):
    """Decorate `callable` so that its result is wrapped in a `preiter`.

    The wrapper passes all positional and keyword arguments through
    unchanged. The undecorated callable remains available as the
    `__wrapped__` attribute of the wrapper.
    """
    def wrapper(*args, **kwargs):
        produced = callable(*args, **kwargs)
        return preiter(produced)
    wrapper.__wrapped__ = callable
    return wrapper
112
class sessiondata(object):
    """Base class for objects stored once per session, keyed by class."""

    @classmethod
    def get(cls, req, create=True):
        """Return the instance of `cls` stored in the session of `req`.

        The session is obtained from `cls.sessdb().get(req)` and is
        accessed while holding its `lock`. If no instance is stored
        under `cls`, one is built as `cls(req, sess)` and stored --
        unless `create` is false, in which case None is returned.
        """
        store = cls.sessdb().get(req)
        with store.lock:
            try:
                return store[cls]
            except KeyError:
                pass
            if not create:
                return None
            fresh = cls(req, store)
            store[cls] = fresh
            return fresh

    @classmethod
    def sessdb(cls):
        """Return the session database to use: `session.default.val`."""
        return session.default.val
130
class autodirty(sessiondata):
    """Session data that marks itself dirty on any attribute change.

    The dirty flag is kept directly in the instance `__dict__` under
    `_is_dirty`, so that maintaining the flag does not itself pass
    through `__setattr__` and count as a modification.
    """

    @classmethod
    def get(cls, req, create=True):
        """Fetch the per-session instance and make sure it has a flag.

        Takes the same `create` argument as `sessiondata.get`; when
        `create` is false and nothing is stored, None is returned.
        """
        ret = super(autodirty, cls).get(req, create)
        if ret is None:
            # Nothing stored and creation was not requested.
            return None
        if "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Reset the dirty flag."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return the current dirty flag."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super(autodirty, self).__setattr__(name, value)
        # Only track changes once the flag exists (set up by `get` or
        # `sessfrozen`); attribute writes before that are not counted.
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # `__delattr__` takes the attribute name only; forwarding a
        # second argument here used to fail on every deletion.
        super(autodirty, self).__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
154
class manudirty(object):
    """Mixin providing an explicitly controlled dirty flag.

    The flag starts out false, is raised by `dirty()`, lowered by
    `sessfrozen()` and reported by `sessdirty()`.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative initialisation: pass everything up the MRO first.
        super(manudirty, self).__init__(*args, **kwargs)
        self.__dirty = False

    def sessfrozen(self):
        """Clear the dirty flag."""
        self.__dirty = False

    def sessdirty(self):
        """Return whether `dirty()` was called since the last reset."""
        return self.__dirty

    def dirty(self):
        """Raise the dirty flag."""
        self.__dirty = True
168
class specslot(object):
    """Descriptor exposing one entry of an instance's slot value list.

    `nm` is the attribute name (used in error messages), `idx` the
    position inside the per-instance list returned by `slist`, and
    `dirty` tells whether assigning through the descriptor should call
    the owning instance's `dirty()` method.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Marker stored in the value list for slots that hold no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm, self.idx, self.dirty = nm, idx, dirty

    @staticmethod
    def slist(ins):
        """Return the slot value list of `ins`."""
        # Going through the slot descriptor directly avoids calling
        # __getattribute__ on the instance.
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        current = self.slist(ins)[self.idx]
        if current is not specslot.unbound:
            return current
        raise AttributeError("specslot %r is unbound" % self.nm)

    def __set__(self, ins, val):
        values = self.slist(ins)
        values[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        values = self.slist(ins)
        values[self.idx] = specslot.unbound
        # NOTE(review): deletion marks the instance dirty regardless of
        # the `dirty` flag, unlike assignment -- confirm this is wanted.
        ins.dirty()
197
class specclass(type):
    """Metaclass that installs `specslot` descriptors on new classes.

    Slot names are collected from the `__saveslots__` and
    `__dirtyslots__` attributes defined directly on each class in the
    MRO. A class that defines no `__dirtyslots__` of its own
    contributes its `__saveslots__` to the dirty set as well.

    The class gets `__sslots_l__` (the save slots), `__sslots_a__`
    (save and dirty slots combined) and, for each name in the latter,
    a `specslot` descriptor whose index is the name's position in
    `__sslots_a__`.
    """

    def __init__(self, name, bases, tdict):
        super(specclass, self).__init__(name, bases, tdict)
        saved, tracked = set(), set()
        for klass in self.__mro__:
            own = klass.__dict__.get("__saveslots__", ())
            saved.update(own)
            # Without an explicit list, every save slot is tracked.
            tracked.update(klass.__dict__.get("__dirtyslots__", own))
        self.__sslots_l__ = list(saved)
        self.__sslots_a__ = list(saved | tracked)
        for position, slotname in enumerate(self.__sslots_a__):
            setattr(self, slotname, specslot(slotname, position, slotname in tracked))
211
class specdirty(sessiondata):
    """Session data whose state lives in `specslot`-managed slots.

    Slot values are held in the per-instance list `__sslots__`, indexed
    in the order given by the class attribute `__sslots_a__` that the
    `specclass` metaclass computes. The metaclass is attached with the
    Python 2 `__metaclass__` class attribute.
    """
    __metaclass__ = specclass
    __slots__ = ["session", "__sslots__", "_is_dirty"]

    def __specinit__(self):
        """Hook run from `__new__`; does nothing by default."""
        pass

    @staticmethod
    def __new__(cls, req, sess):
        inst = super(specdirty, cls).__new__(cls)
        inst.session = sess
        inst.__sslots__ = [specslot.unbound for _ in cls.__sslots_a__]
        inst.__specinit__()
        # Reset last, so whatever __specinit__ assigned does not leave
        # the fresh instance flagged as dirty.
        inst._is_dirty = False
        return inst

    def __getnewargs__(self):
        # Arguments handed to __new__ on unpickling: no request, but
        # the same session object.
        return (None, self.session)

    def dirty(self):
        """Raise the dirty flag."""
        self._is_dirty = True

    def sessfrozen(self):
        """Clear the dirty flag."""
        self._is_dirty = False

    def sessdirty(self):
        """Return the current dirty flag."""
        return self._is_dirty

    def __getstate__(self):
        """Return `{name: (bound, value)}` for every slot.

        Unbound slots are recorded as `(False, None)`.
        """
        state = {}
        for slotname, value in zip(type(self).__sslots_a__, specslot.slist(self)):
            if value is not specslot.unbound:
                state[slotname] = True, value
            else:
                state[slotname] = False, None
        return state

    def __setstate__(self, st):
        """Restore slot values from a mapping made by `__getstate__`.

        Entries are popped from `st`; slots missing from it are left
        unbound.
        """
        values = specslot.slist(self)
        for position, slotname in enumerate(type(self).__sslots_a__):
            bound, value = st.pop(slotname, (False, None))
            values[position] = value if bound else specslot.unbound
257
def datecheck(req, mtime):
    """Handle `If-Modified-Since` and set `Last-Modified` for `mtime`.

    If the request carries an `If-Modified-Since` header, it is parsed
    with `proto.phttpdate`; when the parsed time is not earlier than
    `mtime` rounded down to a whole number, `resp.unmodified()` is
    raised. Otherwise the `Last-Modified` output header is set to
    `proto.httpdate(mtime)`.
    """
    condition = "If-Modified-Since"
    if condition in req.ihead:
        client_time = proto.phttpdate(req.ihead[condition])
        # The resource time is floored before comparing -- presumably
        # because the header value has no sub-second part.
        if client_time >= math.floor(mtime):
            raise resp.unmodified()
    req.ohead["Last-Modified"] = proto.httpdate(mtime)