Commit ac5bbc1b authored by Frederic Bastien

make SpecifyShape.grad work and test it.

Parent 6b7971a5
@@ -1568,8 +1568,13 @@ class SpecifyShape(Op):
         assert len(new_shape) == len(xshape)
         return [new_shape]
-    def grad(self, (x,), (gz,)):
-        return [gz]
+    def grad(self, (x, s), (gz,)):
+        # Should we put a SpecifyShape on gz? I think so,
+        # but not yet: we first need an optimization that removes
+        # that op from the graph so it does not block other optimizations.
+        # Should there be an optimizer that removes the SpecifyShape? I think yes.
+        return [gz, None]
+        # return [specify_shape(gz, s), None]  # deferred until that optimization exists
 specify_shape = SpecifyShape()
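The change above follows Theano's `Op.grad` contract: an op's `grad` must return one gradient term per input, with `None` in the slot of any non-differentiable input (here the shape `s`). A minimal pure-Python sketch of that contract, with a hypothetical `MockSpecifyShape` stand-in so it runs without importing Theano:

```python
# Hypothetical stand-in (no Theano import): illustrates the Op.grad contract
# this diff implements. grad() receives the op's inputs and the gradient of
# its single output, and must return one term per input.

class MockSpecifyShape:
    """Sketch of SpecifyShape.grad: identity on x, None for the shape input."""

    def grad(self, inputs, output_grads):
        (x, s) = inputs          # two inputs: the tensor and its target shape
        (gz,) = output_grads     # one output, hence one output gradient
        # Passing gz through unchanged mirrors `return [gz, None]`; the
        # deferred variant would wrap gz in another shape assertion first.
        return [gz, None]

op = MockSpecifyShape()
print(op.grad(("x", (2, 3)), ("gz",)))  # ['gz', None]
```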
......
@@ -418,6 +418,12 @@ def makeSharedTester(shared_constructor_,
         if theano.config.mode != 'FAST_COMPILE':
             assert len(topo_cst) == 0
+        # Test that we can take the grad.
+        shape_grad = tensor.grad(x1_specify_shape.sum(), x1_shared)
+        shape_constant_fct_grad = theano.function([], shape_grad)
+        theano.printing.debugprint(shape_constant_fct_grad)
+        shape_constant_fct_grad()
         # Test that we can replace with values of a different shape;
         # this raises an error in some cases, but not all.
         specify_shape_fct()
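The new test takes the gradient of `specify_shape(x, s).sum()`. Since the op is an identity with a runtime shape check, that gradient is simply an array of ones with the asserted shape. A hypothetical NumPy sketch of the same math, using a plain-function `specify_shape` rather than the Theano op:

```python
# NumPy sketch (not Theano) of what the new test exercises: specify_shape
# behaves as an identity with a runtime shape assertion, so the gradient of
# specify_shape(x, s).sum() with respect to x is ones_like(x).
import numpy as np

def specify_shape(x, shape):
    """Identity that asserts x has the given shape, like the Theano op."""
    assert x.shape == tuple(shape), (x.shape, shape)
    return x

x1 = np.arange(6.0).reshape(2, 3)
y = specify_shape(x1, (2, 3)).sum()
grad = np.ones_like(x1)   # analytic d(sum)/dx: a matrix of ones
assert grad.shape == x1.shape
print(grad.sum())  # 6.0
```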