提交 be719a61 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Use SpecifyShape in its own gradient

上级 9e9c34b9
...@@ -462,13 +462,9 @@ class SpecifyShape(COp): ...@@ -462,13 +462,9 @@ class SpecifyShape(COp):
def grad(self, inp, grads): def grad(self, inp, grads):
x, *shape = inp x, *shape = inp
(gz,) = grads (gz,) = grads
# Should I set an SpecifyShape on gz? I think so return [specify_shape(gz, shape)] + [
# But I don't do it now as we need to make an optimization aesara.gradient.DisconnectedType()() for _ in range(len(shape))
# to remove that op from the graph to don't block other optimization ]
# Should I do an optimizer that will remove the SpecifyShape?
# I think Yes
# return [specify_shape(gz, s)] + [aesara.gradient.DisconnectedType()() for _ in range(len(shape))]
return [gz] + [aesara.gradient.DisconnectedType()() for _ in range(len(shape))]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import pytest import pytest
import aesara import aesara
from aesara import Mode, function from aesara import Mode, function, grad
from aesara.compile.ops import DeepCopyOp from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Variable from aesara.graph.basic import Variable
...@@ -510,6 +510,13 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -510,6 +510,13 @@ class TestSpecifyShape(utt.InferShapeTester):
assert specify_shape(x, (None, None, 3)) is not x assert specify_shape(x, (None, None, 3)) is not x
assert specify_shape(x, (1, 3, None)) is not x assert specify_shape(x, (1, 3, None)) is not x
def test_specify_shape_in_grad(self):
x = matrix()
y = specify_shape(x, (2, 3))
z = y + 1
z_grad = grad(z.sum(), wrt=x)
assert isinstance(z_grad.owner.op, SpecifyShape)
class TestRopLop(RopLopChecker): class TestRopLop(RopLopChecker):
def test_shape(self): def test_shape(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论