mirror of
https://github.com/espressif/esp-idf.git
synced 2025-11-02 21:48:13 +00:00
feat(esp_http_server): Added pre handshake callback for websocket
1. If the user wants authenticate the request, then user needs to do this before upgrading the protocol to websocket. 2. To achieve this, added pre_handshake_callack, which will execute before handshake, i.e. before switching protocol.
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
menu "Example Configuration"
|
||||
|
||||
config EXAMPLE_ENABLE_WS_PRE_HANDSHAKE_CB
|
||||
bool "Enable WebSocket pre-handshake callback"
|
||||
select HTTPD_WS_PRE_HANDSHAKE_CB_SUPPORT
|
||||
default y
|
||||
help
|
||||
Enable this option to use WebSocket pre-handshake callback.
|
||||
This will allow the server to register a callback function that will be
|
||||
called before the WebSocket handshake is processed.
|
||||
|
||||
endmenu
|
||||
@@ -67,6 +67,26 @@ static esp_err_t trigger_async_send(httpd_handle_t handle, httpd_req_t *req)
|
||||
return ret;
|
||||
}
|
||||
|
||||
#ifdef CONFIG_EXAMPLE_ENABLE_WS_PRE_HANDSHAKE_CB
|
||||
static esp_err_t ws_pre_handshake_cb(httpd_req_t *req)
|
||||
{
|
||||
ESP_LOGI(TAG, "=== ws_pre_handshake_cb called ===");
|
||||
|
||||
// Get the URI with query string
|
||||
const char *uri = req->uri;
|
||||
ESP_LOGI(TAG, "Requested URI: %s", uri ? uri : "NULL");
|
||||
|
||||
// Check if the query string contains token=valid
|
||||
if (uri && strstr(uri, "token=valid") != NULL) {
|
||||
ESP_LOGI(TAG, "Valid token found, accepting handshake");
|
||||
return ESP_OK;
|
||||
} else {
|
||||
ESP_LOGI(TAG, "No valid token found, rejecting handshake");
|
||||
return ESP_FAIL;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
/*
|
||||
* This handler echos back the received ws data
|
||||
* and triggers an async send if certain message received
|
||||
@@ -107,6 +127,7 @@ static esp_err_t echo_handler(httpd_req_t *req)
|
||||
}
|
||||
ESP_LOGI(TAG, "Packet type: %d", ws_pkt.type);
|
||||
if (ws_pkt.type == HTTPD_WS_TYPE_TEXT &&
|
||||
ws_pkt.payload != NULL &&
|
||||
strcmp((char*)ws_pkt.payload,"Trigger async") == 0) {
|
||||
free(buf);
|
||||
return trigger_async_send(req->handle, req);
|
||||
@@ -128,6 +149,17 @@ static const httpd_uri_t ws = {
|
||||
.is_websocket = true
|
||||
};
|
||||
|
||||
static const httpd_uri_t ws_auth = {
|
||||
.uri = "/auth",
|
||||
.method = HTTP_GET,
|
||||
.handler = echo_handler,
|
||||
.user_ctx = NULL,
|
||||
.is_websocket = true,
|
||||
#ifdef CONFIG_EXAMPLE_ENABLE_WS_PRE_HANDSHAKE_CB
|
||||
.ws_pre_handshake_cb = ws_pre_handshake_cb
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
static httpd_handle_t start_webserver(void)
|
||||
{
|
||||
@@ -140,6 +172,7 @@ static httpd_handle_t start_webserver(void)
|
||||
// Registering the ws handler
|
||||
ESP_LOGI(TAG, "Registering URI handlers");
|
||||
httpd_register_uri_handler(server, &ws);
|
||||
httpd_register_uri_handler(server, &ws_auth);
|
||||
return server;
|
||||
}
|
||||
|
||||
|
||||
@@ -23,13 +23,14 @@ OPCODE_PONG = 0xA
|
||||
|
||||
|
||||
class WsClient:
|
||||
def __init__(self, ip: str, port: int) -> None:
|
||||
def __init__(self, ip: str, port: int, uri: str = '') -> None:
|
||||
self.port = port
|
||||
self.ip = ip
|
||||
self.ws = websocket.WebSocket()
|
||||
self.uri = uri
|
||||
|
||||
def __enter__(self): # type: ignore
|
||||
self.ws.connect('ws://{}:{}/ws'.format(self.ip, self.port))
|
||||
self.ws.connect('ws://{}:{}/{}'.format(self.ip, self.port, self.uri))
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
|
||||
@@ -71,7 +72,7 @@ def test_examples_protocol_http_ws_echo_server(dut: Dut) -> None:
|
||||
logging.info('Got Port : {}'.format(got_port))
|
||||
|
||||
# Start ws server test
|
||||
with WsClient(got_ip, int(got_port)) as ws:
|
||||
with WsClient(got_ip, int(got_port), uri='ws') as ws:
|
||||
DATA = 'Espressif'
|
||||
for expected_opcode in [OPCODE_TEXT, OPCODE_BIN, OPCODE_PING]:
|
||||
ws.write(data=DATA, opcode=expected_opcode)
|
||||
@@ -94,3 +95,35 @@ def test_examples_protocol_http_ws_echo_server(dut: Dut) -> None:
|
||||
data = data.decode()
|
||||
if opcode != OPCODE_TEXT or data != 'Async data':
|
||||
raise RuntimeError('Failed to receive correct opcode:{} or data:{}'.format(opcode, data))
|
||||
|
||||
|
||||
@pytest.mark.wifi_router
|
||||
@idf_parametrize('target', ['esp32'], indirect=['target'])
|
||||
def test_ws_auth_handshake(dut: Dut) -> None:
|
||||
"""
|
||||
Test that connecting to /ws does NOT print the handshake success log.
|
||||
This is used to verify ws_pre_handshake_cb can reject the handshake.
|
||||
"""
|
||||
# Wait for device to connect and start server
|
||||
if dut.app.sdkconfig.get('EXAMPLE_WIFI_SSID_PWD_FROM_STDIN') is True:
|
||||
dut.expect('Please input ssid password:')
|
||||
env_name = 'wifi_router'
|
||||
ap_ssid = get_env_config_variable(env_name, 'ap_ssid')
|
||||
ap_password = get_env_config_variable(env_name, 'ap_password')
|
||||
dut.write(f'{ap_ssid} {ap_password}')
|
||||
got_ip = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[1].decode()
|
||||
got_port = dut.expect(r"Starting server on port: '(\d+)'", timeout=30)[1].decode()
|
||||
# Prepare a minimal WebSocket handshake request
|
||||
# Use WSClient to attempt the handshake, expecting it to fail (handshake rejected)
|
||||
|
||||
handshake_success = False
|
||||
try:
|
||||
# Attempt to use WSClient, expecting it to fail handshake
|
||||
with WsClient(got_ip, int(got_port), uri='auth?token=valid') as ws: # type: ignore # noqa: F841
|
||||
handshake_success = True
|
||||
except Exception as e:
|
||||
logging.info(f'WebSocket handshake failed: {e}')
|
||||
handshake_success = False
|
||||
|
||||
if handshake_success is False:
|
||||
raise RuntimeError('WebSocket handshake succeeded, but it should have been rejected by ws_pre_handshake_cb')
|
||||
|
||||
Reference in New Issue
Block a user