Fixed copyrequest missing input.
[wrw.git] / wrw / util.py
1 import inspect, math
2 from . import req, dispatch, session, form, resp, proto
3
def wsgiwrap(callable):
    """Expose a wrw request handler as a WSGI application.

    The returned function has the WSGI signature ``(env, startreq)`` and
    hands every call over to ``dispatch.handleenv`` together with the
    wrapped handler. ``__wrapped__`` points back at *callable* so that
    helpers such as ``funplex.unwrap`` can find the original function.
    """
    def wrapper(env, startreq):
        # All request processing is delegated to the dispatcher.
        result = dispatch.handleenv(env, startreq, callable)
        return result
    wrapper.__wrapped__ = callable
    return wrapper
9
def formparams(callable):
    """Decorate a handler so that form fields are passed as keyword arguments.

    The returned wrapper takes a single request object. It reads the form
    data with ``form.formdata(req)``, turns every field into a keyword
    argument, and also passes the request itself as ``req``. Unless the
    handler accepts ``**kwargs``, fields that do not name one of its
    parameters are dropped. Parameters without a default value (positional
    or keyword-only) must be present.

    Raises (from the wrapper):
        resp.httperror(400) when ``form.formdata`` raises ``IOError``
        (incomplete form data) or when a required parameter is missing.
    """
    # inspect.getargspec() was removed in Python 3.11 (and rejected
    # keyword-only parameters before that); getfullargspec() replaces it.
    spec = inspect.getfullargspec(callable)
    # Every name the handler can receive by keyword.
    accepted = set(spec.args) | set(spec.kwonlyargs)
    # Leading positional parameters that have no default value...
    ndefaults = len(spec.defaults) if spec.defaults else 0
    required = spec.args[:len(spec.args) - ndefaults]
    # ...plus keyword-only parameters that have no default value.
    kwdefaults = spec.kwonlydefaults or {}
    required += [arg for arg in spec.kwonlyargs if arg not in kwdefaults]

    def wrapper(req):
        try:
            data = form.formdata(req)
        except IOError:
            raise resp.httperror(400, "Invalid request", "Form data was incomplete")
        args = dict(data.items())
        args["req"] = req
        if not spec.varkw:
            # Without **kwargs, unknown fields would cause a TypeError.
            args = {name: val for name, val in args.items() if name in accepted}
        for arg in required:
            if arg not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(arg), "' is required but not supplied."))
        return callable(**args)
    wrapper.__wrapped__ = callable
    return wrapper
29
class funplex(object):
    """Route a request to one of several handlers by its first path segment.

    Handlers live in ``self.dir``, keyed by name. Positional handlers given
    to the constructor, and handlers registered with ``add``, are keyed by
    the ``__name__`` of the innermost function reached by following
    ``__wrapped__`` links. Keyword handlers and ``name()`` registrations use
    the given name. A request for the bare ``/`` maps to ``__index__``.
    """
    def __init__(self, *funs, **nfuns):
        self.dir = {}
        for fun in funs:
            self.dir[self.unwrap(fun).__name__] = fun
        self.dir.update(nfuns)

    @staticmethod
    def unwrap(fun):
        """Follow ``__wrapped__`` links down to the innermost callable."""
        inner = fun
        while hasattr(inner, "__wrapped__"):
            inner = inner.__wrapped__
        return inner

    def __call__(self, req):
        """Dispatch *req* to the handler named by its next path segment.

        Raises ``resp.redirect`` (adding a trailing slash) when the path
        info is empty. Raises ``resp.notfound`` when the path info does not
        start with ``/`` or names no registered handler.
        """
        if req.pathinfo == "":
            raise resp.redirect(req.uriname + "/")
        if req.pathinfo[:1] != "/":
            raise resp.notfound()
        rest = req.pathinfo[1:]
        if rest == "":
            key, skip = "__index__", 1
        else:
            key = rest.partition("/")[0]
            skip = 1 + len(key)
        if key not in self.dir:
            raise resp.notfound()
        sub = req.shift(skip)
        # selfpath records req's path info without its leading slash.
        sub.selfpath = req.pathinfo[1:]
        return self.dir[key](sub)

    def add(self, fun):
        """Register *fun* under its unwrapped ``__name__``; usable as a decorator."""
        key = self.unwrap(fun).__name__
        self.dir[key] = fun
        return fun

    def name(self, name):
        """Return a decorator that registers a handler under *name*."""
        def register(fun):
            self.dir[name] = fun
            return fun
        return register
