
OOM on AF DB proteins?

Open · Abhishaike opened this issue 2 years ago · 2 comments

I'm creating a binder for this protein: https://alphafold.ebi.ac.uk/entry/Q8W3K0, and I'm getting the error below on both T4s and A100s.

The error itself makes sense, but I'm confused as to how a protein that requires over 100 GB could have been folded by AlphaFold in the first place. Shouldn't any protein that AlphaFold can take as input also work with ColabDesign? Or does ColabDesign need more memory?

Stage 1: running (logits → soft)
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
[<ipython-input-3-09dc0c470929>](https://localhost:8080/#) in <module>
     36 if optimizer == "pssm_semigreedy":
---> 37   model.design_pssm_semigreedy(120, 32, **flags)
     38   pssm = softmax(model._tmp["seq_logits"],1)

25 frames
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 164515000848 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation: 1006.63MiB
              constant allocation:    97.4KiB
        maybe_live_out allocation:  660.41MiB
     preallocated temp allocation:  153.22GiB
                 total allocation:  154.84GiB
Peak buffers:
	Buffer 1:
		Size: 22.78GiB
		XLA Label: copy
		Shape: f32[288,4,4,1152,1152]
		==========================

	Buffer 2:
		Size: 22.78GiB
		XLA Label: copy
		Shape: f32[288,4,4,1152,1152]
		==========================

	Buffer 3:
		Size: 22.78GiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_attention_starting_node/broadcast_in_dim[shape=(288, 4, 4, 1152, 1152) broadcast_dimensions=()]" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/stateful.py" source_line=640
		XLA Label: broadcast
		Shape: f32[288,4,4,1152,1152]
		==========================

	Buffer 4:
		Size: 648.00MiB
		XLA Label: fusion
		Shape: f32[128,1152,1152]
		==========================

	Buffer 5:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/gating_linear/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 6:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/output_projection/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 7:
		Size: 648.00MiB
		XLA Label: fusion
		Shape: f32[128,1152,1152]
		==========================

	Buffer 8:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/mul" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/layer_norm.py" source_line=205
		XLA Label: fusion
		Shape: f32[128,1152,1152]
		==========================

	Buffer 9:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/mul" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/layer_norm.py" source_line=205
		XLA Label: fusion
		Shape: f32[128,1152,1152]
		==========================

	Buffer 10:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/right_projection/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 11:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/right_gate/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 12:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/left_projection/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 13:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/left_gate/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 14:
		Size: 648.00MiB
		XLA Label: fusion
		Shape: f32[128,1152,1152]
		==========================

	Buffer 15:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/structure_module/broadcast_in_dim[shape=(1152, 1152, 128) broadcast_dimensions=()]" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/stateful.py" source_line=640
		XLA Label: broadcast
		Shape: f32[1152,1152,128]
		==========================

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:
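
As a rough sanity check of the numbers in the log above (a hedged sketch: the shapes are copied from the XLA peak-buffer report, float32 at 4 bytes per element is assumed, and 1152 is presumably the total model length, target plus designed binder):

from math import prod

GiB = 2**30  # bytes per GiB

# Shapes taken from the peak buffers reported by XLA above.
for shape in [(288, 4, 4, 1152, 1152),   # triangle-attention intermediate in the backward pass
              (128, 1152, 1152)]:        # pair-representation-sized activation (128 channels x L x L)
    size_bytes = prod(shape) * 4         # f32 = 4 bytes per element
    print(shape, f"{size_bytes / GiB:.2f} GiB")

# (288, 4, 4, 1152, 1152) -> 22.78 GiB  (matches Buffers 1-3; three of these alone is ~68 GiB)
# (128, 1152, 1152)       -> 0.63 GiB   (matches the 648 MiB buffers)

These pairwise activations scale with the square of the sequence length, which is how the preallocated temp allocation climbs to ~153 GiB for this target.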

Abhishaike · Mar 08 '23 18:03

This is with PSSM; the memory issue doesn't occur with semigreedy. Why is PSSM so memory intensive?

Abhishaike · Mar 08 '23 19:03

Gradient computation takes about 2X more memory. Semigreedy just tries random mutations and accepts the ones that improve the loss, which doesn't require computing gradients.
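
For illustration, here is a minimal sketch of the difference (assumptions: this is not ColabDesign's actual API; loss_fn, the array shapes, and the mutation scheme are stand-ins). The gradient-based stages (logits → soft, PSSM) differentiate through the whole AlphaFold forward pass, so XLA must keep the forward activations alive for the backward pass; the semigreedy stage only runs forward passes:

import jax
import jax.numpy as jnp

def loss_fn(seq_logits):
    # stand-in for "run AlphaFold on the design and score it"
    return jnp.sum(jax.nn.softmax(seq_logits, axis=-1) ** 2)

seq_logits = jnp.zeros((100, 20))  # hypothetical 100-residue design, 20 amino acids

# Gradient-based update: value_and_grad means the backward pass must retain
# the forward activations (in the real model, the huge L x L pair buffers).
loss, grad = jax.value_and_grad(loss_fn)(seq_logits)
seq_logits = seq_logits - 0.1 * grad

# Semigreedy update: forward passes only -- propose a random mutation and
# keep it if the loss improves; nothing is retained for backpropagation.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
pos = int(jax.random.randint(k1, (), 0, seq_logits.shape[0]))
aa = int(jax.random.randint(k2, (), 0, seq_logits.shape[1]))
baseline = loss_fn(seq_logits)
trial = seq_logits.at[pos].set(-1e9).at[pos, aa].set(0.0)  # force one amino acid at pos
if loss_fn(trial) < baseline:
    seq_logits = trial

With everything else equal, the forward-only loop's peak memory is roughly that of a single prediction, while the gradient path roughly doubles it, consistent with the ~2X figure above.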


sokrypton · Mar 08 '23 19:03