python3: To be safe, abort entirely if initial epoller registration fails.
[ashd.git] / python3 / ashd / async.py
1 import sys, os, errno, threading, select, traceback
2
class epoller(object):
    """Level-triggered epoll dispatcher that services channels on a helper thread.

    Channels are objects providing ``fileno()``, ``readable``/``writable``
    (attributes or properties) and the ``read()``/``write()``/``close()``
    callbacks.  A loop thread is started lazily whenever at least one channel
    is registered, and exits (closing its epoll object) once none remain.

    ``exc_handler``, if set, is called as ``exc_handler(ch, type, value, tb)``
    when a channel callback raises; otherwise the traceback is printed.
    """
    exc_handler = None

    def __init__(self, check=None):
        # fd -> (channel, event mask currently registered for it)
        self.registered = {}
        # Reentrant because callbacks run while the lock is held and may call
        # back into add()/update()/remove() (exception() itself calls remove()).
        self.lock = threading.RLock()
        self.ep = None          # live epoll object while the loop thread runs
        self.th = None          # the loop thread, or None when idle
        self.stopped = False
        # Callables invoked as ck(self) at the top of every loop iteration.
        self.loopcheck = set()
        if check is not None:
            self.loopcheck.add(check)
        self._daemon = True

    @staticmethod
    def _evsfor(ch):
        """Return the epoll event mask matching *ch*'s readable/writable state."""
        return ((select.EPOLLIN if ch.readable else 0) |
                (select.EPOLLOUT if ch.writable else 0))

    def _ckrun(self):
        """Start a loop thread if channels are registered and none is running.

        Called with ``self.lock`` held.
        """
        if self.registered and self.th is None:
            th = threading.Thread(target=self._run, name="Async epoll thread")
            th.daemon = self._daemon
            th.start()
            self.th = th

    def exception(self, ch, *exc):
        """Drop *ch* after one of its callbacks raised, then report *exc*.

        *exc* is a ``sys.exc_info()`` triple.  The channel may already be
        unregistered (e.g. when its ``close()`` callback raised), so removal
        tolerates that instead of masking the original error with a KeyError.
        """
        self.remove(ch, True)
        if self.exc_handler is None:
            traceback.print_exception(*exc)
        else:
            self.exc_handler(ch, *exc)

    def _cb(self, ch, nm):
        """Invoke callback method *nm* on *ch*, routing failures to exception()."""
        try:
            m = getattr(ch, nm, None)
            if m is None:
                raise AttributeError("%r has no %s method" % (ch, nm))
            m()
        except Exception:
            self.exception(ch, *sys.exc_info())

    def _closeall(self):
        """Unregister and close every channel (used when the loop is stopped)."""
        with self.lock:
            while self.registered:
                fd, (ch, evs) = next(iter(self.registered.items()))
                del self.registered[fd]
                self.ep.unregister(fd)
                self._cb(ch, "close")

    def _run(self):
        """Loop-thread body: poll and dispatch until nothing is registered."""
        ep = select.epoll()
        try:
            with self.lock:
                try:
                    # Mirror every channel registered while no loop was running.
                    for fd, (ob, evs) in self.registered.items():
                        ep.register(fd, evs)
                except BaseException:
                    # A partial registration would leave channels that are
                    # never polled; abort entirely and let the error surface.
                    self.registered.clear()
                    raise
                self.ep = ep

            while self.registered:
                for ck in self.loopcheck:
                    ck(self)
                if self.stopped:
                    self._closeall()
                    break
                try:
                    # Timeout in seconds, so loop checks run periodically.
                    evlist = ep.poll(10)
                except IOError as exc:
                    if exc.errno == errno.EINTR:
                        continue
                    raise
                for fd, evs in evlist:
                    with self.lock:
                        if fd not in self.registered:
                            # Removed by an earlier callback in this batch.
                            continue
                        ch, cevs = self.registered[fd]
                        # HUP/ERR are delivered to read(), where EOF/errors surface.
                        if fd in self.registered and evs & (select.EPOLLIN | select.EPOLLHUP | select.EPOLLERR):
                            self._cb(ch, "read")
                        if fd in self.registered and evs & select.EPOLLOUT:
                            self._cb(ch, "write")
                        if fd in self.registered:
                            # Re-derive interest now that the callbacks ran.
                            nevs = self._evsfor(ch)
                            if nevs == 0:
                                del self.registered[fd]
                                ep.unregister(fd)
                                self._cb(ch, "close")
                            elif nevs != cevs:
                                self.registered[fd] = ch, nevs
                                ep.modify(fd, nevs)

        finally:
            with self.lock:
                self.th = None
                self.ep = None
                # If channels remain (or were added while this thread was
                # exiting), start a fresh loop thread for them.
                self._ckrun()
            ep.close()

    @property
    def daemon(self):
        """Daemon flag used for loop threads started by this epoller."""
        return self._daemon

    @daemon.setter
    def daemon(self, value):
        # Python refuses to change the daemon flag of a started thread, and
        # self.th is only ever assigned after start(), so the new value applies
        # to the next loop thread.  Never overwrite self.th here: _ckrun() and
        # stop() rely on it being the thread (or None).
        with self.lock:
            self._daemon = bool(value)

    def add(self, ch):
        """Register channel *ch*, starting the loop thread if needed.

        A channel that is neither readable nor writable is closed immediately.
        Raises KeyError if its fd is already registered.
        """
        with self.lock:
            fd = ch.fileno()
            if fd in self.registered:
                raise KeyError("fd %i is already registered" % fd)
            evs = self._evsfor(ch)
            if evs == 0:
                ch.close()
                return
            if self.ep is not None:
                # Register with the live epoll first so a failure leaves the
                # bookkeeping untouched instead of half-registered.
                self.ep.register(fd, evs)
            ch.watcher = self
            self.registered[fd] = (ch, evs)
            self._ckrun()

    def remove(self, ch, ignore=False):
        """Unregister and close *ch*.

        Raises KeyError if it is not registered (unless *ignore*), and
        ValueError if its fd is registered via a different object.
        """
        with self.lock:
            fd = ch.fileno()
            if fd not in self.registered:
                if ignore:
                    return
                raise KeyError("fd %i is not registered" % fd)
            pch, cevs = self.registered[fd]
            if pch is not ch:
                raise ValueError("fd %i registered via object %r, cannot remove with %r" % (fd, pch, ch))
            del self.registered[fd]
            if self.ep is not None:
                self.ep.unregister(fd)
            ch.close()

    def update(self, ch, ignore=False):
        """Re-sync *ch*'s event mask after its readable/writable state changed.

        A channel that is now neither readable nor writable is unregistered and
        closed.  Raises like remove() for unknown/mismatched channels.
        """
        with self.lock:
            fd = ch.fileno()
            if fd not in self.registered:
                if ignore:
                    return
                raise KeyError("fd %i is not registered" % fd)
            pch, cevs = self.registered[fd]
            if pch is not ch:
                raise ValueError("fd %i registered via object %r, cannot update with %r" % (fd, pch, ch))
            evs = self._evsfor(ch)
            if evs == 0:
                del self.registered[fd]
                if self.ep is not None:
                    self.ep.unregister(fd)
                ch.close()
            elif evs != cevs:
                self.registered[fd] = ch, evs
                if self.ep is not None:
                    self.ep.modify(fd, evs)

    def stop(self):
        """Ask the loop to close all channels and exit.

        From the loop thread the flag is set directly; from any other thread
        it is set on the loop thread via a one-shot callbuffer.
        """
        if threading.current_thread() == self.th:
            self.stopped = True
        else:
            def tgt():
                self.stopped = True
            cb = callbuffer()
            cb.call(tgt)
            cb.stop()
            self.add(cb)
