提交 1c31c496 authored 作者: Frederic Bastien's avatar Frederic Bastien

Remove a Python function call and the construction of a useless object to eliminate the slowdown discovered by Ramana

上级 8a41c263
......@@ -2298,12 +2298,15 @@ pprint.assign(fill, printing.FunctionPrinter('fill'))
@constructor
def ones_like(model, dtype=None):
def ones_like(model, dtype=None, opt=False):
"""equivalent of numpy.ones_like
Parameters
----------
model : tensor
dtype : data-type, optional
opt : If True, we will return a constant instead of a graph when possible.
Useful for Theano optimizations, not for users building a graph, as this
has the consequence that model isn't always in the graph.
Returns
-------
......@@ -2312,17 +2315,22 @@ def ones_like(model, dtype=None):
"""
if dtype is None:
dtype = model.type.dtype
ret = fill(model, constant(1.0, dtype=dtype))
return ret
ret = constant(1.0, dtype=dtype)
if opt and ret.type == model.type:
return ret
return fill(model, ret)
@constructor
def zeros_like(model, dtype=None):
def zeros_like(model, dtype=None, opt=False):
"""equivalent of numpy.zeros_like
Parameters
----------
model : tensor
dtype : data-type, optional
opt : If True, we will return a constant instead of a graph when possible.
Useful for Theano optimizations, not for users building a graph, as this
has the consequence that model isn't always in the graph.
Returns
-------
......@@ -2332,7 +2340,10 @@ def zeros_like(model, dtype=None):
if dtype is None:
dtype = model.type.dtype
return fill(model, constant(0.0, dtype=dtype))
ret = constant(0.0, dtype=dtype)
if opt and ret.type == model.type:
return ret
return fill(model, ret)
def zeros(shape, dtype=None):
......
......@@ -2021,24 +2021,14 @@ def local_useless_elemwise(node):
"""
if isinstance(node.op, T.Elemwise):
def zeros_like(node, in_idx):
# it is the same var in the graph. That will always be true
ret = T.fill(node.inputs[in_idx],
T.constant(0.0, dtype=node.outputs[0].type.dtype))
ret = pre_greedy_local_optimizer([local_useless_fill], ret)
return [ret]
def ones_like(node, in_idx):
# it is the same var in the graph. That will always be true
ret = T.fill(node.inputs[in_idx],
T.constant(1.0, dtype=node.outputs[0].type.dtype))
ret = pre_greedy_local_optimizer([local_useless_fill], ret)
return [ret]
# We call zeros_like and ones_like with opt=True to generate a
# cleaner graph.
dtype = node.outputs[0].dtype
if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be true
ret = ones_like(node, 0)
ret = T.ones_like(node.inputs[0], dtype=dtype, opt=True)
# Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret)
......@@ -2046,7 +2036,7 @@ def local_useless_elemwise(node):
elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be false
ret = zeros_like(node, 0)
ret = T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
# Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret)
......@@ -2070,7 +2060,8 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
if not isinstance(const_val, Variable):
if const_val == 0:
return zeros_like(node, 1)
return T.zeros_like(node.inputs[1], dtype=dtype,
opt=True)
else:
return [node.inputs[1]]
......@@ -2078,7 +2069,8 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
if not isinstance(const_val, Variable):
if const_val == 0:
return zeros_like(node, 0)
return T.zeros_like(node.inputs[0], dtype=dtype,
opt=True)
else:
return [node.inputs[0]]
......@@ -2091,7 +2083,7 @@ def local_useless_elemwise(node):
if const_val == 0:
return [node.inputs[1]]
else:
return ones_like(node, 1)
return T.ones_like(node.inputs[1], dtype=dtype, opt=True)
if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
......@@ -2099,12 +2091,12 @@ def local_useless_elemwise(node):
if const_val == 0:
return [node.inputs[0]]
else:
return ones_like(node, 0)
return T.ones_like(node.inputs[0], dtype=dtype, opt=True)
elif (isinstance(node.op.scalar_op, scalar.XOR) and
len(node.inputs) == 2):
if node.inputs[0] is node.inputs[1]:
return zeros_like(node, 0)
return T.zeros_like(node.inputs[0], dtype=dtype, opt=True)
@register_specialize
......@@ -5023,24 +5015,18 @@ def local_useless_elemwise_comparison(node):
if node.op.scalar_op.nin != 2:
return
def zeros_like(model, dtype):
ret = T.zeros_like(model, dtype=node.outputs[0].dtype)
ret = pre_greedy_local_optimizer([local_useless_fill], ret)
return ret
def ones_like(model, dtype):
ret = T.ones_like(model, dtype=node.outputs[0].dtype)
ret = pre_greedy_local_optimizer([local_useless_fill], ret)
return ret
# We call zeros_like and ones_like with opt=True to generate a
# cleaner graph.
dtype = node.outputs[0].dtype
# Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \
node.inputs[0] is node.inputs[1]:
return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \
node.inputs[0] is node.inputs[1]:
return [ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return [T.ones_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[{minimum,maximum}](X, X) -> X
if isinstance(node.op.scalar_op, (scalar.Minimum, scalar.Maximum)) and \
node.inputs[0] is node.inputs[1]:
......@@ -5051,13 +5037,13 @@ def local_useless_elemwise_comparison(node):
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return [T.ones_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \
node.inputs[0].owner and \
......@@ -5075,13 +5061,13 @@ def local_useless_elemwise_comparison(node):
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[minimum](0, X.shape[i]) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \
T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i):
return [zeros_like(node.inputs[1], dtype=node.outputs[0].dtype)]
return [T.zeros_like(node.inputs[1], dtype=dtype, opt=True)]
# Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, scalar.LT) and \
......@@ -5092,7 +5078,7 @@ def local_useless_elemwise_comparison(node):
for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \
......@@ -5101,7 +5087,7 @@ def local_useless_elemwise_comparison(node):
all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return [T.ones_like(node.inputs[0], dtype=dtype, opt=True)]
# Elemwise[EQ](Subtensor(Shape(x)), -N)
# Elemwise[EQ](some graph that only depends on shape, -N)
......@@ -5134,8 +5120,8 @@ def local_useless_elemwise_comparison(node):
cst = get_scalar_constant_value(node.inputs[1],
only_process_constants=True)
if cst < 0:
return [zeros_like(node.inputs[0],
dtype=node.outputs[0].dtype)]
return [T.zeros_like(node.inputs[0],
dtype=dtype, opt=True)]
except NotScalarConstantError:
pass
return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论