238f6b72b68d516b2c9d2179a4ac652399aa91cb
[ashd.git] / python3 / ashd / async.py
1 import sys, os, errno, threading, select, traceback
2
class epoller(object):
    """Epoll-based watcher that services registered channels on its own thread.

    A channel is any object providing ``fileno()``, ``close()``, the
    ``readable``/``writable`` flags and the ``read()``/``write()`` callbacks.
    The worker thread is started on demand (see ``_ckrun``) and runs for as
    long as at least one channel is registered.
    """

    # Optional callable invoked as exc_handler(ch, exc_type, exc_value, tb)
    # when a channel callback raises; when None the traceback is printed.
    exc_handler = None

    def __init__(self, check=None):
        """Create an idle watcher.

        check -- optional callable added to ``loopcheck``; every member of
                 that set is called with the watcher once per loop iteration.
        """
        self.registered = {}            # fd -> (channel, event mask)
        self.fdcache = {}               # channel -> fd it was registered under
        self.lock = threading.RLock()
        self.ep = None                  # epoll object, set while _run is active
        self.th = None                  # worker thread, or None when idle
        self.stopped = False
        self.loopcheck = set()
        if check is not None:
            self.loopcheck.add(check)
        self._daemon = True

    @staticmethod
    def _evsfor(ch):
        """Return the epoll event mask matching ch's readable/writable flags."""
        return ((select.EPOLLIN if ch.readable else 0) |
                (select.EPOLLOUT if ch.writable else 0))

    def _ckrun(self):
        """Start the worker thread if channels are registered and none runs.

        Callers in this class invoke it with ``self.lock`` held.
        """
        if self.registered and self.th is None:
            th = threading.Thread(target=self._run, name="Async epoll thread")
            th.daemon = self._daemon
            th.start()
            self.th = th

    def exception(self, ch, *exc):
        """Deregister ch and report an exception raised by one of its callbacks.

        exc is the (type, value, traceback) triple from ``sys.exc_info()``.
        """
        # The channel may already have been deregistered (for instance when
        # its close() callback is the one that failed), so a missing entry
        # must not raise a second error from inside the handler.
        self.remove(ch, True)
        if self.exc_handler is None:
            traceback.print_exception(*exc)
        else:
            self.exc_handler(ch, *exc)

    def _cb(self, ch, nm):
        """Call the method named nm on ch, routing failures to exception()."""
        try:
            m = getattr(ch, nm, None)
            if m is None:
                raise AttributeError("%r has no %s method" % (ch, nm))
            m()
        except Exception:
            self.exception(ch, *sys.exc_info())

    def _closeall(self):
        """Deregister and close every channel; used when the loop is stopped."""
        with self.lock:
            while self.registered:
                fd, (ch, evs) = next(iter(self.registered.items()))
                del self.registered[fd]
                # Keep the reverse map in step with `registered` so that no
                # stale channel -> fd entry outlives the registration.
                self.fdcache.pop(ch, None)
                self.ep.unregister(fd)
                self._cb(ch, "close")

    def _run(self):
        """Worker thread body: poll and dispatch until nothing is registered."""
        ep = select.epoll()
        try:
            with self.lock:
                try:
                    for fd, (ob, evs) in self.registered.items():
                        ep.register(fd, evs)
                except:
                    # Registration failed part-way: drop every channel (both
                    # maps) before propagating the error.
                    self.registered.clear()
                    self.fdcache.clear()
                    raise
                self.ep = ep

            while self.registered:
                for ck in self.loopcheck:
                    ck(self)
                if self.stopped:
                    self._closeall()
                    break
                try:
                    # The timeout bounds how long the loop checks and the
                    # stop flag can go unexamined.
                    evlist = ep.poll(10)
                except IOError as exc:
                    if exc.errno == errno.EINTR:
                        continue
                    raise
                for fd, evs in evlist:
                    with self.lock:
                        if fd not in self.registered:
                            continue
                        ch, cevs = self.registered[fd]
                        # Membership is re-checked before each step because a
                        # callback may have deregistered the channel.
                        # Hang-up and error conditions are delivered to read().
                        if fd in self.registered and evs & (select.EPOLLIN | select.EPOLLHUP | select.EPOLLERR):
                            self._cb(ch, "read")
                        if fd in self.registered and evs & select.EPOLLOUT:
                            self._cb(ch, "write")
                        if fd in self.registered:
                            nevs = self._evsfor(ch)
                            if nevs == 0:
                                # Neither readable nor writable any more:
                                # the channel is finished.
                                del self.fdcache[ch]
                                del self.registered[fd]
                                ep.unregister(fd)
                                self._cb(ch, "close")
                            elif nevs != cevs:
                                self.registered[fd] = ch, nevs
                                ep.modify(fd, nevs)

        finally:
            with self.lock:
                self.th = None
                self.ep = None
                # Channels may still be registered (or have been added while
                # the loop was exiting); start a fresh thread if so.
                self._ckrun()
            ep.close()

    @property
    def daemon(self):
        """Whether worker threads are created as daemon threads."""
        return self._daemon

    @daemon.setter
    def daemon(self, value):
        # Only the stored preference changes; _ckrun applies it to the next
        # worker thread it creates.  An already running thread is left alone:
        # threading refuses to change the daemon flag of an active thread,
        # and the thread handle in self.th must not be overwritten.
        self._daemon = bool(value)

    def add(self, ch):
        """Register ch with the watcher, starting the worker thread if needed.

        A channel that is neither readable nor writable is closed at once
        instead of being registered.  Raises KeyError if ch's file descriptor
        is already registered.
        """
        with self.lock:
            fd = ch.fileno()
            if fd in self.registered:
                raise KeyError("fd %i is already registered" % fd)
            evs = self._evsfor(ch)
            if evs == 0:
                ch.close()
                return
            ch.watcher = self
            self.fdcache[ch] = fd
            self.registered[fd] = (ch, evs)
            if self.ep:
                self.ep.register(fd, evs)
            self._ckrun()

    def remove(self, ch, ignore=False):
        """Deregister ch and close it.

        ignore -- when true, silently return if ch is not registered.

        Raises KeyError if ch is not registered (and ignore is false) and
        ValueError if ch's file descriptor is registered to another object.
        """
        with self.lock:
            try:
                fd = self.fdcache[ch]
            except KeyError:
                if ignore:
                    return
                # No fd is known for an unregistered channel, so the message
                # names the channel itself.
                raise KeyError("%r is not registered" % (ch,))
            pch, cevs = self.registered[fd]
            if pch is not ch:
                raise ValueError("fd %i registered via object %r, cannot remove with %r" % (fd, pch, ch))
            del self.fdcache[ch]
            del self.registered[fd]
            if self.ep:
                self.ep.unregister(fd)
            ch.close()

    def update(self, ch, ignore=False):
        """Re-read ch's readable/writable flags and adjust its registration.

        A channel that has become neither readable nor writable is
        deregistered and closed.

        ignore -- when true, silently return if ch is not registered.

        Raises KeyError if ch is not registered (and ignore is false) and
        ValueError if ch's file descriptor is registered to another object.
        """
        with self.lock:
            try:
                fd = self.fdcache[ch]
            except KeyError:
                if ignore:
                    return
                raise KeyError("%r is not registered" % (ch,))
            pch, cevs = self.registered[fd]
            if pch is not ch:
                raise ValueError("fd %i registered via object %r, cannot update with %r" % (fd, pch, ch))
            evs = self._evsfor(ch)
            if evs == 0:
                del self.fdcache[ch]
                del self.registered[fd]
                if self.ep:
                    self.ep.unregister(fd)
                ch.close()
            elif evs != cevs:
                self.registered[fd] = ch, evs
                if self.ep:
                    self.ep.modify(fd, evs)

    def stop(self):
        """Ask the loop to close all channels and exit.

        When called from the worker thread the flag is set directly.  From
        any other thread, a callbuffer carrying a callback that sets the flag
        is registered with the watcher, so the flag is set from within the
        loop once that buffer is serviced.
        """
        if threading.current_thread() == self.th:
            self.stopped = True
        else:
            def tgt():
                self.stopped = True
            cb = callbuffer()
            cb.call(tgt)
            cb.stop()
            self.add(cb)
