CSAPP - Proxy Lab

Part 1 Implementing a sequential web proxy

设计:proxy的任务是接受客户端的请求,将请求预处理(uri解析、host解析)后发送给真正的服务器

构建Proxy

main函数负责创建与客户端连接的socket,等待连接请求

/*
 * main - entry point of the sequential proxy.
 *
 * Expects exactly one command-line argument: the port to listen on.
 * Opens a listening socket on that port and then serves clients strictly
 * one at a time: accept a connection, log the peer's address, run
 * proxying() on the connected descriptor and close it.  The accept loop
 * never terminates.
 */
int main(int argc, char* argv[])
{
    int listenfd, connfd;
    char hostname[MAXLINE], port[MAXLINE];
    socklen_t clientlen;
    struct sockaddr_storage clientaddr;

    if (argc != 2) {
        /* Name the expected argument so the usage line is actually useful. */
        fprintf(stderr, "usage: %s <port>\n", argv[0]);
        exit(1);
    }

    listenfd = Open_listenfd(argv[1]);
    while (1) {
        /* Accept() overwrites clientlen, so reset it before every call. */
        clientlen = sizeof(clientaddr);
        connfd = Accept(listenfd, (SA *)&clientaddr, &clientlen);

        Getnameinfo((SA *) &clientaddr, clientlen, hostname, MAXLINE,
                    port, MAXLINE, 0);
        printf("Accepted connection from (%s, %s)\n", hostname, port);

        /* proxying() does not close connfd; this loop owns it. */
        proxying(connfd);
        Close(connfd);
    }
}

proxying函数是整个代理流程的抽象,读取客户端的请求,创建与服务器的连接socket(proxyfd),转发response。

/*
 * proxying - handle one client request on connected descriptor fd.
 *
 * Reads the request line and the request headers from the client, splits
 * the request URI into host, port and path, opens a connection to that
 * server, sends it a rewritten request and relays the response back to
 * the client.  The server-side descriptor is closed here; fd is left open
 * for the caller to close.
 */
void proxying(int fd) {
    char buf[MAXLINE], method[MAXLINE], uri[MAXLINE], version[MAXLINE];
    char hostname[MAXLINE], port[MAXLINE], query[MAXLINE];
    char host[MAXLINE];  // value of the client's Host header, if any
    rio_t rio;

    Rio_readinitb(&rio, fd);
    if (Rio_readlineb(&rio, buf, MAXLINE) <= 0) // read request line
        return;
    printf("Request line:%s", buf);

    // read_requesthdrs() only writes host when it sees a Host header, so
    // start from an empty string to make the strlen() check below defined.
    host[0] = '\0';
    read_requesthdrs(&rio, host);  // read request header

    // Each token is a substring of buf, so it fits its MAXLINE-sized target.
    // A request line without all three parts would leave the targets
    // uninitialised, so refuse it instead of parsing garbage.
    if (sscanf(buf, "%s %s %s", method, uri, version) != 3) {
        fprintf(stderr, "Malformed request line, request dropped\n");
        return;
    }

    parseuri(uri, hostname, port, query);
    if (strlen(host) == 0) {
        // No Host header from the client: fall back to the host in the URI.
        strcpy(host, hostname);
    }

    // open connection to server
    int proxyfd = Open_clientfd(hostname, port);

    // send request to server
    send_requesthdrs(proxyfd, method, query, host);

    // Read response from server
    forward_response(proxyfd, fd);

    Close(proxyfd);
}

其他帮助函数

/*
 * parseuri - split a request URI into host name, port and path.
 *
 * uri      - input, e.g. "http://localhost:8080/index.html"; not modified.
 * hostname - output, the host without the port ("localhost").
 * port     - output, the port as text; "80" when the URI names none.
 * query    - output, the path starting at '/' ("/index.html"); "/" when
 *            the URI has no path.
 *
 * A URI without a "//" scheme separator is treated as "host[:port][/path]".
 * The output buffers must be at least as large as the URI; the callers in
 * this file pass MAXLINE-sized arrays.
 *
 * Always returns 1.
 */
