minpy
Zero gradient for concatenate
It looks like the computation graph breaks at the concatenation operation. MWE:
```python
import minpy.numpy as np
from minpy.core import grad

def foo_nocat(x):
    return 3 * x

def foo_cat(x):
    catx = np.concatenate([x, x], axis=1)
    return np.dot(catx, np.array([[1], [2]]))

test_x = np.array([[3]])
print(grad(foo_nocat)(test_x))  # correct output
print(grad(foo_cat)(test_x))    # should be the same, but the gradient comes back zero
```
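For reference, both functions are mathematically identical: `dot(concatenate([x, x], axis=1), [[1], [2]])` equals `x + 2x = 3x`, so both gradients should be 3. A quick sanity check with plain NumPy and a central finite difference (no minpy involved; this is just to confirm the expected value) agrees:

```python
import numpy as np

def foo_nocat(x):
    return 3 * x

def foo_cat(x):
    # Same computation as the minpy MWE, in plain NumPy.
    catx = np.concatenate([x, x], axis=1)
    return np.dot(catx, np.array([[1], [2]]))

eps = 1e-6
test_x = np.array([[3.0]])
for f in (foo_nocat, foo_cat):
    # Central finite difference: (f(x + eps) - f(x - eps)) / (2 * eps)
    fd = (f(test_x + eps) - f(test_x - eps)) / (2 * eps)
    print(float(fd))  # both print 3.0 (up to floating-point error)
```

So the zero gradient from `grad(foo_cat)` points at the concatenate node in the graph, not at the test function itself.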
@ZihengJiang Could you take a look? Please also add this case to the unit tests.
@ZihengJiang Any follow-up on this?