180
def watcher():
    """Construct and return a new epoller with default settings."""
    loop = epoller()
    return loop
183
class channel(object):
    """Base class for channels: not readable, not writable, no descriptor."""

    # Event-interest flags; both are off unless a subclass overrides them.
    writable = False
    readable = False

    def __init__(self):
        # A new channel starts out without an associated watcher.
        self.watcher = None

    def close(self):
        """Do nothing; subclasses release their resources here."""
        return None

    def fileno(self):
        """Always raise NotImplementedError; subclasses supply the descriptor."""
        raise NotImplementedError("fileno()")
196
class sockbuffer(channel):
    """Channel around a socket, with an outgoing buffer and an EOF flag."""

    def __init__(self, socket, **kwargs):
        """Wrap socket; remaining keyword arguments go to the base class."""
        super().__init__(**kwargs)
        self.sk = socket
        self.obuf = bytearray()     # bytes queued for sending
        self.eof = False

    def fileno(self):
        """Return the wrapped socket's file descriptor."""
        sock = self.sk
        return sock.fileno()

    def close(self):
        """Close the wrapped socket."""
        sock = self.sk
        sock.close()

    def gotdata(self, data):
        """Handle received data; an empty bytes object marks end of file."""
        if data != b"":
            return
        self.eof = True

    def send(self, data, eof=False):
        """Queue data for sending, optionally setting the EOF flag as well.

        If a watcher is attached it is asked to update this channel,
        ignoring the case where the channel is not registered.
        """
        self.obuf.extend(data)
        if eof:
            self.eof = True
        owner = self.watcher
        if owner is None:
            return
        owner.update(self, True)

    @property
    def readable(self):
        """True until the EOF flag has been set."""
        if self.eof:
            return False
        return True

    def read(self):
        """Receive up to 1024 bytes and pass them to gotdata().

        On IOError the output buffer is discarded and the EOF flag is set.
        """
        try:
            self.gotdata(self.sk.recv(1024))
        except IOError:
            del self.obuf[:]
            self.eof = True

    @property
    def writable(self):
        """True while the output buffer holds unsent bytes."""
        return len(self.obuf) > 0

    def write(self):
        """Send buffered bytes and drop whatever the socket accepted.

        On IOError the output buffer is discarded and the EOF flag is set.
        """
        try:
            sent = self.sk.send(self.obuf)
            del self.obuf[:sent]
        except IOError:
            del self.obuf[:]
            self.eof = True
