python: Added a channel superclass for ashd.async.
[ashd.git] / python3 / ashd / async.py
1 import sys, os, errno, threading, select, traceback
2
class epoller(object):
    """Poll registered channels with epoll on a lazily started thread.

    Channels are registered with add().  While at least one channel is
    registered, a background thread polls their file descriptors and calls
    the channel's read()/write() methods when they are ready.  After each
    dispatch the channel's readable/writable flags are re-read; a channel
    that is neither readable nor writable is unregistered and closed.
    """

    # Optional callable invoked as exc_handler(ch, type, value, traceback)
    # when a channel callback raises; when None the traceback is printed.
    exc_handler = None

    def __init__(self):
        # fd -> (channel, epoll event mask currently registered for it)
        self.registered = {}
        self.lock = threading.RLock()
        self.ep = None          # epoll object; only set while _run is active
        self.th = None          # polling thread, or None when not running
        self.stopped = False    # set by stop(); makes _run close everything
        self._daemon = True

    @staticmethod
    def _evsfor(ch):
        """Return the epoll event mask matching ch's readable/writable flags."""
        return ((select.EPOLLIN if ch.readable else 0) |
                (select.EPOLLOUT if ch.writable else 0))

    def _ckrun(self):
        """Start the polling thread if channels exist and none is running."""
        if self.registered and self.th is None:
            th = threading.Thread(target=self._run, name="Async epoll thread")
            th.daemon = self._daemon
            th.start()
            self.th = th

    def _isreg(self, fd, ch):
        """Return whether fd is still registered to this very channel."""
        ent = self.registered.get(fd)
        return ent is not None and ent[0] is ch

    def exception(self, ch, *exc):
        """Handle an exception raised by one of ch's callbacks.

        ch is unregistered and closed if it is still registered.  Then exc
        (a sys.exc_info() triple) is passed to exc_handler, or printed when
        no handler is set.
        """
        # ch may already be unregistered (e.g. its close() callback raised
        # after the watcher removed it).  A KeyError here would mask the
        # original error and propagate out of the polling thread.
        self.remove(ch, ignore=True)
        if self.exc_handler is None:
            # exc is (type, value, traceback): use the three-argument form.
            traceback.print_exception(*exc)
        else:
            self.exc_handler(ch, *exc)

    def _cb(self, ch, nm):
        """Call ch.<nm>(), routing any exception (incl. a missing method) to exception()."""
        try:
            m = getattr(ch, nm, None)
            if m is None:
                raise AttributeError("%r has no %s method" % (ch, nm))
            m()
        except Exception:
            self.exception(ch, *sys.exc_info())

    def _closeall(self):
        """Unregister every channel and invoke its close() callback."""
        with self.lock:
            while self.registered:
                fd, (ch, evs) = next(iter(self.registered.items()))
                del self.registered[fd]
                self.ep.unregister(fd)
                self._cb(ch, "close")

    def _run(self):
        """Body of the polling thread.

        Registers the current channels with a fresh epoll object.  It then
        polls (10-second timeout) and dispatches events until no channels
        remain or stop() has taken effect.  On exit it clears th/ep and
        restarts itself via _ckrun() if channels are still registered.
        """
        ep = select.epoll()
        try:
            with self.lock:
                for fd, (ch, evs) in self.registered.items():
                    ep.register(fd, evs)
                self.ep = ep

            while self.registered:
                if self.stopped:
                    self._closeall()
                    break
                try:
                    evlist = ep.poll(10)
                except IOError as exc:
                    if exc.errno == errno.EINTR:
                        continue
                    raise
                for fd, evs in evlist:
                    with self.lock:
                        if fd not in self.registered:
                            # Removed after poll() returned.
                            continue
                        ch, cevs = self.registered[fd]
                        # Hang-ups and errors are reported through read().
                        if evs & (select.EPOLLIN | select.EPOLLHUP | select.EPOLLERR):
                            self._cb(ch, "read")
                        # A callback may have removed ch (and another channel
                        # may even have reused fd), so re-check identity.
                        if self._isreg(fd, ch) and evs & select.EPOLLOUT:
                            self._cb(ch, "write")
                        if self._isreg(fd, ch):
                            nevs = self._evsfor(ch)
                            if nevs == 0:
                                del self.registered[fd]
                                ep.unregister(fd)
                                self._cb(ch, "close")
                            elif nevs != cevs:
                                self.registered[fd] = ch, nevs
                                ep.modify(fd, nevs)

        finally:
            with self.lock:
                self.th = None
                self.ep = None
                self._ckrun()
            ep.close()

    @property
    def daemon(self):
        """Whether polling threads are created as daemon threads."""
        return self._daemon

    @daemon.setter
    def daemon(self, value):
        # threading.Thread.daemon cannot be changed once a thread is started.
        # The new value therefore applies to polling threads started from now
        # on; self.th must be left intact.
        with self.lock:
            self._daemon = bool(value)

    def add(self, ch):
        """Register ch and start polling; a channel wanting no events is closed.

        Raises KeyError if ch's fd is already registered.
        """
        with self.lock:
            fd = ch.fileno()
            if fd in self.registered:
                raise KeyError("fd %i is already registered" % fd)
            evs = self._evsfor(ch)
            if evs == 0:
                ch.close()
                return
            ch.watcher = self
            self.registered[fd] = (ch, evs)
            if self.ep:
                self.ep.register(fd, evs)
            self._ckrun()

    def remove(self, ch, ignore=False):
        """Unregister and close ch.

        Raises KeyError if ch's fd is not registered (unless ignore is true).
        Raises ValueError if the fd is registered to a different object.
        """
        with self.lock:
            fd = ch.fileno()
            if fd not in self.registered:
                if ignore:
                    return
                raise KeyError("fd %i is not registered" % fd)
            pch, cevs = self.registered[fd]
            if pch is not ch:
                raise ValueError("fd %i registered via object %r, cannot remove with %r" % (fd, pch, ch))
            del self.registered[fd]
            if self.ep:
                self.ep.unregister(fd)
            ch.close()

    def update(self, ch, ignore=False):
        """Re-read ch's readable/writable flags and adjust its registration.

        A channel that no longer wants any events is unregistered and closed.
        Raises KeyError/ValueError under the same conditions as remove().
        """
        with self.lock:
            fd = ch.fileno()
            if fd not in self.registered:
                if ignore:
                    return
                raise KeyError("fd %i is not registered" % fd)
            pch, cevs = self.registered[fd]
            if pch is not ch:
                raise ValueError("fd %i registered via object %r, cannot update with %r" % (fd, pch, ch))
            evs = self._evsfor(ch)
            if evs == 0:
                del self.registered[fd]
                if self.ep:
                    self.ep.unregister(fd)
                ch.close()
            elif evs != cevs:
                self.registered[fd] = ch, evs
                if self.ep:
                    self.ep.modify(fd, evs)

    def stop(self):
        """Make the polling thread close all channels and exit.

        On the polling thread itself this just sets the flag.  From any other
        thread, the flag is set on the polling thread through a one-shot
        callbuffer channel, whose pipe write wakes up a blocked poll().
        """
        if threading.current_thread() == self.th:
            self.stopped = True
        else:
            def tgt():
                self.stopped = True
            cb = callbuffer()
            cb.call(tgt)
            cb.stop()
            self.add(cb)
