提交 ff542349 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Simplify checking for constant zeros in optimization, test it.

More cases are detected now.
上级 8b7aeb6e
......@@ -262,21 +262,19 @@ def local_0_dot_x(node):
x = node.inputs[0]
y = node.inputs[1]
replace = False
if x.owner and isinstance(x.owner.op, T.Alloc):
try:
val = get_constant_value(x.owner.inputs[0])
if numpy.all(val == 0):
replace = True
except TypeError:
pass
try:
if get_constant_value(x) == 0:
replace = True
except TypeError:
pass
if y.owner and isinstance(y.owner.op, T.Alloc):
try:
val = get_constant_value(y.owner.inputs[0])
if numpy.all(val == 0):
replace = True
except TypeError:
pass
try:
if get_constant_value(y) == 0:
replace = True
except TypeError:
pass
# TODO: Integrate that into get_constant_value somehow
if isinstance(x, T.TensorConstant) and (x.tag.unique_value == 0):
replace = True
if isinstance(y, T.TensorConstant) and (y.tag.unique_value == 0):
......@@ -1630,13 +1628,12 @@ def local_incsubtensor_of_allocs(node):
x = node.inputs[0]
y = node.inputs[1]
replace = False
if y.owner and isinstance(y.owner.op, T.Alloc):
try:
val = get_constant_value(y.owner.inputs[0])
if numpy.all(val == 0):
replace = True
except TypeError:
pass
try:
if get_constant_value(y) == 0:
replace = True
except TypeError:
pass
# TODO: Integrate that into get_constant_value
if isinstance(y, T.TensorConstant) and (y.tag.unique_value == 0):
replace = True
......@@ -1655,25 +1652,20 @@ def local_setsubtensor_of_allocs(node):
y = node.inputs[1]
replace_x = None
replace_y = None
if x.owner and isinstance(x.owner.op, T.Alloc):
try:
val = get_constant_value(x.owner.inputs[0])
assert val.size == 1
replace_x = val
except (TypeError, AssertionError):
replace_x = x.owner.inputs[0]
try:
replace_x = get_constant_value(x)
except TypeError:
pass
if isinstance(x, T.TensorConstant) and (x.tag.unique_value is not
None):
replace_x = x.tag.unique_value
if y.owner and isinstance(y.owner.op, T.Alloc):
try:
val = get_constant_value(y.owner.inputs[0])
assert val.size == 1
replace_y = val
except (TypeError, AssertionError):
replace_y = y.owner.inputs[0]
try:
replace_y = get_constant_value(y)
except TypeError:
pass
if isinstance(y, T.TensorConstant) and (y.tag.unique_value is not
None):
......
......@@ -1602,6 +1602,15 @@ class Test_alloc_zero(unittest.TestCase):
assert numpy.all( [ not isinstance(x.op, tensor.IncSubtensor) for x in
f.maker.env.toposort() ])
def test_setsubtensor_allocs1t(self):
y = tensor.matrix()
x0 = tensor.constant(numpy.asarray(numpy.zeros_like((4,4)), dtype=config.floatX))
y0 = tensor.zeros_like(y)
z = tensor.set_subtensor(x0[:4], y0.T)
f = theano.function([y], z)
assert numpy.all( [ not isinstance(x.op, tensor.IncSubtensor) for x in
f.maker.env.toposort() ])
def test_setsubtensor_allocs2(self):
x = tensor.matrix()
y0 = tensor.constant(numpy.asarray(numpy.zeros_like((4,4)), dtype=config.floatX))
......@@ -1620,6 +1629,15 @@ class Test_alloc_zero(unittest.TestCase):
assert numpy.all( [ not isinstance(x.op, tensor.IncSubtensor) for x in
f.maker.env.toposort() ])
def test_incsubtensor_allocs0t(self):
x = tensor.matrix()
y = tensor.matrix()
y0 = tensor.zeros_like(y)
z = tensor.inc_subtensor(x[:4], y0.T)
f = theano.function([x,y], z)
assert numpy.all( [ not isinstance(x.op, tensor.IncSubtensor) for x in
f.maker.env.toposort() ])
def test_incsubtensor_allocs1(self):
x = tensor.matrix()
y0 = tensor.constant(numpy.asarray(numpy.zeros_like((4,4)), dtype=config.floatX))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论