99da89af0d7bdc6e656ec968412b7b8b5dcfcf0f
[ashd.git] / python3 / ashd / async.py
1 import sys, os, errno, threading, select, traceback
2
class epoller(object):
    """Dispatch readiness events for registered channels on an epoll thread.

    Channels are registered with add().  A background thread, started on
    demand by _ckrun(), polls their file descriptors and calls the channels'
    read()/write() methods.  A channel whose readable and writable flags are
    both false is unregistered and closed.
    """

    # Optional callable(ch, type, value, traceback) invoked when a channel
    # callback raises; when None, the traceback is printed instead.
    exc_handler = None

    def __init__(self, check=None):
        """Create a poller with no registered channels and no thread.

        check, if given, is added to loopcheck: callables invoked as
        ck(self) once per iteration of the epoll thread's loop.
        """
        self.registered = {}          # fd -> (channel, registered event mask)
        self.lock = threading.RLock()
        self.ep = None                # epoll object while the thread runs
        self.th = None                # the running epoll thread, if any
        self.stopped = False
        self.loopcheck = set()
        if check is not None:
            self.loopcheck.add(check)
        self._daemon = True

    @staticmethod
    def _evsfor(ch):
        """Return the epoll event mask matching ch's readable/writable flags."""
        return ((select.EPOLLIN if ch.readable else 0) |
                (select.EPOLLOUT if ch.writable else 0))

    def _holds(self, fd, ch):
        """Return whether fd is currently registered via exactly ch."""
        ent = self.registered.get(fd)
        return ent is not None and ent[0] is ch

    def _ckrun(self):
        """Start the epoll thread if channels are registered and none runs.

        Callers in this class hold self.lock when calling this.
        """
        if self.registered and self.th is None:
            th = threading.Thread(target=self._run, name="Async epoll thread")
            th.daemon = self._daemon
            th.start()
            self.th = th

    def exception(self, ch, *exc):
        """Handle an exception raised by one of ch's callbacks.

        ch is unregistered and closed if it is still registered, then the
        (type, value, traceback) triple in exc is passed to exc_handler, or
        printed when no handler is set.
        """
        # ignore=True: the failing callback may be "close", which runs after
        # the channel was already unregistered; a KeyError from remove()
        # would escape here and the original exception would be lost.
        self.remove(ch, True)
        if self.exc_handler is None:
            traceback.print_exception(*exc)
        else:
            self.exc_handler(ch, *exc)

    def _cb(self, ch, nm):
        """Call ch's method named nm, routing any exception to exception()."""
        try:
            m = getattr(ch, nm, None)
            if m is None:
                raise AttributeError("%r has no %s method" % (ch, nm))
            m()
        except Exception:
            self.exception(ch, *sys.exc_info())

    def _closeall(self):
        """Unregister every channel and call its close() callback."""
        with self.lock:
            while self.registered:
                fd, (ch, evs) = next(iter(self.registered.items()))
                del self.registered[fd]
                self.ep.unregister(fd)
                self._cb(ch, "close")

    def _run(self):
        """Epoll thread body: poll and dispatch until nothing is registered."""
        ep = select.epoll()
        try:
            with self.lock:
                for fd, (_ch, evs) in self.registered.items():
                    ep.register(fd, evs)
                self.ep = ep

            while self.registered:
                for ck in self.loopcheck:
                    ck(self)
                if self.stopped:
                    self._closeall()
                    break
                try:
                    evlist = ep.poll(10)
                except IOError as exc:
                    if exc.errno == errno.EINTR:
                        continue
                    raise
                for fd, evs in evlist:
                    with self.lock:
                        if fd not in self.registered:
                            continue
                        ch = self.registered[fd][0]
                        # Hangups and errors are delivered through read().
                        if evs & (select.EPOLLIN | select.EPOLLHUP | select.EPOLLERR):
                            self._cb(ch, "read")
                        # A callback may have removed ch, and its fd may even
                        # have been reused by a newly added channel, so check
                        # that fd still belongs to ch before each step.
                        if self._holds(fd, ch) and evs & select.EPOLLOUT:
                            self._cb(ch, "write")
                        if self._holds(fd, ch):
                            nevs = self._evsfor(ch)
                            if nevs == 0:
                                del self.registered[fd]
                                ep.unregister(fd)
                                self._cb(ch, "close")
                            elif nevs != self.registered[fd][1]:
                                self.registered[fd] = ch, nevs
                                ep.modify(fd, nevs)

        finally:
            with self.lock:
                self.th = None
                self.ep = None
                # Channels added while this thread was exiting (or left
                # behind by an exception) get serviced by a new thread.
                self._ckrun()
            ep.close()

    @property
    def daemon(self):
        """Daemon flag given to epoll threads started by this poller."""
        return self._daemon

    @daemon.setter
    def daemon(self, value):
        # A started thread's daemon flag cannot be changed (Thread.daemon
        # raises RuntimeError), so the value applies to epoll threads
        # started from now on.  self.th must keep referring to the running
        # thread: stop() and _ckrun() depend on it.
        with self.lock:
            self._daemon = bool(value)

    def add(self, ch):
        """Register ch and start the epoll thread if needed.

        Raises KeyError if ch's fd is already registered.  A channel that
        wants no events is closed immediately instead of registered.
        """
        with self.lock:
            fd = ch.fileno()
            if fd in self.registered:
                raise KeyError("fd %i is already registered" % fd)
            evs = self._evsfor(ch)
            if evs == 0:
                ch.close()
                return
            ch.watcher = self
            self.registered[fd] = (ch, evs)
            if self.ep is not None:
                self.ep.register(fd, evs)
            self._ckrun()

    def remove(self, ch, ignore=False):
        """Unregister and close ch.

        Raises KeyError if ch's fd is not registered (unless ignore is
        true), or ValueError if the fd was registered via another object.
        """
        with self.lock:
            fd = ch.fileno()
            if fd not in self.registered:
                if ignore:
                    return
                raise KeyError("fd %i is not registered" % fd)
            pch, cevs = self.registered[fd]
            if pch is not ch:
                raise ValueError("fd %i registered via object %r, cannot remove with %r" % (fd, pch, ch))
            del self.registered[fd]
            if self.ep is not None:
                self.ep.unregister(fd)
            ch.close()

    def update(self, ch, ignore=False):
        """Re-read ch's readable/writable flags and update its registration.

        A channel that no longer wants any events is unregistered and
        closed.  Raises KeyError if ch's fd is not registered (unless
        ignore is true), or ValueError if the fd was registered via another
        object.
        """
        with self.lock:
            fd = ch.fileno()
            if fd not in self.registered:
                if ignore:
                    return
                raise KeyError("fd %i is not registered" % fd)
            pch, cevs = self.registered[fd]
            if pch is not ch:
                raise ValueError("fd %i registered via object %r, cannot update with %r" % (fd, pch, ch))
            evs = self._evsfor(ch)
            if evs == 0:
                del self.registered[fd]
                if self.ep is not None:
                    self.ep.unregister(fd)
                ch.close()
            elif evs != cevs:
                self.registered[fd] = ch, evs
                if self.ep is not None:
                    self.ep.modify(fd, evs)

    def stop(self):
        """Ask the epoll thread to close all channels and exit.

        On the epoll thread itself the flag is set directly; from any other
        thread, the flag change is queued through a callbuffer registered
        with this poller, so it is applied on the epoll thread.
        """
        if threading.current_thread() == self.th:
            self.stopped = True
        else:
            def tgt():
                self.stopped = True
            cb = callbuffer()
            cb.call(tgt)
            cb.stop()
            self.add(cb)
