Added the `ratequeue' program.
[ashd.git] / src / ratequeue.c
CommitLineData
ebe9b505
FT
1/*
2 ashd - A Sane HTTP Daemon
3 Copyright (C) 2008 Fredrik Tolf <fredrik@dolda2000.com>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, either version 3 of the License, or
8 (at your option) any later version.
9
10 This program is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with this program. If not, see <http://www.gnu.org/licenses/>.
17*/
18
19#include <stdlib.h>
20#include <stdio.h>
21#include <unistd.h>
22#include <errno.h>
23#include <string.h>
24#include <time.h>
25#include <signal.h>
26#include <assert.h>
27#include <sys/poll.h>
28#include <arpa/inet.h>
29
30#ifdef HAVE_CONFIG_H
31#include <config.h>
32#endif
33#include <utils.h>
34#include <log.h>
35#include <req.h>
36#include <resp.h>
37#include <proc.h>
38#include <cf.h>
39
40#define SBUCKETS 7
41
/*
 * Identity of the client a request came from: the raw bytes of its
 * network address together with a hash computed over those bytes.
 */
struct source {
    char data[16];      /* address bytes; only the first `len' are used */
    unsigned int len;   /* number of valid bytes in `data' (0 if unknown) */
    unsigned int hash;  /* hash over data[0..len) */
};
46
/* A request held back on a bucket's brim, with its accompanying descriptor. */
struct waiting {
    struct hthead *req;     /* the parsed request head */
    int fd;                 /* descriptor received together with `req' */
};
51
52struct bucket {
53 struct source id;
54 double level, last, etime;
55 typedbuf(struct waiting) brim;
56 int thpos;
57};
58
/*
 * Entry of the time heap: a bucket and the time at which it next
 * needs attention. Field order matters, since entries are built with
 * positional initializers.
 */
struct btime {
    struct bucket *bk;  /* bucket to be checked */
    double tm;          /* time at which to check it */
};
63
/* Tunable parameters of the rate limiter. */
struct config {
    double size;    /* level at which a bucket counts as full */
    double rate;    /* units drained from a bucket per second */
    double retain;  /* seconds an empty bucket is kept before being freed */
    int brimsize;   /* maximum number of requests queued per bucket */
};
68
69static struct bucket *sbuckets[1 << SBUCKETS];
70static struct bucket **buckets = sbuckets;
71static int hashlen = SBUCKETS, nbuckets = 0;
72static typedbuf(struct btime) timeheap;
73static int child, reload;
74static double now;
75static const struct config defcfg = {
76 .size = 100, .rate = 10,
77 .retain = 10, .brimsize = 10,
78};
79static struct config cf;
80
/*
 * Return the number of seconds, as measured by CLOCK_MONOTONIC, that
 * have passed since the first call to this function. The first call
 * itself returns 0.
 */
