02c75a92a9175a61e3f257e05c866be93e305631
[ashd.git] / python3 / ashd / async.py
1 import sys, os, errno, threading, select, traceback
2
class epoller(object):
    """Dispatch I/O readiness events for registered channels via epoll.

    A *channel* is any object providing ``fileno()``, ``close()``, the
    boolean attributes ``readable``/``writable`` and the callbacks
    ``read()``/``write()``.  ``read()`` is invoked on EPOLLIN/EPOLLHUP/
    EPOLLERR, ``write()`` on EPOLLOUT, and ``close()`` once the channel is
    neither readable nor writable any more.  ``add()`` stores the watcher
    on the channel as ``ch.watcher``.

    A single background thread ("Async epoll thread") is started lazily when
    a channel is registered, exits once no channels remain, and is started
    again on demand.  ``self.lock`` (re-entrant) guards the registration
    table and is held while callbacks run, so callbacks may call back into
    ``add``/``remove``/``update``.
    """

    # Optional callable ``exc_handler(ch, type, value, traceback)``; when it
    # is None, exceptions raised by channel callbacks are printed to stderr.
    exc_handler = None

    def __init__(self):
        # fd -> (channel, event mask currently registered with epoll)
        self.registered = {}
        self.lock = threading.RLock()
        # Live epoll object while the poll thread runs, otherwise None.
        self.ep = None
        # The running poll thread, or None when none is active.
        self.th = None
        self._daemon = True

    @staticmethod
    def _evsfor(ch):
        """Return the epoll event mask matching *ch*'s readable/writable flags."""
        return ((select.EPOLLIN if ch.readable else 0) |
                (select.EPOLLOUT if ch.writable else 0))

    def _ckrun(self):
        """Start the poll thread if channels are registered and none runs.

        Callers hold ``self.lock``.
        """
        if self.registered and self.th is None:
            th = threading.Thread(target=self._run, name="Async epoll thread")
            th.daemon = self._daemon
            th.start()
            self.th = th

    def exception(self, ch, *exc):
        """Unregister and close *ch*, then report the exception *exc*.

        *exc* is a ``(type, value, traceback)`` triple as returned by
        ``sys.exc_info()``.  It is passed to ``exc_handler`` when one is set,
        otherwise the traceback is printed to stderr.
        """
        # ch may already be unregistered (e.g. its close() callback failed
        # after _run dropped it from the table); a KeyError here would mask
        # the exception we are trying to report.
        self.remove(ch, True)
        if self.exc_handler is None:
            # print_exception takes type, value and traceback separately.
            traceback.print_exception(*exc)
        else:
            self.exc_handler(ch, *exc)

    def _cb(self, ch, nm):
        """Invoke method *nm* of *ch*, routing any exception to exception()."""
        try:
            m = getattr(ch, nm, None)
            if m is None:
                raise AttributeError("%r has no %s method" % (ch, nm))
            m()
        except Exception:
            self.exception(ch, *sys.exc_info())

    def _run(self):
        """Poll-thread body: dispatch events until no channels remain."""
        ep = select.epoll()
        try:
            with self.lock:
                # Register everything added before this thread got the lock;
                # channels added later are registered on self.ep by add().
                for fd, (ob, evs) in self.registered.items():
                    ep.register(fd, evs)
                self.ep = ep

            while self.registered:
                try:
                    # The 10 s timeout bounds how long we wait before
                    # re-checking whether any channels remain.
                    evlist = ep.poll(10)
                except IOError as exc:
                    if exc.errno == errno.EINTR:
                        continue
                    raise
                for fd, evs in evlist:
                    with self.lock:
                        # The channel may have been removed since poll()
                        # returned.
                        if fd not in self.registered:
                            continue
                        ch, cevs = self.registered[fd]
                        # HUP/ERR are routed to read() so the channel's own
                        # read attempt surfaces EOF or the error.
                        if fd in self.registered and evs & (select.EPOLLIN | select.EPOLLHUP | select.EPOLLERR):
                            self._cb(ch, "read")
                        # Each callback may unregister the channel; re-check.
                        if fd in self.registered and evs & select.EPOLLOUT:
                            self._cb(ch, "write")
                        if fd in self.registered:
                            nevs = self._evsfor(ch)
                            if nevs == 0:
                                # Nothing left to wait for: retire the channel.
                                del self.registered[fd]
                                ep.unregister(fd)
                                self._cb(ch, "close")
                            elif nevs != cevs:
                                self.registered[fd] = ch, nevs
                                ep.modify(fd, nevs)

        finally:
            with self.lock:
                self.th = None
                self.ep = None
                # A channel may have been added after the loop saw an empty
                # table; start a fresh thread for it.
                self._ckrun()
            ep.close()

    @property
    def daemon(self):
        """Whether poll threads are created as daemon threads."""
        return self._daemon

    @daemon.setter
    def daemon(self, value):
        # Python refuses to change the daemon flag of an already started
        # thread, and self.th is only ever set to a started thread, so the
        # new value takes effect for the next poll thread that is started.
        with self.lock:
            self._daemon = bool(value)

    def add(self, ch):
        """Register channel *ch* and start the poll thread if needed.

        A channel that is neither readable nor writable is closed right away
        instead of being registered.

        Raises:
            KeyError: if ``ch.fileno()`` is already registered.
        """
        with self.lock:
            fd = ch.fileno()
            if fd in self.registered:
                raise KeyError("fd %i is already registered" % fd)
            evs = self._evsfor(ch)
            if evs == 0:
                ch.close()
                return
            ch.watcher = self
            self.registered[fd] = (ch, evs)
            if self.ep is not None:
                self.ep.register(fd, evs)
            self._ckrun()

    def remove(self, ch, ignore=False):
        """Unregister and close channel *ch*.

        Args:
            ch: the channel to remove.
            ignore: if true, silently return when *ch*'s fd is not registered.

        Raises:
            KeyError: if the fd is not registered and *ignore* is false.
            ValueError: if the fd is registered by a different object.
        """
        with self.lock:
            fd = ch.fileno()
            if fd not in self.registered:
                if ignore:
                    return
                raise KeyError("fd %i is not registered" % fd)
            pch, cevs = self.registered[fd]
            if pch is not ch:
                raise ValueError("fd %i registered via object %r, cannot remove with %r" % (fd, pch, ch))
            del self.registered[fd]
            if self.ep is not None:
                self.ep.unregister(fd)
            ch.close()

    def update(self, ch, ignore=False):
        """Re-evaluate *ch*'s readable/writable flags after it changed.

        A channel that is neither readable nor writable any more is
        unregistered and closed; otherwise its epoll mask is updated.

        Args:
            ch: the channel to update.
            ignore: if true, silently return when *ch*'s fd is not registered.

        Raises:
            KeyError: if the fd is not registered and *ignore* is false.
            ValueError: if the fd is registered by a different object.
        """
        with self.lock:
            fd = ch.fileno()
            if fd not in self.registered:
                if ignore:
                    return
                raise KeyError("fd %i is not registered" % fd)
            pch, cevs = self.registered[fd]
            if pch is not ch:
                raise ValueError("fd %i registered via object %r, cannot update with %r" % (fd, pch, ch))
            evs = self._evsfor(ch)
            if evs == 0:
                del self.registered[fd]
                if self.ep is not None:
                    self.ep.unregister(fd)
                ch.close()
            elif evs != cevs:
                self.registered[fd] = ch, evs
                if self.ep is not None:
                    self.ep.modify(fd, evs)