169
def watcher():
    """Return a freshly constructed epoller (no thread is started yet)."""
    instance = epoller()
    return instance
172
class channel:
    """Base class for objects that can be registered with a watcher.

    Subclasses provide fileno() and set readable/writable (as plain
    attributes or properties) to say which events they are interested in.
    The watcher attribute starts as None; epoller.add() sets it.
    """

    # Event-interest flags; both off by default.
    writable = False
    readable = False

    def __init__(self):
        self.watcher = None

    def close(self):
        """Release resources held by the channel; the base class holds none."""
        return None

    def fileno(self):
        """Return the file descriptor to watch; subclasses must override."""
        raise NotImplementedError("fileno()")
185
class sockbuffer(channel):
    """Channel that reads from, and buffers writes to, a socket.

    Received data is handed to gotdata(); data queued with send() is written
    out as the socket becomes writable.  Setting eof (end of input, an I/O
    error, or send(..., eof=True)) stops further reading.
    """

    def __init__(self, socket, **kwargs):
        super().__init__(**kwargs)
        self.sk = socket
        self.eof = False          # once set, the channel stops reading
        self.obuf = bytearray()   # output not yet accepted by the socket

    def fileno(self):
        return self.sk.fileno()

    def close(self):
        self.sk.close()

    def _abort(self):
        """Treat the connection as dead: drop pending output, stop reading."""
        self.obuf[:] = b""
        self.eof = True

    def gotdata(self, data):
        """Handle data received from the socket; b"" means end of input."""
        if data == b"":
            self.eof = True

    def send(self, data, eof=False):
        """Queue data for output; with eof=True, also stop reading.

        If a watcher is attached, its registration is refreshed (ignoring
        the case where this channel is not registered).
        """
        self.obuf.extend(data)
        if eof:
            self.eof = True
        if self.watcher is not None:
            self.watcher.update(self, True)

    @property
    def readable(self):
        return not self.eof

    def read(self):
        """Receive up to 1024 bytes and pass them to gotdata()."""
        try:
            data = self.sk.recv(1024)
        except BlockingIOError:
            # Readiness can be reported spuriously for a non-blocking
            # socket; nothing is available yet, so wait for the next event
            # rather than treating the connection as dead.
            return
        except IOError:
            self._abort()
            return
        try:
            self.gotdata(data)
        except IOError:
            # An I/O error raised while handling the data is treated like a
            # failed recv().
            self._abort()

    @property
    def writable(self):
        return bool(self.obuf)

    def write(self):
        """Send as much of the output buffer as the socket accepts."""
        try:
            ret = self.sk.send(self.obuf)
        except BlockingIOError:
            # Spurious write readiness, as in read(): retry on the next event.
            return
        except IOError:
            self._abort()
            return
        del self.obuf[:ret]