def watcher():
    """Create and return a new epoller with no channels registered."""
    w = epoller()
    return w
176
class channel(object):
    """Minimal base for objects serviced by an epoller.

    Subclasses supply ``fileno()`` and override ``readable``/``writable``
    (plain attributes or properties) together with matching ``read()`` and
    ``write()`` callbacks.  ``watcher`` starts out as None.
    """
    readable = writable = False

    def __init__(self):
        self.watcher = None

    def fileno(self):
        """Return the descriptor to poll; subclasses must override this."""
        raise NotImplementedError("fileno()")

    def close(self):
        """Release resources held by the channel; the base class holds none."""
        return None
class sockbuffer(channel):
    """Channel wrapping a socket, with buffered output.

    Received chunks are passed to ``gotdata()`` (a hook for subclasses).
    Outgoing bytes collect in ``obuf`` through ``send()`` and are flushed
    whenever the watcher reports the socket writable.  Any socket error
    discards pending output and marks the buffer as at EOF.
    """
    def __init__(self, socket, **kwargs):
        super().__init__(**kwargs)
        self.sk = socket
        self.eof = False
        self.obuf = bytearray()

    def fileno(self):
        return self.sk.fileno()

    def close(self):
        self.sk.close()

    def gotdata(self, data):
        """Consume received *data*; the base hook only records EOF (empty data)."""
        if data == b"":
            self.eof = True

    def send(self, data, eof=False):
        """Queue *data* for output; with *eof* true, stop reading afterwards."""
        self.obuf.extend(data)
        if eof:
            self.eof = True
        w = self.watcher
        if w is None:
            return
        # The event mask may have changed (now writable and/or no longer readable).
        w.update(self, True)

    @property
    def readable(self):
        return not self.eof

    def read(self):
        try:
            chunk = self.sk.recv(1024)
            self.gotdata(chunk)
        except IOError:
            del self.obuf[:]
            self.eof = True

    @property
    def writable(self):
        return bool(self.obuf)

    def write(self):
        try:
            sent = self.sk.send(self.obuf)
            del self.obuf[:sent]
        except IOError:
            del self.obuf[:]
            self.eof = True
