提交 b601cb50 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Modify local_useless_subtensor to perform less constant folding under the hood.

上级 df1176a3
...@@ -544,7 +544,8 @@ get_scalar_constant_value_elemwises = ( ...@@ -544,7 +544,8 @@ get_scalar_constant_value_elemwises = (
scal.IntDiv, scal.TrueDiv, scal.Minimum, scal.Maximum) scal.IntDiv, scal.TrueDiv, scal.Minimum, scal.Maximum)
def get_scalar_constant_value(orig_v, elemwise=True): def get_scalar_constant_value(orig_v, elemwise=True,
only_process_constants=False):
"""return the constant scalar(0-D) value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, If v is the output of dimshuffles, fills, allocs, rebroadcasts,
...@@ -558,6 +559,11 @@ def get_scalar_constant_value(orig_v, elemwise=True): ...@@ -558,6 +559,11 @@ def get_scalar_constant_value(orig_v, elemwise=True):
:param elemwise: If False, we won't try to go into elemwise. :param elemwise: If False, we won't try to go into elemwise.
So this call is faster. So this call is faster.
:param only_process_constants: If True, we only attempt to obtain
the value of `orig_v` if it's directly constant and don't
try to dig through dimshuffles, fills, allocs, and other to figure
out its value.
:note: There may be another function similar to this one in the :note: There may be another function similar to this one in the
code, but I'm not sure where it is. code, but I'm not sure where it is.
""" """
...@@ -581,7 +587,7 @@ def get_scalar_constant_value(orig_v, elemwise=True): ...@@ -581,7 +587,7 @@ def get_scalar_constant_value(orig_v, elemwise=True):
data = v.data data = v.data
return numpy_scalar(data) return numpy_scalar(data)
if getattr(v, 'owner', None): if not only_process_constants and getattr(v, 'owner', None):
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast, if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
compile.ops.OutputGuard, compile.ops.OutputGuard,
compile.DeepCopyOp)): compile.DeepCopyOp)):
...@@ -4048,7 +4054,7 @@ def flatten(x, outdim=1): ...@@ -4048,7 +4054,7 @@ def flatten(x, outdim=1):
class Tile(Op): class Tile(Op):
""" """
DEPRECATED: use tile() instead. DEPRECATED: use tile() instead.
Construct an array by repeating the input x according to reps pattern. Construct an array by repeating the input x according to reps pattern.
Tiles its input according to reps. The length of reps is the number of Tiles its input according to reps. The length of reps is the number of
......
...@@ -1932,7 +1932,8 @@ def local_useless_subtensor(node): ...@@ -1932,7 +1932,8 @@ def local_useless_subtensor(node):
shape_of = node.fgraph.shape_feature.shape_of shape_of = node.fgraph.shape_feature.shape_of
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
cdata = node.op.get_constant_idx(node.inputs, allow_partial=True) cdata = node.op.get_constant_idx(node.inputs, allow_partial=True,
only_process_constants=True)
for pos, idx in enumerate(cdata): for pos, idx in enumerate(cdata):
if not isinstance(idx, slice): if not isinstance(idx, slice):
# If idx is not a slice, this means we remove this dimension # If idx is not a slice, this means we remove this dimension
......
...@@ -383,7 +383,8 @@ class Subtensor(Op): ...@@ -383,7 +383,8 @@ class Subtensor(Op):
else: else:
raise AdvancedIndexingError(Subtensor.e_indextype, entry) raise AdvancedIndexingError(Subtensor.e_indextype, entry)
def get_constant_idx(self, inputs, allow_partial=False): def get_constant_idx(self, inputs, allow_partial=False,
only_process_constants=False):
""" """
Return the idx_list with constant inputs replaced by their Return the idx_list with constant inputs replaced by their
python scalar equivalent. May raise python scalar equivalent. May raise
...@@ -405,6 +406,11 @@ class Subtensor(Op): ...@@ -405,6 +406,11 @@ class Subtensor(Op):
[v, slice(1, 3, None)] [v, slice(1, 3, None)]
>>> b.owner.op.get_constant_idx(b.owner.inputs) >>> b.owner.op.get_constant_idx(b.owner.inputs)
NotScalarConstantError: v NotScalarConstantError: v
:param only_process_constants: If True, we only attempt to obtain
the value of an index/slice if it's directly constant and don't
try to dig through dimshuffles, fills, allocs, and other to figure
out its value.
""" """
real_idx = get_idx_list(inputs, self.idx_list) real_idx = get_idx_list(inputs, self.idx_list)
...@@ -417,12 +423,14 @@ class Subtensor(Op): ...@@ -417,12 +423,14 @@ class Subtensor(Op):
conv(val.step)) conv(val.step))
else: else:
try: try:
return get_scalar_constant_value(val) return get_scalar_constant_value(val,
only_process_constants=only_process_constants)
except theano.tensor.NotScalarConstantError: except theano.tensor.NotScalarConstantError:
if allow_partial: if allow_partial:
return val return val
else: else:
raise raise
return map(conv, real_idx) return map(conv, real_idx)
def __init__(self, idx_list): def __init__(self, idx_list):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论