keras_multi_gpu
Multi-GPU does not work
Hi, I am trying to use this method to parallelize my model, but it does not seem to work. For example, I use 3 GPUs, but it actually seems to use only one. The code is:
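import tensorflow as tf

# slice_batch(x, n_gpus, part) is not shown in the post; it is assumed to return
# the part-th of n_gpus contiguous slices of x along the batch axis (axis 0).

def multi_gpu_wrapper(single_model, num_gpu):
    inputs = single_model.inputs
    towers = []
    for gpu_id in range(num_gpu):
        print('cur gpu is', gpu_id)
        with tf.device('/gpu:' + str(gpu_id)):
            # Pass gpu_id through `arguments` so each Lambda keeps its own index
            # instead of closing over the loop variable.
            split_layer = tf.keras.layers.Lambda(
                lambda x, part: slice_batch(x, num_gpu, part),
                arguments={'part': gpu_id})
            cur_inputs = [split_layer(inp) for inp in inputs]
            tower_outputs = single_model(cur_inputs)
            # A single-output model returns a tensor rather than a list.
            if not isinstance(tower_outputs, (list, tuple)):
                tower_outputs = [tower_outputs]
            towers.append(tower_outputs)
            print(towers[-1])
    outputs = []
    num_output = len(towers[-1])
    with tf.device('/cpu:0'):
        for i in range(num_output):
            # Collect output i from every tower and rejoin the batch on the CPU.
            tmp_outputs = [towers[j][i] for j in range(num_gpu)]
            outputs.append(tf.keras.layers.Concatenate(axis=0)(tmp_outputs))
    multi_gpu_model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    return multi_gpu_model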
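The split/concatenate round trip the wrapper relies on can be checked without TensorFlow or GPUs. The sketch below is a minimal pure-Python illustration under stated assumptions: `slice_batch_sketch` is a hypothetical stand-in for the unshown `slice_batch` (contiguous slices along axis 0, last tower takes the remainder), and Python lists stand in for tensors. It also shows the closure pitfall that the `arguments=` change guards against: a `lambda` that reads `gpu_id` only sees the loop's final value if it is called after the loop finishes. Whether Keras ever re-invokes the Lambda function that late depends on the TF/Keras version, so this may or may not be related to the single-GPU symptom.

```python
from functools import partial


def slice_batch_sketch(batch, n_gpus, part):
    """Hypothetical stand-in for slice_batch: return the part-th of n_gpus
    contiguous slices along the batch axis; the last part takes the remainder."""
    size = len(batch) // n_gpus
    if part == n_gpus - 1:
        return batch[part * size:]
    return batch[part * size:(part + 1) * size]


def split_and_concat(batch, n_gpus, slicers):
    """Apply one slicer per tower, then join the pieces along axis 0,
    mimicking Concatenate(axis=0) over the tower outputs."""
    pieces = [slicers[tower](batch) for tower in range(n_gpus)]
    merged = [row for piece in pieces for row in piece]
    return pieces, merged


n_gpus = 3
batch = list(range(10))  # ten samples with ids 0..9

# Closure over the loop variable: each lambda looks up gpu_id when it is
# called, and after the loop gpu_id == n_gpus - 1 for all of them.
late = []
for gpu_id in range(n_gpus):
    late.append(lambda x: slice_batch_sketch(x, n_gpus, gpu_id))

# Binding the index at creation time (what Lambda's `arguments=` does in Keras).
bound = [partial(slice_batch_sketch, n_gpus=n_gpus, part=i) for i in range(n_gpus)]

pieces, merged = split_and_concat(batch, n_gpus, bound)
print(pieces)           # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
print(merged == batch)  # True

late_pieces, late_merged = split_and_concat(batch, n_gpus, late)
print(late_pieces)      # [[6, 7, 8, 9], [6, 7, 8, 9], [6, 7, 8, 9]]
```

Here `functools.partial` plays the role that the `arguments` parameter of `tf.keras.layers.Lambda` plays in the wrapper: the tower index is fixed when the slicer is created, not when it runs.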
The output of nvidia-smi is:

Do you know how to fix it? Thank you!