Moved htparser's ioloop to the library.
[ashd.git] / src / htparser.c
index a8705fe..6382bf2 100644 (file)
@@ -20,9 +20,9 @@
 #include <unistd.h>
 #include <stdio.h>
 #include <string.h>
-#include <sys/select.h>
 #include <sys/socket.h>
 #include <netinet/in.h>
+#include <arpa/inet.h>
 #include <errno.h>
 
 #ifdef HAVE_CONFIG_H
 #endif
 #include <utils.h>
 #include <mt.h>
+#include <mtio.h>
 #include <log.h>
 #include <req.h>
+#include <proc.h>
 
-#define EV_READ 1
-#define EV_WRITE 2
-
-struct blocker {
-    struct blocker *n, *p;
-    int fd;
-    int ev;
-    struct muth *th;
-};
-
-static struct blocker *blockers;
-
-static int block(int fd, int ev)
-{
-    struct blocker *bl;
-    int rv;
-    
-    omalloc(bl);
-    bl->fd = fd;
-    bl->ev = ev;
-    bl->th = current;
-    bl->n = blockers;
-    if(blockers)
-       blockers->p = bl;
-    blockers = bl;
-    rv = yield();
-    if(bl->n)
-       bl->n->p = bl->p;
-    if(bl->p)
-       bl->p->n = bl->n;
-    if(bl == blockers)
-       blockers = bl->n;
-    return(rv);
-}
+static int plex;
 
 static int listensock4(int port)
 {
@@ -116,41 +85,49 @@ static int listensock6(int port)
     return(fd);
 }
 
