Fix typo.
[wrw.git] / wrw / auth.py
1 import binascii, hashlib, threading, time
2 from . import resp
3
4 class unauthorized(resp.httperror):
5     def __init__(self, challenge, message=None, detail=None):
6         super().__init__(401, message, detail)
7         if isinstance(challenge, str):
8             challenge = [challenge]
9         self.challenge = challenge
10
11     def handle(self, req):
12         for challenge in self.challenge:
13             req.ohead.add("WWW-Authenticate", challenge)
14         return super().handle(req)
15
16 class forbidden(resp.httperror):
17     def __init__(self, message=None, detail=None):
18         super().__init__(403, message, detail)
19
20 def parsemech(req):
21     h = req.ihead.get("Authorization", None)
22     if h is None:
23         return None, None
24     p = h.find(" ")
25     if p < 0:
26         return None, None
27     return h[:p].strip().lower(), h[p + 1:].strip()
28
29 def parsebasic(req):
30     mech, data = parsemech(req)
31     if mech != "basic":
32         return None, None
33     try:
34         data = data.encode("us-ascii")
35     except UnicodeError:
36         return None, None
37     try:
38         raw = binascii.a2b_base64(data)
39     except binascii.Error:
40         return None, None
41     try:
42         raw = raw.decode("utf-8")
43     except UnicodeError:
44         raw = raw.decode("latin1")
45     p = raw.find(":")
46     if p < 0:
47         return None, None
48     return raw[:p], raw[p + 1:]
49
50 class basiccache(object):
51     cachetime = 300
52
53     def __init__(self, realm, authfn=None):
54         self._lock = threading.Lock()
55         self._cache = {}
56         self.realm = realm
57         if authfn is not None:
58             self.auth = authfn
59
60     def _obscure(self, nm, pw):
61         dig = hashlib.sha256()
62         dig.update(self.realm.encode("utf-8"))
63         dig.update(nm.encode("utf-8"))
64         dig.update(pw.encode("utf-8"))
65         return dig.digest()
66
67     def check(self, req):
68         nm, pw = parsebasic(req)
69         if nm is None:
70             raise unauthorized("Basic Realm=\"%s\"" % self.realm)
71         pwh = self._obscure(nm, pw)
72         now = time.time()
73         with self._lock:
74             if (nm, pwh) in self._cache:
75                 lock, atime, res, resob = self._cache[nm, pwh]
76                 if now - atime < self.cachetime:
77                     if res == "s":
78                         return resob
79                     elif res == "f":
80                         raise resob
81             else:
82                 lock = threading.Lock()
83                 self._cache[nm, pwh] = (lock, now, None, None)
84         with lock:
85             try:
86                 ret = self.auth(req, nm, pw)
87             except forbidden as exc:
88                 with self._lock:
89                     self._cache[nm, pwh] = (lock, now, "f", exc)
90                 raise
91             if ret is None:
92                 raise forbidden()
93             with self._lock:
94                 self._cache[nm, pwh] = (lock, now, "s", ret)
95             return ret
96
97     def auth(self, req, nm, pw):
98         raise Exception("authentication function neither supplied nor overridden")