提交 be719a61 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Use SpecifyShape in its own gradient

上级 9e9c34b9
......@@ -462,13 +462,9 @@ class SpecifyShape(COp):
def grad(self, inp, grads):
    """Return the gradient of a ``SpecifyShape`` node.

    Parameters
    ----------
    inp
        The node's inputs: the wrapped variable ``x`` followed by the
        shape inputs.
    grads
        The single output gradient ``gz``.

    Returns
    -------
    list
        The gradient w.r.t. ``x`` wrapped in ``specify_shape`` (so the
        static shape information is preserved on the gradient graph),
        followed by ``DisconnectedType`` gradients for each shape input,
        since the output does not depend differentiably on the shapes.
    """
    x, *shape = inp
    (gz,) = grads
    # Re-apply the known shape to the gradient: grad(x) has the same
    # shape as x, so SpecifyShape on gz keeps that information available
    # to downstream rewrites.  (The stale unconditional `return [gz] + ...`
    # left over from the previous implementation made the new return
    # unreachable; it has been removed.)
    return [specify_shape(gz, shape)] + [
        aesara.gradient.DisconnectedType()() for _ in range(len(shape))
    ]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
......
......@@ -2,7 +2,7 @@ import numpy as np
import pytest
import aesara
from aesara import Mode, function
from aesara import Mode, function, grad
from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config
from aesara.graph.basic import Variable
......@@ -510,6 +510,13 @@ class TestSpecifyShape(utt.InferShapeTester):
assert specify_shape(x, (None, None, 3)) is not x
assert specify_shape(x, (1, 3, None)) is not x
def test_specify_shape_in_grad(self):
    # The gradient of a graph containing SpecifyShape should itself be
    # wrapped in a SpecifyShape node, carrying the static shape forward.
    inp = matrix()
    shaped = specify_shape(inp, (2, 3))
    cost = (shaped + 1).sum()
    inp_grad = grad(cost, wrt=inp)
    assert isinstance(inp_grad.owner.op, SpecifyShape)
class TestRopLop(RopLopChecker):
def test_shape(self):
......
Markdown 格式
0%
您在此讨论中提及了 0 人。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论