提交 9ec45737 authored 作者: Frederic's avatar Frederic

local_mul_specialize no longer introduces a neg node when doing so would add more nodes than needed.

上级 ea3dcffc
......@@ -3600,7 +3600,15 @@ def local_pow_specialize_device(node):
@gof.local_optimizer([T.mul])
def local_mul_specialize(node):
"""Remove special-case constants from mul arguments
"""Remove special-case constants from mul arguments and useless neg in inputs.
mul(-1, x) -> neg(x)
mul(1, x, y) -> mul(x, y)
mul(0, ...) -> alloc(0, shapes...)
This is not done if we would add more nodes in the graph, like with:
mul(-1, x, y) -/-> neg(mul(x, y))
"""
# here, we are past the point of canonicalization, so we don't
# want to put in un-necessary fills.
......@@ -3610,19 +3618,23 @@ def local_mul_specialize(node):
#the idea here is that we have pow(x, y)
neg = False
new_inputs = []
nb_neg_node = 0
nb_cst = 0
for input in node.inputs:
# remove any neg arguments
while input.owner and input.owner.op == T.neg:
neg ^= True
input = input.owner.inputs[0]
nb_neg_node += 1
# remove special case arguments of 1, -1 or 0
y = local_mul_canonizer.get_constant(input)
if N.all(y == 1.0):
continue
elif N.all(y == -1.0):
if y == 1.0:
nb_cst += 1
elif y == -1.0:
nb_cst += 1
neg ^= True # toggles
elif N.all(y == 0.0):
elif y == 0.0:
# if we find any zero, we just return right away
return [broadcast_like(0, node.outputs[0], node.fgraph)]
else:
......@@ -3636,10 +3648,17 @@ def local_mul_specialize(node):
else:
rval = new_inputs[0]
else:
if neg:
rval = -T.mul(*new_inputs)
else:
rval = T.mul(*new_inputs)
# The next case would cause a replace by an equivalent case.
if (neg and
nb_neg_node == 0 and
nb_cst == 1):
return
elif neg:
# Don't add an extra neg node as we can't
# fully replace this mul by a neg.
m1 = numpy.asarray(-1, dtype=node.outputs[0].dtype)
new_inputs = [m1] + new_inputs
rval = T.mul(*new_inputs)
return [broadcast_like(rval, node.outputs[0], node.fgraph)]
else:
......
......@@ -2838,7 +2838,7 @@ def test_local_mul_specialize():
nodes = [node.op for node in f.maker.fgraph.toposort()]
print nodes
theano.printing.debugprint(f)
assert nodes == [T.mul, inplace.neg_inplace]
assert nodes == [T.mul]
f = function([v, m], v * 0 * (-m), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
......@@ -2852,6 +2852,12 @@ def test_local_mul_specialize():
theano.printing.debugprint(f)
assert nodes == [T.mul]
f = function([v, m], v * (-1) * m, mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
print nodes
theano.printing.debugprint(f)
assert nodes == [T.mul]
def speed_local_pow_specialize_range():
val = numpy.random.rand(1e7)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论