int parseuri(char* uri, char* hostname, char* port, char* query) {
    /* Skip the "scheme://" prefix when there is one. */
    char *start = strstr(uri, "//");
    start = start ? start + 2 : uri;

    char *path = index(start, '/');
    if (!path) { /* URI without a path, e.g. http://localhost:8080 */
        strcpy(hostname, start);
        strcpy(query, "/");
    } else {
        size_t hostlen = (size_t)(path - start);
        memcpy(hostname, start, hostlen);
        hostname[hostlen] = '\0';
        strcpy(query, path);
    }

    /* Split "host:port" in place; an absent or empty port means 80. */
    char *colon = index(hostname, ':');
    if (colon) {
        *colon = '\0';
        strcpy(port, colon[1] != '\0' ? colon + 1 : "80");
    } else {
        strcpy(port, "80");
    }
    Dprintf("hostname=%s, port=%s, query=%s\n", hostname, port, query);
    return 1;
}

/*
 * read_requesthdrs - consume the request headers and extract the Host value.
 *
 * rp   - buffered reader positioned just after the request line.
 * host - output buffer; receives the value of the Host header with leading
 *        blanks and the trailing CRLF removed, or the empty string when
 *        the request has no Host header.  Must hold MAXLINE bytes (the
 *        callers in this file pass MAXLINE-sized arrays).
 *
 * Every header line, including the terminating blank line, is echoed to
 * stdout.  Reading stops at the blank line or when the client closes the
 * connection.
 */
void read_requesthdrs(rio_t *rp, char* host)
{
    char buf[MAXLINE];

    host[0] = '\0';
    while (Rio_readlineb(rp, buf, MAXLINE) > 0) {
        printf("%s", buf);
        if (strcmp(buf, "\r\n") == 0 || strcmp(buf, "\n") == 0) {
            break;                        /* blank line ends the headers */
        }
        /* Header names are case-insensitive; match the full "Host:" token. */
        if (strncasecmp(buf, "Host:", 5) == 0) {
            char *value = buf + 5;
            value += strspn(value, " \t");          /* drop leading blanks */
            value[strcspn(value, "\r\n")] = '\0';   /* drop line ending */
            strcpy(host, value);
        }
    }
    return;
}

/*
 * send_requesthdrs - write the rewritten request to the server.
 *
 * clientfd - descriptor connected to the origin server.
 * method   - request method, e.g. "GET".
 * query    - request path, starting with '/'.
 * host     - value for the Host header; leading blanks and anything from
 *            the first CR/LF on are ignored so a raw header value cannot
 *            terminate the header section early.
 *
 * The request is always sent as HTTP/1.0 with the fixed User-Agent,
 * "Connection: close" and "Proxy-Connection: close" headers.  Nothing is
 * sent if the request cannot be formatted or memory is exhausted.
 */
void send_requesthdrs(int clientfd, char* method, char* query, char* host) {
    static const char fmt[] =
        "%s %s HTTP/1.0\r\n"
        "Host: %.*s\r\n"
        "%s"
        "Connection: close\r\n"
        "Proxy-Connection: close\r\n\r\n";

    const char *h = host + strspn(host, " \t");
    int hlen = (int)strcspn(h, "\r\n");

    /* First pass measures, second pass formats: no fixed-size buffer to
     * overflow and no overlapping source/destination as with
     * sprintf(buf, "%s...", buf, ...). */
    int need = snprintf(NULL, 0, fmt, method, query, hlen, h, user_agent_hdr);
    if (need < 0) {
        fprintf(stderr, "send_requesthdrs: cannot format request\n");
        return;
    }

    char *buf = malloc((size_t)need + 1);
    if (buf == NULL) {
        fprintf(stderr, "send_requesthdrs: out of memory\n");
        return;
    }
    snprintf(buf, (size_t)need + 1, fmt, method, query, hlen, h, user_agent_hdr);

    // send request to server
    Rio_writen(clientfd, buf, (size_t)need);

    Dprintf("%s", buf);
    free(buf);
}

/*
 * forward_response - relay the server's response to the client.
 *
 * proxyfd  - descriptor connected to the origin server (read side).
 * clientfd - descriptor connected to the client (write side).
 *
 * The response is copied line by line exactly as received until the
 * server closes the connection.  For debugging, the header lines are also
 * collected and printed to stdout once the blank line ending the header
 * section is seen; header lines that no longer fit the collection buffer
 * are still forwarded, just not printed.
 */
