python*: Use poll instead of select in ckflush.
[ashd.git] / python3 / ashd / serve.py
index 0a45259..0927710 100644 (file)
@@ -1,4 +1,4 @@
-import sys, os, threading, time, logging, select, queue
+import sys, os, threading, time, logging, select, queue, collections
 from . import perf
 
 log = logging.getLogger("ashd.serve")
@@ -23,7 +23,7 @@ class reqthread(threading.Thread):
         super().__init__(name=name, **kw)
 
 class wsgirequest(object):
-    def __init__(self, handler):
+    def __init__(self, *, handler):
         self.status = None
         self.headers = []
         self.respsent = False
@@ -75,8 +75,10 @@ class handler(object):
     def handle(self, request):
         raise Exception()
     def ckflush(self, req):
+        p = select.poll()
+        p.register(req, select.POLLOUT)
         while len(req.buffer) > 0:
-            rls, wls, els = select.select([], [req], [req])
+            p.poll()
             req.flush()
     def close(self):
         pass
@@ -88,6 +90,8 @@ class handler(object):
         return {}
 
 class single(handler):
+    cname = "single"
+
     def handle(self, req):
         try:
             env = req.mkenv()
@@ -106,7 +110,19 @@ class single(handler):
         finally:
             req.close()
 
+def dbg(*a):
+    f = True
+    for o in a:
+        if not f:
+            sys.stderr.write(" ")
+        sys.stderr.write(str(a))
+        f = False
+    sys.stderr.write("\n")
+    sys.stderr.flush()
+
 class freethread(handler):
+    cname = "free"
+
     def __init__(self, *, max=None, timeout=None, **kw):
         super().__init__(**kw)
         self.current = set()
@@ -138,8 +154,9 @@ class freethread(handler):
                     while len(self.current) >= self.max:
                         self.tcond.wait()
             th = reqthread(target=self.run, args=[req])
+            th.registered = False
             th.start()
-            while th.is_alive() and th not in self.current:
+            while not th.registered:
                 self.tcond.wait()
 
     def run(self, req):
@@ -147,6 +164,7 @@ class freethread(handler):
             th = threading.current_thread()
             with self.lk:
                 self.current.add(th)
+                th.registered = True
                 self.tcond.notify_all()
             try:
                 env = req.mkenv()
@@ -179,39 +197,62 @@ class freethread(handler):
             th.join()
 
 class threadpool(handler):
-    def __init__(self, *, min=0, max=20, live=300, **kw):
+    cname = "pool"
+
+    def __init__(self, *, max=25, qsz=100, timeout=None, **kw):
         super().__init__(**kw)
         self.current = set()
-        self.free = set()
-        self.lk = threading.RLock()
-        self.pcond = threading.Condition(self.lk)
-        self.rcond = threading.Condition(self.lk)
-        self.wreq = None
-        self.min = min
+        self.clk = threading.Lock()
+        self.ccond = threading.Condition(self.clk)
+        self.queue = collections.deque()
+        self.waiting = set()
+        self.waitlimit = 5
+        self.wlstart = 0.0
+        self.qlk = threading.Lock()
+        self.qfcond = threading.Condition(self.qlk)
+        self.qecond = threading.Condition(self.qlk)
         self.max = max
-        self.live = live
-        for i in range(self.min):
-            self.newthread()
+        self.qsz = qsz
+        self.timeout = timeout
 
     @classmethod
-    def parseargs(cls, *, min=None, max=None, live=None, **args):
+    def parseargs(cls, *, max=None, queue=None, abort=None, **args):
         ret = super().parseargs(**args)
-        if min:
-            ret["min"] = int(min)
         if max:
             ret["max"] = int(max)
-        if live:
-            ret["live"] = int(live)
+        if queue:
+            ret["qsz"] = int(queue)
+        if abort:
+            ret["timeout"] = int(abort)
         return ret
 
-    def newthread(self):
-        with self.lk:
-            th = reqthread(target=self.loop)
-            th.start()
-            while not th in self.current:
-                self.pcond.wait()
+    def handle(self, req):
+        spawn = False
+        with self.qlk:
+            if self.timeout is not None:
+                now = start = time.time()
+                while len(self.queue) >= self.qsz:
+                    self.qecond.wait(start + self.timeout - now)
+                    now = time.time()
+                    if now - start > self.timeout:
+                        os.abort()
+            else:
+                while len(self.queue) >= self.qsz:
+                    self.qecond.wait()
+            self.queue.append(req)
+            self.qfcond.notify()
+            if len(self.waiting) < 1:
+                spawn = True
+        if spawn:
+            with self.clk:
+                if len(self.current) < self.max:
+                    th = reqthread(target=self.run)
+                    th.registered = False
+                    th.start()
+                    while not th.registered:
+                        self.ccond.wait()
 
-    def _handle(self, req):
+    def handle1(self, req):
         try:
             env = req.mkenv()
             with perf.request(env) as reqevent:
@@ -226,84 +267,92 @@ class threadpool(handler):
             pass
         except:
             log.error("exception occurred when handling request", exc_info=True)
-        finally:
-            req.close()
 
-    def loop(self):
+    def run(self):
+        timeout = 10.0
         th = threading.current_thread()