164
def watcher():
    """Create and return a new epoller; it starts polling once a channel is added."""
    w = epoller()
    return w
167
class channel(object):
    """Base class for objects that an epoller can watch.

    Subclasses override fileno(), and usually readable/writable together
    with the read()/write() handlers that the watcher invokes.
    """

    # Poll-interest flags; subclasses may replace them with properties.
    readable = False
    writable = False

    def __init__(self):
        # Set by epoller.add() once the channel has been registered.
        self.watcher = None

    def fileno(self):
        """Return the file descriptor to poll; subclasses must override."""
        raise NotImplementedError("fileno()")

    def close(self):
        """Release the channel's resources; the base version does nothing."""
        return None
180
class sockbuffer(channel):
    """Channel wrapping a socket together with an output buffer.

    Received data is passed to gotdata(); data queued with send() is written
    whenever the socket is writable.  The channel stays readable until eof
    is set by one of three things: the peer closing, an I/O error, or
    send(..., eof=True).  Once eof is set and the output buffer has drained,
    the channel is neither readable nor writable, so an epoller unregisters
    and closes it.
    """

    def __init__(self, socket, **kwargs):
        super().__init__(**kwargs)
        self.sk = socket
        self.eof = False
        self.obuf = bytearray()   # output not yet accepted by the socket

    def fileno(self):
        return self.sk.fileno()

    def close(self):
        self.sk.close()

    def gotdata(self, data):
        """Handle received data; b"" means the peer closed its end.

        Subclasses override this to consume data; the default only records EOF.
        """
        if data == b"":
            self.eof = True

    def send(self, data, eof=False):
        """Queue data for sending; with eof=True, finish once it is flushed."""
        self.obuf.extend(data)
        if eof:
            self.eof = True
        if self.watcher is not None:
            # Re-evaluate poll interest; ignored if no longer registered.
            self.watcher.update(self, True)

    @property
    def readable(self):
        return not self.eof

    def read(self):
        try:
            data = self.sk.recv(1024)
        except (BlockingIOError, InterruptedError):
            # Spurious readiness (non-blocking socket): just try again later
            # instead of discarding the connection and its pending output.
            return
        except IOError:
            # The connection is broken: drop unsent output and finish.
            self.obuf[:] = b""
            self.eof = True
        else:
            # Kept outside the try so an IOError raised by a subclass's
            # handler is not mistaken for a socket failure.
            self.gotdata(data)

    @property
    def writable(self):
        return bool(self.obuf)

    def write(self):
        try:
            ret = self.sk.send(self.obuf)
        except (BlockingIOError, InterruptedError):
            # Socket not actually writable right now; keep the buffer.
            return
        except IOError:
            self.obuf[:] = b""
            self.eof = True
        else:
            # send() may accept only part of the buffer.
            self.obuf[:ret] = b""