-static char *slowreadhead(int fd)
+static size_t readhead(int fd, struct charbuf *buf)
 {
-    int ret;
-    struct charbuf buf;
     int nl;
-    char last;
+    size_t off;
     
-    bufinit(buf);
-    nl = 0;
-    while(1) {
-       sizebuf(buf, buf.d + 1);
-       ret = recv(fd, buf.b + buf.d, 1, MSG_DONTWAIT);
-       if(ret <= 0) {
-           if((ret < 0) && (errno == EAGAIN)) {
-               block(fd, EV_READ);
-               continue;
+    int get1(void)
+    {
+       int ret;
+       
+       while(!(off < buf->d)) {
+           sizebuf(*buf, buf->d + 1024);
+           ret = recv(fd, buf->b + buf->d, buf->s - buf->d, MSG_DONTWAIT);
+           if(ret <= 0) {
+               if((ret < 0) && (errno == EAGAIN)) {
+                   if(block(fd, EV_READ, 60) <= 0)
+                       return(-1);
+                   continue;
+               }
+               return(-1);
            }
-           goto err;
+           buf->d += ret;
        }
-       last = buf.b[buf.d++];
-       if(last == '\n') {
+       return(buf->b[off++]);
+    }
+
+    nl = 0;
+    off = 0;
+    while(1) {
+       switch(get1()) {
+       case '\n':
            if(nl)
-               break;
+               return(off);
            nl = 1;
-       } else if(last == '\r') {
-       } else {
+           break;
+       case '\r':
+           break;
+       case -1:
+           return(-1);
+       default:
            nl = 0;
+           break;
        }
     }
-    bufadd(buf, 0);
-    return(buf.b);
-    
-err:
-    buffree(buf);
-    return(NULL);
 }
 
 #define SKIPNL(ptr) ({                         \
@@ -164,11 +141,11 @@ err:
                __buf__ = 1;                    \
            }                                   \
            __buf__;})
-static struct htreq *parseraw(char *buf)
+static struct hthead *parserawreq(char *buf)
 {
     char *p, *p2, *nl;
     char *method, *url, *ver;
-    struct htreq *req;
+    struct hthead *req;
     
     if((nl = strchr(buf, '\n')) == NULL)
        return(NULL);
@@ -192,45 +169,255 @@ static struct htreq *parseraw(char *buf)
     while(1) {
        if(SKIPNL(p)) {
            if(*p)
-               return(NULL);
+               goto fail;
            break;
        }
        if((nl = strchr(p, '\n')) == NULL)
-           return(NULL);
+           goto fail;
        if(((p2 = strchr(p, ':')) == NULL) || (p2 > nl))
-           return(NULL);
+           goto fail;
        *(p2++) = 0;
        for(; (*p2 == ' ') || (*p2 == '\t'); p2++);
        for(nl = p2; (*nl != '\r') && (*nl != '\n'); nl++);
        if(!SKIPNL(nl))
-           return(NULL);
-       reqappheader(req, p, p2);
+           goto fail;
+       if(strncasecmp(p, "x-ash-", 6))
+           headappheader(req, p, p2);
        p = nl;
     }
     return(req);
+    
+fail:
+    freehthead(req);
+    return(NULL);
+}
+
+static struct hthead *parserawresp(char *buf)
+{
+    char *p, *p2, *nl;
+    char *msg, *ver;
+    int code;
+    struct hthead *resp;
+    
+    if((nl = strchr(buf, '\n')) == NULL)
+       return(NULL);
+    p = strchr(buf, '\r');
+    if((p != NULL) && (p < nl))
+       nl = p;
+    if(strncmp(buf, "HTTP/", 5))
+       return(NULL);
+    ver = p = buf + 5;
+    for(; ((*p >= '0') && (*p <= '9')) || (*p == '.'); p++);
+    if(*p != ' ')
+       return(NULL);
+    *(p++) = 0;
+    if(((p2 = strchr(p, ' ')) == NULL) || (p2 > nl))
+       return(NULL);
+    *(p2++) = 0;
+    code = atoi(p);
+    if((code < 100) || (code >= 600))
+       return(NULL);
+    if(p2 >= nl)
+       return(NULL);
+    msg = p2;
+    p = nl;
+    if(!SKIPNL(p))
+       return(NULL);
+
+    resp = mkresp(code, msg, ver);
+    while(1) {
+       if(SKIPNL(p)) {
+           if(*p)
+               goto fail;
+           break;
+       }
+       if((nl = strchr(p, '\n')) == NULL)
+           goto fail;
+       if(((p2 = strchr(p, ':')) == NULL) || (p2 > nl))
+           goto fail;
+       *(p2++) = 0;
+       for(; (*p2 == ' ') || (*p2 == '\t'); p2++);
+       for(nl = p2; (*nl != '\r') && (*nl != '\n'); nl++);
+       if(!SKIPNL(nl))
+           goto fail;
+       headappheader(resp, p, p2);
+       p = nl;
+    }
+    return(resp);
+    
+fail:
+    freehthead(resp);
+    return(NULL);
+}
+
+static off_t passdata(int src, int dst, struct charbuf *buf, off_t max)
+{
+    size_t dataoff, smax;
+    off_t sent;
+    int eof, ret;
+
+    sent = 0;
+    eof = 0;
+    while((!eof || (buf->d > 0)) && ((max < 0) || (sent < max))) {
+       if(!eof && (buf->d < buf->s) && ((max < 0) || (sent + buf->d < max))) {
+           while(1) {
+               ret = recv(src, buf->b + buf->d, buf->s - buf->d, MSG_DONTWAIT);
+               if((ret < 0) && (errno == EAGAIN)) {
+               } else if(ret < 0) {
+                   return(-1);
+               } else if(ret == 0) {
+                   eof = 1;
+                   break;
+               } else {
+                   buf->d += ret;
+                   break;
+               }
+               if(buf->d > 0)
+                   break;
+               if(block(src, EV_READ, 0) <= 0)
+                   return(-1);
+           }
+       }
+       for(dataoff = 0; (dataoff < buf->d) && ((max < 0) || (sent < max));) {
+           if(block(dst, EV_WRITE, 120) <= 0)
+               return(-1);
+           smax = buf->d - dataoff;
+           if(sent + smax > max)
+               smax = max - sent;
+           ret = send(dst, buf->b + dataoff, smax, MSG_NOSIGNAL | MSG_DONTWAIT);
+           if(ret < 0)
+               return(-1);
+           dataoff += ret;
+           sent += ret;
+       }
+       bufeat(*buf, dataoff);
+    }
+    return(sent);
 }
 
 static void serve(struct muth *muth, va_list args)
 {
     vavar(int, fd);
-    char *hb;
-    struct htreq *req;
+    vavar(struct sockaddr_storage, name);
+    int cfd;
+    char old;
+    char *hd, *p;
+    struct charbuf inbuf, outbuf;
+    struct hthead *req, *resp;
+    off_t dlen, sent;
+    ssize_t headoff;
+    char nmbuf[256];
     
-    hb = NULL;
+    bufinit(inbuf);
+    bufinit(outbuf);
+    cfd = -1;
+    req = resp = NULL;
     while(1) {
-       if((hb = slowreadhead(fd)) == NULL)
+       /*
+        * First, find and decode the header:
+        */
+       if((headoff = readhead(fd, &inbuf)) < 0)
            goto out;
-       if((req = parseraw(hb)) == NULL)
+       if(headoff > 65536) {
+           /* We cannot handle arbitrarily large headers, as they
+            * need to fit within a single Unix datagram. This is
+            * probably a safe limit, and larger packets than this are
+            * most likely erroneous (or malicious) anyway. */
            goto out;
-       free(hb);
-       hb = NULL;
-       printf("\"%s\", \"%s\", \"%s\"\n", req->method, req->url, req->ver);
-       freereq(req);
+       }
+       old = inbuf.b[headoff];
+       inbuf.b[headoff] = 0;
+       if((req = parserawreq(inbuf.b)) == NULL)
+           goto out;
+       inbuf.b[headoff] = old;
+       bufeat(inbuf, headoff);
+       /* We strip off the leading slash and any param string from
+        * the rest string, so that multiplexers can parse
+        * coherently. */
+       if(req->rest[0] == '/')
+           replrest(req, req->rest + 1);
+       if((p = strchr(req->rest, '?')) != NULL)
+           *p = 0;
+       
+       /*
+        * Add metainformation and then send the request to the root
+        * multiplexer:
+        */
+       if(name.ss_family == AF_INET) {
+           headappheader(req, "X-Ash-Address", inet_ntop(AF_INET, &((struct sockaddr_in *)&name)->sin_addr, nmbuf, sizeof(nmbuf)));
+           headappheader(req, "X-Ash-Port", sprintf3("%i", ntohs(((struct sockaddr_in *)&name)->sin_port)));
+       } else if(name.ss_family == AF_INET6) {
+           headappheader(req, "X-Ash-Address", inet_ntop(AF_INET6, &((struct sockaddr_in6 *)&name)->sin6_addr, nmbuf, sizeof(nmbuf)));
+           headappheader(req, "X-Ash-Port", sprintf3("%i", ntohs(((struct sockaddr_in6 *)&name)->sin6_port)));
+       }
+       if((cfd = sendreq(plex, req)) < 0)
+           goto out;
+
+       /*
+        * If there is message data, pass it:
+        */
+       if((hd = getheader(req, "content-length")) != NULL) {
+           dlen = atoo(hd);
+           if(dlen > 0)
+               passdata(fd, cfd, &inbuf, dlen);
+       }
+       /* Make sure to send EOF */
+       shutdown(cfd, SHUT_WR);
+       
+       /*
+        * Find and decode the response header:
+        */
+       outbuf.d = 0;
+       if((headoff = readhead(cfd, &outbuf)) < 0)
+           goto out;
+       hd = memcpy(smalloc(headoff + 1), outbuf.b, headoff);
+       hd[headoff] = 0;
+       if((resp = parserawresp(hd)) == NULL)
+           goto out;
+       
+       /*
+        * Pass the actual output:
+        */
+       sizebuf(outbuf, 65536);
+       sent = passdata(cfd, fd, &outbuf, -1);
+       sent -= headoff;
+       
+       /*
+        * Check for connection expiry
+        */
+       if(strcasecmp(req->method, "head")) {
+           if((hd = getheader(resp, "content-length")) != NULL) {
+               if(sent != atoo(hd)) {
+                   /* Exit because of error */
+                   goto out;
+               }
+           } else {
+               if(((hd = getheader(resp, "transfer-encoding")) == NULL) || !strcasecmp(hd, "identity"))
+                   break;
+           }
+           if(((hd = getheader(req, "connection")) != NULL) && !strcasecmp(hd, "close"))
+               break;
+           if(((hd = getheader(resp, "connection")) != NULL) && !strcasecmp(hd, "close"))
+               break;
+       }
+       
+       close(cfd);
+       cfd = -1;
+       freehthead(req);
+       req = NULL;
+       freehthead(resp);
+       resp = NULL;
     }
     
 out:
-    if(hb != NULL)
-       free(hb);
+    if(cfd >= 0)
+       close(cfd);
+    if(req != NULL)
+       freehthead(req);
+    if(resp != NULL)
+       freehthead(resp);
+    buffree(inbuf);
+    buffree(outbuf);
     close(fd);
 }
 
@@ -243,69 +430,31 @@ static void listenloop(struct muth *muth, va_list args)
     
     while(1) {
        namelen = sizeof(name);
-       block(ss, EV_READ);
+       block(ss, EV_READ, 0);
        ns = accept(ss, (struct sockaddr *)&name, &namelen);
        if(ns < 0) {
            flog(LOG_ERR, "accept: %s", strerror(errno));
            goto out;
        }
-       mustart(serve, ns);
+       mustart(serve, ns, name);
     }
     
 out:
     close(ss);
 }
 