void forward_response(int proxyfd, int clientfd) {
    rio_t rio;
    int n;
    char buf[MAXLINE];
    char response_header[MAXLINE];
    size_t header_len = 0;
    int in_header = 1;

    response_header[0] = '\0';   /* start from a valid empty string */

    Rio_readinitb(&rio, proxyfd);
    while ((n = Rio_readlineb(&rio, buf, MAXLINE)) > 0) {
        Rio_writen(clientfd, buf, n);
        if (!in_header) {
            continue;            /* body bytes are only forwarded */
        }
        // for debug print
        if (strcmp(buf, "\r\n") == 0) {
            Fputs("Response header:\n", stdout);
            Fputs(response_header, stdout);
            in_header = 0;
        } else if (header_len + (size_t)n < sizeof response_header) {
            /* n + 1 bytes: the line plus the terminator that follows it. */
            memcpy(response_header + header_len, buf, (size_t)n + 1);
            header_len += (size_t)n;
        }
    }
}

使用curl测试代理服务器

$ curl -v --proxy http://localhost:34240 http://localhost:34241/godzilla.jpg --output god.jpg

对于下载文件的测试,使用curl的--output 选项,下载到指定的file文件。

Part 2 多线程代理服务器

比较简单,把proxying函数放到线程中执行就行了。记得设置为分离式线程。

/*
 * main - entry point of the multi-threaded proxy (abridged listing).
 *
 * Same accept loop as the sequential version, except that every accepted
 * connection is handed to a newly created thread running thread().
 */
int main(int argc, char* argv[])
{
    // Local declarations and the argument check are left out of this listing.
    // NOTE(review): besides the variables used below that the sequential
    // main also declares, this version needs connfdp (int *) and tid
    // (pthread_t) -- confirm against the full source.

    listenfd = Open_listenfd(argv[1]);
    while (1) {
        clientlen = sizeof(clientaddr);
        // Each connection gets its own heap-allocated descriptor slot, so the
        // next Accept() cannot overwrite the value before the new thread has
        // read it.  The thread routine frees the slot.
        connfdp = Malloc(sizeof(int));
        *connfdp = Accept(listenfd, (SA *)&clientaddr, &clientlen);

        Getnameinfo((SA *) &clientaddr, clientlen, hostname, MAXLINE,
                    port, MAXLINE, 0);
        printf("Accepted connection from (%s, %s)\n", hostname, port);

        Pthread_create(&tid, NULL, thread, connfdp);
    }
}

/*
 * thread - per-connection thread routine.
 *
 * vargp points to a heap-allocated int holding the connected descriptor.
 * The routine detaches itself, takes a copy of the descriptor, releases
 * the allocation, serves the request with proxying() and closes the
 * descriptor.  Always returns NULL.
 */
void *thread(void *vargp) {
    int *slot = (int *)vargp;
    int connfd;

    /* Detached threads are never joined. */
    Pthread_detach(Pthread_self());

    connfd = *slot;
    Free(slot);

    proxying(connfd);
    Close(connfd);

    return NULL;
}

Part 3 缓存实现

遇到问题:简单地实现后,发现第一次通过代理服务器下载的图片可以打开,再次向代理服务器请求图片,用curl获取缓存的图片保存到本地,结果显示格式损坏,无法打开。

解决:原来是我在保存response body的时候,用的是strcpy。strcpy把源数据当作以null结尾的字符串,复制到第一个'\0'字节就停止;由于图片数据中可能存在全0字节,所以导致复制的数据缺失。正确做法是使用memcpy,指定要复制的字节数。

Cache实现

使用链表保存缓存对象,并用uri的字符串hash来查找对应的对象。cache替换算法使用的是头部插入、尾部淘汰的简单策略(由于命中时不会把对象移到链表头部,严格来说是FIFO,只能算LRU的近似)。使用pthread_rwlock_t类型作为读写锁控制。

cache.h

/* One cached response, kept as a node of a singly linked list. */
typedef struct object{
    int size;                    /* number of bytes in data */
    char* data;                  /* response body; released together with the node */
    uint32_t uriHash;            /* hash of the request URI, used as the lookup key */
    char response_hdr[MAXLINE];  /* response header text as a NUL-terminated string */
    struct object* next;         /* next node in the list, NULL at the tail */
} Object;

/* The cache: a linked list of objects plus the bookkeeping around it. */
typedef struct {
    Object* objs;             /* list head (newest entry first), NULL when empty */
    int currentSize;          /* sum of the size fields of all cached objects */
    pthread_rwlock_t rwlock;  /* taken for reading by lookups, for writing by inserts */
} Cache;


/* Set up an empty cache and initialise its read-write lock. */
void init_cache(Cache* cache);
/*
 * Add a response for uri to the cache.  The header text is copied into the
 * new entry; the data pointer itself is stored (not copied) when the entry
 * is accepted.  Responses larger than MAX_OBJECT_SIZE are not cached.
 */