-        with self.lk:
+        with self.clk:
             self.current.add(th)
+            th.registered = True
+            self.ccond.notify_all()
         try:
             while True:
-                with self.lk:
-                    self.free.add(th)
-                    try:
-                        self.pcond.notify_all()
-                        now = start = time.time()
-                        while self.wreq is None:
-                            self.rcond.wait(start + self.live - now)
-                            now = time.time()
-                            if now - start > self.live:
-                                if len(self.current) > self.min:
-                                    self.current.remove(th)
-                                    return
-                                else:
-                                    start = now
-                        req, self.wreq = self.wreq, None
-                        self.pcond.notify_all()
-                    finally:
-                        self.free.remove(th)
-                self._handle(req)
-                req = None
-        finally:
-            with self.lk:
+                start = now = time.time()
+                with self.qlk:
+                    while len(self.queue) < 1:
+                        if len(self.waiting) >= self.waitlimit and now - self.wlstart >= timeout:
+                            return
+                        self.waiting.add(th)
+                        try:
+                            if len(self.waiting) == self.waitlimit:
+                                self.wlstart = now
+                            self.qfcond.wait(start + timeout - now)
+                        finally:
+                            self.waiting.remove(th)
+                        now = time.time()
+                        if now - start > timeout:
+                            return
+                    req = self.queue.popleft()
+                    self.qecond.notify()
                 try:
-                    self.current.remove(th)
-                except KeyError:
-                    pass
-                self.pcond.notify_all()
+                    self.handle1(req)
+                finally:
+                    req.close()
+        finally:
+            with self.clk:
+                self.current.remove(th)
 
-    def handle(self, req):
+    def close(self):
         while True:
-            with self.lk:
-                if len(self.free) < 1 and len(self.current) < self.max:
-                    self.newthread()
-                while self.wreq is not None:
-                    self.pcond.wait()
-                if self.wreq is None:
-                    self.wreq = req
-                    self.rcond.notify(1)
+            with self.clk:
+                if len(self.current) > 0:
+                    th = next(iter(self.current))
+                else:
                     return
-
-    def close(self):
-        self.live = 0
-        self.min = 0
-        with self.lk:
-            while len(self.current) > 0:
-                self.rcond.notify_all()
-                self.pcond.wait(1)
+            th.join()
 
 class resplex(handler):
-    def __init__(self, **kw):
+    cname = "rplex"
+
+    def __init__(self, *, max=None, **kw):
         super().__init__(**kw)
         self.current = set()
         self.lk = threading.Lock()
+        self.tcond = threading.Condition(self.lk)
+        self.max = max
         self.cqueue = queue.Queue(5)
         self.cnpipe = os.pipe()
         self.rthread = reqthread(name="Response thread", target=self.handle2)
         self.rthread.start()
 
+    @classmethod
+    def parseargs(cls, *, max=None, **args):
+        ret = super().parseargs(**args)
+        if max:
+            ret["max"] = int(max)
+        return ret
+
     def ckflush(self, req):
         raise Exception("resplex handler does not support the write() function")
 
     def handle(self, req):
-        reqthread(target=self.handle1, args=[req]).start()
+        with self.lk:
+            if self.max is not None:
+                while len(self.current) >= self.max:
+                    self.tcond.wait()
+            th = reqthread(target=self.handle1, args=[req])
+            th.registered = False
+            th.start()
+            while not th.registered:
+                self.tcond.wait()
 
     def handle1(self, req):
         try:
             th = threading.current_thread()
             with self.lk:
                 self.current.add(th)
+                th.registered = True
+                self.tcond.notify_all()
             try:
                 env = req.mkenv()
                 respobj = req.handlewsgi(env, req.startreq)
@@ -318,7 +367,9 @@ class resplex(handler):
                     os.write(self.cnpipe[1], b" ")
                     req = None
             finally:
-                self.current.remove(th)
+                with self.lk:
+                    self.current.remove(th)
+                    self.tcond.notify_all()
         except closed:
             pass
         except:
@@ -352,15 +403,22 @@ class resplex(handler):
                         data = next(respiter)
                     except StopIteration:
                         rem = True
-                        req.flushreq()
+                        try:
+                            req.flushreq()
+                        except:
+                            log.error("exception occurred when handling response data", exc_info=True)
                     except:
                         rem = True
                         log.error("exception occurred when iterating response", exc_info=True)
                     if not rem:
                         if data:
-                            req.flushreq()
-                            req.writedata(data)
-                    else:
+                            try:
+                                req.flushreq()
+                                req.writedata(data)
+                            except:
+                                log.error("exception occurred when handling response data", exc_info=True)
+                                rem = True
+                    if rem:
                         current[req] = None
                         try:
                             if hasattr(respiter, "close"):
@@ -412,10 +470,10 @@ class resplex(handler):
         os.close(self.cnpipe[1])
         self.rthread.join()
 
-names = {"single": single,
-         "free": freethread,
-         "pool": threadpool,
-         "rplex": resplex}
+names = {cls.cname: cls for cls in globals().values() if
+         isinstance(cls, type) and
+         issubclass(cls, handler) and
+         hasattr(cls, "cname")}
 
 def parsehspec(spec):
     if ":" not in spec: