Read speeds decrease 2x when reading with fewer processes
The issue
Given a specific checkpoint, load it in two different settings:
- Load it with 64 nodes, 512 GPUs, 512 processes (1 GPU / process).
- Load it with 64 nodes, 512 GPUs, 64 processes (8 GPUs / process).
What I observe:
- Using 512 processes, reading takes ~20 seconds.
- Using 64 processes, reading takes ~40 seconds (2x).
The checkpoint in question is also written with 512 processes (see below for repro). Except for the number of processes, nothing else changes (sharding etc. stays the same).
To reproduce:
Download this file and run it on 64 nodes with 8 GPUs each. Make sure the hostfile lists the hostnames of all 64 nodes. (mpirun isn't essential here; it's just one way to spawn the processes.)
To create the checkpoint:
mpirun -hostfile hostfile -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 -np 512 -npernode 8 python ts_multigpu.py /data/heiner/ckpttest4-64/ $(hostname):1234 2>&1 | grep -v 'The transformations API'
To load the checkpoint with 512 processes:
mpirun -hostfile hostfile -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 -np 512 -npernode 8 python ts_multigpu.py /data/heiner/ckpttest4-64/ $(hostname):1234 2>&1 | grep -v 'The transformations API'
This takes ~20 sec for me.
To load the checkpoint with 64 processes:
mpirun -hostfile hostfile -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 -np 64 -npernode 1 python ts_multigpu.py /data/heiner/ckpttest4-64/ $(hostname):1234 2>&1 | grep -v 'The transformations API'
This takes ~40 sec for me.
The issue doesn't seem to be in Orbax, because the same happens with a plain jax.experimental.serialization.async_deserialize.
I don't have access to a cluster; is there a way to run this locally?
I'm trying something like:
sudo apt install libopenmpi-dev
python3 -m venv ts-venv
source ts-venv/bin/activate
python3 -m pip install jax orbax numpy tensorstore mpi4py
mpirun -np 1 python3 $(pwd)/ts_mpitest.py /tmp/data1/ $(hostname):3345
Edit: I hacked your file to run on my machine with a single process, replacing some large values and shrinking your large arrays to much smaller ones, so this won't be anything like what you've got. I also dumped the TensorStore metrics:
import tensorstore as ts

for x in ts.experimental_collect_matching_metrics():
    print(x)
The output, including the TensorStore spec, looks something like:
$ mpirun -np 1 python3 $(pwd)/ts_mpitest.py /usr/local/google/tmp/data1/ $(hostname):3345
Loading existing checkpoint
Starting checkpoint load
{'driver': 'zarr', 'kvstore': {'driver': 'ocdbt', 'base': {'driver': 'file', 'path': '/usr/local/google/tmp/data1'}, 'path': '0', 'experimental_read_coalescing_threshold_bytes': 1000000, 'experimental_read_coalescing_merged_bytes': 500000000000, 'experimental_read_coalescing_interval': '1ms', 'cache_pool': 'cache_pool#ocdbt'}, 'recheck_cached_data': False, 'recheck_cached_metadata': False}
...
Loaded checkpoint from /usr/local/google/tmp/data1/ in 107.53 sec
{'name': '/tensorstore/cache/chunk_cache/reads', 'values': [{'value': 33}]}
{'name': '/tensorstore/cache/hit_count', 'values': [{'value': 44}]}
{'name': '/tensorstore/cache/kvs_cache_read', 'values': [{'category': 'changed', 'value': 46}]}
{'name': '/tensorstore/cache/miss_count', 'values': [{'value': 48}]}
{'name': '/tensorstore/futures/force_callbacks', 'values': [{'value': 303}]}
{'name': '/tensorstore/futures/live', 'values': [{'max_value': 162, 'value': 1}]}
{'name': '/tensorstore/futures/not_needed_callbacks', 'values': [{'value': 45}]}
{'name': '/tensorstore/futures/ready_callbacks', 'values': [{'value': 436}]}
{'name': '/tensorstore/internal/riegeli/noncontiguous_bytes', 'values': [{'value': 54767124480}]}
{'name': '/tensorstore/internal/thread/schedule_at/insert_histogram_ms', 'values': [{'0': 0, '1': 46, 'count': 46, 'mean': 0.0, 'sum_of_squared_deviation': 0.0}]}
{'name': '/tensorstore/internal/thread/schedule_at/next_event', 'values': [{'value': 'infinite-future'}]}
{'name': '/tensorstore/internal/thread/schedule_at/queued_ops', 'values': [{'max_value': 15, 'value': 0}]}
{'name': '/tensorstore/kvstore/file/batch_read', 'values': [{'value': 27}]}
{'name': '/tensorstore/kvstore/file/bytes_read', 'values': [{'value': 50616803970}]}
{'name': '/tensorstore/kvstore/file/open_read', 'values': [{'value': 27}]}
{'name': '/tensorstore/kvstore/file/read', 'values': [{'value': 27}]}
{'name': '/tensorstore/kvstore/file/read_latency_ms', 'values': [{'0': 0, '1': 2, '10': 0, '11': 0, '12': 0, '13': 0, '14': 3, '15': 6, '16': 7, '17': 6, '2': 0, '3': 0, '4': 0, '5': 2, '6': 0, '7': 1, '8': 0, '9': 0, 'count': 27, 'mean': 20684.2962962963, 'sum_of_squared_deviation': 9117420933.62963}]}
{'name': '/tensorstore/kvstore/ocdbt/read', 'values': [{'value': 45}]}
{'name': '/tensorstore/thread_pool/active', 'values': [{'max_value': 24, 'value': 13}]}
{'name': '/tensorstore/thread_pool/max_delay_ns', 'values': [{'max_value': 16552002634}]}
{'name': '/tensorstore/thread_pool/started', 'values': [{'value': 24}]}
{'name': '/tensorstore/thread_pool/steal_count', 'values': [{'value': 37.0}]}
{'name': '/tensorstore/thread_pool/task_providers', 'values': [{'max_value': 2, 'value': 0}]}
{'name': '/tensorstore/thread_pool/total_queue_time_ns', 'values': [{'value': 95758202782.0}]}
{'name': '/tensorstore/thread_pool/work_time_ns', 'values': [{'value': 1072415922369.0}]}
Hey Laramie - thanks for taking a look!
Unfortunately, I haven't managed to create a smaller repro yet. I'll run with experimental_collect_matching_metrics and get back to you soon.
More generally, do you know of any settings I might need to change to increase per-process throughput? Or, failing that, is there a (possibly hacky) way to have separate, independent TensorStore clients within a single process? I suspect we're hitting some kind of per-process limit (thread pool, TCP/IP connections, etc.).
At the TensorStore layer, this uses an OCDBT kvstore on top of a file kvstore. TensorStore has some context settings for the file driver that you could try: https://google.github.io/tensorstore/kvstore/file/index.html
Try setting "file_io_concurrency", which defaults to max(4, hardware_concurrency).
https://en.cppreference.com/w/cpp/thread/thread/hardware_concurrency
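A minimal sketch of raising that limit, assuming the array is opened directly with TensorStore rather than through Orbax. It copies a JSON spec and adds a top-level "context" member that overrides file_io_concurrency. The helper names, the path, and the limit of 128 are hypothetical, and os.cpu_count() only approximates std::thread::hardware_concurrency.

```python
import os
from typing import Any, Dict


def default_file_io_concurrency() -> int:
    # Approximates the documented default of max(4, hardware_concurrency);
    # os.cpu_count() stands in for std::thread::hardware_concurrency (assumption).
    return max(4, os.cpu_count() or 1)


def with_file_io_concurrency(spec: Dict[str, Any], limit: int) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of a TensorStore JSON spec whose
    context caps concurrent file I/O operations at `limit`."""
    if limit < 1:
        raise ValueError("limit must be positive")
    new_spec = dict(spec)  # shallow copy; the caller's spec is left unchanged
    context = dict(new_spec.get("context", {}))
    context["file_io_concurrency"] = {"limit": limit}
    new_spec["context"] = context
    return new_spec


base_spec = {
    "driver": "zarr",
    "kvstore": {
        "driver": "ocdbt",
        "base": {"driver": "file", "path": "/tmp/ckpt"},  # hypothetical path
        "path": "0",
    },
}
spec = with_file_io_concurrency(base_spec, limit=128)
print(spec["context"])  # {'file_io_concurrency': {'limit': 128}}
```

With TensorStore installed, the resulting dict could be passed to ts.open(spec); it's worth confirming against the file driver docs above that the resource is picked up by the nested file kvstore.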
You could also enable detailed logging of file operations via TENSORSTORE_VERBOSE_LOGGING=file=2.
How many hosts are in your hostfile? And what is the underlying filesystem?
There are 64 nodes (as stated in the issue description above). The file system is a distributed file system similar to Lustre or VAST.
I already tried setting file_io_concurrency manually, and it didn't seem to help.
I don't work on TensorStore directly, but one setting I've found sometimes helps loading performance is ocdbt_target_data_file_size:
import time

import orbax.checkpoint as ocp

# `log` is assumed to be the logging helper from the repro script.


def save(state, path, ocdbt_target_file_size: int = 2 * 1024 ** 3):
    start = time.time()
    # Smaller target data files split the checkpoint into more OCDBT files.
    ocp.PyTreeCheckpointer(use_ocdbt=True, use_zarr3=True).save(
        path, ocp.args.PyTreeSave(
            item=state, ocdbt_target_data_file_size=ocdbt_target_file_size))
    log(f"Saved checkpoint to {path} in {time.time() - start:.2f} sec")


def load(path, shape_dtype):
    start = time.time()
    state = ocp.PyTreeCheckpointer(use_ocdbt=True, use_zarr3=True).restore(
        path, ocp.args.PyTreeRestore(
            shape_dtype,
            restore_args=ocp.checkpoint_utils.construct_restore_args(shape_dtype),
        ))
    end = time.time()
    log(f"Loaded checkpoint from {path} in {end - start:.2f} sec")
    return state
2 GB is the default, but going smaller might help.
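To get a feel for what the target size changes, here is a rough count of data files under the simplifying assumption that all bytes are packed into files of exactly the target size. Real counts can be higher (for example, if each writing process produces its own files). The function name is hypothetical, and the byte count is simply the /tensorstore/kvstore/file/bytes_read value from the local run above, used as an example size.

```python
def estimated_data_files(total_bytes: int, target_file_size: int) -> int:
    """Hypothetical helper: ceil(total_bytes / target_file_size), computed with
    integer arithmetic so large byte counts don't lose precision."""
    if total_bytes < 0 or target_file_size <= 0:
        raise ValueError("total_bytes must be >= 0 and target_file_size > 0")
    return -(-total_bytes // target_file_size)


GiB = 1024 ** 3
bytes_read = 50_616_803_970  # bytes_read from the local run above

print(estimated_data_files(bytes_read, 2 * GiB))   # 24 files at the 2 GiB default
print(estimated_data_files(bytes_read, GiB // 2))  # 95 files at a 512 MiB target
```

More, smaller files may give readers more independent files to fetch in parallel, which is one plausible reason a smaller target could help; that is intuition, not a measurement.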
I imagine much of the performance difference depends on the specifics of how the filesystem interaction happens. The two configurations are basically a single process per node (n=64) versus 8 processes per node (n=512).
If it's related to file_io_concurrency, that implies going from something like 8x8 threads issuing I/O per node to something like 1x8 (if hardware_concurrency is, for example, 8).
I'd be interested to see the output of the TensorStore counters for the various configs.
Edit: Looking at Orbax, it appears file_io_concurrency is already set to a sufficiently large value:
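The arithmetic above as a small sketch, assuming each process has its own file_io_concurrency pool sized by the default max(4, hardware_concurrency) unless an explicit limit is set. The function name is hypothetical, and hardware_concurrency = 8 is the example value from the comment, not a measured one.

```python
from typing import Optional


def per_node_io_threads(procs_per_node: int, hardware_concurrency: int,
                        limit: Optional[int] = None) -> int:
    """Hypothetical helper: upper bound on concurrent file I/O operations per
    node when every process owns a separate file_io_concurrency pool."""
    per_process = limit if limit is not None else max(4, hardware_concurrency)
    return procs_per_node * per_process


print(per_node_io_threads(8, 8))            # 64: n=512, 8 processes per node
print(per_node_io_threads(1, 8))            # 8:  n=64, 1 process per node
print(per_node_io_threads(1, 8, limit=64))  # 64: one process with a raised limit
```

Under this model a single process per node needs a limit roughly 8x the default to match the 8-process configuration, which is why an explicitly large file_io_concurrency would rule this explanation out.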
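A minimal sketch for lining up counters from two configurations, assuming each run saves the list of dicts that ts.experimental_collect_matching_metrics() returns (the format printed above). The helper names and the chosen metric names are illustrative; the sample entries are copied (abridged) from the local single-process run.

```python
from typing import Any, Dict, Iterable, List, Tuple

Metric = Dict[str, Any]  # one entry, as printed in the metrics dump above


def scalar(entry: Dict[str, Any]) -> Any:
    # Counters and gauges carry "value"; histograms such as
    # read_latency_ms carry "mean" (plus bucket counts) instead.
    for key in ("value", "mean", "max_value"):
        if key in entry:
            return entry[key]
    return None


def summarize(metrics: Iterable[Metric], names: Iterable[str]) -> Dict[str, Any]:
    """Map each requested metric name to the scalar of its first value entry."""
    by_name = {m["name"]: scalar(m["values"][0]) for m in metrics if m["values"]}
    return {name: by_name.get(name) for name in names}


def compare(a: Iterable[Metric], b: Iterable[Metric],
            names: List[str]) -> Dict[str, Tuple[Any, Any]]:
    """Pair up the same metrics from two runs, e.g. n=512 vs n=64."""
    sa, sb = summarize(a, names), summarize(b, names)
    return {name: (sa[name], sb[name]) for name in names}


NAMES = [
    "/tensorstore/kvstore/file/bytes_read",
    "/tensorstore/kvstore/file/read_latency_ms",
    "/tensorstore/thread_pool/total_queue_time_ns",
]

# Abridged entries from the local run above (histogram buckets omitted).
local_run = [
    {"name": "/tensorstore/kvstore/file/bytes_read",
     "values": [{"value": 50616803970}]},
    {"name": "/tensorstore/kvstore/file/read_latency_ms",
     "values": [{"count": 27, "mean": 20684.2962962963}]},
]
print(summarize(local_run, NAMES))
```

With per-rank dumps from both configurations, compare(run_512, run_64, NAMES) would show the two values side by side; a None marks a metric missing from one run.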
https://github.com/google/orbax/blob/d27fcdd8e9227fcd3d631554f17fc90e4c04e150/checkpoint/orbax/checkpoint/type_handlers.py#L58
It would be nice to get a pprof of these; is that possible?
OK, I found an inconsistency in our internal build that makes logging hard to use from Python. Once the fix is in, it will be easier to debug what's going on.
You should now be able to set this environment variable and compare the I/O timing across runs:
TENSORSTORE_VERBOSE_LOGGING=file=1,file_detail=2
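When launching with mpirun, a variable set in the local shell is typically not forwarded to ranks on remote nodes; Open MPI's -x NAME=value option exports it. A sketch that adds the logging variable to the load command from the issue, assuming Open MPI; the helper name is hypothetical and head-node:1234 stands in for $(hostname):1234.

```python
import shlex
from typing import Dict, List, Sequence


def mpirun_with_env(base_cmd: Sequence[str], env: Dict[str, str]) -> List[str]:
    """Hypothetical helper: insert Open MPI `-x NAME=value` exports right after
    `mpirun` so every rank, including remote ones, sees the variables."""
    if not base_cmd or base_cmd[0] != "mpirun":
        raise ValueError("expected a command starting with 'mpirun'")
    exports: List[str] = []
    for name, value in env.items():
        exports += ["-x", f"{name}={value}"]
    return [base_cmd[0], *exports, *base_cmd[1:]]


base = shlex.split(
    "mpirun -hostfile hostfile -mca pml ob1 -mca btl ^openib "
    "-mca btl_tcp_if_include eth0 -np 64 -npernode 1 "
    "python ts_multigpu.py /data/heiner/ckpttest4-64/ head-node:1234"
)
cmd = mpirun_with_env(base, {"TENSORSTORE_VERBOSE_LOGGING": "file=1,file_detail=2"})
print(shlex.join(cmd))  # shell-quoted command line to paste into a terminal
```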
I just submitted a tscli change that will help me create a better test harness/benchmark for this case.
If you could run this and let me know what the output is I'd appreciate it.
If the parameter names (in the path component) are meaningful, it's fine to redact them consistently.
git clone https://github.com/google/tensorstore
cd tensorstore
./bazelisk.py build //tensorstore/tscli
alias tscli=$(pwd)/bazel-bin/tensorstore/tscli/tscli
for x in $(tscli search file:/// /usr/local/google/tmp/data1/); do
  echo
  tscli print_spec --spec "$x"
  tscli print_stats --spec "$x"
  echo
done
I have been running a variant of this with my updated multi_read_benchmark. We found some contention in TensorStore's internal chunk cache, and fixing it may help here. It was alleviated in 5927385942c950048df3143651f72ec27ce7be31.