diff --git a/custom_components/xiaomi_home/miot/miot_storage.py b/custom_components/xiaomi_home/miot/miot_storage.py index a2e9741..ee8e955 100644 --- a/custom_components/xiaomi_home/miot/miot_storage.py +++ b/custom_components/xiaomi_home/miot/miot_storage.py @@ -58,7 +58,6 @@ from enum import Enum, auto from pathlib import Path from typing import Any, Optional, Union import logging -from urllib.request import Request, urlopen from cryptography.hazmat.primitives import serialization from cryptography.hazmat.backends import default_backend from cryptography.x509.oid import NameOID @@ -66,6 +65,8 @@ from cryptography import x509 from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import ed25519 +from custom_components.xiaomi_home.miot.common import MIoTHttp + # pylint: disable=relative-beyond-top-level from .const import ( MANUFACTURER_EFFECTIVE_TIME, @@ -91,10 +92,10 @@ class MIoTStorage: User data will be stored in the `.storage` directory of Home Assistant. """ - _main_loop: asyncio.AbstractEventLoop = None + _main_loop: asyncio.AbstractEventLoop _file_future: dict[str, tuple[MIoTStorageType, asyncio.Future]] - _root_path: str = None + _root_path: str def __init__( self, root_path: str, @@ -138,7 +139,7 @@ class MIoTStorage: if r_data is None: _LOGGER.error('load error, empty file, %s', full_path) return None - data_bytes: bytes = None + data_bytes: bytes # Hash check if with_hash_check: if len(r_data) <= 32: @@ -207,17 +208,16 @@ class MIoTStorage: else: os.makedirs(os.path.dirname(full_path), exist_ok=True) try: - type_: type = type(data) - w_bytes: bytes = None - if type_ == bytes: + w_bytes: bytes + if isinstance(data, bytes): w_bytes = data - elif type_ == str: + elif isinstance(data, str): w_bytes = data.encode('utf-8') - elif type_ in [dict, list]: + elif isinstance(data, (dict, list)): w_bytes = json.dumps(data).encode('utf-8') else: _LOGGER.error( - 'save error, unsupported data type, %s', type_.__name__) + 'save error, unsupported data type, %s', type(data).__name__) return False with open(full_path, 'wb') as w_file: w_file.write(w_bytes) @@ -351,7 +351,8 @@ class MIoTStorage: def load_file(self, domain: str, name_with_suffix: str) -> Optional[bytes]: full_path = os.path.join(self._root_path, domain, name_with_suffix) return self.__load( - full_path=full_path, type_=bytes, with_hash_check=False) + full_path=full_path, type_=bytes, + with_hash_check=False) # type: ignore async def load_file_async( self, domain: str, name_with_suffix: str @@ -369,7 +370,7 @@ class MIoTStorage: None, self.__load, full_path, bytes, False) if not fut.done(): self.__add_file_future(full_path, MIoTStorageType.LOAD_FILE, fut) - return await fut + return await fut # type: ignore def remove_file(self, domain: str, name_with_suffix: str) -> bool: full_path = os.path.join(self._root_path, domain, name_with_suffix) @@ -436,7 +437,7 @@ class MIoTStorage: domain=config_domain, name=config_name, data=config) local_config = (self.load(domain=config_domain, name=config_name, type_=dict)) or {} - local_config.update(config) + local_config.update(config) # type: ignore return self.save( domain=config_domain, name=config_name, data=local_config) @@ -472,27 +473,31 @@ class MIoTStorage: domain=config_domain, name=config_name, data=config) local_config = (await self.load_async( domain=config_domain, name=config_name, type_=dict)) or {} - local_config.update(config) + local_config.update(config) # type: ignore return await self.save_async( domain=config_domain, name=config_name, data=local_config) def load_user_config( self, uid: str, cloud_server: str, keys: Optional[list[str]] = None ) -> dict[str, Any]: - if keys is not None and len(keys) == 0: + if isinstance(keys, list) and len(keys) == 0: # Do nothing return {} config_domain = 'miot_config' config_name = f'{uid}_{cloud_server}' local_config = (self.load(domain=config_domain, - name=config_name, type_=dict)) or {} + name=config_name, type_=dict)) + if not isinstance(local_config, dict): + return {} if keys is None: return local_config - return {key: local_config.get(key, None) for key in keys} + return { + key: local_config[key] for key in keys + if key in local_config} async def load_user_config_async( self, uid: str, cloud_server: str, keys: Optional[list[str]] = None - ) -> dict[str, Any]: + ) -> dict: """Load user configuration. Args: @@ -503,13 +508,15 @@ class MIoTStorage: Returns: dict[str, Any]: query result """ - if keys is not None and len(keys) == 0: + if isinstance(keys, list) and len(keys) == 0: # Do nothing return {} config_domain = 'miot_config' config_name = f'{uid}_{cloud_server}' local_config = (await self.load_async( - domain=config_domain, name=config_name, type_=dict)) or {} + domain=config_domain, name=config_name, type_=dict)) + if not isinstance(local_config, dict): + return {} if keys is None: return local_config return { @@ -517,7 +524,8 @@ class MIoTStorage: if key in local_config} def gen_storage_path( - self, domain: str = None, name_with_suffix: str = None + self, domain: Optional[str] = None, + name_with_suffix: Optional[str] = None ) -> str: """Generate file path.""" result = self._root_path @@ -607,9 +615,8 @@ class MIoTCert: if cert_data is None: return 0 # Check user cert - user_cert: x509.Certificate = None try: - user_cert = x509.load_pem_x509_certificate( + user_cert: x509.Certificate = x509.load_pem_x509_certificate( cert_data, default_backend()) cert_info = {} for attribute in user_cert.subject: @@ -667,7 +674,8 @@ class MIoTCert: NameOID.COMMON_NAME, f'mips.{self._uid}.{did_hash}.2'), ])) csr = builder.sign( - private_key, algorithm=None, backend=default_backend()) + private_key, algorithm=None, # type: ignore + backend=default_backend()) return csr.public_bytes(serialization.Encoding.PEM).decode('utf-8') async def load_user_key_async(self) -> Optional[str]: @@ -730,12 +738,11 @@ class DeviceManufacturer: ) -> None: self._main_loop = loop or asyncio.get_event_loop() self._storage = storage - self._data = None + self._data = {} async def init_async(self) -> None: if self._data: return - data_cache: dict = None data_cache = await self._storage.load_async( domain=self.DOMAIN, name='manufacturer', type_=dict) if ( @@ -749,8 +756,15 @@ class DeviceManufacturer: _LOGGER.debug('load manufacturer data success') return - data_cloud = await self._main_loop.run_in_executor( - None, self.__get_manufacturer_data) + data_cloud = None + try: + data_cloud = await MIoTHttp.get_json_async( + url='https://cdn.cnbj1.fds.api.mi-img.com/res-conf/xiaomi-home/' + 'manufacturer.json', + loop=self._main_loop) + except Exception as err: # pylint: disable=broad-exception-caught + _LOGGER.error('get manufacturer info failed, %s', err) + if data_cloud: await self._storage.save_async( domain=self.DOMAIN, name='manufacturer', @@ -758,32 +772,16 @@ class DeviceManufacturer: self._data = data_cloud _LOGGER.debug('update manufacturer data success') else: - if data_cache: - self._data = data_cache.get('data', None) + if isinstance(data_cache, dict): + self._data = data_cache.get('data', {}) _LOGGER.error('load manufacturer data failed, use local data') else: _LOGGER.error('load manufacturer data failed') async def deinit_async(self) -> None: - self._data = None + self._data.clear() def get_name(self, short_name: str) -> str: if not self._data or not short_name or short_name not in self._data: return short_name return self._data[short_name].get('name', None) or short_name - - def __get_manufacturer_data(self) -> dict: - try: - request = Request( - 'https://cdn.cnbj1.fds.api.mi-img.com/res-conf/xiaomi-home/' - 'manufacturer.json', - method='GET') - content: bytes = None - with urlopen(request) as response: - content = response.read() - return ( - json.loads(str(content, 'utf-8')) - if content else None) - except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('get manufacturer info failed, %s', err) - return None