class callbuffer(channel):
    """Channel that runs queued callables on its watcher's loop thread.

    Other threads hand work over with ``call(cb)``.  A byte written to an
    internal pipe wakes the loop, whose ``read()`` then runs every queued
    callable outside the lock.  ``stop()`` closes the pipe's write end, so the
    read end reaches EOF and the channel is closed once drained.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.queue = []
        self.rp, self.wp = os.pipe()
        self.lock = threading.Lock()
        self.eof = False

    def fileno(self):
        return self.rp

    def close(self):
        """Close both pipe ends (the read end even if closing the write end fails)."""
        with self.lock:
            try:
                if self.wp >= 0:
                    os.close(self.wp)
                self.wp = -1
            finally:
                if self.rp >= 0:
                    os.close(self.rp)
                self.rp = -1

    @property
    def readable(self):
        return not self.eof

    def read(self):
        """Consume wake-up bytes, then run every queued callable in order."""
        with self.lock:
            try:
                data = os.read(self.rp, 1024)
                if data == b"":
                    self.eof = True
            except IOError:
                self.eof = True
            cbs = list(self.queue)
            self.queue[:] = []
        # Run outside the lock so callbacks may call call() again.
        for cb in cbs:
            cb()

    writable = False

    def call(self, cb):
        """Queue *cb* to run on the loop thread; raises if stop() was called."""
        with self.lock:
            if self.wp < 0:
                raise Exception("stopped")
            # One wake-up byte per batch suffices because read() drains the whole
            # queue.  Writing a byte per call could fill the (blocking) pipe and
            # stall here while holding self.lock, which read() needs: deadlock.
            if not self.queue:
                os.write(self.wp, b"a")
            self.queue.append(cb)

    def stop(self):
        """Close the write end so the channel reaches EOF after draining."""
        with self.lock:
            if self.wp >= 0:
                os.close(self.wp)
                self.wp = -1
288
def currentwatcher(io, current):
    """Install a loop check on *io* that stops it once *current* is falsy.

    The check is evaluated each time *io*'s loop runs its loop checks; it is
    called with the watcher as its sole argument.
    """
    def _stop_when_gone(loop):
        if current:
            return
        loop.stop()
    io.loopcheck.add(_stop_when_gone)