jv me buter

This commit is contained in:
Gu://em_ 2026-01-10 18:39:50 +01:00
parent 7a614bd0d4
commit f7af2c3850
8 changed files with 158 additions and 19 deletions

View file

@ -17,7 +17,7 @@ void daemon_init(struct config *cfg)
config = cfg; config = cfg;
} }
int get_pid() int get_pid(void)
{ {
FILE *stream = fopen(config->pid_file, "r"); FILE *stream = fopen(config->pid_file, "r");
if (stream == NULL) if (stream == NULL)

View file

@ -7,7 +7,7 @@
* *
* @return * @return
*/ */
int get_pid(); int get_pid(void);
/* @brief /* @brief
*/ */

View file

@ -53,7 +53,7 @@ ssize_t parse_headers(struct http_request *res, struct string *req,
// Yes I know I do one useless allocation but I really don't care at this // Yes I know I do one useless allocation but I really don't care at this
// point // point
while (req->data[i] != '\n') // ! Blank line while (req->data[i] != '\n' && req->data[i] != '\r') // ! Blank line
{ {
if (header == NULL) if (header == NULL)
{ {
@ -73,21 +73,26 @@ ssize_t parse_headers(struct http_request *res, struct string *req,
return ERR_HTTP_OUT_OF_MEMORY; return ERR_HTTP_OUT_OF_MEMORY;
// Read field // Read field
ssize_t nread = read_field(req, offset, &header->field); ssize_t nread = read_field(req, i, &header->field);
if (nread <= 0) if (nread <= 0)
return nread; // Contains error code when negative return nread; // Contains error code when negative
i += nread; i += nread + 1;
// Read value // Read value
nread = read_value(req, offset, &header->value); nread = read_value(req, i, &header->value);
if (nread <= 0) if (nread <= 0)
return nread; // Contains error code when negative return nread; // Contains error code when negative
i += nread + 1; i += nread + 1;
} }
return i + 1; if (req->data[i] == '\r')
i++;
if (req->data[i] == '\n')
i++;
return i;
} }
struct http_header *get_header(struct http_header *headers, const char *field) struct http_header *get_header(struct http_header *headers, const char *field)

View file

@ -184,10 +184,34 @@ static void check_req(struct http_request *req, struct http_response *resp)
!= 0) != 0)
resp->status_code = 400; resp->status_code = 400;
else if (string_compare_strictly_n_str(req->protocol, "HTTP/1.1", else if (string_compare_strictly_n_str(req->protocol, "HTTP/1.1",
strlen("HTTP/1.1") != 0)) strlen("HTTP/1.1")) != 0)
resp->status_code = 505; resp->status_code = 505;
printf("%s %d\n", req->protocol->data, resp->status_code); // Host
if (resp->status_code != 400 && resp->status_code != 505)
{
int host_count = 0;
struct http_header *cur = req->headers;
while (cur != NULL)
{
if (cur->field->size == 4
&& string_compare_n_str(cur->field, "Host", 4) == 0)
{
host_count++;
if (cur->value == NULL || cur->value->size == 0)
{
resp->status_code = 400;
break;
}
}
cur = cur->next;
}
if (host_count != 1)
resp->status_code = 400;
}
// printf("%s %d\n", req->protocol->data, resp->status_code);
} }
// === Functions // === Functions
@ -215,7 +239,8 @@ void handle_request(int client_fd, char *client_ip)
{ {
string_concat_str(str, buffer, nread); string_concat_str(str, buffer, nread);
} }
string_concat_str(str, buffer, nread); if (nread > 0)
string_concat_str(str, buffer, nread);
// Parse request // Parse request
struct http_request *req = parse_request(str); struct http_request *req = parse_request(str);
@ -288,7 +313,19 @@ struct http_request *parse_request(struct string *req)
size_t i = 0; size_t i = 0;
ssize_t nread = parse_reqline(res, req); ssize_t nread = parse_reqline(res, req);
if (nread <= 0) if (nread <= 0)
return NULL; {
if (nread == ERR_HTTP_NOT_IMPLEMENTED)
res->status_code = 501;
else
res->status_code = 400;
if (res->target == NULL)
res->target = string_create("", 0);
if (res->protocol == NULL)
res->protocol = string_create("HTTP/1.1", 8);
return res;
}
// Split path and query // Split path and query
split_target(res); split_target(res);
@ -298,7 +335,10 @@ struct http_request *parse_request(struct string *req)
// Headers // Headers
nread = parse_headers(res, req, i); nread = parse_headers(res, req, i);
if (nread <= 0) if (nread <= 0)
return NULL; {
res->status_code = 400;
return res;
}
return res; return res;
} }
@ -317,8 +357,11 @@ struct http_response *generate_response(struct http_request *req)
res->protocol = string_create(protocol, strlen(protocol)); res->protocol = string_create(protocol, strlen(protocol));
// Target // Target
str_concat_string(config->servers->root_dir, if (req->status_code == 0)
strlen(config->servers->root_dir), req->target); {
str_concat_string(config->servers->root_dir,
strlen(config->servers->root_dir), req->target);
}
// Status code // Status code
if (req->status_code == 0) if (req->status_code == 0)
@ -340,7 +383,8 @@ struct http_response *generate_response(struct http_request *req)
res->status_code = req->status_code; res->status_code = req->status_code;
// Check protocol and method // Check protocol and method
check_req(req, res); if (req->status_code == 0)
check_req(req, res);
// Headers // Headers
char *time = get_time(); char *time = get_time();
@ -362,6 +406,10 @@ struct http_response *generate_response(struct http_request *req)
res->status_code = 403; res->status_code = 403;
} }
} }
else
{
append_header(&res->headers, create_header("Content-Length", "0"));
}
append_header(&res->headers, create_header("Connection", "close")); append_header(&res->headers, create_header("Connection", "close"));
// Status msg // Status msg