static double rtime(void)
{
    static struct timespec origin;
    static int have_origin = 0;
    struct timespec cur;
    double secs, frac;

    clock_gettime(CLOCK_MONOTONIC, &cur);
    if(!have_origin) {
        /* Remember the first reading as the reference point. */
        origin = cur;
        have_origin = 1;
    }
    secs = cur.tv_sec - origin.tv_sec;
    frac = (cur.tv_nsec - origin.tv_nsec) / 1000000000.0;
    return(secs + frac);
}
94
95static struct source reqsource(struct hthead *req)
96{
97 int i;
98 char *sa;
99 struct in_addr a4;
100 struct in6_addr a6;
101 struct source ret;
102
103 ret = (struct source){};
104 if((sa = getheader(req, "X-Ash-Address")) != NULL) {
105 if(inet_pton(AF_INET, sa, &a4) == 1) {
106 memcpy(ret.data, &a4, ret.len = sizeof(a4));
107 } else if(inet_pton(AF_INET6, sa, &a6) == 1) {
108 memcpy(ret.data, &a6, ret.len = sizeof(a6));
109 }
110 }
111 for(i = 0, ret.hash = 0; i < ret.len; i++)
112 ret.hash = (ret.hash * 31) + ret.data[i];
113 return(ret);
114}
115
116static void rehash(int nlen)
117{
118 int i, o, n, m, pl, nl;
119 struct bucket **new, **old;
120
121 old = buckets;
122 if(nlen <= SBUCKETS) {
123 nlen = SBUCKETS;
124 new = sbuckets;
125 } else {
126 new = szmalloc(sizeof(*new) * (1 << nlen));
127 }
128 if(nlen == hashlen)
129 return;
130 assert(old != new);
131 pl = 1 << hashlen; nl = 1 << nlen; m = nl - 1;
132 for(i = 0; i < pl; i++) {
133 if(!old[i])
134 continue;
135 for(o = old[i]->id.hash & m, n = 0; n < nl; o = (o + 1) & m, n++) {
136 if(!new[o]) {
137 new[o] = old[i];
138 break;
139 }
140 }
141 }
142 if(old != sbuckets)
143 free(old);
144 buckets = new;
145 hashlen = nlen;
146}
147
148static struct bucket *hashget(struct source *src)
149{
150 unsigned int i, n, N, m;
151 struct bucket *bk;
152
153 m = (N = (1 << hashlen)) - 1;
154 for(i = src->hash & m, n = 0; n < N; i = (i + 1) & m, n++) {
155 bk = buckets[i];
156 if(bk && (bk->id.len == src->len) && !memcmp(bk->id.data, src->data, src->len))
157 return(bk);
158 }
159 for(i = src->hash & m; buckets[i]; i = (i + 1) & m);
160 buckets[i] = bk = szmalloc(sizeof(*bk));
161 memcpy(&bk->id, src, sizeof(*src));
162 bk->last = bk->etime = now;
163 bk->thpos = -1;
164 if(++nbuckets > (1 << (hashlen - 1)))
165 rehash(hashlen + 1);
166 return(bk);
167}
168
169static void hashdel(struct bucket *bk)
170{
171 unsigned int i, o, p, n, N, m;
172 struct bucket *sb;
173
174 m = (N = (1 << hashlen)) - 1;
175 for(i = bk->id.hash & m, n = 0; n < N; i = (i + 1) & m, n++) {
176 assert((sb = buckets[i]) != NULL);
177 if((sb->id.len == bk->id.len) && !memcmp(sb->id.data, bk->id.data, bk->id.len))
178 break;
179 }
180 assert(sb == bk);
181 buckets[i] = NULL;
182 for(o = (i + 1) & m; buckets[o] != NULL; o = (o + 1) & m) {
183 sb = buckets[o];
184 p = (sb->id.hash - i) & m;
185 if((p == 0) || (p > ((o - i) & m))) {
186 buckets[i] = sb;
187 buckets[o] = NULL;
188 i = o;
189 }
190 }
191 if(--nbuckets <= (1 << (hashlen - 3)))
192 rehash(hashlen - 1);
193}
194
195static void thraise(struct btime bt, int n)
196{
197 int p;
198
199 while(n > 0) {
200 p = (n - 1) >> 1;
201 if(timeheap.b[p].tm <= bt.tm)
202 break;
203 (timeheap.b[n] = timeheap.b[p]).bk->thpos = n;
204 n = p;
205 }
206 (timeheap.b[n] = bt).bk->thpos = n;
207}
208
209static void thlower(struct btime bt, int n)
210{
211 int c1, c2, c;
212
213 while(1) {
214 c2 = (c1 = (n << 1) + 1) + 1;
215 if(c1 >= timeheap.d)
216 break;
217 c = ((c2 < timeheap.d) && (timeheap.b[c2].tm < timeheap.b[c1].tm)) ? c2 : c1;
218 if(timeheap.b[c].tm > bt.tm)
219 break;
220 (timeheap.b[n] = timeheap.b[c]).bk->thpos = n;
221 n = c;
222 }
223 (timeheap.b[n] = bt).bk->thpos = n;
224}
225
226static void thadjust(struct btime bt, int n)
227{
228 if((n > 0) && (timeheap.b[(n - 1) >> 1].tm > bt.tm))
229 thraise(bt, n);
230 else
231 thlower(bt, n);
232}
233
234static void freebucket(struct bucket *bk)
235{
236 int i, n;
237 struct btime r;
238
239 hashdel(bk);
240 if((n = bk->thpos) >= 0) {
241 r = timeheap.b[--timeheap.d];
242 if(n < timeheap.d)
243 thadjust(r, n);
244 }
245 for(i = 0; i < bk->brim.d; i++) {
246 freehthead(bk->brim.b[i].req);
247 close(bk->brim.b[i].fd);
248 }
249 buffree(bk->brim);
250 free(bk);
251}
252
253static void updbtime(struct bucket *bk)
254{
255 double tm, tm2;
256
257 tm = (bk->level == 0) ? (bk->etime + cf.retain) : (bk->last + (bk->level / cf.rate) + cf.retain);
258 if(bk->brim.d > 0) {
259 tm2 = bk->last + ((bk->level - cf.size) / cf.rate);
260 tm = (tm2 < tm) ? tm2 : tm;
261 }
262 if(bk->thpos < 0) {
263 sizebuf(timeheap, ++timeheap.d);
264 thraise((struct btime){bk, tm}, timeheap.d - 1);
265 } else {
266 thadjust((struct btime){bk, tm}, bk->thpos);
267 }
268}
269
270static void tickbucket(struct bucket *bk)
271{
272 double delta, ll;
273
274 delta = now - bk->last;
275 bk->last = now;
276 ll = bk->level;
277 if((bk->level -= delta * cf.rate) < 0) {
278 bk->level = 0;
279 if(ll > 0)
280 bk->etime = now;
281 }
282 while((bk->brim.d > 0) && (bk->level < cf.size)) {
283 if(sendreq(child, bk->brim.b[0].req, bk->brim.b[0].fd)) {
284 flog(LOG_ERR, "ratequeue: could not pass request to child: %s", strerror(errno));
285 exit(1);
286 }
287 freehthead(bk->brim.b[0].req);
288 close(bk->brim.b[0].fd);
289 bufdel(bk->brim, 0);
290 bk->level += 1;
291 }
292}
293
294static void checkbtime(struct bucket *bk)
295{
296 tickbucket(bk);
297 if((bk->level == 0) && (now >= bk->etime + cf.retain)) {
298 freebucket(bk);
299 return;
300 }
301 updbtime(bk);
302}
303
304static void serve(struct hthead *req, int fd)
305{
306 struct source src;
307 struct bucket *bk;
308
309 now = rtime();
310 src = reqsource(req);
311 bk = hashget(&src);
312 tickbucket(bk);
313 if(bk->level < cf.size) {
314 bk->level += 1;
315 if(sendreq(child, req, fd)) {
316 flog(LOG_ERR, "ratequeue: could not pass request to child: %s", strerror(errno));
317 exit(1);
318 }
319 freehthead(req);
320 close(fd);
321 } else if(bk->brim.d < cf.brimsize) {
322 bufadd(bk->brim, ((struct waiting){.req = req, .fd = fd}));
323 } else {
324 simpleerror(fd, 429, "Too many requests", "Your client is being throttled for issuing too frequent requests.");
325 freehthead(req);
326 close(fd);
327 }
328 updbtime(bk);
329}
330
/*
 * Parse `str' as an integer (decimal, octal with a leading 0, or
 * hexadecimal with a leading 0x) and store it in `*dst'.
 *
 * Returns 0 on success. Returns -1, leaving `*dst' untouched, if the
 * string is empty, has trailing garbage, or holds a value that does
 * not fit in an int.
 */