-static void ioloop(void)
-{
-    int ret;
-    fd_set rfds, wfds, efds;
-    struct blocker *bl, *nbl;
-    int maxfd;
-    int ev;
-    
-    while(blockers != NULL) {
-       FD_ZERO(&rfds);
-       FD_ZERO(&wfds);
-       FD_ZERO(&efds);
-       maxfd = 0;
-       for(bl = blockers; bl; bl = bl->n) {
-           if(bl->ev & EV_READ)
-               FD_SET(bl->fd, &rfds);
-           if(bl->ev & EV_WRITE)
-               FD_SET(bl->fd, &wfds);
-           FD_SET(bl->fd, &efds);
-           if(bl->fd > maxfd)
-               maxfd = bl->fd;
-       }
-       ret = select(maxfd + 1, &rfds, &wfds, &efds, NULL);
-       if(ret < 0) {
-           if(errno != EINTR) {
-               flog(LOG_CRIT, "ioloop: select errored out: %s", strerror(errno));
-               /* To avoid CPU hogging in case it's bad, which it
-                * probably is. */
-               sleep(1);
-           }
-       }
-       for(bl = blockers; bl; bl = nbl) {
-           nbl = bl->n;
-           ev = 0;
-           if(FD_ISSET(bl->fd, &rfds))
-               ev |= EV_READ;
-           if(FD_ISSET(bl->fd, &wfds))
-               ev |= EV_WRITE;
-           if(FD_ISSET(bl->fd, &efds))
-               ev = -1;
-           if(ev != 0)
-               resume(bl->th, ev);
-       }
-    }
-}
-
 int main(int argc, char **argv)
 {
     int fd;
     
+    if(argc < 2) {
+       fprintf(stderr, "usage: htparser ROOT [ARGS...]\n");
+       exit(1);
+    }
+    if((plex = stdmkchild(argv + 1)) < 0) {
+       flog(LOG_ERR, "could not spawn root multiplexer: %s", strerror(errno));
+       return(1);
+    }
     if((fd = listensock6(8080)) < 0) {
        flog(LOG_ERR, "could not listen on IPv6: %s", strerror(errno));
        return(1);