242
class callbuffer(channel):
    """Channel that runs queued callables when its pipe becomes readable."""

    # This channel never asks for write events.
    writable = False

    def __init__(self, **kwargs):
        """Create the callback queue and the pipe that signals new entries."""
        super().__init__(**kwargs)
        self.queue = []
        rfd, wfd = os.pipe()
        self.rp = rfd
        self.wp = wfd
        self.lock = threading.Lock()
        self.eof = False

    def fileno(self):
        """Return the read end of the pipe."""
        return self.rp

    def close(self):
        """Close whichever pipe ends are still open and mark them as -1."""
        with self.lock:
            try:
                wfd = self.wp
                if wfd >= 0:
                    os.close(wfd)
                self.wp = -1
            finally:
                # The read end is closed even if closing the write end failed.
                rfd = self.rp
                if rfd >= 0:
                    os.close(rfd)
                self.rp = -1

    @property
    def readable(self):
        """True until the EOF flag has been set."""
        if self.eof:
            return False
        return True

    def read(self):
        """Drain the pipe once, then run every callable queued so far.

        An empty read or an IOError sets the EOF flag.  The callables are
        invoked after the lock has been released.
        """
        with self.lock:
            try:
                if os.read(self.rp, 1024) == b"":
                    self.eof = True
            except IOError:
                self.eof = True
            pending = self.queue[:]
            del self.queue[:]
        for func in pending:
            func()

    def call(self, cb):
        """Queue cb and write one byte to the pipe to signal the reader.

        Raises Exception("stopped") once the write end has been closed.
        """
        with self.lock:
            if self.wp < 0:
                raise Exception("stopped")
            self.queue.append(cb)
            os.write(self.wp, b"a")

    def stop(self):
        """Close the write end of the pipe if it is still open."""
        with self.lock:
            if self.wp < 0:
                return
            os.close(self.wp)
            self.wp = -1
295
def currentwatcher(io, current):
    """Add a loop check to io that stops it once current is falsy.

    The check is added to ``io.loopcheck``; each time it is called it tests
    the truthiness of ``current`` and calls ``stop()`` on the object it was
    handed when that test fails.
    """
    def _stop_when_gone(loop):
        if current:
            return
        loop.stop()

    io.loopcheck.add(_stop_when_gone)