231
class callbuffer(channel):
    """Channel that runs queued callables on the thread servicing it.

    call() may be used from any thread.  It queues a callable and wakes the
    watcher through a pipe whose read end is this channel's fd.  read() then
    runs every queued callable.  stop() closes the write end, so read()
    eventually sees end-of-file and the channel stops being readable.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.queue = []                  # callables waiting for read()
        self.rp, self.wp = os.pipe()     # wake-up pipe (read end, write end)
        self.lock = threading.Lock()
        self.eof = False

    def fileno(self):
        return self.rp

    def close(self):
        """Close both pipe ends (idempotent)."""
        with self.lock:
            try:
                if self.wp >= 0:
                    os.close(self.wp)
                self.wp = -1
            finally:
                if self.rp >= 0:
                    os.close(self.rp)
                self.rp = -1

    @property
    def readable(self):
        return not self.eof

    def read(self):
        """Consume wake-up bytes and run every queued callable."""
        with self.lock:
            try:
                data = os.read(self.rp, 1024)
                if data == b"":
                    self.eof = True
            except IOError:
                self.eof = True
            cbs = list(self.queue)
            self.queue[:] = []
        # Run callbacks outside the lock so they may call() again.
        for cb in cbs:
            cb()

    writable = False

    def call(self, cb):
        """Queue cb to be run by read(); raises Exception once stopped."""
        with self.lock:
            if self.wp < 0:
                raise Exception("stopped")
            # read() drains the whole queue at once, so one pending wake-up
            # byte is enough.  Writing a byte only when the queue goes from
            # empty to non-empty keeps the pipe from filling up.  Otherwise
            # os.write would block while holding self.lock, and read(), which
            # needs that lock to drain the pipe, would deadlock.
            wake = not self.queue
            self.queue.append(cb)
            if wake:
                os.write(self.wp, b"a")

    def stop(self):
        """Close the write end so read() eventually reports end-of-file."""
        with self.lock:
            if self.wp >= 0:
                os.close(self.wp)
                self.wp = -1
284
def currentwatcher(io, current):
    """Arrange for a watcher to stop itself once ``current`` becomes false.

    A check is added to io.loopcheck.  When called with a watcher (an
    epoller's loop passes itself), the check calls that watcher's stop()
    whenever bool(current) is false.
    """
    def _stop_unless_current(ep):
        if current:
            return
        ep.stop()

    io.loopcheck.add(_stop_unless_current)