提交 65e7e3e8 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Add grad to squeeze.

上级 d1b5b42e
...@@ -225,7 +225,8 @@ class SqueezeOp(theano.Op): ...@@ -225,7 +225,8 @@ class SqueezeOp(theano.Op):
z[0] = squeezed z[0] = squeezed
def grad(self, inputs, outputs_gradients): def grad(self, inputs, outputs_gradients):
return [None for i in inputs] out = outputs_gradients[0]
return [out.reshape(inputs[0].shape)]
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
......
...@@ -125,3 +125,10 @@ class TestSqueezeOp(utt.InferShapeTester): ...@@ -125,3 +125,10 @@ class TestSqueezeOp(utt.InferShapeTester):
a = np.random.random((4, 1, 2, 1)) a = np.random.random((4, 1, 2, 1))
assert np.allclose(np.squeeze(a), f(a)) assert np.allclose(np.squeeze(a), f(a))
def test_grad(self):
x = T.dtensor4('x')
a = np.random.random((1, 1, 3, 4))
gf = theano.function([x], T.grad(T.sum(squeeze(x, out_nd=1)), x))
utt.verify_grad(SqueezeOp(out_nd=2), [a])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论