0


[C++]TinyWebServer

TinyWebServer

文章目录

1 主体框架

客户端如果想和服务器通信,首先要和服务器建立TCP连接,然后发送HTTP请求。服务端接收并处理HTTP请求,然后发送HTTP响应。

在这里插入图片描述

服务器采用单Reactor多线程模式,主线程使用IO多路复用接口监听事件,收到事件后将其分发。连接事件直接处理,而读写事件由线程池负责处理。

Reactor模式介绍:https://xiaolincoding.com/os/8_network_system/reactor.html#单-reactor-多线程-多进程

在这里插入图片描述

项目主要包含以下模块:

  • Server层-基于EPOLL的I/O多路复用和Reactor网络模式
  • HTTP处理层-解析HTTP请求并处理,生成返回的HTTP响应
  • 日志系统
  • 线程池
  • 数据库连接池
  • 定时器
  • 缓冲区-暂存缓冲数据,读写socket

TinyWebServer
├── build
│   ├── bin
│   │   └── server
│   └── Makefile
├── CMakeLists.txt
├── config.ini
├── log
├── Makefile
├── resources
├── webbench-1.5
│   ├── Makefile
│   ├── socket.c
│   ├── webbench
│   ├── webbench.c
│   └── webbench.o
└── webserver
    ├── buffer
    │   ├── Buffer.cpp
    │   └── Buffer.h
    ├── http
    │   ├── HttpResponse.cpp
    │   ├── HttpResponse.h
    │   ├── HttpWork.cpp
    │   ├── HttpWork.h
    │   ├── ParseHttpRequest.cpp
    │   └── ParseHttpRequest.h
    ├── lib
    │   └── inih-r58
    │       ├── cpp
    │       │   ├── INIReader.cpp
    │       │   └── INIReader.h
    │       ├── ini.c
    │       └── ini.h
    ├── log
    │   ├── Log.cpp
    │   ├── Log.h
    │   ├── LogLevel.h
    │   └── LogQueue.h
    ├── main.cpp
    ├── pool
    │   ├── sqlconnpool.cpp
    │   ├── sqlconnpool.h
    │   ├── sqlconnRAII.h
    │   ├── ThreadPool.cpp
    │   └── ThreadPool.h
    ├── server
    │   ├── Epoll.cpp
    │   ├── Epoll.h
    │   ├── Server.cpp
    │   └── Server.h
    ├── timer
    │   ├── Timer.cpp
    │   └── Timer.h
    └── utils
        └── Utils.h

2 Buffer

为了高效便捷的实现数据的存取,我们自定义了一个缓冲区数据结构,其主要功能是实现数据的读取、写入和空间自增。下图是缓冲区的结构图,我们使用

vector<char>

作为最基本的数据结构,实现了一个队列结构的缓冲区,并定义了三个指针,分别是头指针、读指针和写指针(这里的指针都使用下标代替,并非真正的指针)。

在这里插入图片描述

2.1 向Buffer写入数据

写指针被初始化为0,指向

vector

的首元素。写入数据时,需要提供待写入字符串的首地址和长度。待写入字符串被拷贝到写指针指向的位置,此时可能会出现空间不够的情况。

  • 待写入数据长度小于等于预留区域和空闲区域长度之和:将数据区域搬运到头指针指向的位置,并更新读指针和写指针
  • 待写入数据长度大于预留区域和空闲区域长度之和:利用vector的动态扩容机制,增大缓冲区长度。

2.2 从Buffer读取数据

读指针被初始化为0,从读指针开始读取字符串,直到写指针为止。

2.3 动态扩容

当预留区域和空闲区域加在一起也不够写下新的数据时,需要对缓冲区进行扩容。

vector.resize()

函数将被调用,从而实现了扩充容量。

2.4 从socket中读取数据

由于从socket读取的数据长度未知,直接向buffer中写入的数据的长度可能会超过

vector

的最大容量而引发错误,而增大

buffer

的初始容量又会浪费资源。因此,可以借用一个大容量的栈区作为缓冲。

buffer

和栈区同时接收数据,之后再将栈区中的数据写入

buffer

中,这样便可巧妙地解决问题。

2.5 具体实现

#ifndef TINYWEBSERVER_BUFFER_H
#define TINYWEBSERVER_BUFFER_H
#include <vector>
#include <atomic> // atomic
#include <sys/uio.h> // iovec readv
#include <cstring> // errno
#include <iostream>
#include <cassert> // assert
#include <unistd.h> // write

class Buffer {
private:
    std::vector<char> buffer_;
    std::atomic<size_t> readIdx_;
    std::atomic<size_t> writeIdx_;
    int STACK_LEN;

public:
    explicit Buffer(int init_size=1024, int stack_len=4096);
    ~Buffer()=default;

    size_t getContentLen();

    size_t getBufferLen();

    size_t getLeftLen();

    size_t getRealLeftLen();

    char* getReadPtr();

    const char *getConstReadPtr();

    char* getWritePtr();

    ssize_t readFd(int fd, int* Errno);

    ssize_t writeFd(int fd, int* Errno);

    void append(const char* str, size_t len);

    void append(const std::string& str);

    void addWriteIdx(size_t len);

    void addReadIdx(size_t len);

    void addReadIdxUntil(const char* ed);
    // memset缓冲区
    void resetBuffer();

    std::string getStringAndReset();
private:
    char* getBeginPtr();
    bool confirmSpace(size_t len);
};

#endif //TINYWEBSERVER_BUFFER_H
#include "Buffer.h"

Buffer::Buffer(int init_size, int stack_len):buffer_(init_size), readIdx_(0), writeIdx_(0), STACK_LEN(stack_len){}

// buffer中数据长度
size_t Buffer:: getContentLen() {
    return writeIdx_ - readIdx_;
}
// buffer的实际长度
size_t Buffer::getBufferLen() {
    return buffer_.size();
}
// 返回buffer的剩余空间
size_t Buffer::getLeftLen() {
    return getBufferLen() - writeIdx_;
}
// 返回buffer包含预留空间真正剩下的空间
size_t Buffer::getRealLeftLen() {
    return getBufferLen() - (writeIdx_ - readIdx_);
}
// 返回读指针
char *Buffer::getReadPtr() {
    return &buffer_[readIdx_];
}
//返回写指针
char *Buffer::getWritePtr() {
    return &buffer_[writeIdx_];
}

// 返回buffer首元素的指针
char* Buffer::getBeginPtr() {
    return &buffer_[0];
}

// 移动写指针
void Buffer::addWriteIdx(size_t len) {
    writeIdx_ += len;
}
// 移动读指针
void Buffer::addReadIdx(size_t len) {
    assert(len <= getContentLen());
    readIdx_ += len;
}

// 增加读指针移到ed位置
void Buffer::addReadIdxUntil(const char *ed) {
    assert(getReadPtr() <= ed && ed <= getWritePtr());
    addReadIdx(ed - getReadPtr());
}

ssize_t Buffer::readFd(int fd, int* Errno) {
    char stack_buf[STACK_LEN];
    iovec iv[2];
    size_t leftLen = getLeftLen();
    iv[0].iov_base = getWritePtr();
    iv[0].iov_len = leftLen;
    iv[1].iov_base = stack_buf;
    iv[1].iov_len = STACK_LEN;

    ssize_t len = readv(fd, iv, 2);
    if (len < 0) {
        // 记录报错信息
        *Errno = errno;
        return len;
    }
    // 刚好将buffer填满
    else if (static_cast<size_t>(len) <= leftLen) {
        addWriteIdx(len);
    }
    // 读入的数据超过buffer
    else {
        writeIdx_ = getBufferLen();
        // 将栈区的数据复制到buffer中
        append(stack_buf, len - static_cast<ssize_t>(leftLen));
    }
    return len;
}

ssize_t Buffer::writeFd(int fd, int* Errno) {
    ssize_t len = write(fd, getReadPtr(), getContentLen());
    if (len < 0) {
        *Errno = errno;
        return len;
    }
    addReadIdx(len);
    return len;
}

// 添加char[]到buffer中
void Buffer::append(const char* str, size_t len) {
    assert(str);
    confirmSpace(len);
    std::copy(str, str+len, getWritePtr());
    addWriteIdx(len);
}
// 添加string到buffer中
void Buffer::append(const std::string &str) {
    append(str.c_str(), str.length());
}

// 将buffer中内容转换成string,并清空buffer
std::string Buffer::getStringAndReset() {
    std::string str(getReadPtr(), getWritePtr());
    resetBuffer();
    return str;
}
// 重置buffer
void Buffer::resetBuffer() {
    writeIdx_ = 0;
    readIdx_ = 0;
    memset(&buffer_[0], 0, buffer_.size());
}

// 分配空间,扩容
bool Buffer::confirmSpace(size_t len) {
    // 剩余空间能够满足写入len
    if (getLeftLen() >= len) {
        return false;
    }
        // 不够Len,但是能够借用预留空间满足要求
    else if (getRealLeftLen() >= len) {
        auto contentLen = getContentLen();
        std::copy(getBeginPtr() + readIdx_, getBeginPtr() + writeIdx_, getBeginPtr());
        readIdx_ = 0;
        writeIdx_ = contentLen;
        assert(contentLen == getContentLen());
    }
        // 即使挪动也不够空间,需要对vector扩容
    else {
        buffer_.resize(writeIdx_ + len + 1);
    }
    assert(getLeftLen() >= len);
    return true;
}

