提交 1b07a532 authored 作者: Faruk Ahmed's avatar Faruk Ahmed

inc to set for zeros

上级 bd1a12ed
...@@ -3411,6 +3411,24 @@ def local_incsubtensor_of_zeros(node): ...@@ -3411,6 +3411,24 @@ def local_incsubtensor_of_zeros(node):
return return
@register_canonicalize
@register_specialize
@gof.local_optimizer([IncSubtensor])
def incsubtensor_of_zeros_to_setsubtensor(node):
"""
IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...)
"""
if (isinstance(node.op, (IncSubtensor)) and not node.op.set_instead_of_inc):
x = node.inputs[0]
if isinstance(x, T.Constant) and not numpy.any(x.data):
return [IncSubtensor(node.op.idx_list,
node.op.inplace,
set_instead_of_inc=True,
destroyhandler_tolerate_aliased=node.op.destroyhandler_tolerate_aliased,
)(*node.inputs)]
@register_canonicalize('local_setsubtensor_of_allocs') @register_canonicalize('local_setsubtensor_of_allocs')
@register_stabilize('local_setsubtensor_of_allocs') @register_stabilize('local_setsubtensor_of_allocs')
@gof.local_optimizer([IncSubtensor]) @gof.local_optimizer([IncSubtensor])
......
...@@ -3001,6 +3001,37 @@ class Test_alloc_zero(unittest.TestCase): ...@@ -3001,6 +3001,37 @@ class Test_alloc_zero(unittest.TestCase):
assert np.all([not isinstance(n.op, tensor.IncSubtensor) assert np.all([not isinstance(n.op, tensor.IncSubtensor)
for n in f.maker.fgraph.toposort()]) for n in f.maker.fgraph.toposort()])
def test_incsubtensor_x_zeros(self):
x = tensor.constant(np.asarray(np.zeros((4, 4)),
dtype=config.floatX))
y = tensor.matrix()
z = tensor.inc_subtensor(x[:4], y)
f = theano.function([y], z)
inc_nodes = [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, tensor.IncSubtensor)]
assert(len(inc_nodes) == 1)
node_is_set_instead_of_inc = inc_nodes[0].op.set_instead_of_inc
mode = theano.config.mode
assert((mode != "FAST_COMPILE" and node_is_set_instead_of_inc) or
(mode == "FAST_COMPILE" and not node_is_set_instead_of_inc))
test_X = np.random.random((4, 4)).astype(config.floatX)
utt.assert_allclose(f(test_X), test_X)
# also check the flag doesn't get set if first input is not zeros:
not_all_zeros = np.zeros((4, 4))
not_all_zeros[1, 0] = 0.001
x = tensor.constant(np.asarray(not_all_zeros, dtype=config.floatX))
y = tensor.matrix()
z = tensor.inc_subtensor(x[:4], y)
f = theano.function([y], z)
inc_nodes = [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, tensor.IncSubtensor)]
assert(len(inc_nodes) == 1)
assert(inc_nodes[0].op.set_instead_of_inc is False)
test_X = np.random.random((4, 4)).astype(config.floatX)
utt.assert_allclose(f(test_X), test_X + not_all_zeros)
def test_advancedincsubtensor1_allocs0(self): def test_advancedincsubtensor1_allocs0(self):
x = tensor.matrix() x = tensor.matrix()
y = tensor.matrix() y = tensor.matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论