@runtime_checkable
class IsDataclass(Protocol):
    """Structural protocol for dataclasses, identified by their ``__dataclass_fields__`` attribute."""

    # Checking for this attribute is currently the most reliable way to
    # ascertain that something is a dataclass.
    __dataclass_fields__: ClassVar[Dict[str, Any]]


def is_dataclass(cls: type[Any]) -> bool:
    """Return True if *cls* has a ``__dataclass_fields__`` attribute."""
    return hasattr(cls, "__dataclass_fields__")


def _type_contains_dataclass(tp: Any) -> bool:
    """Return True if *tp* is a dataclass or any of its (nested) generic type arguments is one."""
    if is_dataclass(tp):
        return True
    # get_args() returns () for non-generic annotations, which terminates the recursion.
    return any(_type_contains_dataclass(arg) for arg in get_args(tp))


def has_nested_dataclass(cls: type[IsDataclass]) -> bool:
    """Return True if any field of *cls* is, or contains, a dataclass type.

    Besides fields annotated directly with a dataclass, this also detects dataclasses
    inside generic annotations such as ``list[Inner]`` or ``dict[str, Inner]``. The JSON
    dataclass deserializer only rebuilds the top-level dataclass, so such nested values
    would otherwise come back as plain dicts.
    """
    # NOTE(review): f.type is a plain string when the dataclass's module uses
    # ``from __future__ import annotations``; such annotations are not inspected here.
    return any(_type_contains_dataclass(f.type) for f in cls.__dataclass_fields__.values())


def contains_a_union(cls: type[IsDataclass]) -> bool:
    """Return True if ``is_union`` reports any field annotation of *cls* as a union."""
    return any(is_union(f.type) for f in cls.__dataclass_fields__.values())


def has_nested_base_model(cls: type[IsDataclass]) -> bool:
    """Return True if any field of *cls* is, or contains, a ``BaseModel`` subclass.

    Generic annotations (``List[Model]``, ``Optional[Model]``, ``dict[str, list[Model]]``, ...)
    are searched recursively by ``has_nested_base_model_in_type``.
    """
    return any(has_nested_base_model_in_type(f.type) for f in fields(cls))


def has_nested_base_model_in_type(tp: Any) -> bool:
    """Return True if *tp* or any of its (nested) generic type arguments is a ``BaseModel`` subclass."""
    if isinstance(tp, type) and issubclass(tp, BaseModel):
        return True
    origin = get_origin(tp)
    args = get_args(tp)
    if origin is not None and args:
        return any(has_nested_base_model_in_type(arg) for arg in args)
    return False


DataclassT = TypeVar("DataclassT", bound=IsDataclass)

JSON_DATA_CONTENT_TYPE = "application/json"
"""JSON data content type"""

# TODO: what's the correct content type?
# There seems to be some disagreement over what it should be
PROTOBUF_DATA_CONTENT_TYPE = "application/x-protobuf"
"""Protobuf data content type"""


class DataclassJsonMessageSerializer(MessageSerializer[DataclassT]):
    """Serialize flat dataclasses to and from UTF-8 encoded JSON.

    Construction fails with ``ValueError`` when the dataclass has union-typed fields, or
    fields holding nested dataclasses / ``BaseModel`` subclasses. Pydantic models are the
    suggested alternative for those shapes.
    """

    def __init__(self, cls: type[DataclassT]) -> None:
        if contains_a_union(cls):
            raise ValueError(
                "Dataclass has a union type, which is not supported. To use a union, use a Pydantic model"
            )
        # Short-circuit: the BaseModel scan only runs when no nested dataclass was found.
        nested = has_nested_dataclass(cls) or has_nested_base_model(cls)
        if nested:
            raise ValueError(
                "Dataclass has nested dataclasses or base models, which are not supported. "
                "To use nested types, use a Pydantic model"
            )
        self.cls = cls

    @property
    def data_content_type(self) -> str:
        return JSON_DATA_CONTENT_TYPE

    @property
    def type_name(self) -> str:
        return _type_name(self.cls)

    def deserialize(self, payload: bytes) -> DataclassT:
        # Decoded JSON object keys become keyword arguments of the dataclass constructor.
        decoded = json.loads(payload.decode("utf-8"))
        return self.cls(**decoded)

    def serialize(self, message: DataclassT) -> bytes:
        as_json = json.dumps(asdict(message))
        return as_json.encode("utf-8")


PydanticT = TypeVar("PydanticT", bound=BaseModel)