static int parseint(char *str, int *dst)
{
    long buf;
    char *p;

    errno = 0;
    buf = strtol(str, &p, 0);
    if((p == str) || *p)
        return(-1);
    /*
     * Reject values outside the range of long (ERANGE) as well as
     * values that fit in a long but would be truncated by the
     * conversion to int.
     */
    if((errno == ERANGE) || ((long)(int)buf != buf))
        return(-1);
    *dst = (int)buf;
    return(0);
}
342
/*
 * Parse `str' as a floating-point number using strtod and store it
 * in `*dst'. Returns 0 on success, or -1 (leaving `*dst' untouched)
 * if the string is empty or is not consumed in its entirety.
 */
static int parsefloat(char *str, double *dst)
{
    char *end;
    double val;

    val = strtod(str, &end);
    if(end == str)
        return(-1);
    if(*end != '\0')
        return(-1);
    *dst = val;
    return(0);
}
354
355static int readconf(char *path, struct config *buf)
356{
357 FILE *fp;
358 struct cfstate *s;
359 int rv;
360
361 if((fp = fopen(path, "r")) == NULL) {
362 flog(LOG_ERR, "ratequeue: %s: %s", path, strerror(errno));
363 return(-1);
364 }
365 *buf = defcfg;
366 s = mkcfparser(fp, path);
367 rv = -1;
368 while(1) {
369 getcfline(s);
370 if(!strcmp(s->argv[0], "eof")) {
371 break;
372 } else if(!strcmp(s->argv[0], "size")) {
373 if((s->argc < 2) || parsefloat(s->argv[1], &buf->size)) {
374 flog(LOG_ERR, "%s:%i: missing or invalid `size' argument");
375 goto err;
376 }
377 } else if(!strcmp(s->argv[0], "rate")) {
378 if((s->argc < 2) || parsefloat(s->argv[1], &buf->rate)) {
379 flog(LOG_ERR, "%s:%i: missing or invalid `rate' argument");
380 goto err;
381 }
382 } else if(!strcmp(s->argv[0], "brim")) {
383 if((s->argc < 2) || parseint(s->argv[1], &buf->brimsize)) {
384 flog(LOG_ERR, "%s:%i: missing or invalid `brim' argument");
385 goto err;
386 }
387 } else {
388 flog(LOG_WARNING, "%s:%i: unknown directive `%s'", s->file, s->lno, s->argv[0]);
389 }
390 }
391 rv = 0;
392err:
393 freecfparser(s);
394 fclose(fp);
395 return(rv);
396}
397
398static void huphandler(int sig)
399{
400 reload = 1;
401}
402
/*
 * Print a one-line usage summary to `out'. The summary lists every
 * option accepted by main(), including `-c' for the configuration
 * file.
 */
