Made htparser capable of handling basic requests.
[ashd.git] / src / htparser.c
1 /*
2     ashd - A Sane HTTP Daemon
3     Copyright (C) 2008  Fredrik Tolf <fredrik@dolda2000.com>
4
5     This program is free software: you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation, either version 3 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License
16     along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 #include <stdlib.h>
20 #include <unistd.h>
21 #include <stdio.h>
22 #include <string.h>
23 #include <sys/select.h>
24 #include <sys/socket.h>
25 #include <netinet/in.h>
26 #include <arpa/inet.h>
27 #include <errno.h>
28 #include <time.h>
29
30 #ifdef HAVE_CONFIG_H
31 #include <config.h>
32 #endif
33 #include <utils.h>
34 #include <mt.h>
35 #include <log.h>
36 #include <req.h>
37 #include <proc.h>
38
39 #define EV_READ 1
40 #define EV_WRITE 2
41
42 struct blocker {
43     struct blocker *n, *p;
44     int fd;
45     int ev;
46     time_t to;
47     struct muth *th;
48 };
49
50 static struct blocker *blockers;
51 int plex;
52
53 static int block(int fd, int ev, time_t to)
54 {
55     struct blocker *bl;
56     int rv;
57     
58     omalloc(bl);
59     bl->fd = fd;
60     bl->ev = ev;
61     if(to > 0)
62         bl->to = time(NULL) + to;
63     bl->th = current;
64     bl->n = blockers;
65     if(blockers)
66         blockers->p = bl;
67     blockers = bl;
68     rv = yield();
69     if(bl->n)
70         bl->n->p = bl->p;
71     if(bl->p)
72         bl->p->n = bl->n;
73     if(bl == blockers)
74         blockers = bl->n;
75     return(rv);
76 }
77
78 static int listensock4(int port)
79 {
80     struct sockaddr_in name;
81     int fd;
82     int valbuf;
83     
84     memset(&name, 0, sizeof(name));
85     name.sin_family = AF_INET;
86     name.sin_port = htons(port);
87     if((fd = socket(PF_INET, SOCK_STREAM, 0)) < 0)
88         return(-1);
89     valbuf = 1;
90     setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &valbuf, sizeof(valbuf));
91     if(bind(fd, (struct sockaddr *)&name, sizeof(name))) {
92         close(fd);
93         return(-1);
94     }
95     if(listen(fd, 16) < 0) {
96         close(fd);
97         return(-1);
98     }
99     return(fd);
100 }
101
102 static int listensock6(int port)
103 {
104     struct sockaddr_in6 name;
105     int fd;
106     int valbuf;
107     
108     memset(&name, 0, sizeof(name));
109     name.sin6_family = AF_INET6;
110     name.sin6_port = htons(port);
111     if((fd = socket(PF_INET6, SOCK_STREAM, 0)) < 0)
112         return(-1);
113     valbuf = 1;
114     setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &valbuf, sizeof(valbuf));
115     if(bind(fd, (struct sockaddr *)&name, sizeof(name))) {
116         close(fd);
117         return(-1);
118     }
119     if(listen(fd, 16) < 0) {
120         close(fd);
121         return(-1);
122     }
123     return(fd);
124 }
125
126 static size_t readhead(int fd, struct charbuf *buf)
127 {
128     int nl;
129     size_t off;
130     
131     int get1(void)
132     {
133         int ret;
134         
135         while(!(off < buf->d)) {
136             sizebuf(*buf, buf->d + 1024);
137             ret = recv(fd, buf->b + buf->d, buf->s - buf->d, MSG_DONTWAIT);
138             if(ret <= 0) {
139                 if((ret < 0) && (errno == EAGAIN)) {
140                     if(block(fd, EV_READ, 60) <= 0)
141                         return(-1);
142                     continue;
143                 }
144                 return(-1);
145             }
146             buf->d += ret;
147         }
148         return(buf->b[off++]);
149     }
150
151     nl = 0;
152     off = 0;
153     while(1) {
154         switch(get1()) {
155         case '\n':
156             if(nl)
157                 return(off);
158             nl = 1;
159             break;
160         case '\r':
161             break;
162         case -1:
163             return(-1);
164         default:
165             nl = 0;
166             break;
167         }
168     }
169 }
170
171 #define SKIPNL(ptr) ({                          \
172             int __buf__;                        \
173             if(*(ptr) == '\r')                  \
174                 *((ptr)++) = 0;                 \
175             if(*(ptr) != '\n') {                \
176                 __buf__ = 0;                    \
177             } else {                            \
178                 *((ptr)++) = 0;                 \
179                 __buf__ = 1;                    \
180             }                                   \
181             __buf__;})
182 static struct hthead *parserawreq(char *buf)
183 {
184     char *p, *p2, *nl;
185     char *method, *url, *ver;
186     struct hthead *req;
187     
188     if((nl = strchr(buf, '\n')) == NULL)
189         return(NULL);
190     if(((p = strchr(buf, ' ')) == NULL) || (p > nl))
191         return(NULL);
192     method = buf;
193     *(p++) = 0;
194     if(((p2 = strchr(p, ' ')) == NULL) || (p2 > nl))
195         return(NULL);
196     url = p;
197     p = p2;
198     *(p++) = 0;
199     if(strncmp(p, "HTTP/", 5))
200         return(NULL);
201     ver = (p += 5);
202     for(; ((*p >= '0') && (*p <= '9')) || (*p == '.'); p++);
203     if(!SKIPNL(p))
204         return(NULL);
205
206     req = mkreq(method, url, ver);
207     while(1) {
208         if(SKIPNL(p)) {
209             if(*p)
210                 goto fail;
211             break;
212         }
213         if((nl = strchr(p, '\n')) == NULL)
214             goto fail;
215         if(((p2 = strchr(p, ':')) == NULL) || (p2 > nl))
216             goto fail;
217         *(p2++) = 0;
218         for(; (*p2 == ' ') || (*p2 == '\t'); p2++);
219         for(nl = p2; (*nl != '\r') && (*nl != '\n'); nl++);
220         if(!SKIPNL(nl))
221             goto fail;
222         if(strncasecmp(p, "x-ash-", 6))
223             headappheader(req, p, p2);
224         p = nl;
225     }
226     return(req);
227     
228 fail:
229     freehthead(req);
230     return(NULL);
231 }
232
233 static struct hthead *parserawresp(char *buf)
234 {
235     char *p, *p2, *nl;
236     char *msg, *ver;
237     int code;
238     struct hthead *resp;
239     
240     if((nl = strchr(buf, '\n')) == NULL)
241         return(NULL);
242     p = strchr(buf, '\r');
243     if((p != NULL) && (p < nl))
244         nl = p;
245     if(strncmp(buf, "HTTP/", 5))
246         return(NULL);
247     ver = p = buf + 5;
248     for(; ((*p >= '0') && (*p <= '9')) || (*p == '.'); p++);
249     if(*p != ' ')
250         return(NULL);
251     *(p++) = 0;
252     if(((p2 = strchr(p, ' ')) == NULL) || (p2 > nl))
253         return(NULL);
254     *(p2++) = 0;
255     code = atoi(p);
256     if((code < 100) || (code >= 600))
257         return(NULL);
258     if(p2 >= nl)
259         return(NULL);
260     msg = p2;
261     p = nl;
262     if(!SKIPNL(p))
263         return(NULL);
264
265     resp = mkresp(code, msg, ver);
266     while(1) {
267         if(SKIPNL(p)) {
268             if(*p)
269                 goto fail;
270             break;
271         }
272         if((nl = strchr(p, '\n')) == NULL)
273             goto fail;
274         if(((p2 = strchr(p, ':')) == NULL) || (p2 > nl))
275             goto fail;
276         *(p2++) = 0;
277         for(; (*p2 == ' ') || (*p2 == '\t'); p2++);
278         for(nl = p2; (*nl != '\r') && (*nl != '\n'); nl++);
279         if(!SKIPNL(nl))
280             goto fail;
281         headappheader(resp, p, p2);
282         p = nl;
283     }
284     return(resp);
285     
286 fail:
287     freehthead(resp);
288     return(NULL);
289 }
290
291 static off_t passdata(int src, int dst, struct charbuf *buf, off_t max)
292 {
293     size_t dataoff, smax;
294     off_t sent;
295     int eof, ret;
296
297     sent = 0;
298     eof = 0;
299     while(!eof || (buf->d > 0)) {
300         if(!eof && (buf->d < buf->s) && ((max < 0) || (sent + buf->d < max))) {
301             while(1) {
302                 ret = recv(src, buf->b + buf->d, buf->s - buf->d, MSG_DONTWAIT);
303                 if((ret < 0) && (errno == EAGAIN)) {
304                 } else if(ret < 0) {
305                     return(-1);
306                 } else if(ret == 0) {
307                     eof = 1;
308                     break;
309                 } else {
310                     buf->d += ret;
311                     break;
312                 }
313                 if(buf->d > 0)
314                     break;
315                 if(block(src, EV_READ, 0) <= 0)
316                     return(-1);
317             }
318         }
319         for(dataoff = 0; (dataoff < buf->d) && ((max < 0) || (sent < max));) {
320             if(block(dst, EV_WRITE, 120) <= 0)
321                 return(-1);
322             smax = buf->d - dataoff;
323             if(sent + smax > max)
324                 smax = max - sent;
325             ret = send(dst, buf->b + dataoff, smax, MSG_NOSIGNAL | MSG_DONTWAIT);
326             if(ret < 0)
327                 return(-1);
328             dataoff += ret;
329             sent += ret;
330         }
331         bufeat(*buf, dataoff);
332     }
333     return(sent);
334 }
335
336 static void serve(struct muth *muth, va_list args)
337 {
338     vavar(int, fd);
339     vavar(struct sockaddr_storage, name);
340     int cfd;
341     char old;
342     char *hd;
343     struct charbuf inbuf, outbuf;
344     struct hthead *req, *resp;
345     off_t sent;
346     size_t headoff;
347     char nmbuf[256];
348     
349     bufinit(inbuf);
350     bufinit(outbuf);
351     cfd = -1;
352     req = NULL;
353     while(1) {
354         /*
355          * First, find and decode the header:
356          */
357         if((headoff = readhead(fd, &inbuf)) < 0)
358             goto out;
359         if(headoff > 65536) {
360             /* We cannot handle arbitrarily large headers, as they
361              * need to fit within a single Unix datagram. This is
362              * probably a safe limit, and larger packets than this are
363              * most likely erroneous (or malicious) anyway. */
364             goto out;
365         }
366         old = inbuf.b[headoff];
367         inbuf.b[headoff] = 0;
368         if((req = parserawreq(inbuf.b)) == NULL)
369             goto out;
370         inbuf.b[headoff] = old;
371         bufeat(inbuf, headoff);
372         
373         /*
374          * Add metainformation and then send the request to the root
375          * multiplexer:
376          */
377         if(name.ss_family == AF_INET) {
378             headappheader(req, "X-Ash-Address", inet_ntop(AF_INET, &((struct sockaddr_in *)&name)->sin_addr, nmbuf, sizeof(nmbuf)));
379             headappheader(req, "X-Ash-Port", sprintf3("%i", ntohs(((struct sockaddr_in *)&name)->sin_port)));
380         } else if(name.ss_family == AF_INET6) {
381             headappheader(req, "X-Ash-Address", inet_ntop(AF_INET6, &((struct sockaddr_in6 *)&name)->sin6_addr, nmbuf, sizeof(nmbuf)));
382             headappheader(req, "X-Ash-Port", sprintf3("%i", ntohs(((struct sockaddr_in6 *)&name)->sin6_port)));
383         }
384         cfd = sendreq(plex, req);
385         
386         /*
387          * Find and decode the response header:
388          */
389         outbuf.d = 0;
390         headoff = readhead(cfd, &outbuf);
391         hd = memcpy(smalloc(headoff + 1), outbuf.b, headoff);
392         hd[headoff] = 0;
393         if((resp = parserawresp(hd)) == NULL)
394             goto out;
395         
396         /*
397          * Pass the actual output:
398          */
399         sizebuf(outbuf, 65536);
400         sent = passdata(cfd, fd, &outbuf, -1);
401         sent -= headoff;
402         
403         /*
404          * Check for connection expiry
405          */
406         if(strcasecmp(req->method, "head")) {
407             if((hd = getheader(resp, "content-length")) != NULL) {
408                 if(sent != atoo(hd)) {
409                     /* Exit because of error */
410                     goto out;
411                 }
412             } else {
413                 if(((hd = getheader(resp, "transfer-encoding")) == NULL) || !strcasecmp(hd, "identity"))
414                     break;
415             }
416             if(((hd = getheader(req, "connection")) != NULL) && !strcasecmp(hd, "close"))
417                 break;
418             if(((hd = getheader(resp, "connection")) != NULL) && !strcasecmp(hd, "close"))
419                 break;
420         }
421         
422         close(cfd);
423         cfd = -1;
424         freehthead(req);
425         req = NULL;
426         freehthead(resp);
427         resp = NULL;
428     }
429     
430 out:
431     if(cfd >= 0)
432         close(cfd);
433     if(req != NULL)
434         freehthead(req);
435     if(resp != NULL)
436         freehthead(resp);
437     buffree(inbuf);
438     buffree(outbuf);
439     close(fd);
440 }
441
442 static void listenloop(struct muth *muth, va_list args)
443 {
444     vavar(int, ss);
445     int ns;
446     struct sockaddr_storage name;
447     socklen_t namelen;
448     
449     while(1) {
450         namelen = sizeof(name);
451         block(ss, EV_READ, 0);
452         ns = accept(ss, (struct sockaddr *)&name, &namelen);
453         if(ns < 0) {
454             flog(LOG_ERR, "accept: %s", strerror(errno));
455             goto out;
456         }
457         mustart(serve, ns, name);
458     }
459     
460 out:
461     close(ss);
462 }
463
464 static void ioloop(void)
465 {
466     int ret;
467     fd_set rfds, wfds, efds;
468     struct blocker *bl, *nbl;
469     struct timeval toval;
470     time_t now, timeout;
471     int maxfd;
472     int ev;
473     
474     while(blockers != NULL) {
475         FD_ZERO(&rfds);
476         FD_ZERO(&wfds);
477         FD_ZERO(&efds);
478         maxfd = 0;
479         now = time(NULL);
480         timeout = 0;
481         for(bl = blockers; bl; bl = bl->n) {
482             if(bl->ev & EV_READ)
483                 FD_SET(bl->fd, &rfds);
484             if(bl->ev & EV_WRITE)
485                 FD_SET(bl->fd, &wfds);
486             FD_SET(bl->fd, &efds);
487             if(bl->fd > maxfd)
488                 maxfd = bl->fd;
489             if((bl->to != 0) && ((timeout == 0) || (timeout > bl->to)))
490                 timeout = bl->to;
491         }
492         toval.tv_sec = timeout - now;
493         toval.tv_usec = 0;
494         ret = select(maxfd + 1, &rfds, &wfds, &efds, timeout?(&toval):NULL);
495         if(ret < 0) {
496             if(errno != EINTR) {
497                 flog(LOG_CRIT, "ioloop: select errored out: %s", strerror(errno));
498                 /* To avoid CPU hogging in case it's bad, which it
499                  * probably is. */
500                 sleep(1);
501             }
502         }
503         now = time(NULL);
504         for(bl = blockers; bl; bl = nbl) {
505             nbl = bl->n;
506             ev = 0;
507             if(FD_ISSET(bl->fd, &rfds))
508                 ev |= EV_READ;
509             if(FD_ISSET(bl->fd, &wfds))
510                 ev |= EV_WRITE;
511             if(FD_ISSET(bl->fd, &efds))
512                 ev = -1;
513             if(ev != 0)
514                 resume(bl->th, ev);
515             else if((bl->to != 0) && (bl->to <= now))
516                 resume(bl->th, 0);
517         }
518     }
519 }
520
521 int main(int argc, char **argv)
522 {
523     int fd;
524     
525     if(argc < 2) {
526         fprintf(stderr, "usage: htparser ROOT [ARGS...]\n");
527         exit(1);
528     }
529     if((plex = stdmkchild(argv + 1)) < 0) {
530         flog(LOG_ERR, "could not spawn root multiplexer: %s", strerror(errno));
531         return(1);
532     }
533     if((fd = listensock6(8080)) < 0) {
534         flog(LOG_ERR, "could not listen on IPv6: %s", strerror(errno));
535         return(1);
536     }
537     mustart(listenloop, fd);
538     if((fd = listensock4(8080)) < 0) {
539         if(errno != EADDRINUSE) {
540             flog(LOG_ERR, "could not listen on IPv4: %s", strerror(errno));
541             return(1);
542         }
543     } else {
544         mustart(listenloop, fd);
545     }
546     ioloop();
547     return(0);
548 }