class PydanticJsonMessageSerializer(MessageSerializer[PydanticT]):
    """Serialize pydantic models to and from UTF-8 encoded JSON."""

    def __init__(self, cls: type[PydanticT]) -> None:
        self.cls = cls

    @property
    def data_content_type(self) -> str:
        return JSON_DATA_CONTENT_TYPE

    @property
    def type_name(self) -> str:
        return _type_name(self.cls)

    def deserialize(self, payload: bytes) -> PydanticT:
        text = payload.decode("utf-8")
        return self.cls.model_validate_json(text)

    def serialize(self, message: PydanticT) -> bytes:
        text = message.model_dump_json()
        return text.encode("utf-8")


ProtobufT = TypeVar("ProtobufT", bound=Message)


class ProtobufMessageSerializer(MessageSerializer[ProtobufT]):
    """Serialize protobuf messages wrapped in a ``google.protobuf.Any`` envelope.

    ``serialize`` packs the message into an ``Any`` and returns the envelope's serialized
    bytes. ``deserialize`` parses such bytes back into an ``Any`` and unpacks it into a
    fresh ``cls`` instance.
    """

    def __init__(self, cls: type[ProtobufT]) -> None:
        self.cls = cls

    @property
    def data_content_type(self) -> str:
        return PROTOBUF_DATA_CONTENT_TYPE

    @property
    def type_name(self) -> str:
        return _type_name(self.cls)

    def deserialize(self, payload: bytes) -> ProtobufT:
        # Parse payload into a proto Any envelope first.
        envelope = any_pb2.Any()
        envelope.ParseFromString(payload)
        target = self.cls()
        # Unpack reports failure through its return value rather than by raising.
        unpacked = envelope.Unpack(target)  # type: ignore
        if not unpacked:
            raise ValueError(f"Failed to unpack payload into {self.cls}")
        return target

    def serialize(self, message: ProtobufT) -> bytes:
        envelope = any_pb2.Any()
        envelope.Pack(message)  # type: ignore
        return envelope.SerializeToString()
def _type_name(cls: type[Any] | Any) -> str:
    """Return ``cls.__name__`` when *cls* is a class, otherwise the name of its class."""
    if isinstance(cls, type):
        return cls.__name__
    else:
        return cast(str, cls.__class__.__name__)


V = TypeVar("V")


def try_get_known_serializers_for_type(cls: type[Any]) -> list[MessageSerializer[Any]]:
    """Return the built-in serializers applicable to *cls*, or an empty list if none apply.

    Pydantic models, dataclasses and protobuf messages are recognised, checked in that order.

    :meta private:
    """
    serializers: List[MessageSerializer[Any]] = []
    # Non-class annotations (e.g. ``Union[...]``, ``typing.List[int]``) would make
    # ``issubclass`` raise TypeError; no built-in serializer applies to them.
    if not isinstance(cls, type):
        return serializers
    if issubclass(cls, BaseModel):
        serializers.append(PydanticJsonMessageSerializer(cls))
    elif is_dataclass(cls):
        serializers.append(DataclassJsonMessageSerializer(cls))
    elif issubclass(cls, Message):
        serializers.append(ProtobufMessageSerializer(cls))
    return serializers


class SerializationRegistry:
    """Lookup table of message serializers keyed by ``(type_name, data_content_type)``.

    :meta private:
    """

    def __init__(self) -> None:
        # type_name, data_content_type -> serializer
        self._serializers: dict[tuple[str, str], MessageSerializer[Any]] = {}

    def add_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
        """Register a serializer, or every serializer in a sequence.

        A later registration for the same ``(type_name, data_content_type)`` key replaces
        the earlier one.
        """
        if isinstance(serializer, Sequence):
            for c in serializer:
                self.add_serializer(c)
            return

        self._serializers[(serializer.type_name, serializer.data_content_type)] = serializer

    def deserialize(self, payload: bytes, *, type_name: str, data_content_type: str) -> Any:
        """Decode *payload* with the serializer registered for the key.

        Returns an ``UnknownPayload`` carrying the raw bytes when no serializer is registered.
        """
        serializer = self._serializers.get((type_name, data_content_type))
        if serializer is None:
            return UnknownPayload(type_name, data_content_type, payload)

        return serializer.deserialize(payload)

    def serialize(self, message: Any, *, type_name: str, data_content_type: str) -> bytes:
        """Encode *message* with the serializer registered for the key.

        Raises:
            ValueError: if no serializer is registered for ``(type_name, data_content_type)``.
        """
        serializer = self._serializers.get((type_name, data_content_type))
        if serializer is None:
            raise ValueError(f"Unknown type {type_name} with content type {data_content_type}")

        return serializer.serialize(message)

    def is_registered(self, type_name: str, data_content_type: str) -> bool:
        """Return True if a serializer is registered for ``(type_name, data_content_type)``."""
        return (type_name, data_content_type) in self._serializers

    def type_name(self, message: Any) -> str:
        """Return the type name for *message*: its class name, or its own name if it is a class."""
        return _type_name(message)