onnxruntime

Support direct usage of ORT format model flatbuffer for initializers

Open skottmckay opened this issue 3 years ago • 6 comments

Description: An ORT format model contains initializer data that we currently copy into a TensorProto during model load, and copy again into an OrtValue<Tensor> during session state finalization. We can do some optimizations to try to keep peak memory usage from these steps to roughly 2x the original size of the initializers, but that is still inefficient in a mobile scenario.

There is no way to populate the raw_data field of a TensorProto using an existing buffer; the bytes must be copied in. The OrtValue<Tensor>, however, does support constructing the Tensor from an existing buffer, with optional transfer of ownership.

A TensorProto can also point to external data. Typically the external data is stored in a file separate from the model, and the TensorProto contains the filename, offset and size of the data. We can leverage this mechanism to point to external data that is already resident in memory (from the ORT format model flatbuffer) by using a special tag for the filename and storing the memory address in the 'offset' field.
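A minimal sketch of that encoding, assuming a simplified stand-in for the TensorProto external-data entries. ONNX stores external data as string key/value pairs with keys such as `location`, `offset` and `length`; the `ExternalDataEntries` type, the two helper functions, and the tag value `*/_ORT_MEM_ADDR_/*` are illustrative assumptions, not the PR's actual code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for TensorProto.external_data: ONNX stores these
// as string key/value pairs ("location", "offset", "length", ...).
using ExternalDataEntries = std::vector<std::pair<std::string, std::string>>;

// Hypothetical special "filename" meaning: the offset is a memory address.
constexpr const char* kInMemoryTag = "*/_ORT_MEM_ADDR_/*";

// Describe a buffer that is already resident in memory (e.g. inside the ORT
// format flatbuffer) as external data, storing its address in 'offset'.
ExternalDataEntries DescribeInMemoryData(const void* data, std::size_t length) {
  assert(data != nullptr || length == 0);
  const auto address = reinterpret_cast<std::uintptr_t>(data);
  return {
      {"location", kInMemoryTag},
      {"offset", std::to_string(address)},
      {"length", std::to_string(length)},
  };
}

// Recover the pointer and length. Returns nullptr if the entries refer to a
// real file instead of the in-memory tag, so the caller can fall back to its
// normal (e.g. mmap-based) external data path.
const void* ResolveInMemoryData(const ExternalDataEntries& entries, std::size_t& length) {
  std::string location;
  std::uintptr_t address = 0;
  length = 0;
  for (const auto& [key, value] : entries) {
    if (key == "location") {
      location = value;
    } else if (key == "offset") {
      address = static_cast<std::uintptr_t>(std::stoull(value));
    } else if (key == "length") {
      length = static_cast<std::size_t>(std::stoull(value));
    }
  }
  if (location != kInMemoryTag) {
    length = 0;
    return nullptr;
  }
  return reinterpret_cast<const void*>(address);
}
```

Because the stored address is only meaningful inside the current process, a TensorProto encoded this way must never be serialized back out as a model file.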

The existing code to create an OrtValue<Tensor> from a TensorProto containing external data supports the copy-free approach of passing along a pointer with optional transfer of ownership to the OrtValue, as we normally mmap the file containing the external data and use the address of that buffer.
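The borrow-or-own behavior can be illustrated with a hypothetical, simplified `BufferTensor` class. This is not ORT's `Tensor` or `OrtValue` API; it only shows that wrapping an existing buffer avoids a copy, and that an optional deleter decides whether ownership is transferred.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>

// Hypothetical simplified tensor that wraps an existing buffer without copying.
// With a deleter, it takes ownership and releases the buffer on destruction
// (e.g. unmapping an mmap'ed file). Without one, it only borrows the buffer,
// and the caller must keep it alive for as long as the tensor is used.
class BufferTensor {
 public:
  using Deleter = std::function<void(void*)>;

  BufferTensor(void* data, std::size_t size_in_bytes, Deleter deleter = nullptr)
      : data_(data), size_in_bytes_(size_in_bytes), deleter_(std::move(deleter)) {
    assert(data_ != nullptr || size_in_bytes_ == 0);
  }

  ~BufferTensor() {
    if (deleter_) deleter_(data_);
  }

  // Non-copyable so an owned buffer is released exactly once.
  BufferTensor(const BufferTensor&) = delete;
  BufferTensor& operator=(const BufferTensor&) = delete;

  void* Data() const { return data_; }
  std::size_t SizeInBytes() const { return size_in_bytes_; }
  bool OwnsBuffer() const { return static_cast<bool>(deleter_); }

 private:
  void* data_;
  std::size_t size_in_bytes_;
  Deleter deleter_;
};
```

For initializers that point into the ORT format model buffer, the borrowing form is the relevant one: the tensor must not free memory that belongs to the caller's buffer.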

This PR contains the small set of changes needed to implement this approach, so that we can gather feedback. Usage is limited to an ORT format model where the caller provides a buffer containing the pre-loaded model bytes and sets a flag specifying not to copy those bytes (signifying that memory usage is important to them). An additional flag allows the caller to specify that we may also use the buffer directly for initializers. That creates a new requirement: the buffer must remain valid for the entire lifetime of the InferenceSession, rather than only until InferenceSession initialization completes, as is currently required.
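The lifetime rules implied by the two flags can be summarized in a small sketch. The flags are modeled as plain booleans and `RequiredBufferLifetime` is a hypothetical helper, not an ORT API; the sketch also assumes the initializer flag is simply ignored when the model bytes are copied.

```cpp
#include <cassert>

// How long the caller-provided ORT format model buffer must stay valid.
enum class BufferLifetime {
  UntilBytesCopied,          // bytes are copied into the session during load
  UntilInitializeCompletes,  // model bytes used directly; initializers still copied
  ForSessionLifetime,        // initializers point into the caller's buffer
};

// Hypothetical helper: the two booleans stand in for the two flags above.
BufferLifetime RequiredBufferLifetime(bool use_bytes_directly, bool use_bytes_for_initializers) {
  if (!use_bytes_directly) {
    // Assumption: the initializer flag only has an effect in no-copy mode.
    return BufferLifetime::UntilBytesCopied;
  }
  return use_bytes_for_initializers ? BufferLifetime::ForSessionLifetime
                                    : BufferLifetime::UntilInitializeCompletes;
}
```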

Motivation and Context: We have production mobile scenarios that require a reduction in peak memory usage.

Test output from a potential production model is below. The ORT format model being tested is 13.6MB.

Peak working set size (bytes):

| Stage | Original | New | Notes |
|---|---|---|---|
| Pre-load of model into buffer | 7,737,344 | 7,663,616 | Baseline; includes overhead from the size of onnxruntime_test_all |
| Model loaded, pre-InferenceSession::Load | 23,093,248 | 23,040,000 | Roughly equal, as expected |
| Post-InferenceSession::Load, pre-InferenceSession::Initialize | 42,946,560 | 29,179,904 | 13MB+ reduction |
| Post-InferenceSession::Initialize | 61,435,904 | 34,676,736 | 25MB+ reduction |
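The reductions in the last two rows can be checked directly from the table. The sketch below just subtracts the columns; the "13MB+" and "25MB+" figures match the differences expressed in mebibytes (13,766,656 bytes is about 13.1 MiB and 26,759,168 bytes is about 25.5 MiB). The struct and function names are invented for this check.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Peak working set sizes (bytes) copied from the table above.
struct StageMeasurement {
  std::string stage;
  std::int64_t original_bytes;
  std::int64_t new_bytes;
};

const std::vector<StageMeasurement> kMeasurements = {
    {"Pre-load of model into buffer", 7'737'344, 7'663'616},
    {"Model loaded, pre-InferenceSession::Load", 23'093'248, 23'040'000},
    {"Post-InferenceSession::Load, pre-InferenceSession::Initialize", 42'946'560, 29'179'904},
    {"Post-InferenceSession::Initialize", 61'435'904, 34'676'736},
};

// Bytes saved at a stage by the new approach.
std::int64_t ReductionBytes(const StageMeasurement& m) {
  return m.original_bytes - m.new_bytes;
}

// Whole mebibytes (2^20 bytes), truncated toward zero.
std::int64_t WholeMiB(std::int64_t bytes) { return bytes / (1 << 20); }
```

The post-Initialize saving of about 26.8 million bytes is roughly twice the 13.6MB model size, which is consistent with removing both the TensorProto copy and the OrtValue<Tensor> copy of the initializer data; this is an inference from the numbers, not a separately measured breakdown.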

NOTE: Pre-packing was disabled for this testing. If we are using the user-provided buffer directly for the initializers, the pre-packing causes an additional copy of the initializer data when creating the pre-packed OrtValue<Tensor>, and we can't free the original initializer data as that is within the user-provided buffer. If that buffer was mutable we could potentially do in-place pre-packing (pre-pack to temporary buffer, replace original data) to avoid that copy. This is a separate problem to solve if pre-packing is also required in the production scenario.
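The in-place pre-packing idea can be sketched as follows, assuming the buffer is mutable and that a packing routine writes its output to a temporary buffer first. `PrePackInPlace`, `PackFn`, and the byte-matrix transpose used as a stand-in for a real pre-packed layout are all hypothetical; in ORT, pre-packing is done by individual kernels and is not shown here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <functional>
#include <vector>

// Hypothetical packing function: returns the packed form of `size` bytes at `src`.
using PackFn = std::function<std::vector<unsigned char>(const unsigned char* src, std::size_t size)>;

// Pre-pack into a temporary buffer, then overwrite the original initializer
// data in a *mutable* buffer. Returns false (buffer untouched) if the packed
// form does not fit, in which case a separate copy is unavoidable.
// A real implementation would also need to record the packed size.
bool PrePackInPlace(unsigned char* buffer, std::size_t size, const PackFn& pack) {
  std::vector<unsigned char> temp = pack(buffer, size);  // transient extra copy
  if (temp.size() > size) {
    return false;
  }
  if (!temp.empty()) {
    std::memcpy(buffer, temp.data(), temp.size());
  }
  return true;
}

// Stand-in packing: transpose a row-major rows x cols byte matrix.
PackFn MakeTransposePack(std::size_t rows, std::size_t cols) {
  return [rows, cols](const unsigned char* src, std::size_t size) {
    assert(size == rows * cols);
    std::vector<unsigned char> out(size);
    for (std::size_t r = 0; r < rows; ++r) {
      for (std::size_t c = 0; c < cols; ++c) {
        out[c * rows + r] = src[r * cols + c];  // element (r, c) moves to (c, r)
      }
    }
    return out;
  };
}
```

The temporary buffer still costs one initializer's worth of memory, but only transiently while that initializer is packed, instead of a second long-lived copy per initializer.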

skottmckay, Aug 04 '22 10:08

                            std::vector<uint8_t>& bytes_data_holder) {

Should this be uint8_t[]?


Refers to: onnxruntime/core/session/inference_session.cc:993 in e30ebad.

yuslepukhin, Aug 04 '22 23:08

LGTM

yuslepukhin, Aug 04 '22 23:08

                            std::vector<uint8_t>& bytes_data_holder) {

The std::vector is a member of InferenceSession and is providing the storage.
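To illustrate the reply, here is a hypothetical, heavily simplified session class in which a `std::vector<uint8_t>` member is the storage for copied model bytes. Passing such a vector by reference lets the loading code fill session-owned storage of the right size, which a raw `uint8_t[]` parameter could not express. The member name mirrors the `bytes_data_holder` parameter; everything else is invented for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, heavily simplified session showing a std::vector member
// acting as the storage for ORT format model bytes.
class SessionSketch {
 public:
  // When copying, the bytes live in the member vector so the caller's buffer
  // can be released; otherwise only a pointer to the caller's buffer is kept.
  void LoadOrtFormatBytes(const std::uint8_t* data, std::size_t size, bool use_bytes_directly) {
    assert(data != nullptr || size == 0);
    if (use_bytes_directly) {
      bytes_data_holder_.clear();
      model_bytes_ = data;
      owns_bytes_ = false;
    } else {
      bytes_data_holder_.assign(data, data + size);
      model_bytes_ = bytes_data_holder_.data();
      owns_bytes_ = true;
    }
    model_size_ = size;
  }

  const std::uint8_t* ModelBytes() const { return model_bytes_; }
  std::size_t ModelSize() const { return model_size_; }
  bool OwnsBytes() const { return owns_bytes_; }

 private:
  std::vector<std::uint8_t> bytes_data_holder_;  // owned copy; empty when borrowing
  const std::uint8_t* model_bytes_ = nullptr;
  std::size_t model_size_ = 0;
  bool owns_bytes_ = false;
};
```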


In reply to: 1205866045


Refers to: onnxruntime/core/session/inference_session.cc:993 in e30ebad.

skottmckay, Aug 05 '22 04:08

Could we ask customers to use the AddExternalInitializers API instead?

pranavsharma, Aug 08 '22 20:08

Not sure that helps. It looks like the API takes an OrtValue that we convert to a TensorProto (and, I assume, back to an OrtValue during session state finalization). Given that, wouldn't you still have 2x the memory usage for each initializer?

https://github.com/microsoft/onnxruntime/blob/8a86b346a5fef042ab8a91a50eb05feab482e122/onnxruntime/core/graph/graph.cc#L2866-L2873

skottmckay, Aug 08 '22 21:08

Yeah, the approach in the PR looks fine to me. Just some minor comments.

pranavsharma, Aug 09 '22 01:08