69
def persession(data=None):
    """Decorate a handler class so that each session gets its own instance.

    On the first request of a session, the decorated object is instantiated
    and stored in the session (``session.get(req)``), keyed by the decorated
    object itself. Every request is then answered by that instance's
    ``handle(req)`` method.

    If *data* is given, it is likewise instantiated once per session (keyed
    by *data*), and that per-session instance is passed to the handler's
    constructor. Handlers decorated with the same *data* therefore share one
    data instance per session.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is None:
                    sess[callable] = callable()
                else:
                    if data not in sess:
                        sess[data] = data()
                    # Pass the per-session data instance, not the factory;
                    # otherwise the object stored above is never used.
                    sess[callable] = callable(sess[data])
            return sess[callable].handle(req)
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
85
class preiter(object):
    """Iterator wrapper that always holds the next item fetched in advance.

    Construction immediately pulls the first item from the wrapped iterable.
    Any code the iterable runs before producing it (e.g. a generator body up
    to its first ``yield``) therefore executes when the preiter is created.
    ``close()`` forwards to the wrapped iterable's ``close()``, if it has one.
    """
    __slots__ = ["bk", "bki", "_next"]
    # Sentinel stored in _next once the wrapped iterator is exhausted.
    end = object()

    def __init__(self, real):
        self.bk = real
        self.bki = iter(real)
        self._next = None
        # Prime the look-ahead; the value returned here (None) is discarded.
        next(self)

    def __iter__(self):
        return self

    def __next__(self):
        pending = self._next
        if pending is self.end:
            raise StopIteration()
        # Fetch the following item now, or mark the end if there is none.
        self._next = next(self.bki, self.end)
        return pending

    def close(self):
        try:
            closer = self.bk.close
        except AttributeError:
            return
        closer()
111
def pregen(callable):
    """Decorator: wrap the iterable returned by *callable* in a ``preiter``.

    Calling the decorated function therefore starts the underlying iterator
    right away, because its first item is fetched before the call returns.
    """
    def wrapper(*args, **kwargs):
        result = callable(*args, **kwargs)
        return preiter(result)
    wrapper.__wrapped__ = callable
    return wrapper
117
def stringwrap(charset):
    """Decorator factory that encodes every string a generator yields.

    Each item produced by the decorated callable is passed through
    ``.encode(charset)``. The resulting generator is wrapped with
    ``pregen``, so it is started as soon as the decorated function is
    called. ``__wrapped__`` refers to the original callable.
    """
    def dec(callable):
        def encoded(*args, **kwargs):
            for chunk in callable(*args, **kwargs):
                yield chunk.encode(charset)
        wrapper = pregen(encoded)
        # Point past the internal encoding generator to the user's function.
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
127
class sessiondata(object):
    """Base class for objects stored once per session, keyed by their class.

    Subclasses are constructed as ``cls(req, sess)`` the first time
    ``get()`` is called for a session.
    """
    @classmethod
    def get(cls, req, create=True):
        """Return the instance of *cls* stored in the session of *req*.

        The session is obtained with ``cls.sessdb().get(req)``, and the
        lookup and creation happen while holding that session's ``lock``.
        When no instance exists yet, one is created and stored, unless
        *create* is false, in which case None is returned.
        """
        store = cls.sessdb().get(req)
        with store.lock:
            try:
                return store[cls]
            except KeyError:
                pass
            if not create:
                return None
            inst = store[cls] = cls(req, store)
            return inst

    @classmethod
    def sessdb(cls):
        """Return the session database used by ``get()`` (``session.default.val``)."""
        return session.default.val
145
class autodirty(sessiondata):
    """Session data that marks itself dirty whenever an attribute changes.

    The dirty flag is stored directly in the instance ``__dict__``
    (bypassing ``__setattr__``) and is only initialised by ``get()``.
    Attribute changes made before that point, e.g. inside ``__init__``, do
    not mark the object dirty.
    """
    @classmethod
    def get(cls, req, create=True):
        """Return this class's per-session instance and ensure its dirty flag exists.

        Returns None when *create* is false and no instance is stored yet.
        """
        ret = super().get(req, create=create)
        if ret is None:
            return None
        if "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Clear the dirty flag."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return whether an attribute was set or deleted since the last freeze."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        # Only track changes once get() has initialised the flag.
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the name; the previous extra
        # `value` argument was an undefined name and raised NameError.
        super().__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
169
class manudirty(object):
    """Mixin for session data whose dirty state is flagged explicitly.

    Instances start clean. Call ``dirty()`` after a change that should be
    saved. ``sessfrozen()`` clears the flag and ``sessdirty()`` reports it.
    Construction arguments are forwarded cooperatively via ``super()``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__setdirty(False)

    def __setdirty(self, flag):
        # Single place that writes the (name-mangled) dirty flag.
        self.__dirty = flag

    def sessfrozen(self):
        self.__setdirty(False)

    def sessdirty(self):
        return self.__dirty

    def dirty(self):
        self.__setdirty(True)
