add auto generated test folder to components:

1. add test cases and related scripts
2. add CI config files
read README.md for detail
This commit is contained in:
Yinling
2016-09-26 14:27:57 +08:00
committed by Angus Gratton
parent 58aec93dbb
commit 90e57cdf8f
77 changed files with 28574 additions and 0 deletions

View File

@@ -0,0 +1,498 @@
import socket
import ssl
import os
import re
import time
import threading
import Parameter
from PKI import PKIDict
from PKI import PKIItem
from NativeLog import NativeLog
class SerialPortCheckFail(StandardError):
pass
class SSLHandlerFail(StandardError):
pass
class PCFail(StandardError):
pass
class TargetFail(StandardError):
pass
def ssl_handler_wrapper(handler_type):
if handler_type == "PC":
exception_type = PCFail
elif handler_type == "Target":
exception_type = TargetFail
else:
exception_type = None
def _handle_func(func):
def _handle_args(*args, **kwargs):
try:
ret = func(*args, **kwargs)
except StandardError, e:
NativeLog.add_exception_log(e)
raise exception_type(str(e))
return ret
return _handle_args
return _handle_func
class SerialPort(object):
def __init__(self, tc_action, port_name):
self.tc_action = tc_action
self.port_name = port_name
def flush(self):
self.tc_action.flush_data(self.port_name)
def write_line(self, data):
self.tc_action.serial_write_line(self.port_name, data)
def check(self, condition, timeout=10):
if self.tc_action.check_response(self.port_name, condition, timeout) is False:
raise SerialPortCheckFail("serial port check fail, condition is %s" % condition)
def read_data(self):
return self.tc_action.serial_read_data(self.port_name)
pass
class SSLHandler(object):
# ssl operation timeout is 30 seconds
TIMEOUT = 30
def __init__(self, typ, config, serial_port):
self.type = typ
self.config = config
self.timeout = self.TIMEOUT
self.serial_port = serial_port
self.accept_thread = None
self.data_validation = False
def set_timeout(self, timeout):
self.timeout = timeout
def init_context(self):
pass
def connect(self, remote_ip, remote_port, local_ip=0, local_port=0):
pass
def listen(self, local_port=0, local_ip=0):
pass
def send(self, size, data):
pass
def recv(self, length, timeout):
pass
def set_data_validation(self, validation):
pass
def close(self):
if self.accept_thread is not None:
self.accept_thread.exit()
self.accept_thread.join(5)
pass
class TargetSSLHandler(SSLHandler):
def __init__(self, typ, config, serial_port):
SSLHandler.__init__(self, typ, config, serial_port)
self.ssl_id = None
self.server_id = None
@ssl_handler_wrapper("Target")
def init_context(self):
self.serial_port.flush()
self.serial_port.write_line("soc -T")
self.serial_port.check("+CLOSEALL")
if self.type == "client":
version = Parameter.VERSION[self.config["client_version"]]
fragment = self.config["client_fragment_size"]
ca = self.config["client_trust_anchor"]
cert = self.config["client_certificate"]
key = self.config["client_key"]
verify_required = 0x01 if self.config["verify_server"] is True else 0x00
context_type = 1
else:
version = Parameter.VERSION[self.config["server_version"]]
fragment = self.config["server_fragment_size"]
ca = self.config["server_trust_anchor"]
cert = self.config["server_certificate"]
key = self.config["server_key"]
verify_required = 0x02 if self.config["verify_client"] is True else 0x00
context_type = 2
ssc_cmd = "ssl -I -t %u -r %u -v %u -o %u" % (context_type, fragment, version, verify_required)
if ca is not None:
_index = PKIDict.PKIDict.CERT_DICT[ca]
ssc_cmd += " -a %d" % _index
if cert is not None:
_index = PKIDict.PKIDict.CERT_DICT[cert]
ssc_cmd += " -c %d" % _index
if key is not None:
_index = PKIDict.PKIDict.KEY_DICT[key]
ssc_cmd += " -k %d" % _index
# write command and check result
self.serial_port.flush()
self.serial_port.write_line(ssc_cmd)
self.serial_port.check(["+SSL:OK", "AND", "!+SSL:ERROR"])
@ssl_handler_wrapper("Target")
def connect(self, remote_ip, remote_port, local_ip=0, local_port=0):
self.serial_port.flush()
self.serial_port.write_line("soc -B -t SSL -i %s -p %s" % (local_ip, local_port))
self.serial_port.check(["OK", "AND", "!ERROR"])
self.serial_port.flush()
self.serial_port.write_line("soc -C -s 0 -i %s -p %s" % (remote_ip, remote_port))
self.serial_port.check(["OK", "AND", "!ERROR"], timeout=30)
self.ssl_id = 0
pass
def accept_succeed(self):
self.ssl_id = 1
class Accept(threading.Thread):
def __init__(self, serial_port, succeed_cb):
threading.Thread.__init__(self)
self.setDaemon(True)
self.serial_port = serial_port
self.succeed_cb = succeed_cb
self.exit_flag = threading.Event()
def run(self):
while self.exit_flag.isSet() is False:
try:
self.serial_port.check("+ACCEPT:", timeout=1)
self.succeed_cb()
break
except StandardError:
pass
def exit(self):
self.exit_flag.set()
@ssl_handler_wrapper("Target")
def listen(self, local_port=0, local_ip=0):
self.serial_port.flush()
self.serial_port.write_line("soc -B -t SSL -i %s -p %s" % (local_ip, local_port))
self.serial_port.check(["OK", "AND", "!ERROR"])
self.serial_port.flush()
self.serial_port.write_line("soc -L -s 0")
self.serial_port.check(["OK", "AND", "!ERROR"])
self.server_id = 0
self.accept_thread = self.Accept(self.serial_port, self.accept_succeed)
self.accept_thread.start()
pass
@ssl_handler_wrapper("Target")
def send(self, size=10, data=None):
if data is not None:
size = len(data)
self.serial_port.flush()
self.serial_port.write_line("soc -S -s %s -l %s" % (self.ssl_id, size))
self.serial_port.check(["OK", "AND", "!ERROR"])
pass
@ssl_handler_wrapper("Target")
def recv(self, length, timeout=SSLHandler.TIMEOUT):
pattern = re.compile("\+RECV:\d+,(\d+)\r\n")
data_len = 0
data = ""
time1 = time.time()
while time.time() - time1 < timeout:
data += self.serial_port.read_data()
if self.data_validation is True:
if "+DATA_ERROR" in data:
raise SSLHandlerFail("target data validation fail")
while True:
match = pattern.search(data)
if match is None:
break
else:
data_len += int(match.group(1))
data = data[data.find(match.group())+len(match.group()):]
if data_len >= length:
result = True
break
else:
result = False
if result is False:
raise SSLHandlerFail("Target recv fail")
def set_data_validation(self, validation):
self.data_validation = validation
self.serial_port.flush()
self.serial_port.write_line("soc -V -s %s -o %s" % (self.ssl_id, 1 if validation is True else 0))
self.serial_port.check(["OK", "AND", "!ERROR"])
@ssl_handler_wrapper("Target")
def close(self):
SSLHandler.close(self)
self.serial_port.flush()
self.serial_port.write_line("ssl -D")
self.serial_port.check(["+SSL:OK", "OR", "+SSL:ERROR"])
self.serial_port.write_line("soc -T")
self.serial_port.check("+CLOSEALL")
pass
pass
def calc_hash(index):
return (index & 0xffffffff) % 83 + (index & 0xffffffff) % 167
def verify_data(data, start_index):
for i, c in enumerate(data):
if ord(c) != calc_hash(start_index + i):
NativeLog.add_trace_critical("[Data Validation Error] target sent data index %u is error."
" Sent data is %x, should be %x"
% (start_index + i, ord(c), calc_hash(start_index + i)))
return False
return True
def make_validation_data(length, start_index):
return bytes().join([chr(calc_hash(start_index + i)) for i in range(length)])
class PCSSLHandler(SSLHandler):
PROTOCOL_MAPPING = {
"SSLv23": ssl.PROTOCOL_SSLv23,
"SSLv23_2": ssl.PROTOCOL_SSLv23,
"SSLv20": ssl.PROTOCOL_SSLv2,
"SSLv30": ssl.PROTOCOL_SSLv3,
"TLSv10": ssl.PROTOCOL_TLSv1,
"TLSv11": ssl.PROTOCOL_TLSv1_1,
"TLSv12": ssl.PROTOCOL_TLSv1_2,
}
CERT_FOLDER = os.path.join(".", "PKI", PKIDict.PKIDict.CERT_FOLDER)
KEY_FOLDER = os.path.join(".", "PKI", PKIDict.PKIDict.KEY_FOLDER)
def __init__(self, typ, config, serial_port):
SSLHandler.__init__(self, typ, config, serial_port)
self.ssl_context = None
self.ssl = None
self.server_sock = None
self.send_index = 0
self.recv_index = 0
class InitContextThread(threading.Thread):
def __init__(self, handler, version, cipher_suite, ca, cert, key, verify_required, remote_cert):
threading.Thread.__init__(self)
self.setDaemon(True)
self.handler = handler
self.version = version
self.cipher_suite = cipher_suite
self.ca = ca
self.cert = cert
self.key = key
self.verify_required = verify_required
self.remote_cert = remote_cert
pass
@staticmethod
def handle_cert(cert_file, ca_file):
cert = PKIItem.Certificate()
cert.parse_file(cert_file)
ca = PKIItem.Certificate()
ca.parse_file(ca_file)
if cert.file_encoding == "PEM" and ca.name in cert.cert_chain:
cert_chain_t = cert.cert_chain[1:cert.cert_chain.index(ca.name)]
ret = ["%s.pem" % c for c in cert_chain_t]
else:
ret = []
return ret
def run(self):
try:
ssl_context = ssl.SSLContext(self.version)
# cipher suite
ssl_context.set_ciphers(self.cipher_suite)
if self.ca is not None:
ssl_context.load_verify_locations(cafile=os.path.join(self.handler.CERT_FOLDER, self.ca))
# python ssl can't verify cert chain, don't know why
# need to load cert between cert and ca for pem (pem cert contains cert chain)
if self.remote_cert is not None:
cert_chain = self.handle_cert(self.remote_cert, self.ca)
for c in cert_chain:
NativeLog.add_trace_info("load ca chain %s" % c)
ssl_context.load_verify_locations(cafile=os.path.join(self.handler.CERT_FOLDER, c))
if self.cert is not None:
cert = os.path.join(self.handler.CERT_FOLDER, self.cert)
key = os.path.join(self.handler.KEY_FOLDER, self.key)
ssl_context.load_cert_chain(cert, keyfile=key)
if self.verify_required is True:
ssl_context.verify_mode = ssl.CERT_REQUIRED
else:
ssl_context.verify_mode = ssl.CERT_NONE
self.handler.ssl_context = ssl_context
except StandardError, e:
NativeLog.add_exception_log(e)
pass
pass
@ssl_handler_wrapper("PC")
def init_context(self):
if self.type == "client":
version = self.PROTOCOL_MAPPING[self.config["client_version"]]
cipher_suite = Parameter.CIPHER_SUITE[self.config["client_cipher_suite"]]
ca = self.config["client_trust_anchor"]
cert = self.config["client_certificate"]
key = self.config["client_key"]
verify_required = self.config["verify_server"]
remote_cert = self.config["server_certificate"]
else:
version = self.PROTOCOL_MAPPING[self.config["server_version"]]
cipher_suite = Parameter.CIPHER_SUITE[self.config["server_cipher_suite"]]
ca = self.config["server_trust_anchor"]
cert = self.config["server_certificate"]
key = self.config["server_key"]
verify_required = self.config["verify_client"]
remote_cert = self.config["client_certificate"]
_init_context = self.InitContextThread(self, version, cipher_suite, ca, cert, key, verify_required, remote_cert)
_init_context.start()
_init_context.join(5)
if self.ssl_context is None:
raise StandardError("Init Context Fail")
pass
@ssl_handler_wrapper("PC")
def connect(self, remote_ip, remote_port, local_ip=0, local_port=0):
sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
# reuse socket in TIME_WAIT state
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.settimeout(self.timeout)
sock.bind((local_ip, local_port))
self.ssl = self.ssl_context.wrap_socket(sock)
self.ssl.connect((remote_ip, remote_port))
pass
def accept_succeed(self, ssl_new):
ssl_new.settimeout(self.timeout)
self.ssl = ssl_new
class Accept(threading.Thread):
def __init__(self, server_sock, ssl_context, succeed_cb):
threading.Thread.__init__(self)
self.setDaemon(True)
self.server_sock = server_sock
self.ssl_context = ssl_context
self.succeed_cb = succeed_cb
self.exit_flag = threading.Event()
def run(self):
while self.exit_flag.isSet() is False:
try:
new_socket, addr = self.server_sock.accept()
ssl_new = self.ssl_context.wrap_socket(new_socket, server_side=True)
self.succeed_cb(ssl_new)
break
except StandardError:
pass
pass
def exit(self):
self.exit_flag.set()
@ssl_handler_wrapper("PC")
def listen(self, local_port=0, local_ip=0):
self.server_sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
# reuse socket in TIME_WAIT state
self.server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.server_sock.settimeout(1)
self.server_sock.bind((local_ip, local_port))
self.server_sock.listen(5)
self.accept_thread = self.Accept(self.server_sock, self.ssl_context, self.accept_succeed)
self.accept_thread.start()
pass
@ssl_handler_wrapper("PC")
def send(self, size=10, data=None):
if data is None:
self.ssl.send(make_validation_data(size, self.send_index))
if self.data_validation is True:
self.send_index += size
else:
self.ssl.send(data)
@ssl_handler_wrapper("PC")
def recv(self, length, timeout=SSLHandler.TIMEOUT, data_validation=False):
time1 = time.time()
data_len = 0
while time.time() - time1 < timeout:
data = self.ssl.read()
if data_validation is True and len(data) > 0:
if verify_data(data, self.recv_index) is False:
raise SSLHandlerFail("PC data validation fail, index is %s" % self.recv_index)
self.recv_index += len(data)
data_len += len(data)
if data_len >= length:
result = True
break
else:
result = False
if result is False:
raise SSLHandlerFail("PC recv fail")
def set_data_validation(self, validation):
self.data_validation = validation
@ssl_handler_wrapper("PC")
def close(self):
SSLHandler.close(self)
if self.ssl is not None:
self.ssl.close()
self.ssl = None
if self.server_sock is not None:
self.server_sock.close()
self.server_sock = None
del self.ssl_context
def main():
ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
# cipher suite
ssl_context.set_ciphers("AES256-SHA")
ssl_context.load_cert_chain("D:\workspace\\auto_test_script\PKI\Certificate\\"
"L2CertRSA512sha1_L1CertRSA512sha1_RootCertRSA512sha1.pem",
keyfile="D:\workspace\\auto_test_script\PKI\Key\PrivateKey2RSA512.pem")
ssl_context.verify_mode = ssl.CERT_NONE
server_sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
# reuse socket in TIME_WAIT state
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_sock.settimeout(100)
server_sock.bind(("192.168.111.5", 443))
server_sock.listen(5)
while True:
try:
new_socket, addr = server_sock.accept()
ssl_new = ssl_context.wrap_socket(new_socket, server_side=True)
print "server connected"
break
except StandardError:
pass
pass
if __name__ == '__main__':
main()