提交 b601cb50 作者: Pierre Luc Carrier

Modify local_useless_subtensor to perform less constant folding under the hood.

上级 df1176a3
...@@ -544,7 +544,8 @@ get_scalar_constant_value_elemwises = ( ...@@ -544,7 +544,8 @@ get_scalar_constant_value_elemwises = (
scal.IntDiv, scal.TrueDiv, scal.Minimum, scal.Maximum) scal.IntDiv, scal.TrueDiv, scal.Minimum, scal.Maximum)
def get_scalar_constant_value(orig_v, elemwise=True): def get_scalar_constant_value(orig_v, elemwise=True,
only_process_constants=False):
"""return the constant scalar(0-D) value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, If v is the output of dimshuffles, fills, allocs, rebroadcasts,
...@@ -558,6 +559,11 @@ def get_scalar_constant_value(orig_v, elemwise=True): ...@@ -558,6 +559,11 @@ def get_scalar_constant_value(orig_v, elemwise=True):
:param elemwise: If False, we won't try to go into elemwise. :param elemwise: If False, we won't try to go into elemwise.
So this call is faster. So this call is faster.
:param only_process_constants: If True, we only attempt to obtain
the value of `orig_v` if it's directly constant and don't
try to dig through dimshuffles, fills, allocs, and other to figure
out its value.
:note: There may be another function similar to this one in the :note: There may be another function similar to this one in the
code, but I'm not sure where it is. code, but I'm not sure where it is.
""" """
...@@ -581,7 +587,7 @@ def get_scalar_constant_value(orig_v, elemwise=True): ...@@ -581,7 +587,7 @@ def get_scalar_constant_value(orig_v, elemwise=True):
data = v.data data = v.data
return numpy_scalar(data) return numpy_scalar(data)
if getattr(v, 'owner', None): if not only_process_constants and getattr(v, 'owner', None):
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast, if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
compile.ops.OutputGuard, compile.ops.OutputGuard,
compile.DeepCopyOp)): compile.DeepCopyOp)):
......
...@@ -1932,7 +1932,8 @@ def local_useless_subtensor(node): ...@@ -1932,7 +1932,8 @@ def local_useless_subtensor(node):
shape_of = node.fgraph.shape_feature.shape_of shape_of = node.fgraph.shape_feature.shape_of
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
cdata = node.op.get_constant_idx(node.inputs, allow_partial=True) cdata = node.op.get_constant_idx(node.inputs, allow_partial=True,
only_process_constants=True)
for pos, idx in enumerate(cdata): for pos, idx in enumerate(cdata):
if not isinstance(idx, slice): if not isinstance(idx, slice):
# If idx is not a slice, this means we remove this dimension # If idx is not a slice, this means we remove this dimension
......
...@@ -383,7 +383,8 @@ class Subtensor(Op): ...@@ -383,7 +383,8 @@ class Subtensor(Op):
else: else:
raise AdvancedIndexingError(Subtensor.e_indextype, entry) raise AdvancedIndexingError(Subtensor.e_indextype, entry)
def get_constant_idx(self, inputs, allow_partial=False): def get_constant_idx(self, inputs, allow_partial=False,
only_process_constants=False):
""" """
Return the idx_list with constant inputs replaced by their Return the idx_list with constant inputs replaced by their
python scalar equivalent. May raise python scalar equivalent. May raise
...@@ -405,6 +406,11 @@ class Subtensor(Op): ...@@ -405,6 +406,11 @@ class Subtensor(Op):
[v, slice(1, 3, None)] [v, slice(1, 3, None)]
>>> b.owner.op.get_constant_idx(b.owner.inputs) >>> b.owner.op.get_constant_idx(b.owner.inputs)
NotScalarConstantError: v NotScalarConstantError: v
:param only_process_constants: If True, we only attempt to obtain
the value of an index/slice if it's directly constant and don't
try to dig through dimshuffles, fills, allocs, and other to figure
out its value.
""" """
real_idx = get_idx_list(inputs, self.idx_list) real_idx = get_idx_list(inputs, self.idx_list)
...@@ -417,12 +423,14 @@ class Subtensor(Op): ...@@ -417,12 +423,14 @@ class Subtensor(Op):
conv(val.step)) conv(val.step))
else: else:
try: try:
return get_scalar_constant_value(val) return get_scalar_constant_value(val,
only_process_constants=only_process_constants)
except theano.tensor.NotScalarConstantError: except theano.tensor.NotScalarConstantError:
if allow_partial: if allow_partial:
return val return val
else: else:
raise raise
return map(conv, real_idx) return map(conv, real_idx)
def __init__(self, idx_list): def __init__(self, idx_list):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论