const char *Buffer::getConstReadPtr() {
    return &buffer_[readIdx_];
}

3 日志系统

日志系统是实现整个webserver项目的首要前提,利用日志可以方便地调试代码,记录输出。

为了使输出的日志清晰明了,日志信息被划分成了不同的等级:

  • DEBUG
  • INFO
  • WARN
  • ERROR
  • FATAL

服务器初始化时提供了一个日志等级的参数,只有大于等于该等级的日志条目才会输出。

该日志系统主要使用异步方式写入日志信息,由一个队列负责维护要输出的日志信息。其他线程要打印日志时,调用函数将内容插入到队列中;日志输出线程负责从队列中取出日志信息,并将其写入日志文件中。

3.1 生产者-消费者模型

在上述工作流程中,其他线程和输出线程构成一个生产者-消费者模型。其他线程在队列不满的情况下,插入日志信息并通知日志输出线程取出日志信息,否则挂起等待;而日志输出线程在队列不空的情况,从队列中取出日志信息并通知其他线程插入日志信息,否则挂起等待。

为了同步其他线程和日志输出线程,可以使用条件变量。

3.2 数据一致

由于该日志系统会被其他不同的线程调用,需要保证同一时间只有一个线程访问日志队列。可能会出现以下竞态情况:

  • 日志队列中插入日志和取出日志时,对日志队列的访问
  • 其他线程要插入日志时,会在buffer内构造日志信息

在这里插入图片描述

3.3 代码

#ifndef TINYWEBSERVER_LOG_H
#define TINYWEBSERVER_LOG_H

#include <string>
#include <thread>
#include <mutex>
#include <cstdarg>
#include <sys/time.h>
#include "../buffer/Buffer.h"
#include "LogQueue.h"
#include "../utils/Utils.h"
#include "LogLevel.h"

// 日志输出位置
enum LogTarget {
    LOG_TARGET_NONE = 0,
    LOG_TARGET_CONSOLE = 1,
    LOG_TARGET_FILE = 2
};

class Log {
private:
    const char* saveDir_; // 日志存储路径
    char* filename_; // 初始化提供的文件名
    const char* suffix_; // 日志文件名后缀
    std::unique_ptr<LogQueue<std::string>> log_Queue_; // 日志队列
    std::unique_ptr<std::thread> workThread_; // 处理写日志的线程
    FILE *fp_; // 日志文件描述符
    LogTarget target_; // 日志文件输出位置
    LogLevel::value logLevel_; // 日志级别
    std::mutex mtx_;
    Buffer buf_; // 缓冲区
    bool isRun_;
    unsigned long long logCnt;
    bool isAsync_;
    static const int MAX_LINES = 50000;
    static const size_t MAX_FILENAME_LEN = 50; // 最大文件名限制
public:

    void init(LogTarget target, const char* save_dir, const char* suffix, LogLevel::value logLevel,
              size_t maxQueueSize =  1024); // 初始化日志系统
    static void asyncWriteLogThread(); // 工作线程将日志异步写入文件的函数
    bool initLogFile(); // 初始化日志文件
    // 外部调用接口,输出不同类型的日志信息
    void addLog(LogLevel::value type, const char *format, ...);
    bool isRun() const { return isRun_; }

    // 外部获取实例的接口
    static Log* getInstance();
    LogLevel::value getLevel() { return logLevel_; };
    void flush();
    void AsyncWrite_();
    LogTarget getTarget() {
        return  target_;
    }

private:
    Log();
    ~Log();
    void setEntryTime(); // 设置日志条目时间头
    void setEntryType(LogLevel::value t); // 设置日志条目类型
    void setEntryMsg(const std::string &msg);
    void appendEntry(const std::string& entry);
};

