Fixed a couple of htparser bugs.
[ashd.git] / src / htparser.c
1 /*
2     ashd - A Sane HTTP Daemon
3     Copyright (C) 2008  Fredrik Tolf <fredrik@dolda2000.com>
4
5     This program is free software: you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation, either version 3 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License
16     along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 #include <stdlib.h>
20 #include <unistd.h>
21 #include <stdio.h>
22 #include <string.h>
23 #include <sys/select.h>
24 #include <sys/socket.h>
25 #include <netinet/in.h>
26 #include <arpa/inet.h>
27 #include <errno.h>
28 #include <time.h>
29
30 #ifdef HAVE_CONFIG_H
31 #include <config.h>
32 #endif
33 #include <utils.h>
34 #include <mt.h>
35 #include <log.h>
36 #include <req.h>
37 #include <proc.h>
38
39 #define EV_READ 1
40 #define EV_WRITE 2
41
42 struct blocker {
43     struct blocker *n, *p;
44     int fd;
45     int ev;
46     time_t to;
47     struct muth *th;
48 };
49
50 static struct blocker *blockers;
51 int plex;
52
53 static int block(int fd, int ev, time_t to)
54 {
55     struct blocker *bl;
56     int rv;
57     
58     omalloc(bl);
59     bl->fd = fd;
60     bl->ev = ev;
61     if(to > 0)
62         bl->to = time(NULL) + to;
63     bl->th = current;
64     bl->n = blockers;
65     if(blockers)
66         blockers->p = bl;
67     blockers = bl;
68     rv = yield();
69     if(bl->n)
70         bl->n->p = bl->p;
71     if(bl->p)
72         bl->p->n = bl->n;
73     if(bl == blockers)
74         blockers = bl->n;
75     return(rv);
76 }
77
78 static int listensock4(int port)
79 {
80     struct sockaddr_in name;
81     int fd;
82     int valbuf;
83     
84     memset(&name, 0, sizeof(name));
85     name.sin_family = AF_INET;
86     name.sin_port = htons(port);
87     if((fd = socket(PF_INET, SOCK_STREAM, 0)) < 0)
88         return(-1);
89     valbuf = 1;
90     setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &valbuf, sizeof(valbuf));
91     if(bind(fd, (struct sockaddr *)&name, sizeof(name))) {
92         close(fd);
93         return(-1);
94     }
95     if(listen(fd, 16) < 0) {
96         close(fd);
97         return(-1);
98     }
99     return(fd);
100 }
101
102 static int listensock6(int port)
103 {
104     struct sockaddr_in6 name;
105     int fd;
106     int valbuf;
107     
108     memset(&name, 0, sizeof(name));
109     name.sin6_family = AF_INET6;
110     name.sin6_port = htons(port);
111     if((fd = socket(PF_INET6, SOCK_STREAM, 0)) < 0)
112         return(-1);
113     valbuf = 1;
114     setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &valbuf, sizeof(valbuf));
115     if(bind(fd, (struct sockaddr *)&name, sizeof(name))) {
116         close(fd);
117         return(-1);
118     }
119     if(listen(fd, 16) < 0) {
120         close(fd);
121         return(-1);
122     }
123     return(fd);
124 }
125
126 static size_t readhead(int fd, struct charbuf *buf)
127 {
128     int nl;
129     size_t off;
130     
131     int get1(void)
132     {
133         int ret;
134         
135         while(!(off < buf->d)) {
136             sizebuf(*buf, buf->d + 1024);
137             ret = recv(fd, buf->b + buf->d, buf->s - buf->d, MSG_DONTWAIT);
138             if(ret <= 0) {
139                 if((ret < 0) && (errno == EAGAIN)) {
140                     if(block(fd, EV_READ, 60) <= 0)
141                         return(-1);
142                     continue;
143                 }
144                 return(-1);
145             }
146             buf->d += ret;
147         }
148         return(buf->b[off++]);
149     }
150
151     nl = 0;
152     off = 0;
153     while(1) {
154         switch(get1()) {
155         case '\n':
156             if(nl)
157                 return(off);
158             nl = 1;
159             break;
160         case '\r':
161             break;
162         case -1:
163             return(-1);
164         default:
165             nl = 0;
166             break;
167         }
168     }
169 }
170
171 #define SKIPNL(ptr) ({                          \
172             int __buf__;                        \
173             if(*(ptr) == '\r')                  \
174                 *((ptr)++) = 0;                 \
175             if(*(ptr) != '\n') {                \
176                 __buf__ = 0;                    \
177             } else {                            \
178                 *((ptr)++) = 0;                 \
179                 __buf__ = 1;                    \
180             }                                   \
181             __buf__;})
182 static struct hthead *parserawreq(char *buf)
183 {
184     char *p, *p2, *nl;
185     char *method, *url, *ver;
186     struct hthead *req;
187     
188     if((nl = strchr(buf, '\n')) == NULL)
189         return(NULL);
190     if(((p = strchr(buf, ' ')) == NULL) || (p > nl))
191         return(NULL);
192     method = buf;
193     *(p++) = 0;
194     if(((p2 = strchr(p, ' ')) == NULL) || (p2 > nl))
195         return(NULL);
196     url = p;
197     p = p2;
198     *(p++) = 0;
199     if(strncmp(p, "HTTP/", 5))
200         return(NULL);
201     ver = (p += 5);
202     for(; ((*p >= '0') && (*p <= '9')) || (*p == '.'); p++);
203     if(!SKIPNL(p))
204         return(NULL);
205
206     req = mkreq(method, url, ver);
207     while(1) {
208         if(SKIPNL(p)) {
209             if(*p)
210                 goto fail;
211             break;
212         }
213         if((nl = strchr(p, '\n')) == NULL)
214             goto fail;
215         if(((p2 = strchr(p, ':')) == NULL) || (p2 > nl))
216             goto fail;
217         *(p2++) = 0;
218         for(; (*p2 == ' ') || (*p2 == '\t'); p2++);
219         for(nl = p2; (*nl != '\r') && (*nl != '\n'); nl++);
220         if(!SKIPNL(nl))
221             goto fail;
222         if(strncasecmp(p, "x-ash-", 6))
223             headappheader(req, p, p2);
224         p = nl;
225     }
226     return(req);
227     
228 fail:
229     freehthead(req);
230     return(NULL);
231 }
232
233 static struct hthead *parserawresp(char *buf)
234 {
235     char *p, *p2, *nl;
236     char *msg, *ver;
237     int code;
238     struct hthead *resp;
239     
240     if((nl = strchr(buf, '\n')) == NULL)
241         return(NULL);
242     p = strchr(buf, '\r');
243     if((p != NULL) && (p < nl))
244         nl = p;
245     if(strncmp(buf, "HTTP/", 5))
246         return(NULL);
247     ver = p = buf + 5;
248     for(; ((*p >= '0') && (*p <= '9')) || (*p == '.'); p++);
249     if(*p != ' ')
250         return(NULL);
251     *(p++) = 0;
252     if(((p2 = strchr(p, ' ')) == NULL) || (p2 > nl))
253         return(NULL);
254     *(p2++) = 0;
255     code = atoi(p);
256     if((code < 100) || (code >= 600))
257         return(NULL);
258     if(p2 >= nl)
259         return(NULL);
260     msg = p2;
261     p = nl;
262     if(!SKIPNL(p))
263         return(NULL);
264
265     resp = mkresp(code, msg, ver);
266     while(1) {
267         if(SKIPNL(p)) {
268             if(*p)
269                 goto fail;
270             break;
271         }
272         if((nl = strchr(p, '\n')) == NULL)
273             goto fail;
274         if(((p2 = strchr(p, ':')) == NULL) || (p2 > nl))
275             goto fail;
276         *(p2++) = 0;
277         for(; (*p2 == ' ') || (*p2 == '\t'); p2++);
278         for(nl = p2; (*nl != '\r') && (*nl != '\n'); nl++);
279         if(!SKIPNL(nl))
280             goto fail;
281         headappheader(resp, p, p2);
282         p = nl;
283     }
284     return(resp);
285     
286 fail:
287     freehthead(resp);
288     return(NULL);
289 }
290
291 static off_t passdata(int src, int dst, struct charbuf *buf, off_t max)
292 {
293     size_t dataoff, smax;
294     off_t sent;
295     int eof, ret;
296
297     sent = 0;
298     eof = 0;
299     while((!eof || (buf->d > 0)) && ((max < 0) || (sent < max))) {
300         if(!eof && (buf->d < buf->s) && ((max < 0) || (sent + buf->d < max))) {
301             while(1) {
302                 ret = recv(src, buf->b + buf->d, buf->s - buf->d, MSG_DONTWAIT);
303                 if((ret < 0) && (errno == EAGAIN)) {
304                 } else if(ret < 0) {
305                     return(-1);
306                 } else if(ret == 0) {
307                     eof = 1;
308                     break;
309                 } else {
310                     buf->d += ret;
311                     break;
312                 }
313                 if(buf->d > 0)
314                     break;
315                 if(block(src, EV_READ, 0) <= 0)
316                     return(-1);
317             }
318         }
319         for(dataoff = 0; (dataoff < buf->d) && ((max < 0) || (sent < max));) {
320             if(block(dst, EV_WRITE, 120) <= 0)
321                 return(-1);
322             smax = buf->d - dataoff;
323             if(sent + smax > max)
324                 smax = max - sent;
325             ret = send(dst, buf->b + dataoff, smax, MSG_NOSIGNAL | MSG_DONTWAIT);
326             if(ret < 0)
327                 return(-1);
328             dataoff += ret;
329             sent += ret;
330         }
331         bufeat(*buf, dataoff);
332     }
333     return(sent);
334 }
335
336 static void serve(struct muth *muth, va_list args)
337 {
338     vavar(int, fd);
339     vavar(struct sockaddr_storage, name);
340     int cfd;
341     char old;
342     char *hd;
343     struct charbuf inbuf, outbuf;
344     struct hthead *req, *resp;
345     off_t dlen, sent;
346     ssize_t headoff;
347     char nmbuf[256];
348     
349     bufinit(inbuf);
350     bufinit(outbuf);
351     cfd = -1;
352     req = NULL;
353     while(1) {
354         /*
355          * First, find and decode the header:
356          */
357         if((headoff = readhead(fd, &inbuf)) < 0)
358             goto out;
359         if(headoff > 65536) {
360             /* We cannot handle arbitrarily large headers, as they
361              * need to fit within a single Unix datagram. This is
362              * probably a safe limit, and larger packets than this are
363              * most likely erroneous (or malicious) anyway. */
364             goto out;
365         }
366         old = inbuf.b[headoff];
367         inbuf.b[headoff] = 0;
368         if((req = parserawreq(inbuf.b)) == NULL)
369             goto out;
370         inbuf.b[headoff] = old;
371         bufeat(inbuf, headoff);
372         /* We strip off the leading slash from the rest string, so
373          * that multiplexers can parse coherently. */
374         if(req->rest[0] == '/')
375             replrest(req, req->rest + 1);
376         
377         /*
378          * Add metainformation and then send the request to the root
379          * multiplexer:
380          */
381         if(name.ss_family == AF_INET) {
382             headappheader(req, "X-Ash-Address", inet_ntop(AF_INET, &((struct sockaddr_in *)&name)->sin_addr, nmbuf, sizeof(nmbuf)));
383             headappheader(req, "X-Ash-Port", sprintf3("%i", ntohs(((struct sockaddr_in *)&name)->sin_port)));
384         } else if(name.ss_family == AF_INET6) {
385             headappheader(req, "X-Ash-Address", inet_ntop(AF_INET6, &((struct sockaddr_in6 *)&name)->sin6_addr, nmbuf, sizeof(nmbuf)));
386             headappheader(req, "X-Ash-Port", sprintf3("%i", ntohs(((struct sockaddr_in6 *)&name)->sin6_port)));
387         }
388         cfd = sendreq(plex, req);
389
390         /*
391          * If there is message data, pass it:
392          */
393         if((hd = getheader(req, "content-length")) != NULL) {
394             dlen = atoo(hd);
395             if(dlen > 0)
396                 passdata(fd, cfd, &inbuf, dlen);
397         }
398         /* Make sure to send EOF */
399         shutdown(cfd, SHUT_WR);
400         
401         /*
402          * Find and decode the response header:
403          */
404         outbuf.d = 0;
405         headoff = readhead(cfd, &outbuf);
406         hd = memcpy(smalloc(headoff + 1), outbuf.b, headoff);
407         hd[headoff] = 0;
408         if((resp = parserawresp(hd)) == NULL)
409             goto out;
410         
411         /*
412          * Pass the actual output:
413          */
414         sizebuf(outbuf, 65536);
415         sent = passdata(cfd, fd, &outbuf, -1);
416         sent -= headoff;
417         
418         /*
419          * Check for connection expiry
420          */
421         if(strcasecmp(req->method, "head")) {
422             if((hd = getheader(resp, "content-length")) != NULL) {
423                 if(sent != atoo(hd)) {
424                     /* Exit because of error */
425                     goto out;
426                 }
427             } else {
428                 if(((hd = getheader(resp, "transfer-encoding")) == NULL) || !strcasecmp(hd, "identity"))
429                     break;
430             }
431             if(((hd = getheader(req, "connection")) != NULL) && !strcasecmp(hd, "close"))
432                 break;
433             if(((hd = getheader(resp, "connection")) != NULL) && !strcasecmp(hd, "close"))
434                 break;
435         }
436         
437         close(cfd);
438         cfd = -1;
439         freehthead(req);
440         req = NULL;
441         freehthead(resp);
442         resp = NULL;
443     }
444     
445 out:
446     if(cfd >= 0)
447         close(cfd);
448     if(req != NULL)
449         freehthead(req);
450     if(resp != NULL)
451         freehthead(resp);
452     buffree(inbuf);
453     buffree(outbuf);
454     close(fd);
455 }
456
457 static void listenloop(struct muth *muth, va_list args)
458 {
459     vavar(int, ss);
460     int ns;
461     struct sockaddr_storage name;
462     socklen_t namelen;
463     
464     while(1) {
465         namelen = sizeof(name);
466         block(ss, EV_READ, 0);
467         ns = accept(ss, (struct sockaddr *)&name, &namelen);
468         if(ns < 0) {
469             flog(LOG_ERR, "accept: %s", strerror(errno));
470             goto out;
471         }
472         mustart(serve, ns, name);
473     }
474     
475 out:
476     close(ss);
477 }
478
479 static void ioloop(void)
480 {
481     int ret;
482     fd_set rfds, wfds, efds;
483     struct blocker *bl, *nbl;
484     struct timeval toval;
485     time_t now, timeout;
486     int maxfd;
487     int ev;
488     
489     while(blockers != NULL) {
490         FD_ZERO(&rfds);
491         FD_ZERO(&wfds);
492         FD_ZERO(&efds);
493         maxfd = 0;
494         now = time(NULL);
495         timeout = 0;
496         for(bl = blockers; bl; bl = bl->n) {
497             if(bl->ev & EV_READ)
498                 FD_SET(bl->fd, &rfds);
499             if(bl->ev & EV_WRITE)
500                 FD_SET(bl->fd, &wfds);
501             FD_SET(bl->fd, &efds);
502             if(bl->fd > maxfd)
503                 maxfd = bl->fd;
504             if((bl->to != 0) && ((timeout == 0) || (timeout > bl->to)))
505                 timeout = bl->to;
506         }
507         toval.tv_sec = timeout - now;
508         toval.tv_usec = 0;
509         ret = select(maxfd + 1, &rfds, &wfds, &efds, timeout?(&toval):NULL);
510         if(ret < 0) {
511             if(errno != EINTR) {
512                 flog(LOG_CRIT, "ioloop: select errored out: %s", strerror(errno));
513                 /* To avoid CPU hogging in case it's bad, which it
514                  * probably is. */
515                 sleep(1);
516             }
517         }
518         now = time(NULL);
519         for(bl = blockers; bl; bl = nbl) {
520             nbl = bl->n;
521             ev = 0;
522             if(FD_ISSET(bl->fd, &rfds))
523                 ev |= EV_READ;
524             if(FD_ISSET(bl->fd, &wfds))
525                 ev |= EV_WRITE;
526             if(FD_ISSET(bl->fd, &efds))
527                 ev = -1;
528             if(ev != 0)
529                 resume(bl->th, ev);
530             else if((bl->to != 0) && (bl->to <= now))
531                 resume(bl->th, 0);
532         }
533     }
534 }
535
536 int main(int argc, char **argv)
537 {
538     int fd;
539     
540     if(argc < 2) {
541         fprintf(stderr, "usage: htparser ROOT [ARGS...]\n");
542         exit(1);
543     }
544     if((plex = stdmkchild(argv + 1)) < 0) {
545         flog(LOG_ERR, "could not spawn root multiplexer: %s", strerror(errno));
546         return(1);
547     }
548     if((fd = listensock6(8080)) < 0) {
549         flog(LOG_ERR, "could not listen on IPv6: %s", strerror(errno));
550         return(1);
551     }
552     mustart(listenloop, fd);
553     if((fd = listensock4(8080)) < 0) {
554         if(errno != EADDRINUSE) {
555             flog(LOG_ERR, "could not listen on IPv4: %s", strerror(errno));
556             return(1);
557         }
558     } else {
559         mustart(listenloop, fd);
560     }
561     ioloop();
562     return(0);
563 }