fix: miot storage type error

This commit is contained in:
topsworld 2025-01-17 09:41:01 +08:00
parent 37492438e8
commit cebd48355f

View File

@ -58,7 +58,6 @@ from enum import Enum, auto
from pathlib import Path
from typing import Any, Optional, Union
import logging
from urllib.request import Request, urlopen
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from cryptography.x509.oid import NameOID
@ -66,6 +65,8 @@ from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ed25519
from custom_components.xiaomi_home.miot.common import MIoTHttp
# pylint: disable=relative-beyond-top-level
from .const import (
MANUFACTURER_EFFECTIVE_TIME,
@ -91,10 +92,10 @@ class MIoTStorage:
User data will be stored in the `.storage` directory of Home Assistant.
"""
_main_loop: asyncio.AbstractEventLoop = None
_main_loop: asyncio.AbstractEventLoop
_file_future: dict[str, tuple[MIoTStorageType, asyncio.Future]]
_root_path: str = None
_root_path: str
def __init__(
self, root_path: str,
@ -138,7 +139,7 @@ class MIoTStorage:
if r_data is None:
_LOGGER.error('load error, empty file, %s', full_path)
return None
data_bytes: bytes = None
data_bytes: bytes
# Hash check
if with_hash_check:
if len(r_data) <= 32:
@ -207,17 +208,16 @@ class MIoTStorage:
else:
os.makedirs(os.path.dirname(full_path), exist_ok=True)
try:
type_: type = type(data)
w_bytes: bytes = None
if type_ == bytes:
w_bytes: bytes
if isinstance(data, bytes):
w_bytes = data
elif type_ == str:
elif isinstance(data, str):
w_bytes = data.encode('utf-8')
elif type_ in [dict, list]:
elif isinstance(data, (dict, list)):
w_bytes = json.dumps(data).encode('utf-8')
else:
_LOGGER.error(
'save error, unsupported data type, %s', type_.__name__)
'save error, unsupported data type, %s', type(data).__name__)
return False
with open(full_path, 'wb') as w_file:
w_file.write(w_bytes)
@ -351,7 +351,8 @@ class MIoTStorage:
def load_file(self, domain: str, name_with_suffix: str) -> Optional[bytes]:
full_path = os.path.join(self._root_path, domain, name_with_suffix)
return self.__load(
full_path=full_path, type_=bytes, with_hash_check=False)
full_path=full_path, type_=bytes,
with_hash_check=False) # type: ignore
async def load_file_async(
self, domain: str, name_with_suffix: str
@ -369,7 +370,7 @@ class MIoTStorage:
None, self.__load, full_path, bytes, False)
if not fut.done():
self.__add_file_future(full_path, MIoTStorageType.LOAD_FILE, fut)
return await fut
return await fut # type: ignore
def remove_file(self, domain: str, name_with_suffix: str) -> bool:
full_path = os.path.join(self._root_path, domain, name_with_suffix)
@ -436,7 +437,7 @@ class MIoTStorage:
domain=config_domain, name=config_name, data=config)
local_config = (self.load(domain=config_domain,
name=config_name, type_=dict)) or {}
local_config.update(config)
local_config.update(config) # type: ignore
return self.save(
domain=config_domain, name=config_name, data=local_config)
@ -472,27 +473,31 @@ class MIoTStorage:
domain=config_domain, name=config_name, data=config)
local_config = (await self.load_async(
domain=config_domain, name=config_name, type_=dict)) or {}
local_config.update(config)
local_config.update(config) # type: ignore
return await self.save_async(
domain=config_domain, name=config_name, data=local_config)
def load_user_config(
self, uid: str, cloud_server: str, keys: Optional[list[str]] = None
) -> dict[str, Any]:
if keys is not None and len(keys) == 0:
if isinstance(keys, list) and len(keys) == 0:
# Do nothing
return {}
config_domain = 'miot_config'
config_name = f'{uid}_{cloud_server}'
local_config = (self.load(domain=config_domain,
name=config_name, type_=dict)) or {}
name=config_name, type_=dict))
if not isinstance(local_config, dict):
return {}
if keys is None:
return local_config
return {key: local_config.get(key, None) for key in keys}
return {
key: local_config[key] for key in keys
if key in local_config}
async def load_user_config_async(
self, uid: str, cloud_server: str, keys: Optional[list[str]] = None
) -> dict[str, Any]:
) -> dict:
"""Load user configuration.
Args:
@ -503,13 +508,15 @@ class MIoTStorage:
Returns:
dict[str, Any]: query result
"""
if keys is not None and len(keys) == 0:
if isinstance(keys, list) and len(keys) == 0:
# Do nothing
return {}
config_domain = 'miot_config'
config_name = f'{uid}_{cloud_server}'
local_config = (await self.load_async(
domain=config_domain, name=config_name, type_=dict)) or {}
domain=config_domain, name=config_name, type_=dict))
if not isinstance(local_config, dict):
return {}
if keys is None:
return local_config
return {
@ -517,7 +524,8 @@ class MIoTStorage:
if key in local_config}
def gen_storage_path(
self, domain: str = None, name_with_suffix: str = None
self, domain: Optional[str] = None,
name_with_suffix: Optional[str] = None
) -> str:
"""Generate file path."""
result = self._root_path
@ -607,9 +615,8 @@ class MIoTCert:
if cert_data is None:
return 0
# Check user cert
user_cert: x509.Certificate = None
try:
user_cert = x509.load_pem_x509_certificate(
user_cert: x509.Certificate = x509.load_pem_x509_certificate(
cert_data, default_backend())
cert_info = {}
for attribute in user_cert.subject:
@ -667,7 +674,8 @@ class MIoTCert:
NameOID.COMMON_NAME, f'mips.{self._uid}.{did_hash}.2'),
]))
csr = builder.sign(
private_key, algorithm=None, backend=default_backend())
private_key, algorithm=None, # type: ignore
backend=default_backend())
return csr.public_bytes(serialization.Encoding.PEM).decode('utf-8')
async def load_user_key_async(self) -> Optional[str]:
@ -730,12 +738,11 @@ class DeviceManufacturer:
) -> None:
self._main_loop = loop or asyncio.get_event_loop()
self._storage = storage
self._data = None
self._data = {}
async def init_async(self) -> None:
if self._data:
return
data_cache: dict = None
data_cache = await self._storage.load_async(
domain=self.DOMAIN, name='manufacturer', type_=dict)
if (
@ -749,8 +756,15 @@ class DeviceManufacturer:
_LOGGER.debug('load manufacturer data success')
return
data_cloud = await self._main_loop.run_in_executor(
None, self.__get_manufacturer_data)
data_cloud = None
try:
data_cloud = await MIoTHttp.get_json_async(
url='https://cdn.cnbj1.fds.api.mi-img.com/res-conf/xiaomi-home/'
'manufacturer.json',
loop=self._main_loop)
except Exception as err: # pylint: disable=broad-exception-caught
_LOGGER.error('get manufacturer info failed, %s', err)
if data_cloud:
await self._storage.save_async(
domain=self.DOMAIN, name='manufacturer',
@ -758,32 +772,16 @@ class DeviceManufacturer:
self._data = data_cloud
_LOGGER.debug('update manufacturer data success')
else:
if data_cache:
self._data = data_cache.get('data', None)
if isinstance(data_cache, dict):
self._data = data_cache.get('data', {})
_LOGGER.error('load manufacturer data failed, use local data')
else:
_LOGGER.error('load manufacturer data failed')
async def deinit_async(self) -> None:
self._data = None
self._data.clear()
def get_name(self, short_name: str) -> str:
if not self._data or not short_name or short_name not in self._data:
return short_name
return self._data[short_name].get('name', None) or short_name
def __get_manufacturer_data(self) -> dict:
try:
request = Request(
'https://cdn.cnbj1.fds.api.mi-img.com/res-conf/xiaomi-home/'
'manufacturer.json',
method='GET')
content: bytes = None
with urlopen(request) as response:
content = response.read()
return (
json.loads(str(content, 'utf-8'))
if content else None)
except Exception as err: # pylint: disable=broad-exception-caught
_LOGGER.error('get manufacturer info failed, %s', err)
return None