etc: Add environment option to run init.d/ashd silently.
[ashd.git] / src / ratequeue.c
1 /*
2     ashd - A Sane HTTP Daemon
3     Copyright (C) 2008  Fredrik Tolf <fredrik@dolda2000.com>
4
5     This program is free software: you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation, either version 3 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License
16     along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 #include <stdlib.h>
20 #include <stdio.h>
21 #include <unistd.h>
22 #include <errno.h>
23 #include <string.h>
24 #include <time.h>
25 #include <signal.h>
26 #include <assert.h>
27 #include <sys/poll.h>
28 #include <sys/socket.h>
29 #include <netinet/in.h>
30 #include <arpa/inet.h>
31
32 #ifdef HAVE_CONFIG_H
33 #include <config.h>
34 #endif
35 #include <utils.h>
36 #include <log.h>
37 #include <req.h>
38 #include <resp.h>
39 #include <proc.h>
40 #include <cf.h>
41
42 #define SBUCKETS 7
43
44 struct source {
45     int type;
46     char data[16];
47     unsigned int len, hash;
48 };
49
50 struct waiting {
51     struct hthead *req;
52     int fd;
53 };
54
55 struct bucket {
56     struct source id;
57     double level, last, etime, wtime;
58     typedbuf(struct waiting) brim;
59     int thpos, blocked;
60 };
61
62 struct btime {
63     struct bucket *bk;
64     double tm;
65 };
66
67 struct config {
68     double size, rate, retain, warnrate;
69     int brimsize;
70 };
71
72 static struct bucket *sbuckets[1 << SBUCKETS];
73 static struct bucket **buckets = sbuckets;
74 static int hashlen = SBUCKETS, nbuckets = 0;
75 static typedbuf(struct btime) timeheap;
76 static int child, reload;
77 static double now;
78 static const struct config defcfg = {
79     .size = 100, .rate = 10, .warnrate = 60,
80     .retain = 10, .brimsize = 10,
81 };
82 static struct config cf;
83
84 static double rtime(void)
85 {
86     static int init = 0;
87     static struct timespec or;
88     struct timespec ts;
89     
90     clock_gettime(CLOCK_MONOTONIC, &ts);
91     if(!init) {
92         or = ts;
93         init = 1;
94     }
95     return((ts.tv_sec - or.tv_sec) + ((ts.tv_nsec - or.tv_nsec) / 1000000000.0));
96 }
97
98 static struct source reqsource(struct hthead *req)
99 {
100     int i;
101     char *sa;
102     struct in_addr a4;
103     struct in6_addr a6;
104     struct source ret;
105     
106     ret = (struct source){};
107     if((sa = getheader(req, "X-Ash-Address")) != NULL) {
108         if(inet_pton(AF_INET, sa, &a4) == 1) {
109             ret.type = AF_INET;
110             memcpy(ret.data, &a4, ret.len = sizeof(a4));
111         } else if(inet_pton(AF_INET6, sa, &a6) == 1) {
112             ret.type = AF_INET6;
113             memcpy(ret.data, &a6, ret.len = sizeof(a6));
114         }
115     }
116     for(i = 0, ret.hash = ret.type; i < ret.len; i++)
117         ret.hash = (ret.hash * 31) + ret.data[i];
118     return(ret);
119 }
120
121 static int srccmp(const struct source *a, const struct source *b)
122 {
123     int c;
124     
125     if((c = a->len - b->len) != 0)
126         return(c);
127     if((c = a->type - b->type) != 0)
128         return(c);
129     return(memcmp(a->data, b->data, a->len));
130 }
131
132 static const char *formatsrc(const struct source *src)
133 {
134     static char buf[128];
135     struct in_addr a4;
136     struct in6_addr a6;
137     
138     switch(src->type) {
139     case AF_INET:
140         memcpy(&a4, src->data, sizeof(a4));
141         if(!inet_ntop(AF_INET, &a4, buf, sizeof(buf)))
142             return("<invalid ipv4>");
143         return(buf);
144     case AF_INET6:
145         memcpy(&a6, src->data, sizeof(a6));
146         if(!inet_ntop(AF_INET6, &a6, buf, sizeof(buf)))
147             return("<invalid ipv6>");
148         return(buf);
149     default:
150         return("<invalid source record>");
151     }
152 }
153
154 static void rehash(int nlen)
155 {
156     unsigned int i, o, n, m, pl, nl;
157     struct bucket **new, **old;
158     
159     old = buckets;
160     if(nlen <= SBUCKETS) {
161         nlen = SBUCKETS;
162         new = sbuckets;
163     } else {
164         new = szmalloc(sizeof(*new) * (1 << nlen));
165     }
166     if(nlen == hashlen)
167         return;
168     assert(old != new);
169     pl = 1 << hashlen; nl = 1 << nlen; m = nl - 1;
170     for(i = 0; i < pl; i++) {
171         if(!old[i])
172             continue;
173         for(o = old[i]->id.hash & m, n = 0; n < nl; o = (o + 1) & m, n++) {
174             if(!new[o]) {
175                 new[o] = old[i];
176                 break;
177             }
178         }
179     }
180     if(old != sbuckets)
181         free(old);
182     buckets = new;
183     hashlen = nlen;
184 }
185
186 static struct bucket *hashget(const struct source *src)
187 {
188     unsigned int i, n, N, m;
189     struct bucket *bk;
190     
191     m = (N = (1 << hashlen)) - 1;
192     for(i = src->hash & m, n = 0; n < N; i = (i + 1) & m, n++) {
193         bk = buckets[i];
194         if(bk && !srccmp(&bk->id, src))
195             return(bk);
196     }
197     for(i = src->hash & m; buckets[i]; i = (i + 1) & m);
198     buckets[i] = bk = szmalloc(sizeof(*bk));
199     memcpy(&bk->id, src, sizeof(*src));
200     bk->last = bk->etime = now;
201     bk->thpos = -1;
202     bk->blocked = -1;
203     if(++nbuckets > (1 << (hashlen - 1)))
204         rehash(hashlen + 1);
205     return(bk);
206 }
207
208 static void hashdel(struct bucket *bk)
209 {
210     unsigned int i, o, p, n, N, m;
211     struct bucket *sb;
212     
213     m = (N = (1 << hashlen)) - 1;
214     for(i = bk->id.hash & m, n = 0; n < N; i = (i + 1) & m, n++) {
215         assert((sb = buckets[i]) != NULL);
216         if(!srccmp(&sb->id, &bk->id))
217             break;
218     }
219     assert(sb == bk);
220     buckets[i] = NULL;
221     for(o = (i + 1) & m; buckets[o] != NULL; o = (o + 1) & m) {
222         sb = buckets[o];
223         p = (sb->id.hash - i) & m;
224         if((p == 0) || (p > ((o - i) & m))) {
225             buckets[i] = sb;
226             buckets[o] = NULL;
227             i = o;
228         }
229     }
230     if(--nbuckets <= (1 << (hashlen - 3)))
231         rehash(hashlen - 1);
232 }
233
234 static void thraise(struct btime bt, int n)
235 {
236     int p;
237     
238     while(n > 0) {
239         p = (n - 1) >> 1;
240         if(timeheap.b[p].tm <= bt.tm)
241             break;
242         (timeheap.b[n] = timeheap.b[p]).bk->thpos = n;
243         n = p;
244     }
245     (timeheap.b[n] = bt).bk->thpos = n;
246 }
247
248 static void thlower(struct btime bt, int n)
249 {
250     int c1, c2, c;
251     
252     while(1) {
253         c2 = (c1 = (n << 1) + 1) + 1;
254         if(c1 >= timeheap.d)
255             break;
256         c = ((c2 < timeheap.d) && (timeheap.b[c2].tm < timeheap.b[c1].tm)) ? c2 : c1;
257         if(timeheap.b[c].tm > bt.tm)
258             break;
259         (timeheap.b[n] = timeheap.b[c]).bk->thpos = n;
260         n = c;
261     }
262     (timeheap.b[n] = bt).bk->thpos = n;
263 }
264
265 static void thadjust(struct btime bt, int n)
266 {
267     if((n > 0) && (timeheap.b[(n - 1) >> 1].tm > bt.tm))
268         thraise(bt, n);
269     else
270         thlower(bt, n);
271 }
272
273 static void freebucket(struct bucket *bk)
274 {
275     int i, n;
276     struct btime r;
277     
278     hashdel(bk);
279     if((n = bk->thpos) >= 0) {
280         r = timeheap.b[--timeheap.d];
281         if(n < timeheap.d)
282             thadjust(r, n);
283     }
284     for(i = 0; i < bk->brim.d; i++) {
285         freehthead(bk->brim.b[i].req);
286         close(bk->brim.b[i].fd);
287     }
288     buffree(bk->brim);
289     free(bk);
290 }
291
292 static void updbtime(struct bucket *bk)
293 {
294     double tm, tm2;
295     
296     tm = (bk->level == 0) ? (bk->etime + cf.retain) : (bk->last + (bk->level / cf.rate) + cf.retain);
297     if((bk->blocked > 0) && ((tm2 = bk->wtime + cf.warnrate) > tm))
298         tm = tm2;
299     
300     if((bk->brim.d > 0) && ((tm2 = bk->last + ((bk->level - cf.size) / cf.rate)) < tm))
301         tm = tm2;
302     if((bk->blocked > 0) && ((tm2 = bk->wtime + cf.warnrate) < tm))
303         tm = tm2;
304     
305     if(bk->thpos < 0) {
306         sizebuf(timeheap, ++timeheap.d);
307         thraise((struct btime){bk, tm}, timeheap.d - 1);
308     } else {
309         thadjust((struct btime){bk, tm}, bk->thpos);
310     }
311 }
312
313 static void tickbucket(struct bucket *bk)
314 {
315     double delta, ll;
316     
317     delta = now - bk->last;
318     bk->last = now;
319     ll = bk->level;
320     if((bk->level -= delta * cf.rate) < 0) {
321         if(ll > 0)
322             bk->etime = now + (bk->level / cf.rate);
323         bk->level = 0;
324     }
325     while((bk->brim.d > 0) && (bk->level < cf.size)) {
326         if(sendreq(child, bk->brim.b[0].req, bk->brim.b[0].fd)) {
327             flog(LOG_ERR, "ratequeue: could not pass request to child: %s", strerror(errno));
328             exit(1);
329         }
330         freehthead(bk->brim.b[0].req);
331         close(bk->brim.b[0].fd);
332         bufdel(bk->brim, 0);
333         bk->level += 1;
334     }
335     if((bk->blocked > 0) && (now - bk->wtime >= cf.warnrate)) {
336         flog(LOG_NOTICE, "ratequeue: blocked %i requests from %s", bk->blocked, formatsrc(&bk->id));
337         bk->blocked = 0;
338         bk->wtime = now;
339     }
340 }
341
342 static void checkbtime(struct bucket *bk)
343 {
344     tickbucket(bk);
345     if((bk->level == 0) && (now >= bk->etime + cf.retain) && (bk->blocked <= 0)) {
346         freebucket(bk);
347         return;
348     }
349     updbtime(bk);
350 }
351
352 static void serve(struct hthead *req, int fd)
353 {
354     struct source src;
355     struct bucket *bk;
356     
357     now = rtime();
358     src = reqsource(req);
359     bk = hashget(&src);
360     tickbucket(bk);
361     if(bk->level < cf.size) {
362         bk->level += 1;
363         if(sendreq(child, req, fd)) {
364             flog(LOG_ERR, "ratequeue: could not pass request to child: %s", strerror(errno));
365             exit(1);
366         }
367         freehthead(req);
368         close(fd);
369     } else if(bk->brim.d < cf.brimsize) {
370         bufadd(bk->brim, ((struct waiting){.req = req, .fd = fd}));
371     } else {
372         if(bk->blocked < 0) {
373             flog(LOG_NOTICE, "ratequeue: blocking requests from %s", formatsrc(&bk->id));
374             bk->blocked = 0;
375             bk->wtime = now;
376         }
377         simpleerror(fd, 429, "Too many requests", "Your client is being throttled for issuing too frequent requests.");
378         freehthead(req);
379         close(fd);
380         bk->blocked++;
381     }
382     updbtime(bk);
383 }
384
385 static int parseint(const char *str, int *dst)
386 {
387     long buf;
388     char *p;
389     
390     buf = strtol(str, &p, 0);
391     if((p == str) || *p)
392         return(-1);
393     *dst = buf;
394     return(0);
395 }
396
397 static int parsefloat(const char *str, double *dst)
398 {
399     double buf;
400     char *p;
401     
402     buf = strtod(str, &p);
403     if((p == str) || *p)
404         return(-1);
405     *dst = buf;
406     return(0);
407 }
408
409 static int readconf(char *path, struct config *buf)
410 {
411     FILE *fp;
412     struct cfstate *s;
413     int rv;
414     
415     if((fp = fopen(path, "r")) == NULL) {
416         flog(LOG_ERR, "ratequeue: %s: %s", path, strerror(errno));
417         return(-1);
418     }
419     *buf = defcfg;
420     s = mkcfparser(fp, path);
421     rv = -1;
422     while(1) {
423         getcfline(s);
424         if(!strcmp(s->argv[0], "eof")) {
425             break;
426         } else if(!strcmp(s->argv[0], "size")) {
427             if((s->argc < 2) || parsefloat(s->argv[1], &buf->size)) {
428                 flog(LOG_ERR, "%s:%i: missing or invalid `size' argument");
429                 goto err;
430             }
431         } else if(!strcmp(s->argv[0], "rate")) {
432             if((s->argc < 2) || parsefloat(s->argv[1], &buf->rate)) {
433                 flog(LOG_ERR, "%s:%i: missing or invalid `rate' argument");
434                 goto err;
435             }
436         } else if(!strcmp(s->argv[0], "brim")) {
437             if((s->argc < 2) || parseint(s->argv[1], &buf->brimsize)) {
438                 flog(LOG_ERR, "%s:%i: missing or invalid `brim' argument");
439                 goto err;
440             }
441         } else {
442             flog(LOG_WARNING, "%s:%i: unknown directive `%s'", s->file, s->lno, s->argv[0]);
443         }
444     }
445     rv = 0;
446 err:
447     freecfparser(s);
448     fclose(fp);
449     return(rv);
450 }
451
452 static void huphandler(int sig)
453 {
454     reload = 1;
455 }
456
457 static void usage(FILE *out)
458 {
459     fprintf(out, "usage: ratequeue [-h] [-s BUCKET-SIZE] [-r RATE] [-b BRIM-SIZE] PROGRAM [ARGS...]\n");
460 }
461
462 int main(int argc, char **argv)
463 {
464     int c, rv;
465     int fd;
466     struct hthead *req;
467     struct pollfd pfd;
468     double timeout;
469     char *cfname;
470     struct config cfbuf;
471     
472     cf = defcfg;
473     cfname = NULL;
474     while((c = getopt(argc, argv, "+hc:s:r:b:")) >= 0) {
475         switch(c) {
476         case 'h':
477             usage(stdout);
478             return(0);
479         case 'c':
480             cfname = optarg;
481             break;
482         case 's':
483             parsefloat(optarg, &cf.size);
484             break;
485         case 'r':
486             parsefloat(optarg, &cf.rate);
487             break;
488         case 'b':
489             parseint(optarg, &cf.brimsize);
490             break;
491         }
492     }
493     if(argc - optind < 1) {
494         usage(stderr);
495         return(1);
496     }
497     if(cfname) {
498         if(readconf(cfname, &cfbuf))
499             return(1);
500         cf = cfbuf;
501     }
502     if((child = stdmkchild(argv + optind, NULL, NULL)) < 0) {
503         flog(LOG_ERR, "ratequeue: could not fork child: %s", strerror(errno));
504         return(1);
505     }
506     sigaction(SIGHUP, &(struct sigaction){.sa_handler = huphandler}, NULL);
507     while(1) {
508         if(reload) {
509             if(cfname) {
510                 if(!readconf(cfname, &cfbuf))
511                     cf = cfbuf;
512             }
513             reload = 0;
514         }
515         now = rtime();
516         pfd = (struct pollfd){.fd = 0, .events = POLLIN};
517         timeout = (timeheap.d > 0) ? timeheap.b[0].tm : -1;
518         if((rv = poll(&pfd, 1, (timeout < 0) ? -1 : (int)((timeout + 0.1 - now) * 1000))) < 0) {
519             if(errno != EINTR) {
520                 flog(LOG_ERR, "ratequeue: error in poll: %s", strerror(errno));
521                 exit(1);
522             }
523         }
524         if(pfd.revents) {
525             if((fd = recvreq(0, &req)) < 0) {
526                 if(errno == EINTR)
527                     continue;
528                 if(errno != 0)
529                     flog(LOG_ERR, "recvreq: %s", strerror(errno));
530                 break;
531             }
532             serve(req, fd);
533         }
534         while((timeheap.d > 0) && ((now = rtime()) >= timeheap.b[0].tm))
535             checkbtime(timeheap.b[0].bk);
536     }
537     return(0);
538 }