183
class specslot(object):
    """Descriptor backing one declared slot of a ``specdirty`` class.

    The slot's value lives at position ``idx`` of the instance's
    ``__sslots__`` list. Unset slots hold the ``unbound`` sentinel and read
    as AttributeError.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Sentinel stored in the slot list for slots without a value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm        # attribute name, used in error messages
        self.idx = idx      # position in the instance's slot-value list
        self.dirty = dirty  # whether assignment calls ins.dirty()

    @staticmethod
    def slist(ins):
        # Avoid calling __getattribute__
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls=None):
        if ins is None:
            # Class-level access (attribute lookup on the class, hasattr(),
            # introspection): return the descriptor itself, as the
            # descriptor protocol expects, instead of crashing in slist().
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        # NOTE(review): unlike __set__, deletion marks the instance dirty
        # regardless of self.dirty; confirm this asymmetry is intended.
        self.slist(ins)[self.idx] = specslot.unbound
        ins.dirty()
212
class specclass(type):
    """Metaclass that installs ``specslot`` descriptors for declared slots.

    For every class in the MRO it collects ``__saveslots__`` and
    ``__dirtyslots__``. ``__dirtyslots__`` defaults to that class's own
    ``__saveslots__``.

    Attributes set on the class:
        ``__sslots_l__``: the saved slot names.
        ``__sslots_a__``: all slot names, saved or dirty-tracked. A name's
        position in this list is its specslot index.

    Both lists are built from sets, so their order is arbitrary and may
    differ between processes under string hash randomisation.
    """
    def __init__(self, name, bases, tdict):
        super().__init__(name, bases, tdict)
        saved = set()
        dirtying = set()
        for klass in self.__mro__:
            own = vars(klass)
            declared = own.get("__saveslots__", ())
            saved.update(declared)
            dirtying.update(own.get("__dirtyslots__", declared))
        self.__sslots_l__ = list(saved)
        allslots = list(saved | dirtying)
        self.__sslots_a__ = allslots
        for idx, slotname in enumerate(allslots):
            setattr(self, slotname, specslot(slotname, idx, slotname in dirtying))
226
class specdirty(sessiondata, metaclass=specclass):
    """Session data with declared, individually tracked attributes.

    Subclasses list attribute names in ``__saveslots__`` (and optionally
    ``__dirtyslots__``). The ``specclass`` metaclass turns each name into a
    ``specslot`` descriptor whose value lives in the per-instance
    ``__sslots__`` list. Assigning a dirty-tracked slot calls ``dirty()``.
    Pickling uses ``__getnewargs__``, ``__getstate__`` and
    ``__setstate__``, which record every slot by name together with whether
    it was bound.
    """
    __slots__ = ["session", "__sslots__", "_is_dirty"]

    def __specinit__(self):
        """Hook for subclasses to initialise slots; does nothing here.

        Runs during construction, before the dirty flag is cleared.
        """

    @staticmethod
    def __new__(cls, req, sess):
        # `req` is not used here; __getnewargs__ supplies None for it.
        obj = super().__new__(cls)
        obj.session = sess
        obj.__sslots__ = [specslot.unbound for _ in cls.__sslots_a__]
        # Let subclasses populate slots; any dirty() this triggers is reset below.
        obj.__specinit__()
        obj._is_dirty = False
        return obj

    def __getnewargs__(self):
        return (None, self.session)

    def dirty(self):
        """Mark the instance as changed."""
        self._is_dirty = True

    def sessfrozen(self):
        """Clear the dirty flag."""
        self._is_dirty = False

    def sessdirty(self):
        """Return whether the instance was marked dirty since the last freeze."""
        return self._is_dirty

    def __getstate__(self):
        # Map each slot name to (bound, value); unbound slots become (False, None).
        values = specslot.slist(self)
        unbound = specslot.unbound
        return {name: ((False, None) if value is unbound else (True, value))
                for name, value in zip(type(self).__sslots_a__, values)}

    def __setstate__(self, st):
        # Entries are consumed from `st`; slots absent from it end up unbound.
        slots = specslot.slist(self)
        for idx, name in enumerate(type(self).__sslots_a__):
            bound, value = st.pop(name, (False, None))
            slots[idx] = value if bound else specslot.unbound
271
def datecheck(req, mtime):
    """Answer a conditional request against a modification time.

    If the request carries ``If-Modified-Since`` and that date is not older
    than *mtime* truncated to whole seconds, ``resp.unmodified`` is raised.
    Otherwise ``Last-Modified`` is set on the response headers from *mtime*.
    """
    if "If-Modified-Since" in req.ihead:
        since = proto.phttpdate(req.ihead["If-Modified-Since"])
        # HTTP dates carry whole seconds, hence the floor. A None from
        # phttpdate (presumably an unparsable date) disables the check.
        fresh = since is not None and since >= math.floor(mtime)
        if fresh:
            raise resp.unmodified()
    req.ohead["Last-Modified"] = proto.httpdate(mtime)