141
def watcher():
    """Create and return a new epoll-based watcher.

    No poll thread is running yet; the watcher starts one when a channel is
    first registered with it.
    """
    new_watcher = epoller()
    return new_watcher
144
class sockbuffer(object):
    """Watcher channel that buffers outgoing data for a socket.

    The watcher calls ``read()``/``write()`` when the socket is ready.
    Received data is handed to ``gotdata()``; the base implementation only
    records EOF (presumably subclasses override it to consume the data).
    """

    def __init__(self, sk):
        self.sk = sk
        # Set once nothing more will be read: peer EOF, a socket error, or
        # send(..., eof=True).
        self.eof = False
        # Outgoing bytes not yet accepted by the socket.
        self.obuf = bytearray()
        # The watcher this buffer is registered with, if any.
        self.watcher = None

    def fileno(self):
        return self.sk.fileno()

    def close(self):
        self.sk.close()

    def gotdata(self, data):
        """Handle received *data*; an empty chunk signals end of stream."""
        if data == b"":
            self.eof = True

    def send(self, data, eof=False):
        """Queue *data* for sending, optionally marking the stream as done.

        If registered with a watcher, it is asked to re-evaluate this
        buffer's readable/writable state.
        """
        self.obuf.extend(data)
        if eof:
            self.eof = True
        if self.watcher is not None:
            self.watcher.update(self, True)

    @property
    def readable(self):
        return not self.eof

    def read(self):
        """Receive up to 1024 bytes and pass them to gotdata().

        Any socket error other than EAGAIN discards pending output and marks
        the buffer EOF.
        """
        try:
            data = self.sk.recv(1024)
            self.gotdata(data)
        except BlockingIOError:
            # epoll readiness can be spurious; on a non-blocking socket that
            # surfaces as EAGAIN, which only means "try again later".
            pass
        except IOError:
            self.obuf[:] = b""
            self.eof = True

    @property
    def writable(self):
        return bool(self.obuf)

    def write(self):
        """Send as much buffered output as the socket accepts.

        Any socket error other than EAGAIN discards pending output and marks
        the buffer EOF.
        """
        try:
            sent = self.sk.send(self.obuf)
        except BlockingIOError:
            # Spurious writability on a non-blocking socket: keep the data.
            return
        except IOError:
            self.obuf[:] = b""
            self.eof = True
            return
        del self.obuf[:sent]