226
class callbuffer(channel):
    """Channel that runs queued callables on the thread that services it.

    call() may be used from any thread.  It appends the callable to a queue
    and writes one byte into an internal pipe, whose read end is this
    channel's file descriptor.  When the pipe becomes readable, read()
    drains the queue and invokes the callables in FIFO order, outside the
    lock.  stop() closes the write end; once the remaining bytes have been
    consumed, read() sees end-of-file and the channel stops being readable.
    """

    # Only the read end of the pipe is ever polled.
    writable = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.queue = []
        pipe_r, pipe_w = os.pipe()
        self.rp = pipe_r
        self.wp = pipe_w
        self.lock = threading.Lock()
        self.eof = False

    def fileno(self):
        """Return the pipe's read end."""
        return self.rp

    def _closewrite(self):
        # Close the write end if it is still open; caller holds self.lock.
        if self.wp >= 0:
            os.close(self.wp)
            self.wp = -1

    def close(self):
        """Close both ends of the pipe; repeated calls are harmless."""
        with self.lock:
            try:
                self._closewrite()
            finally:
                # The read end is closed even if closing the write end failed.
                if self.rp >= 0:
                    os.close(self.rp)
                self.rp = -1

    @property
    def readable(self):
        """True until the pipe reports end-of-file or a read error."""
        return not self.eof

    def read(self):
        """Drain the wake-up pipe, then run every queued callable."""
        with self.lock:
            try:
                if not os.read(self.rp, 1024):
                    # Write end closed and all wake-up bytes consumed.
                    self.eof = True
            except IOError:
                self.eof = True
            pending = self.queue[:]
            del self.queue[:]
        # Run outside the (non-reentrant) lock so callables may use call().
        for func in pending:
            func()

    def call(self, cb):
        """Queue cb for the servicing thread; raises once stop() was called."""
        with self.lock:
            if self.wp < 0:
                raise Exception("stopped")
            self.queue.append(cb)
            # One wake-up byte per queued callable.
            os.write(self.wp, b"a")

    def stop(self):
        """Refuse further calls and let the channel reach end-of-file."""
        with self.lock:
            self._closewrite()
279
def currentwatcher(io, current):
    """Stop *io* from a helper thread once *current* becomes false.

    *current* may be any object that is tested for truthiness and provides
    a wait() method, which is called repeatedly while it stays true.  After
    that, io.stop() is invoked.
    """
    def _watch():
        while True:
            if not current:
                break
            current.wait()
        io.stop()

    watcher_thread = threading.Thread(target=_watch, name="Current watcher")
    watcher_thread.start()