esp_prov: Compatibility changes and refactoring

- Removed python 2 compatibility
- Removed dependencies on redundant external modules
- Interactive provisioning input for security scheme 2
- Style changes:
  Updated print statements to format strings
  Colored verbose logging
  Raised exceptions on errors instead of clean exits
This commit is contained in:
Laukik Hase
2022-06-22 15:14:19 +05:30
parent 2c4e5c2963
commit 9aefcb12f5
20 changed files with 212 additions and 341 deletions

View File

@@ -1,35 +1,24 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
# APIs for interpreting and creating protobuf packets for
# protocomm endpoint with security type protocomm_security1
from __future__ import print_function
import proto
import session_pb2
import utils
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from future.utils import tobytes
from utils import long_to_bytes, str_to_bytes
from .security import Security
def a_xor_b(a: bytes, b: bytes) -> bytes:
return b''.join(long_to_bytes(a[i] ^ b[i]) for i in range(0, len(b)))
# Enum for state of protocomm_security1 FSM
class security_state:
REQUEST1 = 0
@@ -38,25 +27,11 @@ class security_state:
FINISHED = 3
def xor(a, b):
# XOR two inputs of type `bytes`
ret = bytearray()
# Decode the input bytes to strings
a = a.decode('latin-1')
b = b.decode('latin-1')
for i in range(max(len(a), len(b))):
# Convert the characters to corresponding 8-bit ASCII codes
# then XOR them and store in bytearray
ret.append(([0, ord(a[i])][i < len(a)]) ^ ([0, ord(b[i])][i < len(b)]))
# Convert bytearray to bytes
return bytes(ret)
class Security1(Security):
def __init__(self, pop, verbose):
# Initialize state of the security1 FSM
self.session_state = security_state.REQUEST1
self.pop = tobytes(pop)
self.pop = str_to_bytes(pop)
self.verbose = verbose
Security.__init__(self, self.security1_session)
@@ -66,59 +41,55 @@ class Security1(Security):
if (self.session_state == security_state.REQUEST1):
self.session_state = security_state.RESPONSE1_REQUEST2
return self.setup0_request()
if (self.session_state == security_state.RESPONSE1_REQUEST2):
elif (self.session_state == security_state.RESPONSE1_REQUEST2):
self.session_state = security_state.RESPONSE2
self.setup0_response(response_data)
return self.setup1_request()
if (self.session_state == security_state.RESPONSE2):
elif (self.session_state == security_state.RESPONSE2):
self.session_state = security_state.FINISHED
self.setup1_response(response_data)
return None
else:
print('Unexpected state')
return None
print('Unexpected state')
return None
def __generate_key(self):
# Generate private and public key pair for client
self.client_private_key = X25519PrivateKey.generate()
try:
self.client_public_key = self.client_private_key.public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw)
except TypeError:
# backward compatible call for older cryptography library
self.client_public_key = self.client_private_key.public_key().public_bytes()
self.client_public_key = self.client_private_key.public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw)
def _print_verbose(self, data):
if (self.verbose):
print('++++ ' + data + ' ++++')
print(f'\x1b[32;20m++++ {data} ++++\x1b[0m')
def setup0_request(self):
# Form SessionCmd0 request packet using client public key
setup_req = session_pb2.SessionData()
setup_req.sec_ver = session_pb2.SecScheme1
setup_req = proto.session_pb2.SessionData()
setup_req.sec_ver = proto.session_pb2.SecScheme1
self.__generate_key()
setup_req.sec1.sc0.client_pubkey = self.client_public_key
self._print_verbose('Client Public Key:\t' + utils.str_to_hexstr(self.client_public_key.decode('latin-1')))
self._print_verbose(f'Client Public Key:\t0x{self.client_public_key.hex()}')
return setup_req.SerializeToString().decode('latin-1')
def setup0_response(self, response_data):
# Interpret SessionResp0 response packet
setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data))
setup_resp.ParseFromString(str_to_bytes(response_data))
self._print_verbose('Security version:\t' + str(setup_resp.sec_ver))
if setup_resp.sec_ver != session_pb2.SecScheme1:
print('Incorrect sec scheme')
exit(1)
if setup_resp.sec_ver != proto.session_pb2.SecScheme1:
raise RuntimeError('Incorrect security scheme')
self.device_public_key = setup_resp.sec1.sr0.device_pubkey
# Device random is the initialization vector
device_random = setup_resp.sec1.sr0.device_random
self._print_verbose('Device Public Key:\t' + utils.str_to_hexstr(self.device_public_key.decode('latin-1')))
self._print_verbose('Device Random:\t' + utils.str_to_hexstr(device_random.decode('latin-1')))
self._print_verbose(f'Device Public Key:\t0x{self.device_public_key.hex()}')
self._print_verbose(f'Device Random:\t0x{device_random.hex()}')
# Calculate Curve25519 shared key using Client private key and Device public key
sharedK = self.client_private_key.exchange(X25519PublicKey.from_public_bytes(self.device_public_key))
self._print_verbose('Shared Key:\t' + utils.str_to_hexstr(sharedK.decode('latin-1')))
self._print_verbose(f'Shared Key:\t0x{sharedK.hex()}')
# If PoP is provided, XOR SHA256 of PoP with the previously
# calculated Shared Key to form the actual Shared Key
@@ -128,8 +99,8 @@ class Security1(Security):
h.update(self.pop)
digest = h.finalize()
# XOR with and update Shared Key
sharedK = xor(sharedK, digest)
self._print_verbose('New Shared Key XORed with PoP:\t' + utils.str_to_hexstr(sharedK.decode('latin-1')))
sharedK = a_xor_b(sharedK, digest)
self._print_verbose(f'Updated Shared Key (Shared key XORed with PoP):\t0x{sharedK.hex()}')
# Initialize the encryption engine with Shared Key and initialization vector
cipher = Cipher(algorithms.AES(sharedK), modes.CTR(device_random), backend=default_backend())
self.cipher = cipher.encryptor()
@@ -137,36 +108,33 @@ class Security1(Security):
def setup1_request(self):
# Form SessionCmd1 request packet using encrypted device public key
setup_req = proto.session_pb2.SessionData()
setup_req.sec_ver = session_pb2.SecScheme1
setup_req.sec_ver = proto.session_pb2.SecScheme1
setup_req.sec1.msg = proto.sec1_pb2.Session_Command1
# Encrypt device public key and attach to the request packet
client_verify = self.cipher.update(self.device_public_key)
self._print_verbose('Client Verify:\t' + utils.str_to_hexstr(client_verify.decode('latin-1')))
self._print_verbose(f'Client Proof:\t0x{client_verify.hex()}')
setup_req.sec1.sc1.client_verify_data = client_verify
return setup_req.SerializeToString().decode('latin-1')
def setup1_response(self, response_data):
# Interpret SessionResp1 response packet
setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data))
setup_resp.ParseFromString(str_to_bytes(response_data))
# Ensure security scheme matches
if setup_resp.sec_ver == session_pb2.SecScheme1:
if setup_resp.sec_ver == proto.session_pb2.SecScheme1:
# Read encrypyed device verify string
device_verify = setup_resp.sec1.sr1.device_verify_data
self._print_verbose('Device verify:\t' + utils.str_to_hexstr(device_verify.decode('latin-1')))
self._print_verbose(f'Device Proof:\t0x{device_verify.hex()}')
# Decrypt the device verify string
enc_client_pubkey = self.cipher.update(setup_resp.sec1.sr1.device_verify_data)
self._print_verbose('Enc client pubkey:\t ' + utils.str_to_hexstr(enc_client_pubkey.decode('latin-1')))
# Match decryped string with client public key
if enc_client_pubkey != self.client_public_key:
print('Mismatch in device verify')
return -2
raise RuntimeError('Failed to verify device!')
else:
print('Unsupported security protocol')
return -1
raise RuntimeError('Unsupported security protocol')
def encrypt_data(self, data):
return self.cipher.update(tobytes(data))
return self.cipher.update(data)
def decrypt_data(self, data):
return self.cipher.update(tobytes(data))
return self.cipher.update(data)