python3: Cache async channel FDs so that updates always happen correctly.
[ashd.git] / python3 / ashd / async.py
index 02c75a9..238f6b7 100644 (file)
@@ -3,11 +3,16 @@ import sys, os, errno, threading, select, traceback
 class epoller(object):
     exc_handler = None
 
-    def __init__(self):
+    def __init__(self, check=None):
         self.registered = {}
+        self.fdcache = {}
         self.lock = threading.RLock()
         self.ep = None
         self.th = None
+        self.stopped = False
+        self.loopcheck = set()
+        if check is not None:
+            self.loopcheck.add(check)
         self._daemon = True
 
     @staticmethod
@@ -25,7 +30,7 @@ class epoller(object):
     def exception(self, ch, *exc):
         self.remove(ch)
         if self.exc_handler is None:
-            traceback.print_exception(exc)
+            traceback.print_exception(*exc)
         else:
             self.exc_handler(ch, *exc)
 
@@ -38,15 +43,32 @@ class epoller(object):
         except Exception as exc:
             self.exception(ch, *sys.exc_info())
 
+    def _closeall(self):
+        with self.lock:
+            while self.registered:
+                fd, (ch, evs) = next(iter(self.registered.items()))
+                del self.registered[fd]
+                self.ep.unregister(fd)
+                self._cb(ch, "close")
+
     def _run(self):
         ep = select.epoll()
         try:
             with self.lock:
-                for fd, (ob, evs) in self.registered.items():
-                    ep.register(fd, evs)
+                try:
+                    for fd, (ob, evs) in self.registered.items():
+                        ep.register(fd, evs)
+                except:
+                    self.registered.clear()
+                    raise
                 self.ep = ep
 
             while self.registered:
+                for ck in self.loopcheck:
+                    ck(self)
+                if self.stopped:
+                    self._closeall()
+                    break
                 try:
                     evlist = ep.poll(10)
                 except IOError as exc:
@@ -65,6 +87,7 @@ class epoller(object):
                         if fd in self.registered:
                             nevs = self._evsfor(ch)
                             if nevs == 0:
+                                del self.fdcache[ch]
                                 del self.registered[fd]
                                 ep.unregister(fd)
                                 self._cb(ch, "close")
@@ -98,6 +121,7 @@ class epoller(object):
                 ch.close()
                 return
             ch.watcher = self
+            self.fdcache[ch] = fd
             self.registered[fd] = (ch, evs)
             if self.ep:
                 self.ep.register(fd, evs)
@@ -105,14 +129,16 @@ class epoller(object):
 
     def remove(self, ch, ignore=False):
         with self.lock:
-            fd = ch.fileno()
-            if fd not in self.registered:
+            try:
+                fd = self.fdcache[ch]
+            except KeyError:
                 if ignore:
                     return
                 raise KeyError("fd %i is not registered" % fd)
             pch, cevs = self.registered[fd]
             if pch is not ch:
                 raise ValueError("fd %i registered via object %r, cannot remove with %r" % (pch, ch))
+            del self.fdcache[ch]
             del self.registered[fd]
             if self.ep:
                 self.ep.unregister(fd)
@@ -120,8 +146,9 @@ class epoller(object):
 
     def update(self, ch, ignore=False):
         with self.lock:
-            fd = ch.fileno()
-            if fd not in self.registered:
+            try:
+                fd = self.fdcache[ch]
+            except KeyError:
                 if ignore:
                     return
                 raise KeyError("fd %i is not registered" % fd)
@@ -130,6 +157,7 @@ class epoller(object):
                 raise ValueError("fd %i registered via object %r, cannot update with %r" % (pch, ch))
             evs = self._evsfor(ch)
             if evs == 0:
+                del self.fdcache[ch]
                 del self.registered[fd]
                 if self.ep:
                     self.ep.unregister(fd)
@@ -139,15 +167,39 @@ class epoller(object):
                 if self.ep:
                     self.ep.modify(fd, evs)
 
+    def stop(self):
+        if threading.current_thread() == self.th:
+            self.stopped = True
+        else:
+            def tgt():
+                self.stopped = True
+            cb = callbuffer()
+            cb.call(tgt)
+            cb.stop()
+            self.add(cb)
+
 def watcher():
     return epoller()
 
-class sockbuffer(object):
-    def __init__(self, sk):
-        self.sk = sk
+class channel(object):
+    readable = False
+    writable = False
+
+    def __init__(self):
+        self.watcher = None
+
+    def fileno(self):
+        raise NotImplementedError("fileno()")
+
+    def close(self):
+        pass
+
+class sockbuffer(channel):
+    def __init__(self, socket, **kwargs):
+        super().__init__(**kwargs)
+        self.sk = socket
         self.eof = False
         self.obuf = bytearray()
-        self.watcher = None
 
     def fileno(self):
         return self.sk.fileno()
@@ -188,8 +240,9 @@ class sockbuffer(object):
             self.obuf[:] = b""
             self.eof = True
 
-class callbuffer(object):
-    def __init__(self):
+class callbuffer(channel):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
         self.queue = []
         self.rp, self.wp = os.pipe()
         self.lock = threading.Lock()
@@ -239,3 +292,9 @@ class callbuffer(object):
             if self.wp >= 0:
                 os.close(self.wp)
                 self.wp = -1
+
+def currentwatcher(io, current):
+    def check(io):
+        if not current:
+            io.stop()
+    io.loopcheck.add(check)