Fully Async Recipe OOM with only 8192 response length

Open PokeLu opened this issue 4 months ago • 6 comments

Hello, and thanks to the developers for finally getting the fully async policy working in verl.

In my own trial, I modified the geo3k_qwen25vl_7b_megatron_4_4 script to train Qwen3-4B on GSM8K at commit 0bd03459, using 2 nodes × 8 H100s. Even with a prompt length of only 2048, a response length of 8192, and every available offload option enabled, training still ran out of memory. Here is the full script:

```bash
HF_MODEL_PATH=${HF_MODEL_PATH:-"/mnt/data/llm/ht1/models/Qwen3-4B"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/hpc_scripts/runtime_env/dapo.yaml"}
DASHBOARD_PORT=${DASHBOARD_PORT:-8265}  # Ray dashboard port (8265 is Ray's default)
# Defaults taken from the values in the error log below; defined before first use
project_name=${project_name:-"Async_GRPO"}
exp_name=${exp_name:-"Qwen3-4B-megatron-fully-async-8-8"}
export TENSORBOARD_DIR=${WORKING_DIR}/tensorboard/${project_name}/${exp_name}

train_path="['/mnt/data/llm/ht1/verl/data/processed_gsm8k/train.parquet']"
test_path="['/mnt/data/llm/ht1/verl/data/processed_gsm8k/test.parquet']"

CKPTS_DIR=${CKPTS_DIR:-"${WORKING_DIR}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"/mnt/data/llm/ht1/verl/data/processed_gsm8k/train.parquet"}
TEST_FILE=${TEST_FILE:-"/mnt/data/llm/ht1/verl/data/processed_gsm8k/test.parquet"}

rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
    export VLLM_USE_V1=1
    return_raw_chat="True"
fi

# Algorithm parameters
adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=True
kl_loss_coef=0.005

# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

# Training param args
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
ref_offload=True
actor_offload=True
gen_tp=1
# NOTE: train_tp/train_pp are passed only to the ref model below;
# the actor's Megatron TP/PP are not set anywhere in this script.
train_tp=8
train_pp=1
recompute_granularity=selective
recompute_modules='["core_attn", "mlp"]'

# Fully async specific parameters
NNODES_ROLLOUT=${NNODES_ROLLOUT:-1}
NNODES_TRAIN=${NNODES_TRAIN:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}

train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=8
train_prompt_mini_bsz=64
total_rollout_steps=$(((512*100)))
test_freq=8
save_freq=-1
staleness_threshold=0.1
trigger_parameter_sync_step=8
require_batches=4
partial_rollout=True
total_epochs=200

submit_ray_job() {
    ### important
    # actor_rollout_ref.actor.optim.lr_decay_steps must be > 0, otherwise Megatron-Core (mcore) raises an assertion error
    ray job submit --working-dir ./ --runtime-env="${RUNTIME_ENV}" \
        --address=http://localhost:${DASHBOARD_PORT} -- \
    python -m recipe.fully_async_policy.fully_async_main \
        --config-path=config \
        --config-name='fully_async_ppo_megatron_trainer.yaml' \
        data.train_files="$train_path" \
        data.val_files="$test_path" \
        data.train_batch_size=${train_prompt_bsz} \
        data.max_prompt_length=${max_prompt_length} \
        data.max_response_length=${max_response_length} \
        data.filter_overlong_prompts=True \
        data.truncation='error' \
        data.gen_batch_size=${gen_prompt_bsz} \
        data.return_raw_chat=${return_raw_chat} \
        algorithm.adv_estimator=${adv_estimator} \
        algorithm.use_kl_in_reward=${use_kl_in_reward} \
        algorithm.kl_ctrl.kl_coef=${kl_coef} \
        actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
        actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
        actor_rollout_ref.actor.kl_loss_type=low_var_kl \
        actor_rollout_ref.model.path=$HF_MODEL_PATH \
        actor_rollout_ref.hybrid_engine=False \
        actor_rollout_ref.nccl_timeout=7200 \
        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
        actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
        actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
        actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
        actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
        actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
        actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
        actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
        actor_rollout_ref.actor.optim.lr=2e-6 \
        actor_rollout_ref.actor.optim.min_lr=0.0 \
        actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
        actor_rollout_ref.actor.optim.lr_decay_style=constant \
        actor_rollout_ref.actor.optim.lr_decay_steps=51200 \
        actor_rollout_ref.actor.optim.weight_decay=0.1 \
        actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity="${recompute_granularity}" \
        actor_rollout_ref.actor.megatron.override_transformer_config.recompute_modules="${recompute_modules}" \
        actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \
        actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \
        actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \
        +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \
        +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \
        +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \
        +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \
        actor_rollout_ref.actor.megatron.use_mbridge=True \
        actor_rollout_ref.actor.entropy_coeff=0 \
        actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
        actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
        actor_rollout_ref.ref.megatron.param_offload=${ref_offload} \
        actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
        actor_rollout_ref.rollout.calculate_log_probs=True \
        actor_rollout_ref.rollout.name=${rollout_name} \
        actor_rollout_ref.rollout.max_model_len=$((max_prompt_length + max_response_length)) \
        actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
        actor_rollout_ref.rollout.mode=${rollout_mode} \
        actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
        actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
        trainer.critic_warmup=0 \
        trainer.logger='["console","tensorboard"]' \
        trainer.project_name="${project_name}" \
        trainer.experiment_name="${exp_name}" \
        trainer.test_freq="${test_freq}" \
        trainer.total_epochs="${total_epochs}" \
        trainer.val_before_train=False \
        trainer.save_freq=${save_freq} \
        trainer.default_local_dir="${CKPTS_DIR}" \
        trainer.resume_mode=auto \
        trainer.nnodes="${NNODES_TRAIN}" \
        trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
        rollout.nnodes="${NNODES_ROLLOUT}" \
        rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
        rollout.total_rollout_steps="${total_rollout_steps}" \
        rollout.total_epochs="${total_epochs}" \
        rollout.test_freq="${test_freq}" \
        async_training.staleness_threshold="${staleness_threshold}" \
        async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
        async_training.require_batches="${require_batches}" \
        async_training.partial_rollout="${partial_rollout}" \
        async_training.use_rollout_log_probs=True
}

submit_ray_job
```

Here is the error log:

```
(WorkerDict pid=132182) *** SIGSEGV received at time=1762101582 on cpu 35 ***
(WorkerDict pid=132182) PC: @     0x7efc286a3193  (unknown)  c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::malloc()
(WorkerDict pid=132182)     @     0x7efc5f842520       5728  (unknown)
(WorkerDict pid=132182)     @     0x7efc286a4b85        520  c10::cuda::CUDACachingAllocator::Native::NativeCachingAllocator::malloc()
(WorkerDict pid=132182)     @     0x7ec867ffd2f0  (unknown)  (unknown)
(WorkerDict pid=132182) [2025-11-02 16:39:42,303 E 132182 138464] logging.cc:496: *** SIGSEGV received at time=1762101582 on cpu 35 ***
(WorkerDict pid=132182) [2025-11-02 16:39:42,303 E 132182 138464] logging.cc:496: PC: @     0x7efc286a3193  (unknown)  c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::malloc()
(WorkerDict pid=132182) [2025-11-02 16:39:42,306 E 132182 138464] logging.cc:496:     @     0x7efc5f842520       5728  (unknown)
(WorkerDict pid=132182) [2025-11-02 16:39:42,306 E 132182 138464] logging.cc:496:     @     0x7efc286a4b85        520  c10::cuda::CUDACachingAllocator::Native::NativeCachingAllocator::malloc()
(WorkerDict pid=132182) [2025-11-02 16:39:42,309 E 132182 138464] logging.cc:496:     @     0x7ec867ffd2f0  (unknown)  (unknown)
(WorkerDict pid=132182) Fatal Python error: Segmentation fault
(WorkerDict pid=132182) 
(WorkerDict pid=132182) Stack (most recent call first):
(WorkerDict pid=132182)   <no Python frame>
(WorkerDict pid=132182) 
```
```
(WorkerDict pid=132184) /usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py:824: UserWarning: c10d::broadcast_: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.) [repeated 7x across cluster]
(WorkerDict pid=132184)   return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass [repeated 7x across cluster]
(WorkerDict pid=132181) dlca1ik21m92lu3n-master-0:132181:138578 [0] NCCL INFO [Service thread] Connection closed by localRank 1
(raylet) A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffff1bc865d7417ba6d5182635e30a000000 Worker ID: fbd84e578c92ae1ed64fd6327fca115cfa3b59ca6b971f8b13e0dfa2 Node ID: 52cc930d342eef8646fca91c7d31615f3d57028563cdfacae73e74ed Worker IP address: 172.16.2.252 Worker port: 10312 Worker PID: 132182 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
Error executing job with overrides: ['data.train_files=[/mnt/data/llm/ht1/verl/data/processed_gsm8k/train.parquet]', 'data.val_files=[/mnt/data/llm/ht1/verl/data/processed_gsm8k/test.parquet]', 'data.train_batch_size=0', 'data.max_prompt_length=2048', 'data.max_response_length=8192', 'data.filter_overlong_prompts=True', 'data.truncation=error', 'data.gen_batch_size=1', 'data.return_raw_chat=True', 'algorithm.adv_estimator=grpo', 'algorithm.use_kl_in_reward=False', 'algorithm.kl_ctrl.kl_coef=0.0', 'actor_rollout_ref.actor.use_kl_loss=True', 'actor_rollout_ref.actor.kl_loss_coef=0.005', 'actor_rollout_ref.actor.kl_loss_type=low_var_kl', 'actor_rollout_ref.model.path=/mnt/data/llm/ht1/models/Qwen3-4B', 'actor_rollout_ref.hybrid_engine=False', 'actor_rollout_ref.nccl_timeout=7200', 'actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1', 'actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1', 'actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1', 'actor_rollout_ref.actor.use_dynamic_bsz=True', 'actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True', 'actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True', 'actor_rollout_ref.actor.ppo_max_token_len_per_gpu=10240', 'actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=20480', 'actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=20480', 'actor_rollout_ref.actor.ppo_mini_batch_size=64', 'actor_rollout_ref.actor.optim.lr=2e-6', 'actor_rollout_ref.actor.optim.min_lr=0.0', 'actor_rollout_ref.actor.optim.lr_warmup_steps=10', 'actor_rollout_ref.actor.optim.lr_decay_style=constant', 'actor_rollout_ref.actor.optim.lr_decay_steps=51200', 'actor_rollout_ref.actor.optim.weight_decay=0.1', 'actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=selective', 'actor_rollout_ref.actor.megatron.override_transformer_config.recompute_modules=["core_attn", "mlp"]', 'actor_rollout_ref.actor.megatron.param_offload=True', 'actor_rollout_ref.actor.megatron.optimizer_offload=True', 'actor_rollout_ref.actor.megatron.grad_offload=True', '+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True', '+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1', '+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True', '+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True', 'actor_rollout_ref.actor.megatron.use_mbridge=True', 'actor_rollout_ref.actor.entropy_coeff=0', 'actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1', 'actor_rollout_ref.ref.megatron.tensor_model_parallel_size=8', 'actor_rollout_ref.ref.megatron.param_offload=True', 'actor_rollout_ref.rollout.tensor_model_parallel_size=1', 'actor_rollout_ref.rollout.calculate_log_probs=True', 'actor_rollout_ref.rollout.name=vllm', 'actor_rollout_ref.rollout.max_model_len=10240', 'actor_rollout_ref.rollout.max_num_batched_tokens=10240', 'actor_rollout_ref.rollout.mode=async', 'actor_rollout_ref.rollout.gpu_memory_utilization=0.9', 'actor_rollout_ref.rollout.n=8', 'algorithm.use_kl_in_reward=False', 'trainer.critic_warmup=0', 'trainer.logger=["console","tensorboard"]', 'trainer.project_name=Async_GRPO', 'trainer.experiment_name=Qwen3-4B-megatron-fully-async-8-8', 'trainer.test_freq=8', 'trainer.total_epochs=200', 'trainer.val_before_train=False', 'trainer.save_freq=-1', 'trainer.default_local_dir=/mnt/data/llm/ht1/verl-main-0bd0345/ckpts/Async_GRPO/Qwen3-4B-megatron-fully-async-8-8', 'trainer.resume_mode=auto', 'trainer.nnodes=1', 'trainer.n_gpus_per_node=8', 'rollout.nnodes=1', 'rollout.n_gpus_per_node=8', 'rollout.total_rollout_steps=51200', 'rollout.total_epochs=200', 'rollout.test_freq=8', 'async_training.staleness_threshold=0.1', 'async_training.trigger_parameter_sync_step=8', 'async_training.require_batches=4', 'async_training.partial_rollout=True', 'async_training.use_rollout_log_probs=True']
Traceback (most recent call last):
  File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/recipe/fully_async_policy/fully_async_main.py", line 307, in main
    run_ppo(config, task_runner_class=FullyAsyncTaskRunner)
  File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/verl/trainer/main_ppo.py", line 96, in run_ppo
    ray.get(runner.run.remote(config))
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2849, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 937, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ActorDiedError): ray::FullyAsyncTaskRunner.run() (pid=130397, ip=172.16.2.252, actor_id=1e1247aee30963ca18f2d8b00a000000, repr=<fully_async_main.FullyAsyncTaskRunner object at 0x7f1dac2cb610>)
  File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/recipe/fully_async_policy/fully_async_main.py", line 139, in run
    self._run_training_loop()
  File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/recipe/fully_async_policy/fully_async_main.py", line 283, in _run_training_loop
    raise e
  File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/recipe/fully_async_policy/fully_async_main.py", line 277, in _run_training_loop
    ray.get(future)
ray.exceptions.RayTaskError(ActorDiedError): ray::FullyAsyncTrainer.fit() (pid=131100, ip=172.16.2.252, actor_id=cacf5baae4c84d8505e41d5a0a000000, repr=<recipe.fully_async_policy.fully_async_trainer.FullyAsyncTrainer object at 0x7fbe22327bb0>)
  File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/recipe/fully_async_policy/fully_async_trainer.py", line 261, in fit
    batch, reward_extra_infos_dict = self._process_batch_common(
  File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/recipe/fully_async_policy/ray_trainer.py", line 460, in _process_batch_common
    actor_output = self.actor_rollout_wg.update_actor(batch)
  File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/verl/single_controller/ray/base.py", line 48, in __call__
    output = ray.get(output)
ray.exceptions.ActorDiedError: The actor died unexpectedly before finishing this task.
        class_name: create_colocated_worker_cls.<locals>.WorkerDict
        actor_id: 1bc865d7417ba6d5182635e30a000000
        pid: 132182
        name: ra0GdtWorkerDict_0:1
        namespace: 31d1e1ed-4eeb-4a83-b235-4847d56a68c7
        ip: 172.16.2.252
The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
(FullyAsyncTaskRunner pid=130397) /tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/recipe/fully_async_policy/fully_async_main.py:293: RuntimeWarning: coroutine 'MessageQueueClient.clear_queue' was never awaited
(FullyAsyncTaskRunner pid=130397)   self.components["message_queue_client"].clear_queue()
(FullyAsyncTaskRunner pid=130397) RuntimeWarning: Enable tracemalloc to get the object allocation traceback
(FullyAsyncTaskRunner pid=130397) [ASYNC MAIN] Component failed with error: ray::FullyAsyncTrainer.fit() (pid=131100, ip=172.16.2.252, actor_id=cacf5baae4c84d8505e41d5a0a000000, repr=<recipe.fully_async_policy.fully_async_trainer.FullyAsyncTrainer object at 0x7fbe22327bb0>)
(FullyAsyncTaskRunner pid=130397)   File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/recipe/fully_async_policy/fully_async_trainer.py", line 261, in fit
(FullyAsyncTaskRunner pid=130397)     batch, reward_extra_infos_dict = self._process_batch_common(
(FullyAsyncTaskRunner pid=130397)   File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/recipe/fully_async_policy/ray_trainer.py", line 460, in _process_batch_common
(FullyAsyncTaskRunner pid=130397)     actor_output = self.actor_rollout_wg.update_actor(batch)
(FullyAsyncTaskRunner pid=130397)   File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/verl/single_controller/ray/base.py", line 48, in __call__
(FullyAsyncTaskRunner pid=130397)     output = ray.get(output)
(FullyAsyncTaskRunner pid=130397) ray.exceptions.ActorDiedError: The actor died unexpectedly before finishing this task.
(FullyAsyncTaskRunner pid=130397)       class_name: create_colocated_worker_cls.<locals>.WorkerDict
(FullyAsyncTaskRunner pid=130397)       actor_id: 1bc865d7417ba6d5182635e30a000000
(FullyAsyncTaskRunner pid=130397)       pid: 132182
(FullyAsyncTaskRunner pid=130397)       name: ra0GdtWorkerDict_0:1
(FullyAsyncTaskRunner pid=130397)       namespace: 31d1e1ed-4eeb-4a83-b235-4847d56a68c7
(FullyAsyncTaskRunner pid=130397)       ip: 172.16.2.252
(FullyAsyncTaskRunner pid=130397) The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
(FullyAsyncTaskRunner pid=130397) [ASYNC MAIN] Training failed: ray::FullyAsyncTrainer.fit() (pid=131100, ip=172.16.2.252, actor_id=cacf5baae4c84d8505e41d5a0a000000, repr=<recipe.fully_async_policy.fully_async_trainer.FullyAsyncTrainer object at 0x7fbe22327bb0>)
(FullyAsyncTaskRunner pid=130397)   File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/recipe/fully_async_policy/fully_async_trainer.py", line 261, in fit
(FullyAsyncTaskRunner pid=130397)     batch, reward_extra_infos_dict = self._process_batch_common(
(FullyAsyncTaskRunner pid=130397)   File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/recipe/fully_async_policy/ray_trainer.py", line 460, in _process_batch_common
(FullyAsyncTaskRunner pid=130397)     actor_output = self.actor_rollout_wg.update_actor(batch)
(FullyAsyncTaskRunner pid=130397)   File "/tmp/ray/session_2025-11-02_13-27-21_397096_26/runtime_resources/working_dir_files/_ray_pkg_c7d1668832fb5b15/verl/single_controller/ray/base.py", line 48, in __call__
(FullyAsyncTaskRunner pid=130397)     output = ray.get(output)
(FullyAsyncTaskRunner pid=130397) ray.exceptions.ActorDiedError: The actor died unexpectedly before finishing this task.
(FullyAsyncTaskRunner pid=130397)       class_name: create_colocated_worker_cls.<locals>.WorkerDict
(FullyAsyncTaskRunner pid=130397)       actor_id: 1bc865d7417ba6d5182635e30a000000
(FullyAsyncTaskRunner pid=130397)       pid: 132182
(FullyAsyncTaskRunner pid=130397)       name: ra0GdtWorkerDict_0:1
(FullyAsyncTaskRunner pid=130397)       namespace: 31d1e1ed-4eeb-4a83-b235-4847d56a68c7
(FullyAsyncTaskRunner pid=130397)       ip: 172.16.2.252
(FullyAsyncTaskRunner pid=130397) The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
(FullyAsyncTaskRunner pid=130397) [ASYNC MAIN] Training completed or interrupted
(WorkerDict pid=132181) dlca1ik21m92lu3n-master-0:132181:134952 [0] NCCL INFO [Service thread] Connection closed by localRank 1 [repeated 9x across cluster]

---------------------------------------
Job 'raysubmit_Zw8TnRQiZGRP7Njt' failed
---------------------------------------
```

Could someone please help? I don't think these parameters should cause an OOM in this setting. @ISEEKYAN

PokeLu, Nov 02 '25 16:11

H100 80G?

ArronHZG, Nov 03 '25 11:11

> H100 80G?

Yes. Sorry, I just noticed that I forgot to set TP and PP for the actor in my script. I'll test again with the corrected settings later and close the issue if it works.
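
For reference, the missing actor settings can be passed the same way the script already passes the ref settings. A minimal sketch, assuming the `train_tp`/`train_pp` variable names from the script above and the TP=4, PP=1 values reported later in this thread; `actor_parallel_overrides` is a hypothetical helper array, not part of verl:

```bash
#!/usr/bin/env bash
# Values reported to work later in this thread (TP=4, PP=1 for actor and ref).
train_tp=4
train_pp=1

# Hypothetical helper array: keeps actor and ref parallelism in one place
# so the two models cannot silently end up with different TP/PP.
actor_parallel_overrides=(
  "actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp}"
  "actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp}"
  "actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp}"
  "actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp}"
)

# Print the overrides, one per line, for a quick visual check.
printf '%s\n' "${actor_parallel_overrides[@]}"
```

Inside `submit_ray_job`, the array would be expanded as `"${actor_parallel_overrides[@]}" \` in the override list, replacing the two existing `ref.megatron.*_parallel_size` lines so each key is set only once.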

PokeLu, Nov 03 '25 12:11

> H100 80G?

I got it working with TP=4, PP=1 for the actor and ref models. One more question before I close the issue. I checked the script for colocated sync mode using verl.trainer.main_ppo, which works with TP=2, PP=2, no offloading, and a relatively high batch size:

```bash
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 8))
```

In the async setting, I used TP=4, PP=1 with offloading and a relatively low batch size:

```bash
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
```
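
The two token budgets can be compared per tensor-parallel rank with a bit of shell arithmetic. A rough sketch, assuming `max_prompt_length=2048` and `max_response_length=8192` in both runs, TP=2 for the colocated run and TP=4 for the async run, and, as a simplification, that activation memory per GPU scales with the packed tokens divided by TP (roughly what sequence parallelism gives). `per_rank_tokens` is a hypothetical helper:

```bash
#!/usr/bin/env bash
max_prompt_length=2048
max_response_length=8192
seq_len=$(( max_prompt_length + max_response_length ))   # 10240

# Hypothetical helper: tokens packed per GPU (dynamic bsz) split across TP ranks.
# Usage: per_rank_tokens <multiplier> <tp>
per_rank_tokens() {
  local multiplier=$1 tp=$2
  echo $(( seq_len * multiplier / tp ))
}

# Colocated sync run: TP=2, actor budget x2, log-prob budget x8
sync_actor=$(per_rank_tokens 2 2)
sync_infer=$(per_rank_tokens 8 2)
# Fully async trainer: TP=4, actor budget x1, log-prob budget x2
async_actor=$(per_rank_tokens 1 4)
async_infer=$(per_rank_tokens 2 4)

echo "actor update : sync=${sync_actor} async=${async_actor} tokens per TP rank"
echo "log-prob pass: sync=${sync_infer} async=${async_infer} tokens per TP rank"
```

Under these assumptions the async trainer packs fewer tokens per TP rank (2560 vs 10240 for the actor update), so activations alone would not obviously explain the higher memory use; model and optimizer state sharding also changes with the number of training GPUs.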

The training node reaches a very high GPU memory utilization of 0.993, on the edge of OOM. I expected async training to use less memory than colocated sync mode. Both setups use 2 nodes in total for the training job, with the same model_len. Did I do something wrong, or is this the expected behavior?
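
When comparing runs, a ratio like 0.993 can be read straight from `nvidia-smi`. A small sketch; `mem_utilization` is a hypothetical helper. Note that `nvidia-smi` reports memory reserved by each process, which includes blocks cached by PyTorch's caching allocator, so it can overstate what live tensors need; inside a worker, `torch.cuda.memory_allocated()` versus `torch.cuda.memory_reserved()` shows that gap.

```bash
#!/usr/bin/env bash
# Hypothetical helper: turn "index, used_MiB, total_MiB" CSV rows into a used/total ratio.
mem_utilization() {
  awk -F', *' '{ printf "GPU %s: %.3f\n", $1, $2 / $3 }'
}

# Usage on a live node:
#   nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv,noheader,nounits | mem_utilization
```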

PokeLu, Nov 03 '25 14:11

That a colocated run works with 2x GPUs and 2x sequence length doesn't mean you can train with 1x GPUs and 1x sequence length. Training memory is a bit more complicated; see https://developer.nvidia.com/zh-cn/blog/explore-using-the-megatron-core-training-framework-to-improve-gpu-memory-efficiency-in-large-model-training/ for an overview of Megatron memory usage.
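
To make this concrete, here is a rough per-GPU estimate of model and optimizer state only. It is a sketch under stated assumptions: bf16 parameters with fp32 gradients and Adam with Megatron's distributed optimizer, which the Megatron-LM docs list at about 6 + 12/DP bytes per parameter; roughly 4,000M parameters for Qwen3-4B (an approximation); no offloading; activations, logits and allocator overhead excluded. `model_state_gb` is a hypothetical helper, and the TP=1 case is only an illustration of an unset actor TP.

```bash
#!/usr/bin/env bash
# Hypothetical helper: rough per-GPU model + optimizer state in GB (10^9 bytes).
# Assumes bf16 params, fp32 grads, Adam with the distributed optimizer:
#   bytes/param ~= 6 + 12/DP   (6 = bf16 param + fp32 grad, 12 = fp32 master + Adam m, v sharded over DP)
# Excludes activations, logits, CUDA context and allocator fragmentation.
# Usage: model_state_gb <params_in_millions> <tp> <pp> <num_gpus>
model_state_gb() {
  local params_m=$1 tp=$2 pp=$3 gpus=$4
  local dp=$(( gpus / (tp * pp) ))
  local shard_m=$(( params_m / (tp * pp) ))        # params per model-parallel shard (millions)
  local mbytes=$(( shard_m * (6 * dp + 12) / dp ))  # millions of bytes
  echo $(( mbytes / 1000 ))
}

params_m=4000  # ~4B parameters, approximation for Qwen3-4B
echo "colocated, 16 GPUs, TP=2 PP=2: $(model_state_gb "$params_m" 2 2 16) GB"
echo "async trainer, 8 GPUs, TP=4 PP=1: $(model_state_gb "$params_m" 4 1 8) GB"
echo "async trainer, 8 GPUs, TP=1 PP=1: $(model_state_gb "$params_m" 1 1 8) GB"
```

With half the GPUs, DP shrinks, so each GPU keeps a larger slice of the optimizer state even at a similar TP×PP (about 9 GB vs 12 GB here, and about 30 GB if TP stays at 1). Offloading changes where these bytes live between phases, which this sketch does not model.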

ISEEKYAN, Nov 06 '25 11:11

> That a colocated run works with 2x GPUs and 2x sequence length doesn't mean you can train with 1x GPUs and 1x sequence length. Training memory is a bit more complicated; see https://developer.nvidia.com/zh-cn/blog/explore-using-the-megatron-core-training-framework-to-improve-gpu-memory-efficiency-in-large-model-training/ for an overview of Megatron memory usage.

Thanks for the reply. Your Megatron memory estimator helped a lot during our CPT/SFT phase. Do you have any suggestions for estimating the memory usage of GRPO-based algorithms from that of SFT/PT? For instance, I used the estimator to calculate the VRAM required for SFT; how do I estimate the VRAM required for GRPO with the same training configuration?

PokeLu, Nov 07 '25 17:11

> > That a colocated run works with 2x GPUs and 2x sequence length doesn't mean you can train with 1x GPUs and 1x sequence length. Training memory is a bit more complicated; see https://developer.nvidia.com/zh-cn/blog/explore-using-the-megatron-core-training-framework-to-improve-gpu-memory-efficiency-in-large-model-training/ for an overview of Megatron memory usage.

> Thanks for the reply. Your Megatron memory estimator helped a lot during our CPT/SFT phase. Do you have any suggestions for estimating the memory usage of GRPO-based algorithms from that of SFT/PT? For instance, I used the estimator to calculate the VRAM required for SFT; how do I estimate the VRAM required for GRPO with the same training configuration?

I guess just changing the sequence length to match your RL max_tokens should be fine?
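
Following that suggestion, the value to plug into an SFT-style estimator can be derived from the script's own formulas. A minimal sketch, assuming that with `use_dynamic_bsz=True` a micro batch packs up to `ppo_max_token_len_per_gpu` tokens, so it behaves like one SFT sequence of that length with micro batch size 1. The GRPO-specific extras (reference model, forward-only log-prob passes with a larger budget) are only listed as reminders, not estimated:

```bash
#!/usr/bin/env bash
max_prompt_length=2048
max_response_length=8192

# Same formulas as the fully async script: x1 for the actor update, x2 for log-prob passes.
actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1 ))
infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2 ))

# Assumption: feed the actor budget to the estimator as seq_len with micro_batch_size=1.
estimator_seq_len=${actor_ppo_max_token_len}

echo "estimator seq_len (forward + backward): ${estimator_seq_len}"
echo "log-prob passes (forward only, e.g. reference): up to ${infer_ppo_max_token_len} tokens per GPU"
echo "extra for GRPO: reference model weights (offloaded here via ref.megatron.param_offload)"
```

These match the `ppo_max_token_len_per_gpu=10240` and `log_prob_max_token_len_per_gpu=20480` values in the error log above.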

ISEEKYAN, Nov 12 '25 10:11