Added the `ratequeue' program.
[ashd.git] / src / ratequeue.c
1 /*
2     ashd - A Sane HTTP Daemon
3     Copyright (C) 2008  Fredrik Tolf <fredrik@dolda2000.com>
4
5     This program is free software: you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation, either version 3 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License
16     along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 #include <stdlib.h>
20 #include <stdio.h>
21 #include <unistd.h>
22 #include <errno.h>
23 #include <string.h>
24 #include <time.h>
25 #include <signal.h>
26 #include <assert.h>
27 #include <sys/poll.h>
28 #include <arpa/inet.h>
29
30 #ifdef HAVE_CONFIG_H
31 #include <config.h>
32 #endif
33 #include <utils.h>
34 #include <log.h>
35 #include <req.h>
36 #include <resp.h>
37 #include <proc.h>
38 #include <cf.h>
39
40 #define SBUCKETS 7
41
42 struct source {
43     char data[16];
44     unsigned int len, hash;
45 };
46
47 struct waiting {
48     struct hthead *req;
49     int fd;
50 };
51
52 struct bucket {
53     struct source id;
54     double level, last, etime;
55     typedbuf(struct waiting) brim;
56     int thpos;
57 };
58
59 struct btime {
60     struct bucket *bk;
61     double tm;
62 };
63
64 struct config {
65     double size, rate, retain;
66     int brimsize;
67 };
68
69 static struct bucket *sbuckets[1 << SBUCKETS];
70 static struct bucket **buckets = sbuckets;
71 static int hashlen = SBUCKETS, nbuckets = 0;
72 static typedbuf(struct btime) timeheap;
73 static int child, reload;
74 static double now;
75 static const struct config defcfg = {
76     .size = 100, .rate = 10,
77     .retain = 10, .brimsize = 10,
78 };
79 static struct config cf;
80
81 static double rtime(void)
82 {
83     static int init = 0;
84     static struct timespec or;
85     struct timespec ts;
86     
87     clock_gettime(CLOCK_MONOTONIC, &ts);
88     if(!init) {
89         or = ts;
90         init = 1;
91     }
92     return((ts.tv_sec - or.tv_sec) + ((ts.tv_nsec - or.tv_nsec) / 1000000000.0));
93 }
94
95 static struct source reqsource(struct hthead *req)
96 {
97     int i;
98     char *sa;
99     struct in_addr a4;
100     struct in6_addr a6;
101     struct source ret;
102     
103     ret = (struct source){};
104     if((sa = getheader(req, "X-Ash-Address")) != NULL) {
105         if(inet_pton(AF_INET, sa, &a4) == 1) {
106             memcpy(ret.data, &a4, ret.len = sizeof(a4));
107         } else if(inet_pton(AF_INET6, sa, &a6) == 1) {
108             memcpy(ret.data, &a6, ret.len = sizeof(a6));
109         }
110     }
111     for(i = 0, ret.hash = 0; i < ret.len; i++)
112         ret.hash = (ret.hash * 31) + ret.data[i];
113     return(ret);
114 }
115
116 static void rehash(int nlen)
117 {
118     int i, o, n, m, pl, nl;
119     struct bucket **new, **old;
120     
121     old = buckets;
122     if(nlen <= SBUCKETS) {
123         nlen = SBUCKETS;
124         new = sbuckets;
125     } else {
126         new = szmalloc(sizeof(*new) * (1 << nlen));
127     }
128     if(nlen == hashlen)
129         return;
130     assert(old != new);
131     pl = 1 << hashlen; nl = 1 << nlen; m = nl - 1;
132     for(i = 0; i < pl; i++) {
133         if(!old[i])
134             continue;
135         for(o = old[i]->id.hash & m, n = 0; n < nl; o = (o + 1) & m, n++) {
136             if(!new[o]) {
137                 new[o] = old[i];
138                 break;
139             }
140         }
141     }
142     if(old != sbuckets)
143         free(old);
144     buckets = new;
145     hashlen = nlen;
146 }
147
148 static struct bucket *hashget(struct source *src)
149 {
150     unsigned int i, n, N, m;
151     struct bucket *bk;
152     
153     m = (N = (1 << hashlen)) - 1;
154     for(i = src->hash & m, n = 0; n < N; i = (i + 1) & m, n++) {
155         bk = buckets[i];
156         if(bk && (bk->id.len == src->len) && !memcmp(bk->id.data, src->data, src->len))
157             return(bk);
158     }
159     for(i = src->hash & m; buckets[i]; i = (i + 1) & m);
160     buckets[i] = bk = szmalloc(sizeof(*bk));
161     memcpy(&bk->id, src, sizeof(*src));
162     bk->last = bk->etime = now;
163     bk->thpos = -1;
164     if(++nbuckets > (1 << (hashlen - 1)))
165         rehash(hashlen + 1);
166     return(bk);
167 }
168
169 static void hashdel(struct bucket *bk)
170 {
171     unsigned int i, o, p, n, N, m;
172     struct bucket *sb;
173     
174     m = (N = (1 << hashlen)) - 1;
175     for(i = bk->id.hash & m, n = 0; n < N; i = (i + 1) & m, n++) {
176         assert((sb = buckets[i]) != NULL);
177         if((sb->id.len == bk->id.len) && !memcmp(sb->id.data, bk->id.data, bk->id.len))
178             break;
179     }
180     assert(sb == bk);
181     buckets[i] = NULL;
182     for(o = (i + 1) & m; buckets[o] != NULL; o = (o + 1) & m) {
183         sb = buckets[o];
184         p = (sb->id.hash - i) & m;
185         if((p == 0) || (p > ((o - i) & m))) {
186             buckets[i] = sb;
187             buckets[o] = NULL;
188             i = o;
189         }
190     }
191     if(--nbuckets <= (1 << (hashlen - 3)))
192         rehash(hashlen - 1);
193 }
194
195 static void thraise(struct btime bt, int n)
196 {
197     int p;
198     
199     while(n > 0) {
200         p = (n - 1) >> 1;
201         if(timeheap.b[p].tm <= bt.tm)
202             break;
203         (timeheap.b[n] = timeheap.b[p]).bk->thpos = n;
204         n = p;
205     }
206     (timeheap.b[n] = bt).bk->thpos = n;
207 }
208
209 static void thlower(struct btime bt, int n)
210 {
211     int c1, c2, c;
212     
213     while(1) {
214         c2 = (c1 = (n << 1) + 1) + 1;
215         if(c1 >= timeheap.d)
216             break;
217         c = ((c2 < timeheap.d) && (timeheap.b[c2].tm < timeheap.b[c1].tm)) ? c2 : c1;
218         if(timeheap.b[c].tm > bt.tm)
219             break;
220         (timeheap.b[n] = timeheap.b[c]).bk->thpos = n;
221         n = c;
222     }
223     (timeheap.b[n] = bt).bk->thpos = n;
224 }
225
226 static void thadjust(struct btime bt, int n)
227 {
228     if((n > 0) && (timeheap.b[(n - 1) >> 1].tm > bt.tm))
229         thraise(bt, n);
230     else
231         thlower(bt, n);
232 }
233
234 static void freebucket(struct bucket *bk)
235 {
236     int i, n;
237     struct btime r;
238     
239     hashdel(bk);
240     if((n = bk->thpos) >= 0) {
241         r = timeheap.b[--timeheap.d];
242         if(n < timeheap.d)
243             thadjust(r, n);
244     }
245     for(i = 0; i < bk->brim.d; i++) {
246         freehthead(bk->brim.b[i].req);
247         close(bk->brim.b[i].fd);
248     }
249     buffree(bk->brim);
250     free(bk);
251 }
252
253 static void updbtime(struct bucket *bk)
254 {
255     double tm, tm2;
256     
257     tm = (bk->level == 0) ? (bk->etime + cf.retain) : (bk->last + (bk->level / cf.rate) + cf.retain);
258     if(bk->brim.d > 0) {
259         tm2 = bk->last + ((bk->level - cf.size) / cf.rate);
260         tm = (tm2 < tm) ? tm2 : tm;
261     }
262     if(bk->thpos < 0) {
263         sizebuf(timeheap, ++timeheap.d);
264         thraise((struct btime){bk, tm}, timeheap.d - 1);
265     } else {
266         thadjust((struct btime){bk, tm}, bk->thpos);
267     }
268 }
269
270 static void tickbucket(struct bucket *bk)
271 {
272     double delta, ll;
273     
274     delta = now - bk->last;
275     bk->last = now;
276     ll = bk->level;
277     if((bk->level -= delta * cf.rate) < 0) {
278         bk->level = 0;
279         if(ll > 0)
280             bk->etime = now;
281     }
282     while((bk->brim.d > 0) && (bk->level < cf.size)) {
283         if(sendreq(child, bk->brim.b[0].req, bk->brim.b[0].fd)) {
284             flog(LOG_ERR, "ratequeue: could not pass request to child: %s", strerror(errno));
285             exit(1);
286         }
287         freehthead(bk->brim.b[0].req);
288         close(bk->brim.b[0].fd);
289         bufdel(bk->brim, 0);
290         bk->level += 1;
291     }
292 }
293
294 static void checkbtime(struct bucket *bk)
295 {
296     tickbucket(bk);
297     if((bk->level == 0) && (now >= bk->etime + cf.retain)) {
298         freebucket(bk);
299         return;
300     }
301     updbtime(bk);
302 }
303
304 static void serve(struct hthead *req, int fd)
305 {
306     struct source src;
307     struct bucket *bk;
308     
309     now = rtime();
310     src = reqsource(req);
311     bk = hashget(&src);
312     tickbucket(bk);
313     if(bk->level < cf.size) {
314         bk->level += 1;
315         if(sendreq(child, req, fd)) {
316             flog(LOG_ERR, "ratequeue: could not pass request to child: %s", strerror(errno));
317             exit(1);
318         }
319         freehthead(req);
320         close(fd);
321     } else if(bk->brim.d < cf.brimsize) {
322         bufadd(bk->brim, ((struct waiting){.req = req, .fd = fd}));
323     } else {
324         simpleerror(fd, 429, "Too many requests", "Your client is being throttled for issuing too frequent requests.");
325         freehthead(req);
326         close(fd);
327     }
328     updbtime(bk);
329 }
330
331 static int parseint(char *str, int *dst)
332 {
333     long buf;
334     char *p;
335     
336     buf = strtol(str, &p, 0);
337     if((p == str) || *p)
338         return(-1);
339     *dst = buf;
340     return(0);
341 }
342
343 static int parsefloat(char *str, double *dst)
344 {
345     double buf;
346     char *p;
347     
348     buf = strtod(str, &p);
349     if((p == str) || *p)
350         return(-1);
351     *dst = buf;
352     return(0);
353 }
354
355 static int readconf(char *path, struct config *buf)
356 {
357     FILE *fp;
358     struct cfstate *s;
359     int rv;
360     
361     if((fp = fopen(path, "r")) == NULL) {
362         flog(LOG_ERR, "ratequeue: %s: %s", path, strerror(errno));
363         return(-1);
364     }
365     *buf = defcfg;
366     s = mkcfparser(fp, path);
367     rv = -1;
368     while(1) {
369         getcfline(s);
370         if(!strcmp(s->argv[0], "eof")) {
371             break;
372         } else if(!strcmp(s->argv[0], "size")) {
373             if((s->argc < 2) || parsefloat(s->argv[1], &buf->size)) {
374                 flog(LOG_ERR, "%s:%i: missing or invalid `size' argument");
375                 goto err;
376             }
377         } else if(!strcmp(s->argv[0], "rate")) {
378             if((s->argc < 2) || parsefloat(s->argv[1], &buf->rate)) {
379                 flog(LOG_ERR, "%s:%i: missing or invalid `rate' argument");
380                 goto err;
381             }
382         } else if(!strcmp(s->argv[0], "brim")) {
383             if((s->argc < 2) || parseint(s->argv[1], &buf->brimsize)) {
384                 flog(LOG_ERR, "%s:%i: missing or invalid `brim' argument");
385                 goto err;
386             }
387         } else {
388             flog(LOG_WARNING, "%s:%i: unknown directive `%s'", s->file, s->lno, s->argv[0]);
389         }
390     }
391     rv = 0;
392 err:
393     freecfparser(s);
394     fclose(fp);
395     return(rv);
396 }
397
398 static void huphandler(int sig)
399 {
400     reload = 1;
401 }
402
403 static void usage(FILE *out)
404 {
405     fprintf(out, "usage: ratequeue [-h] [-s BUCKET-SIZE] [-r RATE] [-b BRIM-SIZE] PROGRAM [ARGS...]\n");
406 }
407
408 int main(int argc, char **argv)
409 {
410     int c, rv;
411     int fd;
412     struct hthead *req;
413     struct pollfd pfd;
414     double timeout;
415     char *cfname;
416     struct config cfbuf;
417     
418     cf = defcfg;
419     cfname = NULL;
420     while((c = getopt(argc, argv, "+hc:s:r:b:")) >= 0) {
421         switch(c) {
422         case 'h':
423             usage(stdout);
424             return(0);
425         case 'c':
426             cfname = optarg;
427             break;
428         case 's':
429             parsefloat(optarg, &cf.size);
430             break;
431         case 'r':
432             parsefloat(optarg, &cf.rate);
433             break;
434         case 'b':
435             parseint(optarg, &cf.brimsize);
436             break;
437         }
438     }
439     if(argc - optind < 1) {
440         usage(stderr);
441         return(1);
442     }
443     if(cfname) {
444         if(readconf(cfname, &cfbuf))
445             return(1);
446         cf = cfbuf;
447     }
448     if((child = stdmkchild(argv + optind, NULL, NULL)) < 0) {
449         flog(LOG_ERR, "ratequeue: could not fork child: %s", strerror(errno));
450         return(1);
451     }
452     sigaction(SIGHUP, &(struct sigaction){.sa_handler = huphandler}, NULL);
453     while(1) {
454         if(reload) {
455             if(cfname) {
456                 if(!readconf(cfname, &cfbuf))
457                     cf = cfbuf;
458             }
459             reload = 0;
460         }
461         now = rtime();
462         pfd = (struct pollfd){.fd = 0, .events = POLLIN};
463         timeout = (timeheap.d > 0) ? timeheap.b[0].tm : -1;
464         if((rv = poll(&pfd, 1, (timeout < 0) ? -1 : (int)((timeout + 0.1 - now) * 1000))) < 0) {
465             if(errno != EINTR) {
466                 flog(LOG_ERR, "ratequeue: error in poll: %s", strerror(errno));
467                 exit(1);
468             }
469         }
470         if(pfd.revents) {
471             if((fd = recvreq(0, &req)) < 0) {
472                 if(errno == EINTR)
473                     continue;
474                 if(errno != 0)
475                     flog(LOG_ERR, "recvreq: %s", strerror(errno));
476                 break;
477             }
478             serve(req, fd);
479         }
480         while((timeheap.d > 0) && ((now = rtime()) >= timeheap.b[0].tm))
481             checkbtime(timeheap.b[0].bk);
482     }
483     return(0);
484 }