提交 6f85d17f authored 作者: Saizheng Zhang's avatar Saizheng Zhang

#2801, minor fix

上级 675970e0
...@@ -1895,11 +1895,14 @@ def local_track_shape_i(node): ...@@ -1895,11 +1895,14 @@ def local_track_shape_i(node):
return [shape_feature.shape_of[replacement][node.op.i]] return [shape_feature.shape_of[replacement][node.op.i]]
# Subtensor(SetSubtensor(x, y, idx), idx) -> y
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor])
def local_subtensor_inc_subtensor(node): def local_subtensor_inc_subtensor(node):
"""
Subtensor(SetSubtensor(x, y, idx), idx) -> y
"""
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
x = node.inputs[0] x = node.inputs[0]
if not x.owner or not isinstance(x.owner.op, IncSubtensor): if not x.owner or not isinstance(x.owner.op, IncSubtensor):
......
...@@ -1956,7 +1956,7 @@ def test_subtensor_inc_subtensor(): ...@@ -1956,7 +1956,7 @@ def test_subtensor_inc_subtensor():
v = tensor.tensor3('v') v = tensor.tensor3('v')
y = tensor.set_subtensor(x[i1, :i2, i3:, ::i4], v) y = tensor.set_subtensor(x[i1, :i2, i3:, ::i4], v)
z = y[i1, :i2, i3:, ::i4] z = y[i1, :i2, i3:, ::i4]
mode = theano.compile.mode.get_mode('FAST_COMPILE').including('local_subtensor_inc_subtensor') mode = theano.compile.mode.get_default_mode().including('local_subtensor_inc_subtensor')
f = theano.function([x, i1, i2, i3, i4, v], z, mode=mode) f = theano.function([x, i1, i2, i3, i4, v], z, mode=mode)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert len(prog) == 1 assert len(prog) == 1
...@@ -1964,7 +1964,7 @@ def test_subtensor_inc_subtensor(): ...@@ -1964,7 +1964,7 @@ def test_subtensor_inc_subtensor():
# case not use this optimization # case not use this optimization
z = y[i1, :i3, i2:, ::i4] z = y[i1, :i3, i2:, ::i4]
mode = theano.compile.mode.get_mode('FAST_COMPILE').including('local_subtensor_inc_subtensor') mode = theano.compile.mode.get_default_mode().including('local_subtensor_inc_subtensor')
f = theano.function([x, i1, i2, i3, i4, v], z, mode=mode) f = theano.function([x, i1, i2, i3, i4, v], z, mode=mode)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert len(prog) != 1 assert len(prog) != 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论