190
class callbuffer(object):
    """Watcher channel that runs queued callables when its pipe is read.

    ``call(cb)`` queues *cb* and makes the internal pipe readable; when the
    watcher then invokes ``read()``, the pipe is drained and every queued
    callable is run on the calling thread.  ``stop()`` closes the write end
    so the reader eventually sees EOF and the channel becomes unreadable.
    """

    def __init__(self):
        self.queue = []
        # Self-pipe: a byte written to wp makes rp readable.
        self.rp, self.wp = os.pipe()
        self.lock = threading.Lock()
        self.eof = False

    def fileno(self):
        return self.rp

    def close(self):
        """Close both pipe ends (idempotent)."""
        with self.lock:
            try:
                if self.wp >= 0:
                    os.close(self.wp)
                self.wp = -1
            finally:
                if self.rp >= 0:
                    os.close(self.rp)
                self.rp = -1

    @property
    def readable(self):
        return not self.eof

    def read(self):
        """Drain the wakeup pipe and run all queued callables.

        The callables run after ``self.lock`` is released, so they may call
        ``call()`` again.
        """
        with self.lock:
            try:
                data = os.read(self.rp, 1024)
                if data == b"":
                    self.eof = True
            except IOError:
                self.eof = True
            cbs = list(self.queue)
            self.queue[:] = []
        for cb in cbs:
            cb()

    writable = False

    def call(self, cb):
        """Queue *cb* to be run by the next read().

        Raises:
            Exception: if the buffer has been stopped.
        """
        with self.lock:
            if self.wp < 0:
                raise Exception("stopped")
            # Only the empty -> non-empty transition needs a wakeup byte:
            # read() consumes the pending byte and the whole queue together
            # under self.lock, so a byte is already pending whenever the queue
            # is non-empty.  Writing a byte per call could fill the pipe and
            # block here while holding the lock read() needs (deadlock).
            if not self.queue:
                os.write(self.wp, b"a")
            self.queue.append(cb)

    def stop(self):
        """Close the write end so no more calls are accepted."""
        with self.lock:
            if self.wp >= 0:
                os.close(self.wp)
                self.wp = -1