void cache_object(Cache* cache, char* uri, char* header, char* data, int size);
/* Look up uri by its hash; returns the cached object, or NULL on a miss. */
Object * get_object(Cache* cache, char* uri);
/* Free every cached object and destroy the lock. */
void destruct_cache(Cache* cache);

cache.c

#include "cache.h"

static uint32_t hash_string(const char * str, uint32_t len) {
    uint32_t hash=0;
    uint32_t i;
    for (i=0;idata);
    obj->data = NULL;
    free(obj);
}

/*
 * remove_cache_object - evict the entry at the tail of the list.
 *
 * New entries are inserted at the head, so the tail is the oldest one.
 * The evicted node is unlinked before it is freed and its size is
 * subtracted from currentSize.  Does nothing when the cache is empty.
 * No locking is done here.
 */
static void remove_cache_object(Cache* cache) {
    /* link always points at the pointer that refers to the current node,
     * so the head and every later node are unlinked the same way. */
    Object **link = &cache->objs;
    if (*link == NULL) {
        return;
    }
    while ((*link)->next != NULL) { // walk to the tail
        link = &(*link)->next;
    }

    Object *victim = *link;
    *link = NULL;   /* unlink first so no dangling pointer stays in the list */

    cache->currentSize -= victim->size;
    free_object(victim);
}

/*
 * insert_cache - make room for obj, then push it onto the head of the list.
 *
 * Entries are evicted with remove_cache_object() for as long as the space
 * left below MAX_CACHE_SIZE is smaller than obj->size.  No locking is done
 * here.
 */
static void insert_cache(Cache* cache, Object* obj) {
    for (;;) {
        if (!(MAX_CACHE_SIZE - cache->currentSize < obj->size)) {
            break;      /* enough room for the new object */
        }
        remove_cache_object(cache);
    }

    cache->currentSize += obj->size;
    obj->next = cache->objs;
    cache->objs = obj;
}

/*
 * init_cache - put a cache into its empty state: default-attribute
 * read-write lock, no objects, zero bytes in use.
 */
void init_cache(Cache* cache) {
    pthread_rwlock_init(&cache->rwlock, NULL);
    cache->objs = NULL;
    cache->currentSize = 0;
}

/*
 * cache_object - add a response to the cache.
 *
 * cache  - the cache to insert into.
 * uri    - request URI; only its hash is stored as the lookup key.
 * header - response header text (NUL-terminated); copied into the entry.
 * data   - response body of `size` bytes.  The pointer itself is stored,
 *          not a copy, and is released when the entry is evicted.
 * size   - number of bytes in data.
 *
 * The object is silently not cached when size is negative or larger than
 * MAX_OBJECT_SIZE, or when the header does not fit the entry's header
 * buffer (a truncated header would later be replayed to clients).  In
 * those cases data is left untouched and stays with the caller.
 */
void cache_object(Cache* cache, char* uri, char* header, char* data, int size) {
    if (size < 0 || size > MAX_OBJECT_SIZE) {
        return;
    }
    size_t header_len = strlen(header);
    if (header_len >= sizeof ((Object *)0)->response_hdr) {
        return;
    }

    /* Build the node before taking the lock so readers are blocked only
     * for the list update itself. */
    Object *obj = Malloc(sizeof *obj);
    obj->size = size;
    obj->data = data;
    obj->uriHash = hash_string(uri, strlen(uri));
    memcpy(obj->response_hdr, header, header_len + 1);
    obj->next = NULL;

    pthread_rwlock_wrlock(&cache->rwlock);
    insert_cache(cache, obj);
    pthread_rwlock_unlock(&cache->rwlock);
}

/*
 * get_object - look up the cached object for uri.
 *
 * Walks the list under the read lock and returns the first entry whose
 * stored hash equals the hash of uri, or NULL when there is none.
 *
 * NOTE(review): entries are matched on the 32-bit hash alone, so two URIs
 * with the same hash are indistinguishable.  The lock is released before
 * the pointer is returned, so the entry is no longer protected against
 * eviction while the caller uses it.
 */
Object * get_object(Cache* cache, char* uri) {
    Object *found = NULL;
    Object *cur;
    uint32_t wanted;

    pthread_rwlock_rdlock(&cache->rwlock);
    wanted = hash_string(uri, strlen(uri));
    for (cur = cache->objs; cur != NULL; cur = cur->next) {
        if (cur->uriHash == wanted) {
            found = cur;
            break;
        }
    }
    pthread_rwlock_unlock(&cache->rwlock);

    return found;
}