#define LOG_BASE(level, format, ...) \
    do {\
        Log* l = Log::getInstance();\
        if (l->isRun() && l->getLevel() <= level && l->getTarget() != LOG_TARGET_NONE) {\
            l->addLog(level, format, ##__VA_ARGS__); \
            l->flush();\
        }\
    } while(0);

#define LOG_DEBUG(format, ...) do {LOG_BASE(LogLevel::value::DEBUG, format, ##__VA_ARGS__)} while(0);
#define LOG_INFO(format, ...) do {LOG_BASE(LogLevel::value::INFO, format, ##__VA_ARGS__)} while(0);
#define LOG_WARN(format, ...) do {LOG_BASE(LogLevel::value::WARN, format, ##__VA_ARGS__)} while(0);
#define LOG_ERROR(format, ...) do {LOG_BASE(LogLevel::value::ERROR, format, ##__VA_ARGS__)} while(0);
#define LOG_FATAL(format, ...) do {LOG_BASE(LogLevel::value::FATAL, format, ##__VA_ARGS__)} while(0);
#endif //TINYWEBSERVER_LOG_H
#include "Log.h"

Log::Log() {
    saveDir_ = nullptr;
    filename_ = nullptr;
    suffix_ = nullptr;
    log_Queue_ = nullptr;
    workThread_ = nullptr;
    isRun_ = true;
    fp_ = nullptr;
    target_ = LOG_TARGET_CONSOLE;
    logCnt = 0;
}

Log::~Log() {
    printf("close logging...\n");
    if (log_Queue_->size()) {
        sleep(2);
    }
    isRun_ = false;
    fflush(fp_);
    if (fp_ != nullptr) {
        fclose(fp_);
        fp_ = nullptr;
    }
    delete[] filename_;
}

Log* Log::getInstance() {
    static Log log_;
    return &log_;
}

void
Log::init(LogTarget target, const char *save_dir, const char *suffix, LogLevel::value logLevel,
          size_t maxQueueSize) {
    saveDir_ = save_dir;
    suffix_ = suffix;
    logLevel_ = logLevel;
    filename_ = new char(MAX_FILENAME_LEN);
    target_ = target;
    if (maxQueueSize > 0) {
        isAsync_ = true;
        if (!log_Queue_) {
            std::unique_ptr<LogQueue<std::string>> q(new LogQueue<std::string>(maxQueueSize));
            log_Queue_ = std::move(q);
            std::unique_ptr<std::thread> t(new std::thread(asyncWriteLogThread));
            workThread_ = std::move(t);
        }
    } else {
        isAsync_ = false;
    }
    if (!initLogFile()) {
        printf("start loging failed...\n");
        return ;
    }
}

void Log::asyncWriteLogThread() {
    Log::getInstance()->AsyncWrite_();
}

bool Log::initLogFile() {
    if (target_ == LOG_TARGET_CONSOLE) {
        fp_ = stdout;
    } else if (target_ == LOG_TARGET_FILE){
        char time_str[25];
        util::Date::getDateTimeByFormat(time_str, 25, "%Y_%m_%d_%H_%M_%S");
        snprintf(filename_, MAX_FILENAME_LEN, "%s%s", time_str, suffix_);
        fp_ = util::File::createFile(saveDir_, filename_);
        if (fp_ == nullptr) {
            return false;
        }
    } else {
        fp_ = nullptr;
        return true;
    }
    printf("start logging...\n");
    return true;
}

void Log::setEntryTime() {
    char time_str[25];
    util::Date::getDateTime(time_str, 25);
    size_t str_len = strlen(time_str);
    time_str[str_len] = ' ';
    buf_.append(time_str, str_len + 1); // 追加空格
}

void Log::setEntryType(LogLevel::value t) {
     buf_.append(LogLevel::toString(t) + std::string(" "));
}

void Log::setEntryMsg(const std::string &msg) {
    buf_.append(msg + std::string("\n"));
}

void Log::appendEntry(const std::string &entry) {
    log_Queue_->push(entry);
}

void Log::flush() {
    if (isAsync_) {
        log_Queue_->flush();
    }
    fflush(fp_);
}

void Log::addLog(LogLevel::value type, const char *format, ...) {
    struct timeval now = {0, 0};
    gettimeofday(&now, nullptr);
    time_t tSec = now.tv_sec;
    struct tm *sysTime = localtime(&tSec);
    struct tm t = *sysTime;
    va_list vaList;

    // 向buf_中添加数据,如果多线程访问需要确保只有一个线程访问buf_
    std::unique_lock<std::mutex> locker(mtx_);
    int n = snprintf(buf_.getWritePtr(), 128, "%d-%02d-%02d %02d:%02d:%02d.%06ld ",
                     t.tm_year + 1900, t.tm_mon + 1, t.tm_mday,
                     t.tm_hour, t.tm_min, t.tm_sec, now.tv_usec);
    buf_.addWriteIdx(n);
    setEntryType(type);

    va_start(vaList, format);
    int m = vsnprintf(buf_.getWritePtr(), buf_.getLeftLen(), format, vaList);
    va_end(vaList);
    buf_.addWriteIdx(m);
    buf_.append("\n\0", 2);

    if (isAsync_ && log_Queue_ && !log_Queue_->full()) {
        log_Queue_->push(buf_.getStringAndReset());
    } else {
        fputs(buf_.getReadPtr(), fp_); // 这一部分不确定作用
    }
}

void Log::AsyncWrite_() {
    std::string str;
    while (log_Queue_->pop(str)) {
        std::lock_guard<std::mutex> locker(mtx_);
        fputs(str.c_str(), fp_);
    }
}
#ifndef TINYWEBSERVER_LOGQUEUE_H
#define TINYWEBSERVER_LOGQUEUE_H

#include <queue>
#include <cassert>
#include <mutex>
#include <condition_variable>
template <typename T>
class LogQueue {
private:
    std::deque<T> log;
    size_t capacity;
    bool deleted;
    std::mutex mtx_;
    std::condition_variable consumer_cv;
    std::condition_variable producer_cv;
public:
    explicit LogQueue(size_t c);
    ~LogQueue();
    void push(const T &data);
    bool pop(T &item);
    size_t size();
    bool empty();
    void onDelete();
    bool full();
    void flush();
};

template<typename T>
void LogQueue<T>::flush() {
    consumer_cv.notify_one();
}

template<typename T>
bool LogQueue<T>::full() {
    std::lock_guard<std::mutex> locker(mtx_);
    return log.size() >= capacity;
}

template<typename T>
LogQueue<T>::LogQueue(size_t c): capacity(c), deleted(false) {
    assert(c > 0);
}

template<typename T>
LogQueue<T>::~LogQueue() {
    onDelete();
}

template<typename T>
void LogQueue<T>::onDelete() {
    {
        std::lock_guard<std::mutex> locker(mtx_);
        deleted = true;
        log.clear();
    }
    consumer_cv.notify_one();
    producer_cv.notify_one();
}

template<typename T>
bool LogQueue<T>::empty() {
    std::lock_guard<std::mutex> locker(mtx_);
    return log.empty();
}

template<typename T>
size_t LogQueue<T>::size() {
    std::lock_guard<std::mutex> locker(mtx_);
    return log.size();
}
// 消费者读取日志
template<typename T>
bool LogQueue<T>::pop(T &item) {
    std::unique_lock<std::mutex> locker(mtx_);
    while(log.empty()) {
        consumer_cv.wait(locker);
        if (deleted) {
            return false;
        }
    }
    item = log.front();
    log.pop_front();
    producer_cv.notify_one();
    return true;
}

// 生产者插入日志
template<typename T>
void LogQueue<T>::push(const T &data) {
    std::unique_lock<std::mutex> locker(mtx_);
    // 直至log.size() <= capacity  缓冲区未满
    while(log.size() >= capacity) {
        producer_cv.wait(locker);
    }
    log.push_back(data);
    consumer_cv.notify_one();
}

#endif //TINYWEBSERVER_LOGQUEUE_H
#ifndef TINYWEBSERVER_LOGLEVEL_H
#define TINYWEBSERVER_LOGLEVEL_H
class LogLevel
{
public:

    enum class value
    {
        UNKNOWN =0,
        DEBUG,
        INFO,
        WARN,
        ERROR,
        FATAL
    };

    static const char *toString(value level)
    {
        switch (level)
        {
            case LogLevel::value::DEBUG: return "[DEBUG]:";
            case LogLevel::value::INFO: return  "[INFO] :";
            case LogLevel::value::WARN: return  "[WARN] :";
            case LogLevel::value::ERROR: return "[ERROR]:";
            case LogLevel::value::FATAL: return "[FATAL]:";
            case LogLevel::value::OFF: return   "[OFF]  :";
            default: return "UNKNOW";
        }
    }
};
#endif //TINYWEBSERVER_LOGLEVEL_H

4 定时器

客户端和服务器建立TCP连接后,客户端可能会不再发送数据,此时需要断开连接释放资源。在服务器初始化时指定超时时间,当客户端和服务器之间未发生通信的时间超过超时时间后,便关闭二者的连接。

定时器的数据结构基于小跟堆实现,堆顶是距离过期最近的连接。当客户端连接服务器后,向定时器内插入超时关闭连接事件。每当服务器收到来自客户端的请求时,更新定时器内的超时时间。当某个连接超时时,将其从堆顶取出,执行关闭连接回调函数。

4.1 调整堆中元素操作

定时器的堆基于vector动态扩容数组实现。以下定义了两个调整堆中元素的操作:

  • 向上调整:将某个结点不断地与其父结点比较交换,直到不能交换为止
  • 向下调整:将某个结点不断地与其子结点比较交换,直到不能交换为止

4.2 堆的操作

4.2.1 增

向堆中插入元素,可以现将新插入的元素放入vector最后位置,并执行向上调整操作。

4.2.2 删

将堆顶元素与vector最后一个元素交换位置,并将记录元素个数的变量递减,然后执行向上调整操作。

4.2.3 改

利用元素与堆中下标的映射数组,找到在堆中的位置,然后分别执行向上调整和向下调整操作。

4.2.4 查

取出堆顶元素。

4.3 代码

#ifndef TINYWEBSERVER_TIMER_H
#define TINYWEBSERVER_TIMER_H
#include <vector>
#include <functional>
#include <cassert>
#include <chrono>
#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <atomic>
#include "../log/Log.h"
typedef std::function<void()> TimerCallback;
typedef std::chrono::high_resolution_clock Clock;
typedef std::chrono::milliseconds MS;
typedef Clock::time_point TimeStamp;

// 定时器结点
struct TimerNode {
    int id_; // 定时器id
    TimeStamp expires_; // 定时器过期时间点
    TimerCallback cb_; // 回调函数
    bool operator<(const TimerNode &tn) const {
        return expires_ < tn.expires_;
    }
    bool operator>(const TimerNode &tn) const {
        return expires_ > tn.expires_;
    }
};

// 堆定时器,存储定时事件
class Timer {
private:
    // 定时器堆,采用vector数组方式存储
    std::vector<TimerNode> heap_;
    std::unordered_map<int, size_t> ref_;  // 从id到heap中的下标映射,方便直接操作某个node
    std::atomic<size_t> si_{};

private:
    void up(size_t u); // 将某个结点向上调整的操作
    void down(size_t u); // 将某个结点向下调整的操作
    void pop(); // 删除堆顶的结点
    TimerNode &top(); // 获得堆顶的结点
    void del(size_t i); // 删除下标为i的结点 并执行回调函数
    void swap_(size_t t1, size_t t2); // 交换两个下标的位置

public:
    explicit Timer(size_t cnt);
    ~Timer();
    int getNextTick(); // 返回最近的定时器事件超时的间隔时间
    void reset(int id, int timeout); // 重新设置某个结点的过期时间
    void reset(int id, int timeout, TimerCallback &cb); // 重设某个结点的过期时间和回调函数
    void execCb(int id); // 执行id的回调函数
    void tick(); // 处理堆中过期的定时器
    void push(int id, int timeout, const TimerCallback &cb);

    bool empty();
};

#endif //TINYWEBSERVER_TIMER_H
#include "Timer.h"

#include <utility>

void Timer::up(size_t u) {
    while(u != 1 && heap_[u] < heap_[u/2]) {
        swap_(u, u / 2);
        u /= 2;
    }
}
// 从1开始 1 2 3 4 ... 1是根结点
void Timer::down(size_t u) {
    size_t t = u;
    if (u * 2 <= si_ && heap_[u*2] < heap_[t]) t = u * 2;
    if (u * 2 + 1 <= si_ && heap_[u * 2 + 1] < heap_[t]) t = u * 2 + 1;
    if (t != u) {
        swap_(u, t);
        down(t);
    }
}

void Timer::swap_(size_t t1, size_t t2) {
    assert(t1 >= 1 && t1 < heap_.size());
    assert(t2 >= 1 && t2 <= heap_.size());
    std::swap(heap_[t1], heap_[t2]);
    ref_[heap_[t1].id_] = t1;
    ref_[heap_[t2].id_] = t2;
}
// 删除顶部结点
void Timer::pop() {
    assert(si_ > 0);
    del(1);
}

TimerNode& Timer::top() {
    return heap_[1];
}

// 删除给定结点i,将其和最后一个结点交换,之后执行up和down操作
void Timer::del(size_t i) {
    assert(i > 0 && i <= si_);
    swap_(i, si_);
    -- si_;
    ref_.erase(heap_.back().id_);
    heap_.pop_back();
    up(i);
    down(i);
}
// 执行id的回调函数
void Timer::execCb(int id) {
    if (si_ == 0 || ref_.count(id) == 0) {
        return ;
    }
    auto idx = ref_[id];
    auto &node = heap_[idx];
    node.cb_();
    del(idx);
}

void Timer::push(int id, int timeout, const TimerCallback &cb) {
    assert(id >= 0);
    if (ref_.count(id)) {
        auto idx = ref_[id];
        auto &node = heap_[idx];
        node.expires_ = Clock::now() + MS(timeout);
        node.cb_ = cb;
        up(idx);
        down(idx);
    } else {
        LOG_INFO("增加计时时间 id: %d timeout: %d", id, timeout);
        heap_.push_back({id, MS(timeout) + Clock::now(), cb});
        ++ si_;
        ref_[id] = si_;
        up(si_);
    }
}

Timer::Timer(size_t cnt) {
//    Log::INFO("%s", "Timer start...");
    heap_.reserve(cnt + 1);
    heap_.emplace_back();
    si_ = 0;
}

void Timer::reset(int id, int timeout, TimerCallback &cb) {
    assert(id >= 0);
    auto idx = ref_[id];
    auto &node = heap_[idx];
    node.expires_ = Clock::now() + MS(timeout);
    node.cb_ = cb;
    down(idx);
    up(idx);
}

void Timer::reset(int id, int timeout) {
    assert(id >= 0);
    auto idx = ref_[id];
    auto &node = heap_[idx];
    node.expires_ = Clock::now() + MS(timeout);
    down(idx);
    up(idx);
}

void Timer::tick() {
    while(si_) {
        auto &node = top();
        if (std::chrono::duration_cast<MS>(node.expires_ - Clock::now()).count() > 0)
            break;
        LOG_INFO("timer %d is expired", node.id_);
        node.cb_();
        pop();
    }
}

bool Timer::empty() {
    return si_ == 0;
}

int Timer::getNextTick() {
    tick();
    size_t res = -1;
    if (si_) {
        res = std::chrono::duration_cast<MS>(top().expires_ - Clock::now()).count();
        if (res < 0) {
            res = 0;
        }

    }
    return res;
}

Timer::~Timer() {
    heap_.clear();
    ref_.clear();
}

5 线程池

由于线程的创建和销毁需要开销,频繁创建和销毁线程会影响服务器的性能。因此,服务器维护一个预先创建好的线程池。每当任务队列中有任务时,某个线程将其从中取出并执行。执行完成后,继续等待任务。

在这里插入图片描述

在这里使用一个条件变量,当任务队列为空时阻塞线程,当有任务插入队列时,通知线程执行任务。

5.1 代码

#ifndef TINYWEBSERVER_THREADPOOL_H
#define TINYWEBSERVER_THREADPOOL_H
#include <queue>
#include <functional>
#include <mutex>
#include <cassert>
#include <thread>
#include <condition_variable>
#include <unistd.h>
#include "../log/Log.h"
class ThreadPool {
private:
    struct Pool {
        std::mutex mtx_;
        bool isRun = true;
        std::condition_variable cv;
        std::queue<std::function<void()>>taskQueue_;
    };
    std::shared_ptr<Pool> pool_;

//private:
//    static void work(); // 工作函数,从任务队列中取出任务并执行
public:
    explicit ThreadPool(int max_thread_cnt);
    ThreadPool() = default;
    // 定义移动构造函数
    ThreadPool(ThreadPool&&) = default;

    ~ThreadPool();
    bool addTask(std::function<void()> &&f); // 外部调用接口
    void resetTaskQueue();
};

#endif //TINYWEBSERVER_THREADPOOL_H
#include "ThreadPool.h"

ThreadPool::ThreadPool(int max_thread_cnt): pool_(std::make_shared<Pool>()) {
    assert(max_thread_cnt > 0);
    for (int i = 0; i < max_thread_cnt; ++ i) {
        printf("init thread %d\n", i);
        std::thread([pool = pool_, i]{
            std::unique_lock<std::mutex> locker(pool->mtx_);
            while(true) {
                if (!pool->taskQueue_.empty()) {
//                    LOG_INFO("thread pool: thread %d process task", i);
                    // 有任务,开始干活
                    // 这个地方使用move,提高效率
                    auto task = std::move(pool->taskQueue_.front());
                    pool->taskQueue_.pop();
                    locker.unlock();
                    task(); // 处理任务
                    locker.lock();
                } else if (!pool->isRun) {
                    break;
                } else {
                    pool->cv.wait(locker);
                }
            }
        }).detach();
    }
}

ThreadPool::~ThreadPool() {
    if (static_cast<bool>(pool_))
    {
        {
            std::lock_guard<std::mutex> locker(pool_->mtx_);
            pool_->isRun = false;
        }
        pool_->cv.notify_all();
    }
}

void ThreadPool::resetTaskQueue() {
    std::queue<std::function<void()>> q;
    swap(q, pool_->taskQueue_);
}

bool ThreadPool::addTask(std::function<void()> &&f) {
    {
        std::lock_guard<std::mutex> locker(pool_->mtx_);
        pool_->taskQueue_.emplace(std::forward<std::function<void()>>(f));
    }
//    LOG_INFO("thead pool: %s", "add task");
    pool_->cv.notify_one();
    return true;
}

6 数据库连接池

和线程池类似,数据库的连接池由一个队列来维护与数据库的多个连接。初始化时,创建

n

个数据库连接并将其插入到队列中。在需要访问数据库时,从队列中取出一个连接进行数据库读写操作。在读写完数据后重新将连接插入到队列中。

6.1 RAII

RAII(Resource Acquisition Is Initialization,资源获取即初始化)是一种C++编程惯用法,用于管理资源(如内存、文件句柄、网络连接等),确保它们在对象的生命周期内得到正确的管理和释放。

  • 资源绑定到对象的生命周期: 资源在对象创建时被获取,并在对象销毁时被释放。构造函数负责获取资源,析构函数负责释放资源。
  • 自动管理:通过栈上对象的自动创建和销毁,避免手动管理资源的复杂性和潜在错误(如资源泄漏)。

数据库连接的管理采用RAII机制,可以简化资源管理。

6.2 代码

#ifndef SQLCONNPOOL_H
#define SQLCONNPOOL_H

#include <mysql/mysql.h>
#include <string>
#include <queue>
#include <mutex>
#include <semaphore.h>
#include <thread>
#include <cassert>
#include "../log/Log.h"

class SqlConnPool {
public:
    static SqlConnPool *Instance();

    MYSQL *GetConn();
    void FreeConn(MYSQL * conn);
    int GetFreeConnCount();

    void Init(const char* host, int port,
              const char* user,const char* pwd, 
              const char* dbName, int connSize);
    void ClosePool();

private:
    SqlConnPool();
    ~SqlConnPool();

    int MAX_CONN_;
    int useCount_;
    int freeCount_;

    std::queue<MYSQL*> connQue_;
    std::mutex mtx_;
    sem_t semId_;
};

#endif // SQLCONNPOOL_H
#include "sqlconnpool.h"
using namespace std;

SqlConnPool::SqlConnPool() {
    useCount_ = 0;
    freeCount_ = 0;
}

SqlConnPool* SqlConnPool::Instance() {
    static SqlConnPool connPool;
    return &connPool;
}

void SqlConnPool::Init(const char* host, int port,
            const char* user,const char* pwd, const char* dbName,
            int connSize = 10) {
    assert(connSize > 0);
    for (int i = 0; i < connSize; i++) {
        MYSQL *sql = nullptr;
        sql = mysql_init(sql);
        if (!sql) {
            LOG_ERROR("MySql init error!");
            assert(sql);
        }
        sql = mysql_real_connect(sql, host,
                                 user, pwd,
                                 dbName, port, nullptr, 0);
        if (!sql) {
            LOG_ERROR("MySql Connect error!");
        }
        connQue_.push(sql);
    }
    MAX_CONN_ = connSize;
    sem_init(&semId_, 0, MAX_CONN_);
}

MYSQL* SqlConnPool::GetConn() {
    MYSQL *sql = nullptr;
    if(connQue_.empty()){
        LOG_WARN("SqlConnPool busy!");
        return nullptr;
    }
    sem_wait(&semId_);
    {
        lock_guard<mutex> locker(mtx_);
        sql = connQue_.front();
        connQue_.pop();
    }
    return sql;
}

void SqlConnPool::FreeConn(MYSQL* sql) {
    assert(sql);
    lock_guard<mutex> locker(mtx_);
    connQue_.push(sql);
    sem_post(&semId_);
}

void SqlConnPool::ClosePool() {
    lock_guard<mutex> locker(mtx_);
    while(!connQue_.empty()) {
        auto item = connQue_.front();
        connQue_.pop();
        mysql_close(item);
    }
    mysql_library_end();        
}

int SqlConnPool::GetFreeConnCount() {
    lock_guard<mutex> locker(mtx_);
    return connQue_.size();
}

SqlConnPool::~SqlConnPool() {
    ClosePool();
}
#ifndef SQLCONNRAII_H
#define SQLCONNRAII_H
#include "sqlconnpool.h"

/* 资源在对象构造初始化 资源在对象析构时释放*/
class SqlConnRAII {
public:
    SqlConnRAII(MYSQL** sql, SqlConnPool *connpool) {
        assert(connpool);
        *sql = connpool->GetConn();
        sql_ = *sql;
        connpool_ = connpool;
    }
    
    ~SqlConnRAII() {
        if(sql_) { connpool_->FreeConn(sql_); }
    }
    
private:
    MYSQL *sql_;
    SqlConnPool* connpool_;
};

#endif //SQLCONNRAII_H

7 HTTP层处理

7.1 HTTP解析

HTTP请求的格式如下图所示:

在这里插入图片描述

HTTP请求报文遵循着规定的格式,我们只需要按要求即可准确解析。

报文分为三个部分:

  • 请求行:包含请求方法(GET,POST,…),请求url和HTTP版本。中间由空格分隔,最后有个\r\n
  • 请求头:包含若干个请求体,由key: value组成,末尾有\r\n。最后一个请求头最后有两个\r\n
  • 请求体(可有可无)
GET /index.html HTTP/1.1
Host: www.example.com
User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8
Accept-Language: en-US,en;q=0.5
Accept-Encoding: gzip, deflate, br
Connection: keep-alive

由于HTTP报文每一行最后都有一个

\r\n

,我们可以从缓冲区中搜索该字符串,将每一行截取出来再进行解析。整个过程由于分三个阶段进行,因此可使用状态机解决。为此,我们规定4个状态,分别是解析请求行、解析请求头、解析请求体和结束。在不同的状态执行不同的操作,来解析不同的内容。

7.1.1 代码


#ifndef TINYWEBSERVER_PARSEHTTPREQUEST_H
#define TINYWEBSERVER_PARSEHTTPREQUEST_H

#include <unordered_map>
#include <regex>
#include <unordered_set>
#include <mysql/mysql.h>
#include "../buffer/Buffer.h"
#include "../utils/Utils.h"
#include "../log/Log.h"
#include "../pool/sqlconnpool.h"
#include "../pool/sqlconnRAII.h"

class ParseHttpRequest {
public:
    // 当前的处理状态
    enum Status {
        PARSE_LINE,
        PARSE_HEADERS,
        PARSE_BODY,
        FINISH
    };
private:
    std::string version_; // HTTP版本
    std::unordered_map<std::string, std::string> headers_; // 头部字段
    Status state_ = PARSE_LINE;
    std::string url_;
    std::string method_;
    std::string body_;
    static const std::unordered_set<std::string> DEFAULT_HTML;
    static const std::unordered_map<std::string, int>DEFAULT_HTML_TAG;
    std::unordered_map<std::string, std::string> post_;
public:

    ParseHttpRequest();
    ~ParseHttpRequest();
    void init();

    bool parse(Buffer &buf);
    void parse_url();
    bool parseRequestLine(const std::string &request_line);
    void parseRequestHeader(const std::string &header_line);
    void parseRequestBody(const std::string &body);

    std::string &method();
    std::string &version();
    std::string &path();

    bool keepAlive();

    void parsePost();

    void parseFromUrlencoded();

    static bool userVerify(const std::string &name, const std::string &pwd, bool isLogin);

    static int convertHex(char ch);
};

#endif //TINYWEBSERVER_PARSEHTTPREQUEST_H
#include "ParseHttpRequest.h"
using namespace std;

const unordered_set<string> ParseHttpRequest::DEFAULT_HTML{
        "/index", "/register", "/login",
        "/welcome", "/video", "/picture", };

const unordered_map<string, int> ParseHttpRequest::DEFAULT_HTML_TAG {
        {"/register.html", 0}, {"/login.html", 1},  };

void ParseHttpRequest::init() {
    method_ = url_ = version_ = body_ = "";
    state_ = PARSE_LINE;
    headers_.clear();
    post_.clear();
}

bool ParseHttpRequest::keepAlive() {
    if(headers_.count("Connection") == 1) {
        return headers_.find("Connection")->second == "keep-alive" && version_ == "1.1";
    }
    return false;
}

bool ParseHttpRequest::parse(Buffer& buff) {
    const char CRLF[] = "\r\n";
    if(buff.getContentLen() <= 0) {
        return false;
    }
    while(buff.getContentLen() && state_ != FINISH) {
        const char* lineEnd = search(buff.getReadPtr(), buff.getWritePtr(), CRLF, CRLF + 2);
        std::string line(buff.getConstReadPtr(), lineEnd);
        switch(state_)
        {
            case PARSE_LINE:
                if(!parseRequestLine(line)) {
                    return false;
                }
                parse_url();
                break;
            case PARSE_HEADERS:
                parseRequestHeader(line);
                if(buff.getContentLen() <= 2) {
                    state_ = FINISH;
                }
                break;
            case PARSE_BODY:
                parseRequestBody(line);
                break;
            default:
                break;
        }
        if(lineEnd == buff.getWritePtr()) { break; }
        buff.addReadIdxUntil(lineEnd + 2);
    }
    LOG_DEBUG("[%s], [%s], [%s]", method_.c_str(), url_.c_str(), version_.c_str());
    return true;
}

void ParseHttpRequest::parse_url() {
    if(url_ == "/") {
        url_ = "/index.html";
    }
    else {
        for(auto &item: DEFAULT_HTML) {
            if(item == url_) {
                url_ += ".html";
                break;
            }
        }
    }
}

bool ParseHttpRequest::parseRequestLine(const string& line) {
    regex patten("^([^ ]*) ([^ ]*) HTTP/([^ ]*)$");
    smatch subMatch;
    if(regex_match(line, subMatch, patten)) {
        method_ = subMatch[1];
        url_ = subMatch[2];
        version_ = subMatch[3];
        state_ = PARSE_HEADERS;
        return true;
    }
    LOG_ERROR("RequestLine Error");
    return false;
}

void ParseHttpRequest::parseRequestHeader(const string& line) {
    regex patten("^([^:]*): ?(.*)$");
    smatch subMatch;
    if(regex_match(line, subMatch, patten)) {
        headers_[subMatch[1]] = subMatch[2];
    }
    else {
        state_ = PARSE_BODY;
    }
}

void ParseHttpRequest::parseRequestBody(const string& line) {
    body_ = line;
    parsePost();
    state_ = FINISH;
    LOG_DEBUG("Body:%s, len:%d", line.c_str(), line.size());
}

int ParseHttpRequest::convertHex(char ch) {
    if(ch >= 'A' && ch <= 'F') return ch -'A' + 10;
    if(ch >= 'a' && ch <= 'f') return ch -'a' + 10;
    return ch;
}

void ParseHttpRequest::parsePost() {
    if(method_ == "POST" && headers_["Content-Type"] == "application/x-www-form-urlencoded") {
        parseFromUrlencoded();
        if(DEFAULT_HTML_TAG.count(url_)) {
            int tag = DEFAULT_HTML_TAG.find(url_)->second;
            LOG_DEBUG("Tag:%d", tag);
            if(tag == 0 || tag == 1) {
                bool isLogin = (tag == 1);
                if(userVerify(post_["username"], post_["password"], isLogin)) {
                    url_ = "/welcome.html";
                }
                else {
                    url_ = "/error.html";
                }
            }
        }
    }
}

void ParseHttpRequest::parseFromUrlencoded() {
    if(body_.size() == 0) { return; }

    string key, value;
    int num = 0;
    int n = body_.size();
    int i = 0, j = 0;

    for(; i < n; i++) {
        char ch = body_[i];
        switch (ch) {
            case '=':
                key = body_.substr(j, i - j);
                j = i + 1;
                break;
            case '+':
                body_[i] = ' ';
                break;
            case '%':
                num = convertHex(body_[i + 1]) * 16 + convertHex(body_[i + 2]);
                body_[i + 2] = num % 10 + '0';
                body_[i + 1] = num / 10 + '0';
                i += 2;
                break;
            case '&':
                value = body_.substr(j, i - j);
                j = i + 1;
                post_[key] = value;
                LOG_DEBUG("%s = %s", key.c_str(), value.c_str());
                break;
            default:
                break;
        }
    }
    assert(j <= i);
    if(post_.count(key) == 0 && j < i) {
        value = body_.substr(j, i - j);
        post_[key] = value;
    }
}

bool ParseHttpRequest::userVerify(const string &name, const string &pwd, bool isLogin) {
    if(name.empty() || pwd.empty()) { return false; }
    LOG_INFO("Verify name:%s pwd:%s", name.c_str(), pwd.c_str());
    MYSQL* sql;
    SqlConnRAII give_me_a_name(&sql,  SqlConnPool::Instance());
    assert(sql);

    bool flag = false;
    char order[256] = { 0 };

    MYSQL_RES *res = nullptr;

    if(!isLogin) { flag = true; }
    /* 查询用户及密码 */
    snprintf(order, 256, "SELECT username, password FROM user WHERE username='%s' LIMIT 1", name.c_str());
    LOG_DEBUG("%s", order);

    if(mysql_query(sql, order)) {
        mysql_free_result(res);
        return false;
    }
    res = mysql_store_result(sql);
    mysql_num_fields(res);
    mysql_fetch_fields(res);

    while(MYSQL_ROW row = mysql_fetch_row(res)) {
        LOG_DEBUG("MYSQL ROW: %s %s", row[0], row[1]);
        string password(row[1]);
        /* 注册行为 且 用户名未被使用*/
        if(isLogin) {
            if(pwd == password) { flag = true; }
            else {
                flag = false;
                LOG_DEBUG("pwd error!");
            }
        }
        else {
            flag = false;
            LOG_DEBUG("user used!");
        }
    }
    mysql_free_result(res);

    /* 注册行为 且 用户名未被使用*/
    if(!isLogin && flag) {
        LOG_DEBUG("regirster!");
        bzero(order, 256);
        snprintf(order, 256,"INSERT INTO user(username, password) VALUES('%s','%s')", name.c_str(), pwd.c_str());
        LOG_DEBUG( "%s", order)
        if(mysql_query(sql, order)) {
            LOG_DEBUG( "Insert error!");
            flag = false;
        }
        flag = true;
    }
    SqlConnPool::Instance()->FreeConn(sql);
    LOG_DEBUG( "UserVerify success!!");
    return flag;
}

std::string &ParseHttpRequest::path(){
    return url_;
}
std::string &ParseHttpRequest::method() {
    return method_;
}

std::string &ParseHttpRequest::version() {
    return version_;
}

ParseHttpRequest::ParseHttpRequest() {
    init();
}

ParseHttpRequest::~ParseHttpRequest() {

}

7.2 HTTP响应

以下是HTTP响应报文的一个例子,主要包含响应行、响应头和响应体。

HTTP/1.1 200 OK
Date: Fri, 19 Jul 2024 10:00:00 GMT
Server: Apache/2.4.41 (Ubuntu)
Last-Modified: Mon, 28 Jun 2024 14:30:00 GMT
Content-Type: text/html; charset=UTF-8
Content-Length: 305
Connection: close

<!DOCTYPEhtml><html><head><title>Example Page</title></head><body><h1>Welcome to Example Page</h1><p>This is a sample HTML page.</p></body></html>

根据对HTTP请求的处理结果,生成相应的HTTP响应结果。

响应行中包含HTTP版本、响应状态码和摘要。

响应头中包含连接状态、返回的文件类型和长度。

响应体中包含返回的资源文件。

7.2.1 代码

#ifndef TINYWEBSERVER_HTTPRESPONSE_H
#define TINYWEBSERVER_HTTPRESPONSE_H
#include <unordered_map>
#include <sys/stat.h>
#include <fcntl.h>
#include <sys/mman.h>
#include "../utils/Utils.h"
#include "../buffer/Buffer.h"
#include "../log/Log.h"

// 构造HTTP响应报文
class HttpResponse {
private:
    static const std::unordered_map<std::string, std::string> SUFFIX_TYPE;
    static const std::unordered_map<int, std::string> CODE_PATH;
    static std::unordered_map<int, std::string> CODE;
    std::string srcDir_;
    std::string path_;
    bool keepAlive_{};
    int code_{};
    char* mmFile_{};
    struct stat mmFileStat_{};
public:

    HttpResponse() = default;
    ~HttpResponse() = default;

    static void addCors(Buffer &buf);

    static void ErrorContent(Buffer &buff, std::string &&message);

    void unmapFile();

    void makeResponse(Buffer &buf);

    void init(const std::string &srcDir, const std::string &path, bool isKeepAlive, int code);

    void ErrorHtml_();

    void AddStateLine_(Buffer &buf);

    void AddHeader_(Buffer &buf);

    void AddContent_(Buffer &buff);

    std::string GetFileType_();

    size_t fileLen() const;

    char* file();
};

#endif //TINYWEBSERVER_HTTPRESPONSE_H
#include "HttpResponse.h"

const std::unordered_map<std::string, std::string> HttpResponse::SUFFIX_TYPE = {
        { ".html",  "text/html" },
        { ".xml",   "text/xml" },
        { ".xhtml", "application/xhtml+xml" },
        { ".txt",   "text/plain" },
        { ".rtf",   "application/rtf" },
        { ".pdf",   "application/pdf" },
        { ".word",  "application/nsword" },
        { ".png",   "image/png" },
        { ".gif",   "image/gif" },
        { ".jpg",   "image/jpeg" },
        { ".jpeg",  "image/jpeg" },
        { ".au",    "audio/basic" },
        { ".mpeg",  "video/mpeg" },
        { ".mpg",   "video/mpeg" },
        { ".avi",   "video/x-msvideo" },
        { ".gz",    "application/x-gzip" },
        { ".tar",   "application/x-tar" },
        { ".css",   "text/css "},
        { ".js",    "text/javascript "},
};
std::unordered_map<int, std::string> HttpResponse::CODE = {
        { 200, "OK" },
        { 400, "Bad Request" },
        { 403, "Forbidden" },
        { 404, "Not Found" },
};
const std::unordered_map<int, std::string> HttpResponse::CODE_PATH = {
        { 400, "/400.html" },
        { 403, "/403.html" },
        { 404, "/404.html" },
};

void HttpResponse::init(const std::string& srcDir, const std::string& path, bool isKeepAlive, int code) {
    if (mmFile_) {
        unmapFile();
    }
    keepAlive_ = isKeepAlive;
    srcDir_ = srcDir;
    mmFile_ = nullptr;
    mmFileStat_ = { 0 };
    code_ = code;
    path_ = path;
}

//void HttpResponse::addHeaders(bool keepAlive, Buffer &buf, int type) {
//    buf.append("Connection: ");
//    if (keepAlive) {
//        buf.append("keep-alive\r\n");
//        buf.append("keep-alive: max=6, timeout=120\r\n");
//    } else {
//        buf.append("close\r\n");
//    }
//    if (type) {
//        buf.append("Content-Type:application/json\r\n");
//    } else {
//        buf.append("Content-Type:text/html\r\n");
//    }
//    buf.append("Access-Control-Allow-Origin:*\r\n");
//}

void HttpResponse::addCors(Buffer &buf) {
    buf.append("Access-Control-Allow-Methods:POST, OPTIONS, GET, PUT, DELETE\r\n");
    buf.append("Access-Control-Allow-Headers:Content-Type, Connection, Content-Length, Keep-Alive, \r\n");
    buf.append("Access-Control-Max-Age:3600\r\n");
    buf.append("Cache-Control:no-cache, no-store, must-revalidate\r\n");
}

//void HttpResponse::addBody(const std::string &&data, Buffer &buf) {
//    buf.append("Content-Length: " + std::to_string(data.length()) + "\r\n\r\n");
//    buf.append(data + "\r\n");
//}

void HttpResponse::AddStateLine_(Buffer& buf) {
    std::string status;
    if(CODE.count(code_) == 1) {
        status = CODE.find(code_)->second;
    }
    else {
        code_ = 400;
        status = CODE.find(400)->second;
    }
    buf.append("HTTP/1.1 " + std::to_string(code_) + " " + status + "\r\n");
}

void HttpResponse::AddHeader_(Buffer& buf) {
    buf.append("Connection: ");
    if(keepAlive_) {
        buf.append("keep-alive\r\n");
        buf.append("keep-alive: max=6, timeout=120\r\n");
    } else{
        buf.append("close\r\n");
    }
    buf.append("Content-type: " + GetFileType_() + "\r\n");
}

void HttpResponse::AddContent_(Buffer& buf) {
    int srcFd = open((srcDir_ + path_).data(), O_RDONLY);
    if(srcFd < 0) {
        ErrorContent(buf, "File NotFound!");
        return;
    }

    /* 将文件映射到内存提高文件的访问速度
        MAP_PRIVATE 建立一个写入时拷贝的私有映射*/
//    LOG_DEBUG("file path %s, size: %d", (srcDir_ + path_).data(), mmFileStat_.st_size);
    int* mmRet = (int*)mmap(nullptr, mmFileStat_.st_size, PROT_READ, MAP_PRIVATE, srcFd, 0);
    if(*mmRet == -1) {
        LOG_ERROR("map file failed");
        ErrorContent(buf, "File NotFound!");
        return;
    }
    mmFile_ = (char*)mmRet;
    close(srcFd);
    buf.append("Content-length: " + std::to_string(mmFileStat_.st_size) + "\r\n\r\n");
}

void HttpResponse::makeResponse(Buffer &buf) {
    /* 判断请求的资源文件 */
    if(stat((srcDir_ + path_).data(), &mmFileStat_) < 0 || S_ISDIR(mmFileStat_.st_mode)) {
        code_ = 404;
    }
    else if(!(mmFileStat_.st_mode & S_IROTH)) {
        code_ = 403;
    }
    else if(code_ == -1) {
        code_ = 200;
    }
    ErrorHtml_();
    AddStateLine_(buf);
    AddHeader_(buf);
    AddContent_(buf);
}

void HttpResponse::ErrorContent(Buffer& buff, std::string &&message)
{
    std::string body;
    body += "<html><title>Error</title>";
    body += "<body bgcolor=\"ffffff\">";
    body += "<p>" + message + "</p>";
    body += "<hr><em>TinyWebServer</em></body></html>";
    buff.append("Content-length: " + std::to_string(body.size()) + "\r\n\r\n");
    buff.append(body);
}

void HttpResponse::unmapFile() {
    if(mmFile_) {
        munmap(mmFile_, mmFileStat_.st_size);
        mmFile_ = nullptr;
    }
}
void HttpResponse::ErrorHtml_() {
    if(CODE_PATH.count(code_) == 1) {
        path_ = CODE_PATH.find(code_)->second;
        stat((srcDir_ + path_).data(), &mmFileStat_);
    }
}
std::string HttpResponse::GetFileType_() {
    /* 判断文件类型 */
    std::string::size_type idx = path_.find_last_of('.');
    if(idx == std::string::npos) {
        return "text/plain";
    }
    std::string suffix = path_.substr(idx);
    if(SUFFIX_TYPE.count(suffix) == 1) {
        return SUFFIX_TYPE.find(suffix)->second;
    }
    return "text/plain";
}

char *HttpResponse::file() {
    return mmFile_;
}

size_t HttpResponse::fileLen() const {
    return mmFileStat_.st_size;
}

7.3 HTTP处理

HTTP处理模块是整个服务器的核心模块,负责管理客户端的连接、读写数据、HTTP处理逻辑。

  • 管理连接:当有客户端连接时,初始化相关数据,存储fd和客户端地址。当由于某种原因,需要断开连接时,该模块关闭fd,重置相关数据
  • 读写数据:- 根据不同的模式读取数据,调用buffer中的readFd函数。- 将缓冲区的数据写入socket中。此时需要注意一次可能不能将全部数据写出,需要循环写出,并更新指针。
  • 负责整个HTTP的处理流程:首先调用ParseHttpRequest解析读缓冲区的数据,然后再调用HttpResponse生成响应报文,并放入写缓冲区中,最后将写缓冲和请求文件地址赋值给iovec结点,等待写出。

7.3.1 代码

#ifndef TINYWEBSERVER_HTTPWORK_H
#define TINYWEBSERVER_HTTPWORK_H

#include <sys/socket.h>
#include <netinet/in.h>
#include "HttpResponse.h"
#include <mutex>
#include "ParseHttpRequest.h"
#include "../buffer/Buffer.h"

// 每个工作线程操纵的类接口,负责读写数据,处理Http请求,每个用户持有一个类
class HttpWork {
private:
    Buffer writeBuf_;
    Buffer readBuf_;
    int fd_{};
    bool isRun_;
    struct sockaddr_in addr_{};
    iovec iv[2]{};
    int io_cnt = 2;

    std::mutex mtx_;

public:
    ParseHttpRequest request_;
    HttpResponse response_;

    static std::string srcDir_;
    static bool et_;
    static std::atomic<int> userCount;
public:

    HttpWork();
    ~HttpWork();
    void init(int fd, const sockaddr_in &addr);
    ssize_t writeFd(int *Errno);
    ssize_t readFd(int *Errno);
    bool processHttp();
    size_t getWriteLen();
    void closeConn();
    int getFd();
    bool isKeepAlive();

    void resetBuffer();

    bool getIsRun();
};

#endif //TINYWEBSERVER_HTTPWORK_H
#include "HttpWork.h"
bool HttpWork::et_;
std::string HttpWork::srcDir_;
std::atomic<int> HttpWork::userCount;

void HttpWork::init(int fd, const sockaddr_in &addr) {
    assert(fd > 0);
    std::lock_guard<std::mutex> locker(mtx_);
    isRun_ = true;
    fd_ = fd;
    addr_ = addr;
    writeBuf_.resetBuffer();
    readBuf_.resetBuffer();
    request_.init();
    userCount ++;
}

HttpWork::HttpWork() {
    isRun_ = false;
    addr_ = {0};
}

ssize_t HttpWork::readFd(int *Errno) {
    assert(fd_ >= 0);
    ssize_t len = 0;
    do {
        auto t_len = readBuf_.readFd(fd_, Errno);
        // 返回0代表此次读取数据为0
        if (t_len <= 0) {
            break;
        }
        len += t_len;
    } while(et_);
    // len是此次总计读取的数据
    return len;
}

ssize_t HttpWork::writeFd(int *Errno) {
    assert(fd_ >= 0);
    ssize_t len = 0;
    do {
        len = writev(fd_, iv, io_cnt);
        if (len <= 0) {
            // 写错误
            *Errno = errno;
            break;
        }
        // 处理第一个缓冲区
        if (iv[0].iov_len > 0) {
            // 此时第一个iovec没有写完
            // 我们将更新iovec的base和buffer中的指针位置
            auto iv_len1 = writeBuf_.getContentLen(); // 获取待写入数据的长度
            if (iv_len1 <= static_cast<size_t>(len)) { // buf中全部写完
                // iv1已经全部写完,后续不再处理
                iv[0].iov_base = nullptr;
                iv[0].iov_len = 0;
                writeBuf_.resetBuffer();
                len = static_cast<ssize_t>(static_cast<size_t>(len) - iv_len1); // 获取第二个iv结点写入的数据
            } else {
                // iv1写了一部分
                writeBuf_.addReadIdx(len); // 更新
                // 指针
                iv[0].iov_base = writeBuf_.getReadPtr();
                iv[0].iov_len = writeBuf_.getContentLen();
                len = 0;
            }
        }
        // 处理第二个缓冲区
        if (iv[0].iov_len == 0)
        {
            iv[1].iov_base = (uint8_t*)iv[1].iov_base + len;
            iv[1].iov_len -= len;
        }
        if (0 == getWriteLen()) {
            iv[1].iov_base = nullptr;
            iv[1].iov_len = 0;
            break; // 写成功
        }
    } while(et_);
    return len;
}

HttpWork::~HttpWork() {
    writeBuf_.resetBuffer();
    readBuf_.resetBuffer();
    fd_ = -1;
    close(fd_);
}

size_t HttpWork::getWriteLen() {
    return iv[0].iov_len + iv[1].iov_len;
}

void HttpWork::closeConn() {
    std::lock_guard<std::mutex> locker(mtx_);
    if (isRun_) {
        close(fd_);
        fd_ = -1;
        isRun_ = false;
        userCount --;
        LOG_DEBUG("client %d is closed", fd_);
    }
}

bool HttpWork::getIsRun() {
    std::lock_guard<std::mutex> locker(mtx_);
    return isRun_;
}

int HttpWork::getFd() {
    std::lock_guard<std::mutex> locker(mtx_);
    return fd_;
}

bool HttpWork::isKeepAlive() {
    return request_.keepAlive();
}

void HttpWork::resetBuffer() {
    readBuf_.resetBuffer();
    writeBuf_.resetBuffer();
}

bool HttpWork::processHttp() {
    // 读缓冲中没有数据,接下来继续等待读
    if (readBuf_.getContentLen() <= 0) {
        return false;
    }
//    LOG_DEBUG("readBuf: %s", std::string(readBuf_.getConstReadPtr(), readBuf_.getContentLen()).c_str());
    request_.init(); // 清空上一次的数据
    // 请求成功解析
    if (request_.parse(readBuf_)) {
        // 解析成功,正式进入业务逻辑处理流程
        response_.init(srcDir_, request_.path(), request_.keepAlive(), 200);
    } else {
        response_.init(srcDir_, request_.path(), false, 400);
    }
    LOG_INFO("%s %s", request_.method().c_str(), request_.path().c_str());
    response_.makeResponse(writeBuf_);
    // 输出报文11
    iv[0].iov_base = writeBuf_.getReadPtr();
    iv[0].iov_len = writeBuf_.getContentLen();
    io_cnt = 1;
    if (response_.file() && response_.fileLen() > 0) {
        iv[1].iov_base = response_.file();
        iv[1].iov_len = response_.fileLen();
        io_cnt = 2;
    }
//    LOG_DEBUG("wait for write data: %d", getWriteLen());
    // 返回true表示等待写
    return true;
}

8 Server层处理

上层的基础API已经实现,Server层主要负责监听事件,等待客户端连接、接收请求和发送响应。

在这里使用EPOLL来监听各种事件,之后分类处理。但是主线程并不是真正的处理,而是将读写事件插入到任务队列中,由线程池负责处理。

此外当某个fd有事件发生时,要延长定时器的超时时间。

服务器所需的参数使用配置文件的形式传入程序中。

8.1 代码

#ifndef TINYWEBSERVER_SERVER_H
#define TINYWEBSERVER_SERVER_H
#include "../http/HttpWork.h"
#include "../http/HttpResponse.h"
#include "../http/ParseHttpRequest.h"
#include "../log/Log.h"
#include "../pool/ThreadPool.h"
#include "../timer/Timer.h"
#include "Epoll.h"
#include <unordered_map>
#include <sys/epoll.h>
#include <arpa/inet.h>
#include <functional>

class Server {
private:
    const char *ip_;
    int port_;
    int trigMod_;
    int timeoutMs_;
    int MAXFD_;
    std::unique_ptr<ThreadPool> threadPool_;
    std::unique_ptr<Timer> timer_;
    std::unique_ptr<Epoll> epoll_;

    std::unordered_map<int, HttpWork> users_; // 负责处理HTTP请求
    uint32_t httpConnEvents_{};
    uint32_t listenEvents_{};
    int listenFd_{};
    bool isRun_;
    std::string log_dir_;
    std::string srcDir_;

public:
    // 提供服务器运行参数
    Server(const char* ip, int port, int trigMod, int timeout, LogTarget target, LogLevel::value logLevel,
           int max_thread_cnt, int max_timer_cnt, int max_fd, int max_epoll_events, int sqlPort, const char * sqlUser,
           const char * sqlPwd, const char * dbName, int connPoolNum);
    ~Server();
    void initTrigMode();
    bool startListen();
    static int setNonBlocking(int fd);
    void dealListen();
    void addClient(int fd, sockaddr_in &addr);
    void dealWrite(HttpWork &client);
    void dealRead(HttpWork &client);
    static void sendError(int fd, const char *msg);
    void closeConn(HttpWork &client);
    void extendTime(int fd);
    void readCb(HttpWork &client);
    void writeCb(HttpWork &client);
    void run();
};

#endif //TINYWEBSERVER_SERVER_
#include "Server.h"

Server::Server(const char *ip, int port, int trigMod, int timeout, LogTarget target, LogLevel::value logLevel, int max_thread_cnt,
               int max_timer_cnt, int max_fd, int max_epoll_events, int sqlPort, const char *sqlUser,
               const char * sqlPwd, const char * dbName, int connPoolNum):ip_(ip), port_(port),
               trigMod_(trigMod), timeoutMs_(timeout), MAXFD_(max_fd),
               threadPool_(new ThreadPool(max_thread_cnt)), timer_(new Timer(max_timer_cnt)),
               epoll_(new Epoll(max_epoll_events)) {
    isRun_ = false;
    srcDir_ = getcwd(nullptr, 256);
    auto l = Log::getInstance();
//     初始化日志系统
    l->init(target, (srcDir_ + "/log").c_str(), ".log", logLevel);
    HttpWork::srcDir_ = srcDir_ + "/resources";

    SqlConnPool::Instance()->Init("localhost", sqlPort, sqlUser, sqlPwd, dbName, connPoolNum);
//     初始化监听事件
    initTrigMode();
//     启动listenFd
    if (startListen()) {
        isRun_ = true;
    }
}

void Server::initTrigMode() {
    listenEvents_ = EPOLLRDHUP;    // 检测socket关闭
    httpConnEvents_ = EPOLLONESHOT | EPOLLRDHUP;     // EPOLLONESHOT由一个线程处理
    switch (trigMod_) {
        case 0:
            break;
        case 1:
            httpConnEvents_ |= EPOLLET;
            break;
        case 2:
            listenEvents_ |= EPOLLET;
            break;
        case 3:
            listenEvents_ |= EPOLLET;
            httpConnEvents_ |= EPOLLET;
            break;
        default:
            listenEvents_ |= EPOLLET;
            httpConnEvents_ |= EPOLLET;
    }
    HttpWork::et_ = (httpConnEvents_ & EPOLLET);
}

bool Server::startListen() {
    struct sockaddr_in address = {0};
    address.sin_port = htons(port_);
    address.sin_family = AF_INET;
    inet_pton(AF_INET, ip_, &address.sin_addr);

    listenFd_ = socket(PF_INET, SOCK_STREAM, 0);
    if (listenFd_ < 0) {
        LOG_FATAL("create socket failed");
        return false;
    }
    setNonBlocking(listenFd_);
    int res;
    int optVal = 1;
    res = setsockopt(listenFd_, SOL_SOCKET, SO_REUSEADDR, &optVal, sizeof(int));
    if(res == -1) {
        LOG_FATAL("set socket setsockopt error !");
        close(listenFd_);
        return false;
    }

    res = bind(listenFd_, (struct sockaddr*)&address, sizeof address);
    if (res == -1) {
        LOG_FATAL("bind socket failed");
        return false;
    }

    res = listen(listenFd_, 8);
    if (res < 0) {
        LOG_FATAL("%s %d", "listen failed", res);
        return false;
    }

    epoll_->addFd(listenFd_, EPOLLIN|listenEvents_);

    LOG_INFO("listening on %s:%d", ip_, port_);
    return true;
}

int Server::setNonBlocking(int fd) {
    int old = fcntl(fd, F_GETFL);
    int newOp = old | O_NONBLOCK;
    fcntl(fd, F_SETFL, newOp);
    return old;
}

void Server::run() {
    if (!isRun_) {
        LOG_ERROR("Server start failed");
        return;
    }
    int timeout = -1;
    LOG_INFO("Server start running");
    while(isRun_) {
        if (timeoutMs_ > 0) {
            // 清理过期时间
            timeout = timer_->getNextTick();
        }
        // 等待直到下一个定时事件超时,如果timeout为-1代表队列中已经没有定时任务,阻塞等待
        int cnt = epoll_->wait(timeout);
         for (int i = 0; i < cnt; ++ i) {
            // 以此处理每个事件
            int fd = epoll_->getEventFd(i);
            uint32_t events = epoll_->getEvents(i);
            if (fd == listenFd_) {
                // 处理服务器连接请求
                dealListen();
            } else if (events & (EPOLLRDHUP & EPOLLERR & EPOLLHUP)) {
                LOG_WARN("(main): close event: fd(%d)", fd);
                closeConn(users_[fd]); // 关闭连接
            } else if (events & EPOLLIN) {
                dealRead(users_[fd]);
            } else if (events & EPOLLOUT) {
                dealWrite(users_[fd]);
            } else {
                LOG_ERROR("(main): unexpected event");
            }
        }
    }
}

void Server::dealListen() {
    sockaddr_in address{0};
    socklen_t addr_len = sizeof address;
    do {
        int fd = accept(listenFd_, (struct sockaddr*)&address, &addr_len);
        if (fd <= 0) {
            return;
        } else if (HttpWork::userCount >= MAXFD_) {
            sendError(fd, "Server busy");
            LOG_ERROR("server is full");
            return;
        }
        addClient(fd, address);
    } while(listenEvents_ & EPOLLET);
}
void Server::addClient(int fd, sockaddr_in &addr) {
    // 初始化连接
    users_[fd].init(fd, addr);
    HttpWork &client = users_[fd];
//    Log::DEBUG("(main): user %d isRun: %s", fd, std::to_string(client.getIsRun()).c_str());
    setNonBlocking(fd);
    // 假如监听列表
    epoll_->addFd(fd, EPOLLIN|httpConnEvents_);
    // 超时后断开连接
    if (timeoutMs_ > 0) {
        // 添加定时事件u
        timer_->push(fd, timeoutMs_, [this, &client] { closeConn(client); }); // 这里报错了,原因是closeConn的client参数应为指针
    }
    LOG_INFO("(main): user[%d] in, ip: %s, port: %d", fd, inet_ntoa(addr.sin_addr), ntohs(addr.sin_port));
}

void Server::dealWrite(HttpWork &client) {
    assert(client.getIsRun());
    extendTime(client.getFd());
    threadPool_->addTask([this, &client] { writeCb(client); });
}

void Server::dealRead(HttpWork &client) {
//    LOG_INFO("(main): dealRead client: %d", client.getFd());
    assert(client.getIsRun());
    extendTime(client.getFd());
    threadPool_->addTask([this, &client] { readCb(client); });
}

void Server::sendError(int fd, const char *msg) {
    assert(fd >= 0);
    auto len = write(fd, msg, sizeof msg);
    if (len <= 0) {
        LOG_WARN("(main): send error to client %d error", fd);
    }
    close(fd);
}

void Server::extendTime(int fd) {
    assert(fd >= 0);
    timer_->reset(fd, timeoutMs_);
}

void Server::readCb(HttpWork &client) {
    assert(client.getIsRun());
    int Errno = 0;

    auto len = client.readFd(&Errno);
    if (len <= 0 && !(Errno == EAGAIN || Errno == 0)) {
        // 出现了其他错误,关闭连接
        LOG_ERROR("(thread):read error: %d, client %d is closing", Errno, client.getFd());
        closeConn(client);
        return;
    }
    if (client.processHttp()) {
        // 成功处理了http读请求,response已生成,等待写出
        epoll_->modFd(client.getFd(), EPOLLOUT | httpConnEvents_);
    } else {
        // http请求未处理,读缓冲为空,重新等待请求
        epoll_->delFd(client.getFd());
        LOG_ERROR("(thread): readBuf is none, client: %d", client.getFd());
        closeConn(client);
    }
}

void Server::writeCb(HttpWork &client) {
    assert(client.getIsRun()); // 连接未关闭
    int Errno = 0;
    auto len = client.writeFd(&Errno);
//    LOG_DEBUG("Error: %d", Errno);
    if (client.getWriteLen() == 0) {
        LOG_INFO("(thread): write successfully from user %d", client.getFd());
        // 传输成功
        if (client.isKeepAlive()) {
            epoll_->modFd(client.getFd(), EPOLLIN | httpConnEvents_);
            client.resetBuffer();
            return;
        }
    } else if (len <= 0 && Errno == EAGAIN) {
        // 写缓冲满了,继续传输
//        LOG_WARN("EAGAIN, continue write, client %d", client.getFd());
        epoll_->modFd(client.getFd(), EPOLLOUT | httpConnEvents_);
        return;
    }
    LOG_INFO("(thread): client %d is closing", client.getFd());
    closeConn(client);
}

void Server::closeConn(HttpWork &client) {
    if (!client.getIsRun())
        return;
    LOG_INFO("(main): client %d is closing", client.getFd());
    epoll_->delFd(client.getFd());
    client.closeConn();
}

Server::~Server() {
    close(listenFd_);
    isRun_ = false;
}

9 压力测试

9.1 ET模式

./webbench-1.5/webbench -c5000-t10 http://127.0.0.1:20001/

在这里插入图片描述

./webbench-1.5/webbench -c8000-t10 http://127.0.0.1:20001/

在这里插入图片描述

./webbench-1.5/webbench -c10000-t10 http://127.0.0.1:20001/

在这里插入图片描述

9.2 LT模式

./webbench-1.5/webbench -c10000-t10 http://127.0.0.1:20001/

在这里插入图片描述

9.3 测试环境

  • Ubuntu: 20.04
  • cpu: i5-1035G1
  • 内存: 16G

10 运行说明

10.1 数据库初始化

CREATEDATABASE webserver;USE webserver;CREATETABLEuser(
    username VARCHAR(50)NOTNULL,
    password VARCHAR(50)NOTNULL,PRIMARYKEY(username));INSERTINTOuser(username, password)VALUES('root','123456');

10.2 导入mysql.h

安装mysql驱动

sudo apt-get install libmysqlclient-dev

10.3 编译运行

进入项目根目录

make
./build/bin/server

11 致谢

https://github.com/markparticle/WebServer
Linux高性能服务器编程,游双著.

完整项目链接:https://github.com/Joker0x00/TinyWebServer

标签: c++ 开发语言

本文转载自: https://blog.csdn.net/Joker15517/article/details/140621509
版权归原作者 青铜世纪 所有, 如有侵权,请联系我们删除。

“[C++]TinyWebServer”的评论:

还没有评论