Skip to content

Commit

Permalink
Add "absl::StatusOr<PythonConstMessagePointer> GetConstMessagePointer…
Browse files Browse the repository at this point in the history
…(PyObject* msg)" in proto_api which works with cpp extension, upb and pure python.

PiperOrigin-RevId: 699316527
  • Loading branch information
anandolee authored and copybara-github committed Nov 23, 2024
1 parent aded9b7 commit b9e69e8
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 36 deletions.
1 change: 1 addition & 0 deletions python/build_targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def build_targets(name):
visibility = ["//visibility:public"],
deps = [
"//src/google/protobuf",
"//src/google/protobuf/io",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/status",
"@system_python//:python_headers",
Expand Down
85 changes: 85 additions & 0 deletions python/google/protobuf/proto_api.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#include "google/protobuf/proto_api.h"

#include <Python.h>

#include <memory>
#include <string>

#include "absl/log/absl_check.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/message.h"
namespace google {
namespace protobuf {
Expand Down Expand Up @@ -52,6 +56,87 @@ PythonMessageMutator PyProto_API::CreatePythonMessageMutator(
return PythonMessageMutator(owned_msg, msg, py_msg);
}

PythonConstMessagePointer::PythonConstMessagePointer(Message* owned_msg,
const Message* message,
PyObject* py_msg)
: owned_msg_(owned_msg), message_(message), py_msg_(py_msg) {
ABSL_DCHECK(py_msg != nullptr);
ABSL_DCHECK(message != nullptr);
Py_INCREF(py_msg_);
}

PythonConstMessagePointer::PythonConstMessagePointer(
PythonConstMessagePointer&& other)
: owned_msg_(other.owned_msg_ == nullptr ? nullptr
: other.owned_msg_.release()),
message_(other.message_),
py_msg_(other.py_msg_) {
other.message_ = nullptr;
other.py_msg_ = nullptr;
}

bool PythonConstMessagePointer::NotChanged() {
ABSL_DCHECK(!PyErr_Occurred());
if (owned_msg_ == nullptr) {
return false;
}

PyObject* py_serialized_pb(
PyObject_CallMethod(py_msg_, "SerializeToString", nullptr));
if (py_serialized_pb == nullptr) {
PyErr_Format(PyExc_ValueError, "Fail to serialize py_msg");
return false;
}
char* data;
Py_ssize_t len;
if (PyBytes_AsStringAndSize(py_serialized_pb, &data, &len) < 0) {
Py_DECREF(py_serialized_pb);
PyErr_Format(PyExc_ValueError, "Fail to get bytes from serialized data");
return false;
}

// Even if serialize python message deterministic above, the
// serialize result may still diff between languages. So parse to
// another c++ message for compare.
std::unique_ptr<google::protobuf::Message> parsed_msg(owned_msg_->New());
parsed_msg->ParseFromArray(data, static_cast<int>(len));
std::string wire_other;
google::protobuf::io::StringOutputStream stream_other(&wire_other);
google::protobuf::io::CodedOutputStream output_other(&stream_other);
output_other.SetSerializationDeterministic(true);
parsed_msg->SerializeToCodedStream(&output_other);

std::string wire;
google::protobuf::io::StringOutputStream stream(&wire);
google::protobuf::io::CodedOutputStream output(&stream);
output.SetSerializationDeterministic(true);
owned_msg_->SerializeToCodedStream(&output);

if (wire == wire_other) {
Py_DECREF(py_serialized_pb);
return true;
}
PyErr_Format(PyExc_ValueError, "pymessage has been changed");
Py_DECREF(py_serialized_pb);
return false;
}

PythonConstMessagePointer::~PythonConstMessagePointer() {
if (py_msg_ == nullptr) {
ABSL_DCHECK(message_ == nullptr);
ABSL_DCHECK(owned_msg_ == nullptr);
return;
}
ABSL_DCHECK(owned_msg_ != nullptr);
ABSL_DCHECK(NotChanged());
Py_DECREF(py_msg_);
}

PythonConstMessagePointer PyProto_API::CreatePythonConstMessagePointer(
Message* owned_msg, const Message* msg, PyObject* py_msg) const {
return PythonConstMessagePointer(owned_msg, msg, py_msg);
}

} // namespace python
} // namespace protobuf
} // namespace google
44 changes: 41 additions & 3 deletions python/google/protobuf/proto_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
// PyProtoAPICapsuleName(), 0));
// if (!py_proto_api) { ...handle ImportError... }
// Then use the methods of the returned class:
// py_proto_api->GetMessagePointer(...);
// py_proto_api->GetConstMessagePointer(...);

