ratequeue: Be consistent with signedness in rehash vs. hashget/hashdel.
[ashd.git] / src / ratequeue.c
1 /*
2     ashd - A Sane HTTP Daemon
3     Copyright (C) 2008  Fredrik Tolf <fredrik@dolda2000.com>
4
5     This program is free software: you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation, either version 3 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License
16     along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 #include <stdlib.h>
20 #include <stdio.h>
21 #include <unistd.h>
22 #include <errno.h>
23 #include <string.h>
24 #include <time.h>
25 #include <signal.h>
26 #include <assert.h>
27 #include <sys/poll.h>
28 #include <arpa/inet.h>
29
30 #ifdef HAVE_CONFIG_H
31 #include <config.h>
32 #endif
33 #include <utils.h>
34 #include <log.h>
35 #include <req.h>
36 #include <resp.h>
37 #include <proc.h>
38 #include <cf.h>
39
40 #define SBUCKETS 7
41
42 struct source {
43     int type;
44     char data[16];
45     unsigned int len, hash;
46 };
47
48 struct waiting {
49     struct hthead *req;
50     int fd;
51 };
52
53 struct bucket {
54     struct source id;
55     double level, last, etime, wtime;
56     typedbuf(struct waiting) brim;
57     int thpos, blocked;
58 };
59
60 struct btime {
61     struct bucket *bk;
62     double tm;
63 };
64
65 struct config {
66     double size, rate, retain, warnrate;
67     int brimsize;
68 };
69
70 static struct bucket *sbuckets[1 << SBUCKETS];
71 static struct bucket **buckets = sbuckets;
72 static int hashlen = SBUCKETS, nbuckets = 0;
73 static typedbuf(struct btime) timeheap;
74 static int child, reload;
75 static double now;
76 static const struct config defcfg = {
77     .size = 100, .rate = 10, .warnrate = 60,
78     .retain = 10, .brimsize = 10,
79 };
80 static struct config cf;
81
82 static double rtime(void)
83 {
84     static int init = 0;
85     static struct timespec or;
86     struct timespec ts;
87     
88     clock_gettime(CLOCK_MONOTONIC, &ts);
89     if(!init) {
90         or = ts;
91         init = 1;
92     }
93     return((ts.tv_sec - or.tv_sec) + ((ts.tv_nsec - or.tv_nsec) / 1000000000.0));
94 }
95
96 static struct source reqsource(struct hthead *req)
97 {
98     int i;
99     char *sa;
100     struct in_addr a4;
101     struct in6_addr a6;
102     struct source ret;
103     
104     ret = (struct source){};
105     if((sa = getheader(req, "X-Ash-Address")) != NULL) {
106         if(inet_pton(AF_INET, sa, &a4) == 1) {
107             ret.type = AF_INET;
108             memcpy(ret.data, &a4, ret.len = sizeof(a4));
109         } else if(inet_pton(AF_INET6, sa, &a6) == 1) {
110             ret.type = AF_INET6;
111             memcpy(ret.data, &a6, ret.len = sizeof(a6));
112         }
113     }
114     for(i = 0, ret.hash = ret.type; i < ret.len; i++)
115         ret.hash = (ret.hash * 31) + ret.data[i];
116     return(ret);
117 }
118
119 static int srccmp(const struct source *a, const struct source *b)
120 {
121     int c;
122     
123     if((c = a->len - b->len) != 0)
124         return(c);
125     if((c = a->type - b->type) != 0)
126         return(c);
127     return(memcmp(a->data, b->data, a->len));
128 }
129
130 static const char *formatsrc(const struct source *src)
131 {
132     static char buf[128];
133     struct in_addr a4;
134     struct in6_addr a6;
135     
136     switch(src->type) {
137     case AF_INET:
138         memcpy(&a4, src->data, sizeof(a4));
139         if(!inet_ntop(AF_INET, &a4, buf, sizeof(buf)))
140             return("<invalid ipv4>");
141         return(buf);
142     case AF_INET6:
143         memcpy(&a6, src->data, sizeof(a6));
144         if(!inet_ntop(AF_INET6, &a6, buf, sizeof(buf)))
145             return("<invalid ipv6>");
146         return(buf);
147     default:
148         return("<invalid source record>");
149     }
150 }
151
152 static void rehash(int nlen)
153 {
154     unsigned int i, o, n, m, pl, nl;
155     struct bucket **new, **old;
156     
157     old = buckets;
158     if(nlen <= SBUCKETS) {
159         nlen = SBUCKETS;
160         new = sbuckets;
161     } else {
162         new = szmalloc(sizeof(*new) * (1 << nlen));
163     }
164     if(nlen == hashlen)
165         return;
166     assert(old != new);
167     pl = 1 << hashlen; nl = 1 << nlen; m = nl - 1;
168     for(i = 0; i < pl; i++) {
169         if(!old[i])
170             continue;
171         for(o = old[i]->id.hash & m, n = 0; n < nl; o = (o + 1) & m, n++) {
172             if(!new[o]) {
173                 new[o] = old[i];
174                 break;
175             }
176         }
177     }
178     if(old != sbuckets)
179         free(old);
180     buckets = new;
181     hashlen = nlen;
182 }
183
184 static struct bucket *hashget(const struct source *src)
185 {
186     unsigned int i, n, N, m;
187     struct bucket *bk;
188     
189     m = (N = (1 << hashlen)) - 1;
190     for(i = src->hash & m, n = 0; n < N; i = (i + 1) & m, n++) {
191         bk = buckets[i];
192         if(bk && !srccmp(&bk->id, src))
193             return(bk);
194     }
195     for(i = src->hash & m; buckets[i]; i = (i + 1) & m);
196     buckets[i] = bk = szmalloc(sizeof(*bk));
197     memcpy(&bk->id, src, sizeof(*src));
198     bk->last = bk->etime = now;
199     bk->thpos = -1;
200     bk->blocked = -1;
201     if(++nbuckets > (1 << (hashlen - 1)))
202         rehash(hashlen + 1);
203     return(bk);
204 }
205
206 static void hashdel(struct bucket *bk)
207 {
208     unsigned int i, o, p, n, N, m;
209     struct bucket *sb;
210     
211     m = (N = (1 << hashlen)) - 1;
212     for(i = bk->id.hash & m, n = 0; n < N; i = (i + 1) & m, n++) {
213         assert((sb = buckets[i]) != NULL);
214         if(!srccmp(&sb->id, &bk->id))
215             break;
216     }
217     assert(sb == bk);
218     buckets[i] = NULL;
219     for(o = (i + 1) & m; buckets[o] != NULL; o = (o + 1) & m) {
220         sb = buckets[o];
221         p = (sb->id.hash - i) & m;
222         if((p == 0) || (p > ((o - i) & m))) {
223             buckets[i] = sb;
224             buckets[o] = NULL;
225             i = o;
226         }
227     }
228     if(--nbuckets <= (1 << (hashlen - 3)))
229         rehash(hashlen - 1);
230 }
231
232 static void thraise(struct btime bt, int n)
233 {
234     int p;
235     
236     while(n > 0) {
237         p = (n - 1) >> 1;
238         if(timeheap.b[p].tm <= bt.tm)
239             break;
240         (timeheap.b[n] = timeheap.b[p]).bk->thpos = n;
241         n = p;
242     }
243     (timeheap.b[n] = bt).bk->thpos = n;
244 }
245
246 static void thlower(struct btime bt, int n)
247 {
248     int c1, c2, c;
249     
250     while(1) {
251         c2 = (c1 = (n << 1) + 1) + 1;
252         if(c1 >= timeheap.d)
253             break;
254         c = ((c2 < timeheap.d) && (timeheap.b[c2].tm < timeheap.b[c1].tm)) ? c2 : c1;
255         if(timeheap.b[c].tm > bt.tm)
256             break;
257         (timeheap.b[n] = timeheap.b[c]).bk->thpos = n;
258         n = c;
259     }
260     (timeheap.b[n] = bt).bk->thpos = n;
261 }
262
263 static void thadjust(struct btime bt, int n)
264 {
265     if((n > 0) && (timeheap.b[(n - 1) >> 1].tm > bt.tm))
266         thraise(bt, n);
267     else
268         thlower(bt, n);
269 }
270
271 static void freebucket(struct bucket *bk)
272 {
273     int i, n;
274     struct btime r;
275     
276     hashdel(bk);
277     if((n = bk->thpos) >= 0) {
278         r = timeheap.b[--timeheap.d];
279         if(n < timeheap.d)
280             thadjust(r, n);
281     }
282     for(i = 0; i < bk->brim.d; i++) {
283         freehthead(bk->brim.b[i].req);
284         close(bk->brim.b[i].fd);
285     }
286     buffree(bk->brim);
287     free(bk);
288 }
289
290 static void updbtime(struct bucket *bk)
291 {
292     double tm, tm2;
293     
294     tm = (bk->level == 0) ? (bk->etime + cf.retain) : (bk->last + (bk->level / cf.rate) + cf.retain);
295     if((bk->blocked > 0) && ((tm2 = bk->wtime + cf.warnrate) > tm))
296         tm = tm2;
297     
298     if((bk->brim.d > 0) && ((tm2 = bk->last + ((bk->level - cf.size) / cf.rate)) < tm))
299         tm = tm2;
300     if((bk->blocked > 0) && ((tm2 = bk->wtime + cf.warnrate) < tm))
301         tm = tm2;
302     
303     if(bk->thpos < 0) {
304         sizebuf(timeheap, ++timeheap.d);
305         thraise((struct btime){bk, tm}, timeheap.d - 1);
306     } else {
307         thadjust((struct btime){bk, tm}, bk->thpos);
308     }
309 }
310
311 static void tickbucket(struct bucket *bk)
312 {
313     double delta, ll;
314     
315     delta = now - bk->last;
316     bk->last = now;
317     ll = bk->level;
318     if((bk->level -= delta * cf.rate) < 0) {
319         if(ll > 0)
320             bk->etime = now + (bk->level / cf.rate);
321         bk->level = 0;
322     }
323     while((bk->brim.d > 0) && (bk->level < cf.size)) {
324         if(sendreq(child, bk->brim.b[0].req, bk->brim.b[0].fd)) {
325             flog(LOG_ERR, "ratequeue: could not pass request to child: %s", strerror(errno));
326             exit(1);
327         }
328         freehthead(bk->brim.b[0].req);
329         close(bk->brim.b[0].fd);
330         bufdel(bk->brim, 0);
331         bk->level += 1;
332     }
333     if((bk->blocked > 0) && (now - bk->wtime >= cf.warnrate)) {
334         flog(LOG_NOTICE, "ratequeue: blocked %i requests from %s", bk->blocked, formatsrc(&bk->id));
335         bk->blocked = 0;
336         bk->wtime = now;
337     }
338 }
339
340 static void checkbtime(struct bucket *bk)
341 {
342     tickbucket(bk);
343     if((bk->level == 0) && (now >= bk->etime + cf.retain) && (bk->blocked <= 0)) {
344         freebucket(bk);
345         return;
346     }
347     updbtime(bk);
348 }
349
350 static void serve(struct hthead *req, int fd)
351 {
352     struct source src;
353     struct bucket *bk;
354     
355     now = rtime();
356     src = reqsource(req);
357     bk = hashget(&src);
358     tickbucket(bk);
359     if(bk->level < cf.size) {
360         bk->level += 1;
361         if(sendreq(child, req, fd)) {
362             flog(LOG_ERR, "ratequeue: could not pass request to child: %s", strerror(errno));
363             exit(1);
364         }
365         freehthead(req);
366         close(fd);
367     } else if(bk->brim.d < cf.brimsize) {
368         bufadd(bk->brim, ((struct waiting){.req = req, .fd = fd}));
369     } else {
370         if(bk->blocked < 0) {
371             flog(LOG_NOTICE, "ratequeue: blocking requests from %s", formatsrc(&bk->id));
372             bk->blocked = 0;
373             bk->wtime = now;
374         }
375         simpleerror(fd, 429, "Too many requests", "Your client is being throttled for issuing too frequent requests.");
376         freehthead(req);
377         close(fd);
378         bk->blocked++;
379     }
380     updbtime(bk);
381 }
382
383 static int parseint(const char *str, int *dst)
384 {
385     long buf;
386     char *p;
387     
388     buf = strtol(str, &p, 0);
389     if((p == str) || *p)
390         return(-1);
391     *dst = buf;
392     return(0);
393 }
394
395 static int parsefloat(const char *str, double *dst)
396 {
397     double buf;
398     char *p;
399     
400     buf = strtod(str, &p);
401     if((p == str) || *p)
402         return(-1);
403     *dst = buf;
404     return(0);
405 }
406
407 static int readconf(char *path, struct config *buf)
408 {
409     FILE *fp;
410     struct cfstate *s;
411     int rv;
412     
413     if((fp = fopen(path, "r")) == NULL) {
414         flog(LOG_ERR, "ratequeue: %s: %s", path, strerror(errno));
415         return(-1);
416     }
417     *buf = defcfg;
418     s = mkcfparser(fp, path);
419     rv = -1;
420     while(1) {
421         getcfline(s);
422         if(!strcmp(s->argv[0], "eof")) {
423             break;
424         } else if(!strcmp(s->argv[0], "size")) {
425             if((s->argc < 2) || parsefloat(s->argv[1], &buf->size)) {
426                 flog(LOG_ERR, "%s:%i: missing or invalid `size' argument");
427                 goto err;
428             }
429         } else if(!strcmp(s->argv[0], "rate")) {
430             if((s->argc < 2) || parsefloat(s->argv[1], &buf->rate)) {
431                 flog(LOG_ERR, "%s:%i: missing or invalid `rate' argument");
432                 goto err;
433             }
434         } else if(!strcmp(s->argv[0], "brim")) {
435             if((s->argc < 2) || parseint(s->argv[1], &buf->brimsize)) {
436                 flog(LOG_ERR, "%s:%i: missing or invalid `brim' argument");
437                 goto err;
438             }
439         } else {
440             flog(LOG_WARNING, "%s:%i: unknown directive `%s'", s->file, s->lno, s->argv[0]);
441         }
442     }
443     rv = 0;
444 err:
445     freecfparser(s);
446     fclose(fp);
447     return(rv);
448 }
449
450 static void huphandler(int sig)
451 {
452     reload = 1;
453 }
454
455 static void usage(FILE *out)
456 {
457     fprintf(out, "usage: ratequeue [-h] [-s BUCKET-SIZE] [-r RATE] [-b BRIM-SIZE] PROGRAM [ARGS...]\n");
458 }
459
460 int main(int argc, char **argv)
461 {
462     int c, rv;
463     int fd;
464     struct hthead *req;
465     struct pollfd pfd;
466     double timeout;
467     char *cfname;
468     struct config cfbuf;
469     
470     cf = defcfg;
471     cfname = NULL;
472     while((c = getopt(argc, argv, "+hc:s:r:b:")) >= 0) {
473         switch(c) {
474         case 'h':
475             usage(stdout);
476             return(0);
477         case 'c':
478             cfname = optarg;
479             break;
480         case 's':
481             parsefloat(optarg, &cf.size);
482             break;
483         case 'r':
484             parsefloat(optarg, &cf.rate);
485             break;
486         case 'b':
487             parseint(optarg, &cf.brimsize);
488             break;
489         }
490     }
491     if(argc - optind < 1) {
492         usage(stderr);
493         return(1);
494     }
495     if(cfname) {
496         if(readconf(cfname, &cfbuf))
497             return(1);
498         cf = cfbuf;
499     }
500     if((child = stdmkchild(argv + optind, NULL, NULL)) < 0) {
501         flog(LOG_ERR, "ratequeue: could not fork child: %s", strerror(errno));
502         return(1);
503     }
504     sigaction(SIGHUP, &(struct sigaction){.sa_handler = huphandler}, NULL);
505     while(1) {
506         if(reload) {
507             if(cfname) {
508                 if(!readconf(cfname, &cfbuf))
509                     cf = cfbuf;
510             }
511             reload = 0;
512         }
513         now = rtime();
514         pfd = (struct pollfd){.fd = 0, .events = POLLIN};
515         timeout = (timeheap.d > 0) ? timeheap.b[0].tm : -1;
516         if((rv = poll(&pfd, 1, (timeout < 0) ? -1 : (int)((timeout + 0.1 - now) * 1000))) < 0) {
517             if(errno != EINTR) {
518                 flog(LOG_ERR, "ratequeue: error in poll: %s", strerror(errno));
519                 exit(1);
520             }
521         }
522         if(pfd.revents) {
523             if((fd = recvreq(0, &req)) < 0) {
524                 if(errno == EINTR)
525                     continue;
526                 if(errno != 0)
527                     flog(LOG_ERR, "recvreq: %s", strerror(errno));
528                 break;
529             }
530             serve(req, fd);
531         }
532         while((timeheap.d > 0) && ((now = rtime()) >= timeheap.b[0].tm))
533             checkbtime(timeheap.b[0].bk);
534     }
535     return(0);
536 }