提交 b4e6bda3 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2532 from carriepl/speed_up_local_useless_subtensor

Speed up local useless subtensor
......@@ -544,7 +544,8 @@ get_scalar_constant_value_elemwises = (
scal.IntDiv, scal.TrueDiv, scal.Minimum, scal.Maximum)
def get_scalar_constant_value(orig_v, elemwise=True):
def get_scalar_constant_value(orig_v, elemwise=True,
only_process_constants=False):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts,
......@@ -558,6 +559,11 @@ def get_scalar_constant_value(orig_v, elemwise=True):
:param elemwise: If False, we won't try to go into elemwise.
So this call is faster.
:param only_process_constants: If True, we only attempt to obtain
the value of `orig_v` if it's directly constant and don't
try to dig through dimshuffles, fills, allocs, and other to figure
out its value.
:note: There may be another function similar to this one in the
code, but I'm not sure where it is.
"""
......@@ -581,7 +587,7 @@ def get_scalar_constant_value(orig_v, elemwise=True):
data = v.data
return numpy_scalar(data)
if getattr(v, 'owner', None):
if not only_process_constants and getattr(v, 'owner', None):
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
compile.ops.OutputGuard,
compile.DeepCopyOp)):
......@@ -4048,7 +4054,7 @@ def flatten(x, outdim=1):
class Tile(Op):
"""
DEPRECATED: use tile() instead.
Construct an array by repeating the input x according to reps pattern.
Tiles its input according to reps. The length of reps is the number of
......
......@@ -1935,7 +1935,8 @@ def local_useless_subtensor(node):
shape_of = node.fgraph.shape_feature.shape_of
if isinstance(node.op, Subtensor):
cdata = node.op.get_constant_idx(node.inputs, allow_partial=True)
cdata = node.op.get_constant_idx(node.inputs, allow_partial=True,
only_process_constants=True)
for pos, idx in enumerate(cdata):
if not isinstance(idx, slice):
# If idx is not a slice, this means we remove this dimension
......@@ -1950,15 +1951,17 @@ def local_useless_subtensor(node):
# is not a useless subtensor
return False
length_pos_data = maxsize
for pos, idx in enumerate(cdata):
length_pos = shape_of[node.inputs[0]][pos]
try:
length_pos_data = get_scalar_constant_value(length_pos)
except NotScalarConstantError:
pass
if isinstance(idx.stop, (int, numpy.integer)):
length_pos_data = maxsize
try:
length_pos_data = get_scalar_constant_value(length_pos)
except NotScalarConstantError:
pass
if idx.stop < length_pos_data:
return False
elif isinstance(idx.stop, gof.Variable):
......@@ -1999,10 +2002,10 @@ def local_useless_subtensor(node):
length = get_scalar_constant_value(shape_of[node.inputs[0]][0])
except NotScalarConstantError:
return False
# get index (which must be a vector by definition)
idx = node.inputs[1]
# `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
# this optimization
if isinstance(idx, T.Constant):
......@@ -2017,7 +2020,7 @@ def local_useless_subtensor(node):
idx.owner.inputs)
except NotScalarConstantError:
return False
if start != 0:
return False
if stop != length:
......
......@@ -383,7 +383,8 @@ class Subtensor(Op):
else:
raise AdvancedIndexingError(Subtensor.e_indextype, entry)
def get_constant_idx(self, inputs, allow_partial=False):
def get_constant_idx(self, inputs, allow_partial=False,
only_process_constants=False):
"""
Return the idx_list with constant inputs replaced by their
python scalar equivalent. May raise
......@@ -405,6 +406,11 @@ class Subtensor(Op):
[v, slice(1, 3, None)]
>>> b.owner.op.get_constant_idx(b.owner.inputs)
NotScalarConstantError: v
:param only_process_constants: If True, we only attempt to obtain
the value of an index/slice if it's directly constant and don't
try to dig through dimshuffles, fills, allocs, and other to figure
out its value.
"""
real_idx = get_idx_list(inputs, self.idx_list)
......@@ -417,12 +423,14 @@ class Subtensor(Op):
conv(val.step))
else:
try:
return get_scalar_constant_value(val)
return get_scalar_constant_value(val,
only_process_constants=only_process_constants)
except theano.tensor.NotScalarConstantError:
if allow_partial:
return val
else:
raise
return map(conv, real_idx)
def __init__(self, idx_list):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论