#ifndef GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
#define GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
Expand All @@ -31,11 +31,14 @@
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/message.h"

PyObject* pymessage_mutate_const(PyObject* self, PyObject* args);

namespace google {
namespace protobuf {
namespace python {

class PythonMessageMutator;
class PythonConstMessagePointer;

// Note on the implementation:
// This API is designed after
Expand All @@ -55,23 +58,36 @@ struct PyProto_API {
// Side-effect: The message will definitely be cleared. *When* the message
// gets cleared is undefined (C++ will clear it up-front, python/upb will
// clear it on destruction). Nothing should rely on the python message
// during the lifetime of this object
// during the lifetime of this object.
// User should not hold onto the returned PythonMessageMutator while
// calling back into Python
// calling back into Python.
// Warning: there is a risk of deadlock with Python/C++ if users use the
// returned message->GetDescriptor()->file->pool()
virtual absl::StatusOr<PythonMessageMutator> GetClearedMessageMutator(
PyObject* msg) const = 0;

// Returns a PythonConstMessagePointer. For UPB and Pure Python, it points
// to a new c++ message copied from python message. For cpp extension, it
// points the internal c++ message.
// User should not hold onto the returned PythonConstMessagePointer
// while calling back into Python.
virtual absl::StatusOr<PythonConstMessagePointer> GetConstMessagePointer(
PyObject* msg) const = 0;

// If the passed object is a Python Message, returns its internal pointer.
// Otherwise, returns NULL with an exception set.
// TODO: Remove deprecated GetMessagePointer().
[[deprecated(
"GetMessagePointer() only work with Cpp Extension, "
"please migrate to GetConstMessagePointer().")]]
virtual const Message* GetMessagePointer(PyObject* msg) const = 0;

// If the passed object is a Python Message, returns a mutable pointer.
// Otherwise, returns NULL with an exception set.
// This function will succeed only if there are no other Python objects
// pointing to the message, like submessages or repeated containers.
// With the current implementation, only empty messages are in this case.
// TODO: Remove deprecated GetMutableMessagePointer().
[[deprecated(
"GetMutableMessagePointer() only work with Cpp Extension, "
"please migrate to GetClearedMessageMutator().")]]
Expand Down Expand Up @@ -133,6 +149,8 @@ struct PyProto_API {
PythonMessageMutator CreatePythonMessageMutator(Message* owned_msg,
Message* msg,
PyObject* py_msg) const;
PythonConstMessagePointer CreatePythonConstMessagePointer(
Message* owned_msg, const Message* msg, PyObject* py_msg) const;
};

// User should not hold onto this object while calling back into Python
Expand Down Expand Up @@ -161,6 +179,26 @@ class PythonMessageMutator {
PyObject* py_msg_;
};

class PythonConstMessagePointer {
public:
PythonConstMessagePointer(PythonConstMessagePointer&& other);
~PythonConstMessagePointer();

const Message& get() { return *message_; }

private:
friend struct google::protobuf::python::PyProto_API;
PythonConstMessagePointer(Message* owned_msg, const Message* message,
PyObject* py_msg);

friend PyObject* ::pymessage_mutate_const(PyObject* self, PyObject* args);
// Check if the const message has been changed.
bool NotChanged();
std::unique_ptr<Message> owned_msg_;
const Message* message_;
PyObject* py_msg_;
};

inline const char* PyProtoAPICapsuleName() {
static const char kCapsuleName[] = "google.protobuf.pyext._message.proto_API";
return kCapsuleName;
Expand Down
115 changes: 82 additions & 33 deletions python/google/protobuf/pyext/message_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,54 +158,103 @@ google::protobuf::DynamicMessageFactory* GetFactory() {
return factory;
}

absl::StatusOr<google::protobuf::Message*> CreateNewMessage(PyObject* py_msg) {
PyObject* pyd = PyObject_GetAttrString(py_msg, "DESCRIPTOR");
if (pyd == nullptr) {
return absl::InvalidArgumentError("py_msg has no attribute 'DESCRIPTOR'");
}

PyObject* fn = PyObject_GetAttrString(pyd, "full_name");
if (fn == nullptr) {
return absl::InvalidArgumentError(
"DESCRIPTOR has no attribute 'full_name'");
}

const char* descriptor_full_name = PyUnicode_AsUTF8(fn);
if (descriptor_full_name == nullptr) {
return absl::InternalError("Fail to convert descriptor full name");
}

PyObject* pyfile = PyObject_GetAttrString(pyd, "file");
Py_DECREF(pyd);
if (pyfile == nullptr) {
return absl::InvalidArgumentError("DESCRIPTOR has no attribute 'file'");
}
auto gen_d = google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
descriptor_full_name);
if (gen_d) {
Py_DECREF(pyfile);
Py_DECREF(fn);
return google::protobuf::MessageFactory::generated_factory()
->GetPrototype(gen_d)
->New();
}
auto d = FindMessageDescriptor(pyfile, descriptor_full_name);
Py_DECREF(pyfile);
RETURN_IF_ERROR(d.status());
Py_DECREF(fn);
return GetFactory()->GetPrototype(*d)->New();
}

bool CopyToOwnedMsg(google::protobuf::Message** copy, const google::protobuf::Message& message) {
*copy = message.New();
std::string wire;
message.SerializeToString(&wire);
(*copy)->ParseFromArray(wire.data(), wire.size());
return true;
}

// C++ API. Clients get at this via proto_api.h
struct ApiImplementation : google::protobuf::python::PyProto_API {
absl::StatusOr<google::protobuf::python::PythonMessageMutator> GetClearedMessageMutator(
PyObject* py_msg) const override {
if (PyObject_TypeCheck(py_msg, google::protobuf::python::CMessage_Type)) {
google::protobuf::Message* message =
google::protobuf::python::PyMessage_GetMutableMessagePointer(py_msg);
if (message == nullptr) {
return absl::InternalError(
"Fail to get message pointer. The message "
"may already had a reference.");
}
message->Clear();
return CreatePythonMessageMutator(nullptr, message, py_msg);
}
PyObject* pyd = PyObject_GetAttrString(py_msg, "DESCRIPTOR");
if (pyd == nullptr) {
return absl::InvalidArgumentError("py_msg has no attribute 'DESCRIPTOR'");
}

PyObject* fn = PyObject_GetAttrString(pyd, "full_name");
if (fn == nullptr) {
return absl::InvalidArgumentError(
"DESCRIPTOR has no attribute 'full_name'");
}
auto msg = CreateNewMessage(py_msg);
RETURN_IF_ERROR(msg.status());
return CreatePythonMessageMutator(*msg, *msg, py_msg);
}

const char* descriptor_full_name = PyUnicode_AsUTF8(fn);
if (descriptor_full_name == nullptr) {
return absl::InternalError("Fail to convert descriptor full name");
absl::StatusOr<google::protobuf::python::PythonConstMessagePointer>
GetConstMessagePointer(PyObject* py_msg) const override {
if (PyObject_TypeCheck(py_msg, google::protobuf::python::CMessage_Type)) {
const google::protobuf::Message* message =
google::protobuf::python::PyMessage_GetMessagePointer(py_msg);
google::protobuf::Message* owned_msg = nullptr;
ABSL_DCHECK(CopyToOwnedMsg(&owned_msg, *message));
return CreatePythonConstMessagePointer(owned_msg, message, py_msg);
}

PyObject* pyfile = PyObject_GetAttrString(pyd, "file");
Py_DECREF(pyd);
if (pyfile == nullptr) {
return absl::InvalidArgumentError("DESCRIPTOR has no attribute 'file'");
auto msg = CreateNewMessage(py_msg);
RETURN_IF_ERROR(msg.status());
PyObject* serialized_pb(
PyObject_CallMethod(py_msg, "SerializeToString", nullptr));
if (serialized_pb == nullptr) {
return absl::InternalError("Fail to serialize py_msg");
}
auto gen_d =
google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
descriptor_full_name);
if (gen_d) {
Py_DECREF(pyfile);
Py_DECREF(fn);
google::protobuf::Message* msg = google::protobuf::MessageFactory::generated_factory()
->GetPrototype(gen_d)
->New();
return CreatePythonMessageMutator(msg, msg, py_msg);
char* data;
Py_ssize_t len;
if (PyBytes_AsStringAndSize(serialized_pb, &data, &len) < 0) {
Py_DECREF(serialized_pb);
return absl::InternalError(
"Fail to get bytes from py_msg serialized data");
}
auto d = FindMessageDescriptor(pyfile, descriptor_full_name);
Py_DECREF(pyfile);
RETURN_IF_ERROR(d.status());
Py_DECREF(fn);
google::protobuf::Message* msg = GetFactory()->GetPrototype(*d)->New();
return CreatePythonMessageMutator(msg, msg, py_msg);
if (!(*msg)->ParseFromArray(data, len)) {
Py_DECREF(serialized_pb);
return absl::InternalError(
"Couldn't parse py_message to google::protobuf::Message*!");
}
Py_DECREF(serialized_pb);
return CreatePythonConstMessagePointer(*msg, *msg, py_msg);
}

const google::protobuf::Message* GetMessagePointer(PyObject* msg) const override {
Expand Down

0 comments on commit b9e69e8

Please sign in to comment.