static void usage(FILE *out)
{
    fprintf(out, "usage: ratequeue [-h] [-c CONFIG] [-s BUCKET-SIZE] [-r RATE] [-b BRIM-SIZE] PROGRAM [ARGS...]\n");
}
407
408int main(int argc, char **argv)
409{
410 int c, rv;
411 int fd;
412 struct hthead *req;
413 struct pollfd pfd;
414 double timeout;
415 char *cfname;
416 struct config cfbuf;
417
418 cf = defcfg;
419 cfname = NULL;
420 while((c = getopt(argc, argv, "+hc:s:r:b:")) >= 0) {
421 switch(c) {
422 case 'h':
423 usage(stdout);
424 return(0);
425 case 'c':
426 cfname = optarg;
427 break;
428 case 's':
429 parsefloat(optarg, &cf.size);
430 break;
431 case 'r':
432 parsefloat(optarg, &cf.rate);
433 break;
434 case 'b':
435 parseint(optarg, &cf.brimsize);
436 break;
437 }
438 }
439 if(argc - optind < 1) {
440 usage(stderr);
441 return(1);
442 }
443 if(cfname) {
444 if(readconf(cfname, &cfbuf))
445 return(1);
446 cf = cfbuf;
447 }
448 if((child = stdmkchild(argv + optind, NULL, NULL)) < 0) {
449 flog(LOG_ERR, "ratequeue: could not fork child: %s", strerror(errno));
450 return(1);
451 }
452 sigaction(SIGHUP, &(struct sigaction){.sa_handler = huphandler}, NULL);
453 while(1) {
454 if(reload) {
455 if(cfname) {
456 if(!readconf(cfname, &cfbuf))
457 cf = cfbuf;
458 }
459 reload = 0;
460 }
461 now = rtime();
462 pfd = (struct pollfd){.fd = 0, .events = POLLIN};
463 timeout = (timeheap.d > 0) ? timeheap.b[0].tm : -1;
464 if((rv = poll(&pfd, 1, (timeout < 0) ? -1 : (int)((timeout + 0.1 - now) * 1000))) < 0) {
465 if(errno != EINTR) {
466 flog(LOG_ERR, "ratequeue: error in poll: %s", strerror(errno));
467 exit(1);
468 }
469 }
470 if(pfd.revents) {
471 if((fd = recvreq(0, &req)) < 0) {
472 if(errno == EINTR)
473 continue;
474 if(errno != 0)
475 flog(LOG_ERR, "recvreq: %s", strerror(errno));
476 break;
477 }
478 serve(req, fd);
479 }
480 while((timeheap.d > 0) && ((now = rtime()) >= timeheap.b[0].tm))
481 checkbtime(timeheap.b[0].bk);
482 }
483 return(0);
484}