/*
 * destruct_cache - free every cached object and destroy the lock.
 *
 * Safe to call on an empty cache.  The list head and the size counter are
 * reset so that no stale pointer to freed nodes remains in the structure.
 * The lock is not taken here.
 */
void destruct_cache(Cache* cache){
    Object *cur = cache->objs;
    while (cur) {
        Object *next = cur->next;   /* read the link before the node is freed */
        free_object(cur);
        cur = next;
    }
    cache->objs = NULL;
    cache->currentSize = 0;
    pthread_rwlock_destroy(&cache->rwlock);
}

proxying处理逻辑

/*
 * proxying - handle one client request on connected descriptor fd, using
 * the global cache.
 *
 * Reads the request line and headers, splits the URI into host, port and
 * path, and then either replays a cached response for the URI or contacts
 * the server, relays its response and (for requests parseuri() reports as
 * static) offers it to the cache.  fd is left open for the caller.
 *
 * NOTE(review): the object returned by get_object() is used after the
 * cache lock has been released, so a concurrent insert could evict it
 * while it is being sent.
 */
void proxying(int fd) {
    char buf[MAXLINE], method[MAXLINE], uri[MAXLINE], version[MAXLINE];
    char hostname[MAXLINE], port[MAXLINE], query[MAXLINE];
    char host[MAXLINE];  // value of the client's Host header, if any
    rio_t rio;
    int is_static;

    Rio_readinitb(&rio, fd);
    if (Rio_readlineb(&rio, buf, MAXLINE) <= 0) // read request line
        return;
    printf("Request line:%s", buf);

    // read_requesthdrs() only writes host when it sees a Host header, so
    // start from an empty string to make the strlen() check below defined.
    host[0] = '\0';
    read_requesthdrs(&rio, host);  // read request header

    // Each token is a substring of buf, so it fits its MAXLINE-sized target.
    // A request line without all three parts would leave the targets
    // uninitialised, so refuse it instead of parsing garbage.
    if (sscanf(buf, "%s %s %s", method, uri, version) != 3) {
        fprintf(stderr, "Malformed request line, request dropped\n");
        return;
    }

    is_static = parseuri(uri, hostname, port, query);
    if (strlen(host) == 0) {
        // No Host header from the client: fall back to the host in the URI.
        strcpy(host, hostname);
    }

    Object *obj = NULL;
    if (is_static) {
        obj = get_object(&cache, uri);
    }

    if (obj == NULL) {
        // open connection to server
        int proxyfd = Open_clientfd(hostname, port);
        // send request to server
        send_requesthdrs(proxyfd, method, query, host);
        // Read response from server and forward to client
        if (is_static) {
            forward_response(proxyfd, fd, uri);
            Dprintf("Cache current size=%d\n", cache.currentSize);
        } else {
            forward_dynamic_response(proxyfd, fd);
        }
        Close(proxyfd);
    } else {
        Dprint("Using cached object.\n");
        forward_cached_response(obj, fd);
    }
}

/*
 * append_header_line - append the n-byte, NUL-terminated line to the
 * header text in hdr (capacity cap, current length *len).
 * Returns 1 on success, 0 if the line does not fit (hdr is unchanged).
 */
static int append_header_line(char *hdr, size_t cap, size_t *len,
                              const char *line, size_t n) {
    if (*len + n >= cap) {
        return 0;
    }
    memcpy(hdr + *len, line, n + 1);   /* n bytes plus the terminator */
    *len += n;
    return 1;
}

/*
 * forward_response - relay the server's response to the client and offer
 * it to the global cache.
 *
 * proxyfd  - descriptor connected to the origin server (read side).
 * clientfd - descriptor connected to the client (write side).
 * uri      - request URI, used as the cache key.
 *
 * Every line read from the server is forwarded unchanged.  The response is
 * cached only when all of the following hold: the status code is 200, the
 * whole header fits the local header buffer, a positive Content-length
 * header is present (matched case-insensitively), and exactly that many
 * body bytes arrive.  Otherwise the body buffer is released here.
 *
 * NOTE(review): the buffer is sized from the server-supplied
 * Content-length without an upper bound, and cache_object() does not
 * release the buffer when it declines an object -- the caller should
 * apply the cache's object size limit before buffering; confirm and fix
 * where that limit is visible.
 */
