Source code for flare.internal.serde

from __future__ import annotations

import abc
import typing as t

from flare.converters import get_converter
from flare.exceptions import SerializerError, SerializerVersionViolation
from flare.utils import gather_iter

if t.TYPE_CHECKING:
    from flare.components import base

__all__: t.Final[t.Sequence[str]] = ("Serde",)


class SerdeABC(abc.ABC):
    """Abstract class for implementing a custom serializer and deserializer."""

    @abc.abstractmethod
    async def serialize(self, cookie: str, types: dict[str, t.Any], kwargs: dict[str, t.Any]) -> str:
        """
        Encode a custom_id for a component.

        Args:
            cookie:
                A unique identifier for the component.
            types:
                A dictionary of argument names to argument type hints. The type hint
                is used to encode a value to a string.
            kwargs:
                Values that the user passes to save state.
        """

    @abc.abstractmethod
    async def deserialize(
        self, custom_id: str, map: dict[str, t.Any]
    ) -> tuple[type[base.SupportsCallback[t.Any]], dict[str, t.Any]]:
        """
        Decode a custom_id for a component.

        Args:
            custom_id:
                The custom_id of the component.
            map:
                A dictionary of cookies to components.
        """


[docs] class Serde(SerdeABC): """ A class that handles serialization and deserialization of component custom_id encoded data. For simple behaviour changes it may be sufficient to subclass this class, but if you desire to completely overhaul serialization and deserialization, you may wish to only subclass SerdeABC instead. Args: sep: The character used to serperate fields. null: The character used to signify `None`. esc: The escape character. increment_length: `increment` is a unique number to allow buttons for the same values in the same message. `increment_length` can be set to `0` if identical buttons are never used in the same message. version: The serializer version number. """ def __init__( self, sep: str = "\x81", null: str = "\x82", esc: str = "\\", increment_length: int = 3, version: int | None = 0, ) -> None: self._SEP: str = sep self._ESC: str = esc self._NULL: str = null self._VER: int | None = version self._increment_length = increment_length self._increment = 0 if len(sep) != 1: raise ValueError("Separator must be a single character.") if len(null) != 1: raise ValueError("Null must be a single character.") if len(esc) != 1: raise ValueError("Escape must be a single character.") @property def SEP(self) -> str: """The separator used to separate arguments.""" return self._SEP @property def ESC(self) -> str: """The escape character.""" return self._ESC @property def NULL(self) -> str: """Character used to represent a missing value.""" return self._NULL @property def VER(self) -> int | None: """ The version of the serialization format. If None, the serializer will not attempt to verify the version of the serialized data. """ return self._VER def get_inc(self) -> str: self._increment += 1 if self._increment > 2**self._increment_length - 1: self._increment = 0 return self._increment.to_bytes(self._increment_length, "little").decode("latin1")
[docs] def escape(self, string: str) -> str: """Escape a string using `self.ESC`, `self.NULL` and `self.SEP`.""" out: list[str] = [] for char in string: if char in [self.ESC, self.NULL, self.SEP]: out.append(f"{self.ESC}{char}") else: out.append(char) return "".join(out)
[docs] def unescape(self, string: str) -> list[tuple[str, bool]]: """Returns a list of tuples signifying (the character, whether it was escaped)""" out: list[tuple[str, bool]] = [] last_was_esc = False for char in string: if not last_was_esc and char != self.ESC: out.append((char, False)) continue if last_was_esc: out.append((char, True)) last_was_esc = False continue last_was_esc = True return out
[docs] async def serialize(self, cookie: str, types: dict[str, t.Any], kwargs: dict[str, t.Any]) -> str: version = "" if self.VER is None else await get_converter(int).to_str(self.VER) async def serialize_one(k: str, v: t.Any) -> str: val = kwargs.get(k) converter = get_converter(v) return self.escape(await converter.to_str(val)) if val is not None else self.NULL out = self.SEP.join( ( f"{version}{self.get_inc()}{self.escape(cookie)}", *await gather_iter(serialize_one(k, v) for k, v in types.items()), ) ) if len(out) > 100: raise SerializerError( f"The serialized custom_id for component {cookie} may be too long." " Try reducing the number of parameters the component takes." f" Got length: {len(out)} Expected length: 100 or less" ) return out
[docs] def split_on_sep(self, string: list[tuple[str, bool]]) -> list[list[tuple[str, bool]]]: """Split the provided string on the separator, but ignore separators that are escaped. Args: string: The provided string. Returns: list[str] The split string. """ out: list[list[tuple[str, bool]]] = [[]] for char, is_escaped in string: if char == self.SEP and not is_escaped: out.append([]) else: out[-1].append((char, is_escaped)) return out
[docs] @staticmethod def tuple_list_to_string(string: list[tuple[str, bool]]) -> str: """Combine a list of tuples into a string, ignoring the second value.""" out: list[str] = [] for char, _ in string: out.append(char) return "".join(out)
async def cast_kwargs(self, kwargs: dict[str, t.Any], types: dict[str, t.Any]) -> dict[str, t.Any]: ret: dict[str, t.Any] = {} async def convert_one(k: str, v: t.Any) -> None: if v is None: ret[k] = None return cast_to = types[k] ret[k] = await get_converter(cast_to).from_str(v) await gather_iter(convert_one(k, v) for k, v in kwargs.items()) return ret
[docs] async def deserialize( self, custom_id: str, map: dict[str, t.Any] ) -> tuple[type[base.SupportsCallback[t.Any]], dict[str, t.Any]]: if self.VER is not None: # Allow for no version to disable verification version = await get_converter(int).from_str(custom_id[0]) if version != self.VER: raise SerializerVersionViolation( f"Serializer {self.__class__.__name__} cannot deserialize version {version}." ) custom_id = custom_id[1:] custom_id = custom_id[self._increment_length :] cookie, *args = self.split_on_sep(self.unescape(custom_id)) component_ = map.get(self.tuple_list_to_string(cookie)) if component_ is None: raise SerializerError(f"Component with cookie {cookie} does not exist.") types = component_._dataclass_annotations transformed_args: dict[str, t.Any] = {} for k, arg in zip(types.keys(), args): if len(arg) == 1: if arg[0] == (self.NULL, False): transformed_args[k] = None continue transformed_args[k] = self.tuple_list_to_string(arg) return (component_, await self.cast_kwargs(transformed_args, types))