提交 f570839f authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4891 from nouiz/bugfix_gh_4865

fix slowdown introduced by gh-4865
...@@ -2298,12 +2298,15 @@ pprint.assign(fill, printing.FunctionPrinter('fill')) ...@@ -2298,12 +2298,15 @@ pprint.assign(fill, printing.FunctionPrinter('fill'))
@constructor @constructor
def ones_like(model, dtype=None): def ones_like(model, dtype=None, opt=False):
"""equivalent of numpy.ones_like """equivalent of numpy.ones_like
Parameters Parameters
---------- ----------
model : tensor model : tensor
dtype : data-type, optional dtype : data-type, optional
opt : If True, we will return a constant instead of a graph when possible.
Useful for Theano optimization, not for a user building a graph, as this
has the consequence that model isn't always in the graph.
Returns Returns
------- -------
...@@ -2312,17 +2315,22 @@ def ones_like(model, dtype=None): ...@@ -2312,17 +2315,22 @@ def ones_like(model, dtype=None):
""" """
if dtype is None: if dtype is None:
dtype = model.type.dtype dtype = model.type.dtype
ret = fill(model, constant(1.0, dtype=dtype)) ret = constant(1.0, dtype=dtype)
return ret if opt and ret.type == model.type:
return ret
return fill(model, ret)
@constructor @constructor
def zeros_like(model, dtype=None): def zeros_like(model, dtype=None, opt=False):
"""equivalent of numpy.zeros_like """equivalent of numpy.zeros_like
Parameters Parameters
---------- ----------
model : tensor model : tensor
dtype : data-type, optional dtype : data-type, optional
opt : If True, we will return a constant instead of a graph when possible.
Useful for Theano optimization, not for a user building a graph, as this
has the consequence that model isn't always in the graph.
Returns Returns
------- -------
...@@ -2332,7 +2340,10 @@ def zeros_like(model, dtype=None): ...@@ -2332,7 +2340,10 @@ def zeros_like(model, dtype=None):
if dtype is None: if dtype is None:
dtype = model.type.dtype dtype = model.type.dtype
return fill(model, constant(0.0, dtype=dtype)) ret = constant(0.0, dtype=dtype)
if opt and ret.type == model.type:
return ret
return fill(model, ret)
def zeros(shape, dtype=None): def zeros(shape, dtype=None):
......
...@@ -2021,36 +2021,26 @@ def local_useless_elemwise(node): ...@@ -2021,36 +2021,26 @@ def local_useless_elemwise(node):
""" """
if isinstance(node.op, T.Elemwise): if isinstance(node.op, T.Elemwise):
def zeros_like(node, in_idx): # We call zeros_like and ones_like with opt=True to generate a
# it is the same var in the graph. That will always be true # cleaner graph.
ret = T.fill(node.inputs[in_idx], dtype = node.outputs[0].dtype
T.constant(0.0, dtype=node.outputs[0].type.dtype))
ret = pre_greedy_local_optimizer([local_useless_fill], ret)
return [ret]
def ones_like(node, in_idx):
# it is the same var in the graph. That will always be true
ret = T.fill(node.inputs[in_idx],
T.constant(1.0, dtype=node.outputs[0].type.dtype))
ret = pre_greedy_local_optimizer([local_useless_fill], ret)
return [ret]
if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2: if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be true # it is the same var in the graph. That will always be true
ret = ones_like(node, 0) ret = T.ones_like(node.inputs[0], dtype=dtype, opt=True)
# Copy stack trace from input to constant output # Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return ret return [ret]
elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2: elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be false # it is the same var in the graph. That will always be false
ret = zeros_like(node, 0) ret = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
# Copy stack trace from input to constant output # Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return ret return [ret]
elif node.op.scalar_op == theano.scalar.mul and len(node.inputs) == 1: elif node.op.scalar_op == theano.scalar.mul and len(node.inputs) == 1:
# No need to copy over any stack trace # No need to copy over any stack trace
...@@ -2070,7 +2060,8 @@ def local_useless_elemwise(node): ...@@ -2070,7 +2060,8 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[0], only_process_constants=True) const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return zeros_like(node, 1) return [T.zeros_like(node.inputs[1], dtype=dtype,
opt=True)]
else: else:
return [node.inputs[1]] return [node.inputs[1]]
...@@ -2078,7 +2069,8 @@ def local_useless_elemwise(node): ...@@ -2078,7 +2069,8 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True) const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return zeros_like(node, 0) return [T.zeros_like(node.inputs[0], dtype=dtype,
opt=True)]
else: else:
return [node.inputs[0]] return [node.inputs[0]]
...@@ -2091,7 +2083,8 @@ def local_useless_elemwise(node): ...@@ -2091,7 +2083,8 @@ def local_useless_elemwise(node):
if const_val == 0: if const_val == 0:
return [node.inputs[1]] return [node.inputs[1]]
else: else:
return ones_like(node, 1) return [T.ones_like(node.inputs[1], dtype=dtype,
opt=True)]
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True) const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
...@@ -2099,12 +2092,13 @@ def local_useless_elemwise(node): ...@@ -2099,12 +2092,13 @@ def local_useless_elemwise(node):
if const_val == 0: if const_val == 0:
return [node.inputs[0]] return [node.inputs[0]]
else: else:
return ones_like(node, 0) return [T.ones_like(node.inputs[0], dtype=dtype,
opt=True)]
elif (isinstance(node.op.scalar_op, scalar.XOR) and elif (isinstance(node.op.scalar_op, scalar.XOR) and
len(node.inputs) == 2): len(node.inputs) == 2):
if node.inputs[0] is node.inputs[1]: if node.inputs[0] is node.inputs[1]:
return zeros_like(node, 0) return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
@register_specialize @register_specialize
...@@ -5023,24 +5017,18 @@ def local_useless_elemwise_comparison(node): ...@@ -5023,24 +5017,18 @@ def local_useless_elemwise_comparison(node):
if node.op.scalar_op.nin != 2: if node.op.scalar_op.nin != 2:
return return
def zeros_like(model, dtype): # We call zeros_like and ones_like with opt=True to generate a
ret = T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype) # cleaner graph.
ret = pre_greedy_local_optimizer([local_useless_fill], ret) dtype = node.outputs[0].dtype
return ret
def ones_like(model, dtype):
ret = T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)
ret = pre_greedy_local_optimizer([local_useless_fill], ret)
return ret
# Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \ if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \ if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
return [ones_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.ones_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[{minimum,maximum}](X, X) -> X # Elemwise[{minimum,maximum}](X, X) -> X
if isinstance(node.op.scalar_op, (scalar.Minimum, scalar.Maximum)) and \ if isinstance(node.op.scalar_op, (scalar.Minimum, scalar.Maximum)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
...@@ -5051,13 +5039,13 @@ def local_useless_elemwise_comparison(node): ...@@ -5051,13 +5039,13 @@ def local_useless_elemwise_comparison(node):
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \ if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [ones_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.ones_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i] # Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \ if isinstance(node.op.scalar_op, scalar.Maximum) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
...@@ -5075,13 +5063,15 @@ def local_useless_elemwise_comparison(node): ...@@ -5075,13 +5063,15 @@ def local_useless_elemwise_comparison(node):
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
# It doesn't detect the case where the 0 is all zeros with ndim > 0.
# Elemwise[minimum](0, X.shape[i]) -> 0 # Elemwise[minimum](0, X.shape[i]) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \ if isinstance(node.op.scalar_op, scalar.Minimum) and \
T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \ T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \ node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i): isinstance(node.inputs[1].owner.op, Shape_i):
return [zeros_like(node.inputs[1], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[1], dtype=dtype, opt=True)]
# Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, scalar.LT) and \ if isinstance(node.op.scalar_op, scalar.LT) and \
...@@ -5092,7 +5082,7 @@ def local_useless_elemwise_comparison(node): ...@@ -5092,7 +5082,7 @@ def local_useless_elemwise_comparison(node):
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \ if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
...@@ -5101,7 +5091,7 @@ def local_useless_elemwise_comparison(node): ...@@ -5101,7 +5091,7 @@ def local_useless_elemwise_comparison(node):
all([isinstance(var.owner and var.owner.op, Shape_i) all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [ones_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.ones_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[EQ](Subtensor(Shape(x)), -N) # Elemwise[EQ](Subtensor(Shape(x)), -N)
# Elemwise[EQ](somegraph that only depend of shape, -N) # Elemwise[EQ](somegraph that only depend of shape, -N)
...@@ -5134,8 +5124,8 @@ def local_useless_elemwise_comparison(node): ...@@ -5134,8 +5124,8 @@ def local_useless_elemwise_comparison(node):
cst = get_scalar_constant_value(node.inputs[1], cst = get_scalar_constant_value(node.inputs[1],
only_process_constants=True) only_process_constants=True)
if cst < 0: if cst < 0:
return [zeros_like(node.inputs[0], return [T.zeros_like(node.inputs[0],
dtype=node.outputs[0].dtype)] dtype=dtype, opt=True)]
except NotScalarConstantError: except NotScalarConstantError:
pass pass
return return
......
...@@ -3432,6 +3432,9 @@ def test_local_fill_useless(): ...@@ -3432,6 +3432,9 @@ def test_local_fill_useless():
class Test_local_useless_elemwise_comparison(unittest.TestCase): class Test_local_useless_elemwise_comparison(unittest.TestCase):
def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed())
def test_local_useless_elemwise_comparison(self): def test_local_useless_elemwise_comparison(self):
# TODO: test each case individually. # TODO: test each case individually.
# The following case is what made me discover those cases. # The following case is what made me discover those cases.
...@@ -3469,6 +3472,8 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3469,6 +3472,8 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
mode = theano.compile.get_default_mode().excluding('fusion') mode = theano.compile.get_default_mode().excluding('fusion')
f = theano.function([X, Y], Z, mode=mode) f = theano.function([X, Y], Z, mode=mode)
f(self.rng.rand(2, 3).astype(config.floatX),
self.rng.rand(2).astype(config.floatX))
# theano.printing.debugprint(f, print_type=True) # theano.printing.debugprint(f, print_type=True)
# here is the output for the debug print: # here is the output for the debug print:
""" """
...@@ -3571,9 +3576,15 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3571,9 +3576,15 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x], T.minimum(x.shape[0], 0), mode=mode) f = theano.function([x], T.minimum(x.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
assert f(x_val) == 0
f = theano.function([x], T.minimum(0, x.shape[0]), mode=mode) f = theano.function([x], T.minimum(0, x.shape[0]), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
assert f(x_val) == 0
f = theano.function([x], T.minimum([0, 0], x.shape[0]), mode=mode)
# This case isn't optimized.
# self.assert_eqs_const(f, 0)
utt.assert_allclose(f(x_val), [0, 0])
def test_shape_add_inequality(self): def test_shape_add_inequality(self):
x = T.vector('x', dtype=config.floatX) x = T.vector('x', dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论