View file

@ -26,7 +26,7 @@ void errlog_init(bool enabled, int logfile_fd, struct server_config *serv_cfg)
config.server_cfg = serv_cfg; config.server_cfg = serv_cfg;
} }
void print_err() void print_err(void)
{ {
print_log_err("%s", get_err()); print_log_err("%s", get_err());
} }
@ -55,7 +55,7 @@ void print_log_err(char *format, ...)
fprintf(stderr, "Error: %s", get_err()); fprintf(stderr, "Error: %s", get_err());
} }
char *get_err() char *get_err(void)
{ {
return strerror(errno); return strerror(errno);
} }

View file

@ -14,7 +14,7 @@ void errlog_init(bool enabled, int logfile_fd, struct server_config *serv_cfg);
/* @brief Retrieves the last error with errno and prints the corresponding /* @brief Retrieves the last error with errno and prints the corresponding
* error message in the logs and stderr * error message in the logs and stderr
*/ */
void print_err(); void print_err(void);
/* @brief Prints error logs, just like print_log() but for errors /* @brief Prints error logs, just like print_log() but for errors
*/ */
@ -22,6 +22,6 @@ void print_log_err(char *format, ...);
/* @brief Returns the string corresponding to the last error that happened /* @brief Returns the string corresponding to the last error that happened
*/ */
char *get_err(); char *get_err(void);
#endif // ! ERRORS_H #endif // ! ERRORS_H

View file

@ -0,0 +1,40 @@
#!/bin/sh
# Simple test script for HTTP/1.1 Host header compliance
# Usage: ./test_host_compliance.sh [IP] [PORT]
IP=${1:-"127.0.0.1"}
PORT=${2:-"6996"}
echo "Targeting server at $IP:$PORT"
test_req() {
NAME="$1"
PAYLOAD="$2"
EXPECTED="$3"
echo -n "Test: $NAME ... "
# Send payload, wait max 1s for response
RESP=$(printf "$PAYLOAD" | nc -w 1 $IP $PORT 2>/dev/null | head -n 1)
if echo "$RESP" | grep -q "$EXPECTED"; then
echo "PASS"
else
echo "FAIL (Expected '$EXPECTED', got '$RESP')"
fi
}
# 1. Valid Request
test_req "Valid Request" "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n" "200 OK"
# 2. Missing Host Header
test_req "Missing Host" "GET / HTTP/1.1\r\n\r\n" "400 Bad Request"
# 3. Empty Host Header
test_req "Empty Host" "GET / HTTP/1.1\r\nHost:\r\n\r\n" "400 Bad Request"
# 4. Multiple Host Headers
test_req "Multiple Hosts" "GET / HTTP/1.1\r\nHost: a\r\nHost: b\r\n\r\n" "400 Bad Request"
# 5. Bad Protocol Version (Should be 505 now with the fix)
test_req "Bad Protocol (HTTP/1.0)" "GET / HTTP/1.0\r\nHost: localhost\r\n\r\n" "505"

View file

@ -130,3 +130,49 @@ def test_bad_request():
assert response.status == 400 assert response.status == 400
finally: finally:
kill_httpd(proc) kill_httpd(proc)
@pytest.mark.timeout(2)
def test_head_index():
proc = spawn_httpd("out.log")
try:
req = requests.head(f"http://{host}:{port}/index.html")
assert req.status_code == 200
assert req.text == ""
finally:
kill_httpd(proc)
@pytest.mark.timeout(2)
def test_missing_host():
proc = spawn_httpd("out.log")
sock = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
sock.connect((host,int(port)))
request = f"GET /index.html HTTP/1.1\r\nConnection: close\r\n\r\n"
sock.sendall(request.encode())
response = http.client.HTTPResponse(sock)
response.begin()
try:
assert response.status == 400
finally:
kill_httpd(proc)
@pytest.mark.timeout(2)
def test_directory_traversal():
proc = spawn_httpd("out.log")
sock = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
sock.connect((host,int(port)))
request = f"GET /../test_suite.py HTTP/1.1\r\nHOST: {host}:{port}\r\nConnection: close\r\n\r\n"
sock.sendall(request.encode())
response = http.client.HTTPResponse(sock)
response.begin()
try:
assert response.status in [400, 403, 404]
finally:
kill_httpd(proc)