Support for proto extensions

Open chasezheng opened this issue 1 year ago • 5 comments

Hi all,

Thank you for developing this library, which has been really useful to me. Could you look into supporting proto extensions? If a message has extensions, then it effectively has an extra map field: the keys are integers and the values can be any message.

Best regards, Chase

chasezheng avatar Dec 25 '24 01:12 chasezheng

the keys are integers and the values can be any message.

How would you represent this in arrow? Would the values just be a binary string with the message payload?

0x26res avatar Dec 27 '24 10:12 0x26res

I think so, it would probably be a map of integer tag numbers to serialized bytes.

chasezheng avatar Dec 27 '24 17:12 chasezheng

Instead of extensions, would it be easier to support only extension declarations? These would behave like regular fields: https://protobuf.dev/programming-guides/extension_declarations/#:~:text=Extension%20declarations%20aim%20to%20strike,difficult%20or%20impossible%20to%20strip.

chasezheng avatar Feb 02 '25 21:02 chasezheng

Hey I was having a look at this, but I'm not too familiar with the concept.

The code to access the extensions payload isn't exposed in the public API of protobuf so it takes a bit more effort to understand and reuse.

0x26res avatar Feb 03 '25 11:02 0x26res

@chasezheng I can put the extensions payload in a pa.map_(pa.int32(), pa.binary()), encoded like this:

import io

from google.protobuf.descriptor import FieldDescriptor
# Assumed import location for the internal (non-public) encoder table.
from google.protobuf.internal.type_checkers import TYPE_TO_ENCODER
from google.protobuf.message import Message


def encode_extension(message: Message, extension: FieldDescriptor) -> bytes:
    repeated = extension.label == FieldDescriptor.LABEL_REPEATED
    value = message.Extensions[extension]

    # _Wrapper is a helper that is not shown in this comment.
    if extension.type == FieldDescriptor.TYPE_MESSAGE:
        if repeated:
            value = [_Wrapper(m) for m in value]
        else:
            value = _Wrapper(value)

    encoder = TYPE_TO_ENCODER[extension.type](
        field_number=extension.number,
        is_repeated=repeated,
        is_packed=extension.is_packed,
    )

    with io.BytesIO() as buffer:
        encoder(write=buffer.write, value=value, deterministic=True)
        # extension._encoder(write=buffer.write, value=value, deterministic=True)
        return buffer.getvalue()


def encode_extensions(message: Message) -> dict[int, bytes]:
    return {
        extension.number: encode_extension(message, extension)
        for extension in message.Extensions
    }

But encoded this way, they are hard to decode individually: so far I haven't been able to decode a single payload on its own. The code below doesn't work:

def decode_extension(payload, extension_descriptor: FieldDescriptor):
    decoder = TYPE_TO_DECODER[extension_descriptor.type](
        is_repeated=extension_descriptor.label == FieldDescriptor.LABEL_REPEATED,
        is_packed=extension_descriptor.is_packed,
        field_number=extension_descriptor.number,
        key=None,
        new_default=None,
    )
    with io.BytesIO(payload) as buffer:
        decoder(
            pos=0,
            end=len(payload),
            buffer=buffer.read,
            message=Base(),
            field_dict=Base.DESCRIPTOR.field_dict,
        )

An alternative would be to save all the extensions together in a single binary string (pa.binary()). Stored this way they are not really usable from pyarrow, but we would be able to put them back into the message, so the round trip from protobuf to pyarrow and back doesn't lose any information.

def encode_extensions_simple(message: Message) -> bytes:
    copy = message.__class__()
    for extension_descriptor in message.Extensions:
        if extension_descriptor.type == FieldDescriptor.TYPE_MESSAGE:
            copy.Extensions[extension_descriptor].MergeFrom(
                message.Extensions[extension_descriptor]
            )
        else:
            copy.Extensions[extension_descriptor] = message.Extensions[
                extension_descriptor
            ]
    return copy.SerializeToString()


def decode_extensions_simple(payload: bytes, message: Message) -> Message:
    message.MergeFromString(payload)
    return message

0x26res avatar Feb 05 '25 13:02 0x26res