void forward_response(int proxyfd, int clientfd, char *uri) {
    rio_t rio;
    int n;
    char buf[MAXLINE];
    char response_header[MAXLINE];
    size_t header_len = 0;
    char *body = NULL;
    int length = -1;      // declared body length, -1 until seen
    int received = 0;     // body bytes buffered so far
    int in_header = 1;
    int cacheable = 0;
    int status = 0;

    response_header[0] = '\0';

    // parse response header
    Rio_readinitb(&rio, proxyfd);
    n = Rio_readlineb(&rio, buf, MAXLINE);
    if (n <= 0) {
        Dprint("Response failed. Not cache.\n");
        return;
    }
    Rio_writen(clientfd, buf, n);

    // Status line is "<version> <code> <reason>"; only 200 is cached.
    if (sscanf(buf, "%*s %d", &status) == 1 && status == 200) {
        cacheable = 1;
    }
    if (!append_header_line(response_header, sizeof response_header,
                            &header_len, buf, (size_t)n)) {
        cacheable = 0;
    }

    while ((n = Rio_readlineb(&rio, buf, MAXLINE)) > 0) {
        Rio_writen(clientfd, buf, n);

        if (!in_header) {
            if (cacheable) {
                if (n > length - received) {
                    // More data than announced: stop buffering, keep relaying.
                    free(body);
                    body = NULL;
                    cacheable = 0;
                } else {
                    memcpy(body + received, buf, (size_t)n);
                    received += n;
                }
            }
            continue;
        }

        // A cached header must be complete, so an overflow disables caching.
        if (!append_header_line(response_header, sizeof response_header,
                                &header_len, buf, (size_t)n)) {
            cacheable = 0;
        }

        if (strcmp(buf, "\r\n") == 0) {
            // for debug print
            Fputs("Response header:", stdout);
            Fputs(response_header, stdout);
            in_header = 0;
            if (body == NULL) {
                cacheable = 0;   // no usable Content-length: nothing to buffer
            }
        } else if (body == NULL && strncasecmp(buf, "Content-length:", 15) == 0) {
            char *end;
            long value = strtol(buf + 15, &end, 10);
            if (end != buf + 15 && value > 0 && (long)(int)value == value) {
                length = (int)value;
                if (cacheable) {
                    body = Malloc((size_t)length);
                }
            } else {
                cacheable = 0;
            }
        }
    }

    if (cacheable && body != NULL && received == length) {
        // On acceptance the cache keeps the body pointer; do not free it here.
        cache_object(&cache, uri, response_header, body, length);
        Dprintf("Data length=%d\n", received);
    } else {
        free(body);
        Dprint("Response failed. Not cache.\n");
    }
}

/*
 * forward_dynamic_response - relay a response that is not to be cached.
 *
 * Copies the server's response from proxyfd to clientfd line by line until
 * the read returns 0.  The header lines (everything before the first blank
 * line) are additionally echoed to stdout for debugging.
 */
void forward_dynamic_response(int proxyfd, int clientfd) {
    rio_t rio;
    char line[MAXLINE];
    int nread;
    int in_header = 1;

    Rio_readinitb(&rio, proxyfd);
    printf("Response header:\n");
    for (;;) {
        nread = Rio_readlineb(&rio, line, MAXLINE);
        if (nread == 0) {
            break;
        }
        Rio_writen(clientfd, line, nread);

        // for debug print
        if (!in_header) {
            continue;
        }
        if (strcmp(line, "\r\n") == 0) {
            in_header = 0;       /* blank line: the body starts here */
        } else {
            Fputs(line, stdout);
        }
    }
}

/*
 * forward_cached_response - replay a cached response to the client:
 * first the stored header text, then the obj->size bytes of body data.
 */
void forward_cached_response(Object* obj, int clientfd) {
    char *header = obj->response_hdr;
    size_t header_len = strlen(header);

    Rio_writen(clientfd, header, header_len);
    Rio_writen(clientfd, obj->data, obj->size);
}

/*
 * sigint_handler - signal handler that frees the global cache, logs that
 * it did so and terminates the process with exit status 0.
 *
 * sig is not used.  Presumably installed for SIGINT; the registration is
 * not part of this listing.
 *
 * NOTE(review): destruct_cache() frees heap memory and exit() is called
 * from signal context; neither is on the POSIX list of async-signal-safe
 * functions -- confirm this is acceptable, or set a flag here and clean
 * up outside the handler.
 */
void sigint_handler(int sig) {
    destruct_cache(&cache);
    Dprint("Cache freed.\n");
    exit(0);
}