fix: miot storage type error

This commit is contained in:
topsworld 2025-01-17 09:41:01 +08:00
parent 37492438e8
commit cebd48355f

View File

@ -58,7 +58,6 @@ from enum import Enum, auto
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Union from typing import Any, Optional, Union
import logging import logging
from urllib.request import Request, urlopen
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.x509.oid import NameOID from cryptography.x509.oid import NameOID
@ -66,6 +65,8 @@ from cryptography import x509
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ed25519 from cryptography.hazmat.primitives.asymmetric import ed25519
from custom_components.xiaomi_home.miot.common import MIoTHttp
# pylint: disable=relative-beyond-top-level # pylint: disable=relative-beyond-top-level
from .const import ( from .const import (
MANUFACTURER_EFFECTIVE_TIME, MANUFACTURER_EFFECTIVE_TIME,
@ -91,10 +92,10 @@ class MIoTStorage:
User data will be stored in the `.storage` directory of Home Assistant. User data will be stored in the `.storage` directory of Home Assistant.
""" """
_main_loop: asyncio.AbstractEventLoop = None _main_loop: asyncio.AbstractEventLoop
_file_future: dict[str, tuple[MIoTStorageType, asyncio.Future]] _file_future: dict[str, tuple[MIoTStorageType, asyncio.Future]]
_root_path: str = None _root_path: str
def __init__( def __init__(
self, root_path: str, self, root_path: str,
@ -138,7 +139,7 @@ class MIoTStorage:
if r_data is None: if r_data is None:
_LOGGER.error('load error, empty file, %s', full_path) _LOGGER.error('load error, empty file, %s', full_path)
return None return None
data_bytes: bytes = None data_bytes: bytes
# Hash check # Hash check
if with_hash_check: if with_hash_check:
if len(r_data) <= 32: if len(r_data) <= 32:
@ -207,17 +208,16 @@ class MIoTStorage:
else: else:
os.makedirs(os.path.dirname(full_path), exist_ok=True) os.makedirs(os.path.dirname(full_path), exist_ok=True)
try: try:
type_: type = type(data) w_bytes: bytes
w_bytes: bytes = None if isinstance(data, bytes):
if type_ == bytes:
w_bytes = data w_bytes = data
elif type_ == str: elif isinstance(data, str):
w_bytes = data.encode('utf-8') w_bytes = data.encode('utf-8')
elif type_ in [dict, list]: elif isinstance(data, (dict, list)):
w_bytes = json.dumps(data).encode('utf-8') w_bytes = json.dumps(data).encode('utf-8')
else: else:
_LOGGER.error( _LOGGER.error(
'save error, unsupported data type, %s', type_.__name__) 'save error, unsupported data type, %s', type(data).__name__)
return False return False
with open(full_path, 'wb') as w_file: with open(full_path, 'wb') as w_file:
w_file.write(w_bytes) w_file.write(w_bytes)
@ -351,7 +351,8 @@ class MIoTStorage:
def load_file(self, domain: str, name_with_suffix: str) -> Optional[bytes]: def load_file(self, domain: str, name_with_suffix: str) -> Optional[bytes]:
full_path = os.path.join(self._root_path, domain, name_with_suffix) full_path = os.path.join(self._root_path, domain, name_with_suffix)
return self.__load( return self.__load(
full_path=full_path, type_=bytes, with_hash_check=False) full_path=full_path, type_=bytes,
with_hash_check=False) # type: ignore
async def load_file_async( async def load_file_async(
self, domain: str, name_with_suffix: str self, domain: str, name_with_suffix: str
@ -369,7 +370,7 @@ class MIoTStorage:
None, self.__load, full_path, bytes, False) None, self.__load, full_path, bytes, False)
if not fut.done(): if not fut.done():
self.__add_file_future(full_path, MIoTStorageType.LOAD_FILE, fut) self.__add_file_future(full_path, MIoTStorageType.LOAD_FILE, fut)
return await fut return await fut # type: ignore
def remove_file(self, domain: str, name_with_suffix: str) -> bool: def remove_file(self, domain: str, name_with_suffix: str) -> bool:
full_path = os.path.join(self._root_path, domain, name_with_suffix) full_path = os.path.join(self._root_path, domain, name_with_suffix)
@ -436,7 +437,7 @@ class MIoTStorage:
domain=config_domain, name=config_name, data=config) domain=config_domain, name=config_name, data=config)
local_config = (self.load(domain=config_domain, local_config = (self.load(domain=config_domain,
name=config_name, type_=dict)) or {} name=config_name, type_=dict)) or {}
local_config.update(config) local_config.update(config) # type: ignore
return self.save( return self.save(
domain=config_domain, name=config_name, data=local_config) domain=config_domain, name=config_name, data=local_config)
@ -472,27 +473,31 @@ class MIoTStorage:
domain=config_domain, name=config_name, data=config) domain=config_domain, name=config_name, data=config)
local_config = (await self.load_async( local_config = (await self.load_async(
domain=config_domain, name=config_name, type_=dict)) or {} domain=config_domain, name=config_name, type_=dict)) or {}
local_config.update(config) local_config.update(config) # type: ignore
return await self.save_async( return await self.save_async(
domain=config_domain, name=config_name, data=local_config) domain=config_domain, name=config_name, data=local_config)
def load_user_config( def load_user_config(
self, uid: str, cloud_server: str, keys: Optional[list[str]] = None self, uid: str, cloud_server: str, keys: Optional[list[str]] = None
) -> dict[str, Any]: ) -> dict[str, Any]:
if keys is not None and len(keys) == 0: if isinstance(keys, list) and len(keys) == 0:
# Do nothing # Do nothing
return {} return {}
config_domain = 'miot_config' config_domain = 'miot_config'
config_name = f'{uid}_{cloud_server}' config_name = f'{uid}_{cloud_server}'
local_config = (self.load(domain=config_domain, local_config = (self.load(domain=config_domain,
name=config_name, type_=dict)) or {} name=config_name, type_=dict))
if not isinstance(local_config, dict):
return {}
if keys is None: if keys is None:
return local_config return local_config
return {key: local_config.get(key, None) for key in keys} return {
key: local_config[key] for key in keys
if key in local_config}
async def load_user_config_async( async def load_user_config_async(
self, uid: str, cloud_server: str, keys: Optional[list[str]] = None self, uid: str, cloud_server: str, keys: Optional[list[str]] = None
) -> dict[str, Any]: ) -> dict:
"""Load user configuration. """Load user configuration.
Args: Args:
@ -503,13 +508,15 @@ class MIoTStorage:
Returns: Returns:
dict[str, Any]: query result dict[str, Any]: query result
""" """
if keys is not None and len(keys) == 0: if isinstance(keys, list) and len(keys) == 0:
# Do nothing # Do nothing
return {} return {}
config_domain = 'miot_config' config_domain = 'miot_config'
config_name = f'{uid}_{cloud_server}' config_name = f'{uid}_{cloud_server}'
local_config = (await self.load_async( local_config = (await self.load_async(
domain=config_domain, name=config_name, type_=dict)) or {} domain=config_domain, name=config_name, type_=dict))
if not isinstance(local_config, dict):
return {}
if keys is None: if keys is None:
return local_config return local_config
return { return {
@ -517,7 +524,8 @@ class MIoTStorage:
if key in local_config} if key in local_config}
def gen_storage_path( def gen_storage_path(
self, domain: str = None, name_with_suffix: str = None self, domain: Optional[str] = None,
name_with_suffix: Optional[str] = None
) -> str: ) -> str:
"""Generate file path.""" """Generate file path."""
result = self._root_path result = self._root_path
@ -607,9 +615,8 @@ class MIoTCert:
if cert_data is None: if cert_data is None:
return 0 return 0
# Check user cert # Check user cert
user_cert: x509.Certificate = None
try: try:
user_cert = x509.load_pem_x509_certificate( user_cert: x509.Certificate = x509.load_pem_x509_certificate(
cert_data, default_backend()) cert_data, default_backend())
cert_info = {} cert_info = {}
for attribute in user_cert.subject: for attribute in user_cert.subject:
@ -667,7 +674,8 @@ class MIoTCert:
NameOID.COMMON_NAME, f'mips.{self._uid}.{did_hash}.2'), NameOID.COMMON_NAME, f'mips.{self._uid}.{did_hash}.2'),
])) ]))
csr = builder.sign( csr = builder.sign(
private_key, algorithm=None, backend=default_backend()) private_key, algorithm=None, # type: ignore
backend=default_backend())
return csr.public_bytes(serialization.Encoding.PEM).decode('utf-8') return csr.public_bytes(serialization.Encoding.PEM).decode('utf-8')
async def load_user_key_async(self) -> Optional[str]: async def load_user_key_async(self) -> Optional[str]:
@ -730,12 +738,11 @@ class DeviceManufacturer:
) -> None: ) -> None:
self._main_loop = loop or asyncio.get_event_loop() self._main_loop = loop or asyncio.get_event_loop()
self._storage = storage self._storage = storage
self._data = None self._data = {}
async def init_async(self) -> None: async def init_async(self) -> None:
if self._data: if self._data:
return return
data_cache: dict = None
data_cache = await self._storage.load_async( data_cache = await self._storage.load_async(
domain=self.DOMAIN, name='manufacturer', type_=dict) domain=self.DOMAIN, name='manufacturer', type_=dict)
if ( if (
@ -749,8 +756,15 @@ class DeviceManufacturer:
_LOGGER.debug('load manufacturer data success') _LOGGER.debug('load manufacturer data success')
return return
data_cloud = await self._main_loop.run_in_executor( data_cloud = None
None, self.__get_manufacturer_data) try:
data_cloud = await MIoTHttp.get_json_async(
url='https://cdn.cnbj1.fds.api.mi-img.com/res-conf/xiaomi-home/'
'manufacturer.json',
loop=self._main_loop)
except Exception as err: # pylint: disable=broad-exception-caught
_LOGGER.error('get manufacturer info failed, %s', err)
if data_cloud: if data_cloud:
await self._storage.save_async( await self._storage.save_async(
domain=self.DOMAIN, name='manufacturer', domain=self.DOMAIN, name='manufacturer',
@ -758,32 +772,16 @@ class DeviceManufacturer:
self._data = data_cloud self._data = data_cloud
_LOGGER.debug('update manufacturer data success') _LOGGER.debug('update manufacturer data success')
else: else:
if data_cache: if isinstance(data_cache, dict):
self._data = data_cache.get('data', None) self._data = data_cache.get('data', {})
_LOGGER.error('load manufacturer data failed, use local data') _LOGGER.error('load manufacturer data failed, use local data')
else: else:
_LOGGER.error('load manufacturer data failed') _LOGGER.error('load manufacturer data failed')
async def deinit_async(self) -> None: async def deinit_async(self) -> None:
self._data = None self._data.clear()
def get_name(self, short_name: str) -> str: def get_name(self, short_name: str) -> str:
if not self._data or not short_name or short_name not in self._data: if not self._data or not short_name or short_name not in self._data:
return short_name return short_name
return self._data[short_name].get('name', None) or short_name return self._data[short_name].get('name', None) or short_name
def __get_manufacturer_data(self) -> dict:
try:
request = Request(
'https://cdn.cnbj1.fds.api.mi-img.com/res-conf/xiaomi-home/'
'manufacturer.json',
method='GET')
content: bytes = None
with urlopen(request) as response:
content = response.read()
return (
json.loads(str(content, 'utf-8'))
if content else None)
except Exception as err: # pylint: disable=broad-exception-caught
_LOGGER.error('get manufacturer info failed, %s', err)
return None