
Static graph allocation in cached op consumes much more memory without giving other benefits?

Open • matteosal opened this issue 3 years ago • 1 comment

I bumped into this while investigating the memory efficiency of v2.0 vs. v1.6. I am building the master branch (v2.0) from commit https://github.com/apache/incubator-mxnet/commit/fabcd145cd496628791f9f2ea813048360ac33ca and the 1.x branch (v1.6) from commit https://github.com/apache/incubator-mxnet/commit/6eec9da55c5096079355d1f1a5fa58dcf35d6752

This script loads a BERT implementation (JSON files attached) and runs it repeatedly with increasing input sequence lengths:

```python
import json
import os
import time

import mxnet as mx
import psutil

symbol_path = '/home/matteo/bert_sym.json'
shapes_path = '/home/matteo/bert_shapes.json'


def get_memory_usage():
    # Resident set size (RSS) of the current process, in MB
    return psutil.Process(os.getpid()).memory_info().rss / 1e+6


sym = mx.symbol.load(symbol_path)
with open(shapes_path) as shapes_file:
    shapes = json.load(shapes_file)

version = mx.__version__
print('version: ' + version)

if version == '2.0.0':
    # v2.0: run the symbol through a CachedOp; set static_alloc = True to compare
    inputs = [mx.nd.ones(shape) for shape in shapes.values()]
    static_alloc = False
    cached_op = mx.ndarray.CachedOp(sym, flags=[('static_alloc', static_alloc)])
    print('static_alloc: ' + str(static_alloc))
else:
    # v1.6: bind a classic executor
    arrays = {name: mx.nd.ones(shape) for (name, shape) in shapes.items()}
    ex = sym.bind(mx.cpu(), arrays, grad_req='null')

batch_size = 32
seq_lengths = [5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, 64, 88]

print('initial memory: ' + str(get_memory_usage()))
print()
for seq_len in seq_lengths:
    data = mx.nd.ones([batch_size, seq_len, 2])
    start = time.time()
    if version == '2.0.0':
        # assumes the first entry in the shapes file is the sequence input
        inputs[0] = data
        output = cached_op(*inputs, default_ctx=mx.cpu())
    else:
        # rebind for the new input shape, sharing memory with the previous executor
        arrays['.Inputs.Input'] = data
        ex = sym.bind(mx.cpu(), arrays, grad_req='null', shared_exec=ex)
        ex.forward()
        output = ex.outputs[0]
    mx.ndarray.waitall()  # block until all asynchronous work has finished
    end = time.time()
    print('input length: ' + str(seq_len) + ', output shape: ' + str(output.shape)
          + ', memory: ' + str(get_memory_usage()) + ', time: ' + str(end - start))
```

Attachment: bert_model.zip

I ran this script with both v2.0 and v1.6. For v2.0, you can change the line `static_alloc = False` to `static_alloc = True` to explore the effects of static allocation. This is what I get:

```text
version: 1.6.0
initial memory: 513.232896

input length: 5, output shape: (160, 1, 768), memory: 627.634176, time: 0.17993521690368652
input length: 6, output shape: (192, 1, 768), memory: 634.53184, time: 0.16498708724975586
input length: 7, output shape: (224, 1, 768), memory: 636.997632, time: 0.1961820125579834
input length: 8, output shape: (256, 1, 768), memory: 639.373312, time: 0.3744235038757324
input length: 10, output shape: (320, 1, 768), memory: 646.094848, time: 0.6234118938446045
input length: 12, output shape: (384, 1, 768), memory: 647.86432, time: 0.7562270164489746
input length: 14, output shape: (448, 1, 768), memory: 648.101888, time: 0.817521333694458
input length: 16, output shape: (512, 1, 768), memory: 650.092544, time: 0.546708345413208
input length: 20, output shape: (640, 1, 768), memory: 653.774848, time: 0.8869819641113281
input length: 24, output shape: (768, 1, 768), memory: 663.30624, time: 1.1742339134216309
input length: 28, output shape: (896, 1, 768), memory: 675.590144, time: 1.5763235092163086
input length: 32, output shape: (1024, 1, 768), memory: 679.415808, time: 1.2759253978729248
input length: 40, output shape: (1280, 1, 768), memory: 685.662208, time: 1.3794775009155273
input length: 48, output shape: (1536, 1, 768), memory: 691.77344, time: 1.8214092254638672
input length: 56, output shape: (1792, 1, 768), memory: 700.432384, time: 2.0428481101989746
input length: 64, output shape: (2048, 1, 768), memory: 706.981888, time: 2.3951635360717773
input length: 88, output shape: (2816, 1, 768), memory: 742.8096, time: 3.1772890090942383
```

```text
version: 2.0.0
static_alloc: False
initial memory: 450.53952

input length: 5, output shape: (160, 1, 768), memory: 646.582272, time: 0.11594986915588379
input length: 6, output shape: (192, 1, 768), memory: 655.331328, time: 0.11730408668518066
input length: 7, output shape: (224, 1, 768), memory: 660.733952, time: 0.12042427062988281
input length: 8, output shape: (256, 1, 768), memory: 666.7264, time: 0.13665461540222168
input length: 10, output shape: (320, 1, 768), memory: 673.849344, time: 0.16487836837768555
input length: 12, output shape: (384, 1, 768), memory: 688.164864, time: 0.20292162895202637
input length: 14, output shape: (448, 1, 768), memory: 694.419456, time: 0.22304844856262207
input length: 16, output shape: (512, 1, 768), memory: 702.029824, time: 0.26688313484191895
input length: 20, output shape: (640, 1, 768), memory: 722.939904, time: 0.31215453147888184
input length: 24, output shape: (768, 1, 768), memory: 746.889216, time: 0.43781352043151855
input length: 28, output shape: (896, 1, 768), memory: 768.385024, time: 0.7122831344604492
input length: 32, output shape: (1024, 1, 768), memory: 772.091904, time: 0.8170092105865479
input length: 40, output shape: (1280, 1, 768), memory: 806.465536, time: 0.9898905754089355
input length: 48, output shape: (1536, 1, 768), memory: 853.385216, time: 1.1750028133392334
input length: 56, output shape: (1792, 1, 768), memory: 901.287936, time: 1.3260750770568848
input length: 64, output shape: (2048, 1, 768), memory: 953.700352, time: 1.5255098342895508
input length: 88, output shape: (2816, 1, 768), memory: 1040.379904, time: 2.304577350616455
```

```text
version: 2.0.0
static_alloc: True
initial memory: 460.10368

input length: 5, output shape: (160, 1, 768), memory: 707.33824, time: 0.13922739028930664
input length: 6, output shape: (192, 1, 768), memory: 774.643712, time: 0.10790896415710449
input length: 7, output shape: (224, 1, 768), memory: 819.228672, time: 0.12243270874023438
input length: 8, output shape: (256, 1, 768), memory: 831.901696, time: 0.13582420349121094
input length: 10, output shape: (320, 1, 768), memory: 853.38112, time: 0.17856955528259277
input length: 12, output shape: (384, 1, 768), memory: 967.712768, time: 0.20433759689331055
input length: 14, output shape: (448, 1, 768), memory: 987.136, time: 0.22087836265563965
input length: 16, output shape: (512, 1, 768), memory: 1007.665152, time: 0.24690675735473633
input length: 20, output shape: (640, 1, 768), memory: 1199.075328, time: 0.3204777240753174
input length: 24, output shape: (768, 1, 768), memory: 1425.625088, time: 0.3830389976501465
input length: 28, output shape: (896, 1, 768), memory: 1689.489408, time: 0.4304239749908447
input length: 32, output shape: (1024, 1, 768), memory: 1726.783488, time: 0.515824556350708
input length: 40, output shape: (1280, 1, 768), memory: 2105.565184, time: 0.9770240783691406
input length: 48, output shape: (1536, 1, 768), memory: 2558.291968, time: 1.218252420425415
input length: 56, output shape: (1792, 1, 768), memory: 3089.73568, time: 1.4080862998962402
input length: 64, output shape: (2048, 1, 768), memory: 3693.993984, time: 1.6184279918670654
input length: 88, output shape: (2816, 1, 768), memory: 4524.1344, time: 2.394423246383667
```
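To compare runs programmatically rather than by eye, the lines printed by the script can be parsed back into numbers. This is a minimal sketch using only the standard library; `LINE_RE` and `parse_log` are hypothetical helpers, and the regular expression assumes exactly the print format used in the script.

```python
import re

# Matches one line printed by the benchmark loop, e.g.
# "input length: 5, output shape: (160, 1, 768), memory: 627.634176, time: 0.1799..."
LINE_RE = re.compile(
    r"input length: (\d+), output shape: \(([\d, ]+)\), "
    r"memory: ([\d.]+), time: ([\d.]+)"
)


def parse_log(text):
    """Return a list of (seq_len, output_shape, memory_mb, seconds) tuples."""
    rows = []
    for line in text.splitlines():
        m = LINE_RE.search(line)
        if m:
            shape = tuple(int(dim) for dim in m.group(2).split(","))
            rows.append((int(m.group(1)), shape, float(m.group(3)), float(m.group(4))))
    return rows


# Two lines taken from the v1.6 log; header lines are simply skipped
sample = """version: 1.6.0
initial memory: 513.232896

input length: 5, output shape: (160, 1, 768), memory: 627.634176, time: 0.17993521690368652
input length: 88, output shape: (2816, 1, 768), memory: 742.8096, time: 3.1772890090942383
"""

for seq_len, shape, mem, secs in parse_log(sample):
    print(seq_len, shape, mem, secs)
```

Feeding each of the three logs through `parse_log` gives lists that can be zipped by sequence length for side-by-side comparison.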

Although I'm glad that v2.0 is faster, there are two things I don't understand:

  1. Static allocation uses roughly 4 times as much memory while not being any faster (at input length 88: about 4524 MB vs. 1040 MB RSS, 2.39 s vs. 2.30 s). Is this expected for such a model?
  2. In terms of memory usage, static allocation in CachedOp is supposed to be closer to the old v1.6 executors, since those are also supposed to allocate their computational graphs statically. Yet v1.6 uses far less memory than v2.0 with static allocation, and also less than v2.0 without it (about 743 MB vs. 4524 MB and 1040 MB at input length 88). What's happening here?
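The comparisons behind both questions can be recomputed from the logs. The sketch below uses the figures at input length 88, the longest one tested; `runs` and `growth` are illustrative names, not part of MXNet. "Growth" here is the RSS increase over the `initial memory` reading, which is only a rough proxy for what the graph itself allocates, since RSS covers the whole process.

```python
# Figures copied from the three logs, at input length 88.
# Memory is process RSS in MB (as printed by get_memory_usage); time is in seconds.
runs = {
    "v1.6 executor": {"initial": 513.232896, "final": 742.8096, "time": 3.1772890090942383},
    "v2.0 static_alloc=False": {"initial": 450.53952, "final": 1040.379904, "time": 2.304577350616455},
    "v2.0 static_alloc=True": {"initial": 460.10368, "final": 4524.1344, "time": 2.394423246383667},
}


def growth(run):
    """RSS increase over the 'initial memory' reading, in MB."""
    return run["final"] - run["initial"]


for name, run in runs.items():
    print(f"{name}: final {run['final']:.0f} MB, growth {growth(run):.0f} MB, time {run['time']:.2f} s")

static = runs["v2.0 static_alloc=True"]
dynamic = runs["v2.0 static_alloc=False"]
print(f"static/dynamic final RSS: {static['final'] / dynamic['final']:.2f}x")
print(f"static/dynamic growth:    {growth(static) / growth(dynamic):.2f}x")
print(f"static/dynamic time:      {static['time'] / dynamic['time']:.2f}x")
```

At this length the final RSS ratio comes out at about 4.3x and the growth ratio at about 6.9x, while the time ratio stays close to 1. These are single measurements, so small timing differences may be noise.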

matteosal, Jun 07 '22 13:06

@barry-jin @szha Can you help with this?

bgawrych, Jun 15 '22 06:06