test: format all test scripts

This commit is contained in:
igor.udot
2025-02-24 10:18:03 +08:00
parent 717c18a58e
commit daf2d31008
381 changed files with 6180 additions and 4289 deletions

View File

@@ -1,6 +1,5 @@
# SPDX-FileCopyrightText: 2022-2023 Espressif Systems (Shanghai) CO LTD
# SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import contextlib
import logging
import os
@@ -9,31 +8,49 @@ import socketserver
import ssl
import subprocess
from threading import Thread
from typing import Any, Callable, Dict, Optional
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
import pytest
from common_test_methods import get_host_ip4_by_dest_ip
from pytest_embedded import Dut
from pytest_embedded_idf.utils import idf_parametrize
SERVER_PORT = 2222
def _path(f): # type: (str) -> str
return os.path.join(os.path.dirname(os.path.realpath(__file__)),f)
return os.path.join(os.path.dirname(os.path.realpath(__file__)), f)
def set_server_cert_cn(ip): # type: (str) -> None
arg_list = [
['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'),'-subj', '/CN={}'.format(ip), '-new'],
['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'),
'-CAkey', _path('ca.key'), '-CAcreateserial', '-out', _path('srv.crt'), '-days', '360']]
['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'), '-subj', '/CN={}'.format(ip), '-new'],
[
'openssl',
'x509',
'-req',
'-in',
_path('srv.csr'),
'-CA',
_path('ca.crt'),
'-CAkey',
_path('ca.key'),
'-CAcreateserial',
'-out',
_path('srv.crt'),
'-days',
'360',
],
]
for args in arg_list:
if subprocess.check_call(args) != 0:
raise RuntimeError('openssl command {} failed'.format(args))
class MQTTHandler(socketserver.StreamRequestHandler):
def handle(self) -> None:
logging.info(' - connection from: {}'.format(self.client_address))
data = bytearray(self.request.recv(1024))
@@ -56,12 +73,14 @@ class TlsServer(socketserver.TCPServer):
allow_reuse_address = True
allow_reuse_port = True
def __init__(self,
port:int = SERVER_PORT,
ServerHandler: Callable[[Any, Any, Any], socketserver.BaseRequestHandler] = MQTTHandler,
client_cert:bool=False,
refuse_connection:bool=False,
use_alpn:bool=False):
def __init__(
self,
port: int = SERVER_PORT,
ServerHandler: Callable[[Any, Any, Any], socketserver.BaseRequestHandler] = MQTTHandler,
client_cert: bool = False,
refuse_connection: bool = False,
use_alpn: bool = False,
):
self.refuse_connection = refuse_connection
self.context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
self.ssl_error = ''
@@ -73,7 +92,7 @@ class TlsServer(socketserver.TCPServer):
if use_alpn:
self.context.set_alpn_protocols(['mymqtt', 'http/1.1'])
self.server_thread = Thread(target=self.serve_forever)
super().__init__(('',port), ServerHandler)
super().__init__(('', port), ServerHandler)
def server_activate(self) -> None:
self.socket = self.context.wrap_socket(self.socket, server_side=True)
@@ -125,14 +144,16 @@ def get_test_cases(dut: Dut) -> Any:
cases = {}
try:
# Get connection test cases configuration: symbolic names for test cases
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_CERT',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH',
'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT',
'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
for case in [
'EXAMPLE_CONNECT_CASE_NO_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_CERT',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH',
'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT',
'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN',
]:
cases[case] = dut.app.sdkconfig.get(case)
except Exception:
logging.error('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
@@ -147,7 +168,7 @@ def get_dut_ip(dut: Dut) -> Any:
@contextlib.contextmanager
def connect_dut(dut: Dut, uri:str, case_id:int) -> Any:
def connect_dut(dut: Dut, uri: str, case_id: int) -> Any:
dut.write('connection_setup')
dut.write(f'connect {uri} {case_id}')
dut.expect(f'Test case:{case_id} started')
@@ -157,12 +178,16 @@ def connect_dut(dut: Dut, uri:str, case_id:int) -> Any:
dut.write('disconnect')
def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None:
def run_cases(dut: Dut, uri: str, cases: Dict[str, int]) -> None:
try:
dut.write('init')
dut.write(f'start')
dut.write(f'disconnect')
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
for case in [
'EXAMPLE_CONNECT_CASE_NO_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
]:
# All these cases connect to the server with no server verification or with server only verification
with TlsServer(), connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: default server - expect to connect normally')
@@ -172,9 +197,13 @@ def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None:
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('MQTT ERROR: 0x5') # expecting 0x5 ... connection not authorized error
with TlsServer(client_cert=True) as server, connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: server with client verification - handshake error since client presents no client certificate')
logging.info(
f'Running {case}: server with client verification - handshake error since client presents no client certificate'
)
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
dut.expect(
'ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED'
) # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
assert 'PEER_DID_NOT_RETURN_A_CERTIFICATE' in server.last_ssl_error()
for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
@@ -187,15 +216,21 @@ def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None:
with TlsServer() as s, connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: invalid server certificate on default server - expect ssl handshake error')
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
if re.match('.*alert.*unknown.*ca',s.last_ssl_error(), flags=re.I) is None:
dut.expect(
'ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED'
) # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
if re.match('.*alert.*unknown.*ca', s.last_ssl_error(), flags=re.I) is None:
raise Exception(f'Unexpected ssl error from the server: {s.last_ssl_error()}')
case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
with TlsServer(client_cert=True) as s, connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: Invalid client certificate on server with client verification - expect ssl handshake error')
logging.info(
f'Running {case}: Invalid client certificate on server with client verification - expect ssl handshake error'
)
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
dut.expect(
'ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED'
) # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
if 'CERTIFICATE_VERIFY_FAILED' not in s.last_ssl_error():
raise Exception('Unexpected ssl error from the server {}'.format(s.last_ssl_error()))
@@ -214,8 +249,8 @@ def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None:
dut.write('destroy')
@pytest.mark.esp32
@pytest.mark.ethernet
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_mqtt_connect(
dut: Dut,
log_performance: Callable[[str, object], None],

View File

@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2023-2024 Espressif Systems (Shanghai) CO LTD
# SPDX-FileCopyrightText: 2023-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import contextlib
import difflib
@@ -22,13 +22,13 @@ import paho.mqtt.client as mqtt
import pexpect
import pytest
from pytest_embedded import Dut
from pytest_embedded_idf.utils import idf_parametrize
DEFAULT_MSG_SIZE = 16
# Publisher class creating a python client to send/receive published data from esp-mqtt client
class MqttPublisher(mqtt.Client):
def __init__(self, config, log_details=False): # type: (MqttPublisher, dict, bool) -> None
self.log_details = log_details
self.config = config
@@ -40,7 +40,9 @@ class MqttPublisher(mqtt.Client):
self.event_client_subscribed = Event()
self.event_client_got_all = Event()
transport = 'websockets' if self.config['transport'] in ['ws', 'wss'] else 'tcp'
client_id = 'MqttTestRunner' + ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(5))
client_id = 'MqttTestRunner' + ''.join(
random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(5)
)
super().__init__(client_id, userdata=0, transport=transport)
def print_details(self, text): # type: (str) -> None
@@ -53,7 +55,7 @@ class MqttPublisher(mqtt.Client):
logging.info(f'Subscribed to {self.config["subscribe_topic"]} successfully with QoS: {granted_qos}')
self.event_client_subscribed.set()
def on_connect(self, mqttc: Any, obj: Any, flags: Any, rc:int) -> None:
def on_connect(self, mqttc: Any, obj: Any, flags: Any, rc: int) -> None:
self.event_client_connected.set()
def on_connect_fail(self, mqttc: Any, obj: Any) -> None:
@@ -67,8 +69,10 @@ class MqttPublisher(mqtt.Client):
self.event_client_got_all.set()
else:
differences = len(list(filter(lambda data: data[0] != data[1], zip(payload, self.expected_data))))
logging.error(f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:'
f'{len(self.expected_data)}')
logging.error(
f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:'
f'{len(self.expected_data)}'
)
logging.info(f'Repetitions: {payload.count(self.config["pattern"])}')
logging.info(f'Pattern: {self.config["pattern"]}')
logging.info(f'First: {payload[:DEFAULT_MSG_SIZE]}')
@@ -107,9 +111,10 @@ class MqttPublisher(mqtt.Client):
self.loop_stop()
def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]:
def get_configurations(dut: Dut, test_case: Any) -> Dict[str, Any]:
publish_cfg = {}
try:
@no_type_check
def get_config_from_dut(dut, config_option):
# logging.info('Option:', config_option, dut.app.sdkconfig.get(config_option))
@@ -117,11 +122,18 @@ def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]:
if value is None:
return None, None
return value.group(1), int(value.group(2))
# Get publish test configuration
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI')
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI')
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_config_from_dut(
dut, 'EXAMPLE_BROKER_SSL_URI'
)
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_config_from_dut(
dut, 'EXAMPLE_BROKER_TCP_URI'
)
publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_WS_URI')
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI')
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_config_from_dut(
dut, 'EXAMPLE_BROKER_WSS_URI'
)
except Exception:
logging.info('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
@@ -133,9 +145,13 @@ def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]:
publish_cfg['qos'] = qos
publish_cfg['enqueue'] = enqueue
publish_cfg['transport'] = transport
publish_cfg['pattern'] = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE))
publish_cfg['pattern'] = ''.join(
random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE)
)
publish_cfg['test_timeout'] = get_timeout(test_case)
unique_topic = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(DEFAULT_MSG_SIZE))
unique_topic = ''.join(
random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(DEFAULT_MSG_SIZE)
)
publish_cfg['subscribe_topic'] = 'test/subscribe_to/' + unique_topic
publish_cfg['publish_topic'] = 'test/subscribe_to/' + unique_topic
logging.info(f'configuration: {publish_cfg}')
@@ -143,7 +159,7 @@ def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]:
@contextlib.contextmanager
def connected_and_subscribed(dut:Dut) -> Any:
def connected_and_subscribed(dut: Dut) -> Any:
dut.write('start')
dut.expect(re.compile(rb'MQTT_EVENT_SUBSCRIBED'), timeout=60)
yield
@@ -155,16 +171,17 @@ def get_scenarios() -> List[Dict[str, int]]:
# Initialize message sizes and repeat counts (if defined in the environment)
for i in count(0):
# Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x}
env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']}
env_dict = {var: 'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']}
if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']):
scenarios.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore
continue
break
if not scenarios: # No message sizes present in the env - set defaults
scenarios = [{'msg_len':0, 'nr_of_msgs':5}, # zero-sized messages
{'msg_len':2, 'nr_of_msgs':5}, # short messages
{'msg_len':200, 'nr_of_msgs':3}, # long messages
]
if not scenarios: # No message sizes present in the env - set defaults
scenarios = [
{'msg_len': 0, 'nr_of_msgs': 5}, # zero-sized messages
{'msg_len': 2, 'nr_of_msgs': 5}, # short messages
{'msg_len': 200, 'nr_of_msgs': 3}, # long messages
]
return scenarios
@@ -181,17 +198,23 @@ def get_timeout(test_case: Any) -> int:
def run_publish_test_case(dut: Dut, config: Any) -> None:
logging.info(f'Starting Publish test: transport:{config["transport"]}, qos:{config["qos"]},'
f'nr_of_msgs:{config["scenario"]["nr_of_msgs"]},'
f' msg_size:{config["scenario"]["msg_len"] * DEFAULT_MSG_SIZE}, enqueue:{config["enqueue"]}')
dut.write(f'publish_setup {config["transport"]} {config["publish_topic"]} {config["subscribe_topic"]} {config["pattern"]} {config["scenario"]["msg_len"]}')
logging.info(
f'Starting Publish test: transport:{config["transport"]}, qos:{config["qos"]},'
f'nr_of_msgs:{config["scenario"]["nr_of_msgs"]},'
f' msg_size:{config["scenario"]["msg_len"] * DEFAULT_MSG_SIZE}, enqueue:{config["enqueue"]}'
)
dut.write(
f'publish_setup {config["transport"]} {config["publish_topic"]} {config["subscribe_topic"]} {config["pattern"]} {config["scenario"]["msg_len"]}'
)
with MqttPublisher(config) as publisher, connected_and_subscribed(dut):
assert publisher.event_client_subscribed.wait(timeout=config['test_timeout']), 'Runner failed to subscribe'
msgs_published: List[mqtt.MQTTMessageInfo] = []
dut.write(f'publish {config["scenario"]["nr_of_msgs"]} {config["qos"]} {config["enqueue"]}')
assert publisher.event_client_got_all.wait(timeout=config['test_timeout']), (f'Not all data received from ESP32: {config["transport"]} '
f'qos={config["qos"]} received: {publisher.received} '
f'expected: {config["scenario"]["nr_of_msgs"]}')
assert publisher.event_client_got_all.wait(timeout=config['test_timeout']), (
f'Not all data received from ESP32: {config["transport"]} '
f'qos={config["qos"]} received: {publisher.received} '
f'expected: {config["scenario"]["nr_of_msgs"]}'
)
logging.info(' - all data received from ESP32')
payload = config['pattern'] * config['scenario']['msg_len']
for _ in range(config['scenario']['nr_of_msgs']):
@@ -214,15 +237,17 @@ def run_publish_test_case(dut: Dut, config: Any) -> None:
logging.info('ESP32 received all data from runner')
stress_scenarios = [{'msg_len':20, 'nr_of_msgs':30}] # many medium sized
stress_scenarios = [{'msg_len': 20, 'nr_of_msgs': 30}] # many medium sized
transport_cases = ['tcp', 'ws', 'wss', 'ssl']
qos_cases = [0, 1, 2]
enqueue_cases = [0, 1]
local_broker_supported_transports = ['tcp']
local_broker_scenarios = [{'msg_len':0, 'nr_of_msgs':5}, # zero-sized messages
{'msg_len':5, 'nr_of_msgs':20}, # short messages
{'msg_len':500, 'nr_of_msgs':10}, # long messages
{'msg_len':20, 'nr_of_msgs':20}] # many medium sized
local_broker_scenarios = [
{'msg_len': 0, 'nr_of_msgs': 5}, # zero-sized messages
{'msg_len': 5, 'nr_of_msgs': 20}, # short messages
{'msg_len': 500, 'nr_of_msgs': 10}, # long messages
{'msg_len': 20, 'nr_of_msgs': 20},
] # many medium sized
def make_cases(transport: Any, scenarios: List[Dict[str, int]]) -> List[Tuple[str, int, int, Dict[str, int]]]:
@@ -233,10 +258,10 @@ test_cases = make_cases(transport_cases, get_scenarios())
stress_test_cases = make_cases(transport_cases, stress_scenarios)
@pytest.mark.esp32
@pytest.mark.ethernet
@pytest.mark.parametrize('test_case', test_cases)
@pytest.mark.parametrize('config', ['default'], indirect=True)
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_mqtt_publish(dut: Dut, test_case: Any) -> None:
publish_cfg = get_configurations(dut, test_case)
dut.expect(re.compile(rb'mqtt>'), timeout=30)
@@ -244,11 +269,11 @@ def test_mqtt_publish(dut: Dut, test_case: Any) -> None:
run_publish_test_case(dut, publish_cfg)
@pytest.mark.esp32
@pytest.mark.ethernet_stress
@pytest.mark.nightly_run
@pytest.mark.parametrize('test_case', stress_test_cases)
@pytest.mark.parametrize('config', ['default'], indirect=True)
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_mqtt_publish_stress(dut: Dut, test_case: Any) -> None:
publish_cfg = get_configurations(dut, test_case)
dut.expect(re.compile(rb'mqtt>'), timeout=30)
@@ -256,10 +281,10 @@ def test_mqtt_publish_stress(dut: Dut, test_case: Any) -> None:
run_publish_test_case(dut, publish_cfg)
@pytest.mark.esp32
@pytest.mark.ethernet
@pytest.mark.parametrize('test_case', make_cases(local_broker_supported_transports, local_broker_scenarios))
@pytest.mark.parametrize('config', ['local_broker'], indirect=True)
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_mqtt_publish_local(dut: Dut, test_case: Any) -> None:
if test_case[0] not in local_broker_supported_transports:
pytest.skip(f'Skipping transport: {test_case[0]}...')