提交 b601cb50 作者: Pierre Luc Carrier

Modify local_useless_subtensor to perform less constant folding under the hood.

上级 df1176a3
......@@ -544,7 +544,8 @@ get_scalar_constant_value_elemwises = (
scal.IntDiv, scal.TrueDiv, scal.Minimum, scal.Maximum)
def get_scalar_constant_value(orig_v, elemwise=True):
def get_scalar_constant_value(orig_v, elemwise=True,
only_process_constants=False):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts,
......@@ -558,6 +559,11 @@ def get_scalar_constant_value(orig_v, elemwise=True):
:param elemwise: If False, we won't try to go into elemwise.
So this call is faster.
:param only_process_constants: If True, we only attempt to obtain
the value of `orig_v` if it is directly a constant, and we do not
try to dig through dimshuffles, fills, allocs, and other ops to
figure out its value.
:note: There may be another function similar to this one in the
code, but I'm not sure where it is.
"""
......@@ -581,7 +587,7 @@ def get_scalar_constant_value(orig_v, elemwise=True):
data = v.data
return numpy_scalar(data)
if getattr(v, 'owner', None):
if not only_process_constants and getattr(v, 'owner', None):
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
compile.ops.OutputGuard,
compile.DeepCopyOp)):
......@@ -4048,7 +4054,7 @@ def flatten(x, outdim=1):
class Tile(Op):
"""
DEPRECATED: use tile() instead.
Construct an array by repeating the input x according to reps pattern.
Tiles its input according to reps. The length of reps is the number of
......
......@@ -1932,7 +1932,8 @@ def local_useless_subtensor(node):
shape_of = node.fgraph.shape_feature.shape_of
if isinstance(node.op, Subtensor):
cdata = node.op.get_constant_idx(node.inputs, allow_partial=True)
cdata = node.op.get_constant_idx(node.inputs, allow_partial=True,
only_process_constants=True)
for pos, idx in enumerate(cdata):
if not isinstance(idx, slice):
# If idx is not a slice, this means we remove this dimension
......
......@@ -383,7 +383,8 @@ class Subtensor(Op):
else:
raise AdvancedIndexingError(Subtensor.e_indextype, entry)
def get_constant_idx(self, inputs, allow_partial=False):
def get_constant_idx(self, inputs, allow_partial=False,
only_process_constants=False):
"""
Return the idx_list with constant inputs replaced by their
python scalar equivalent. May raise
......@@ -405,6 +406,11 @@ class Subtensor(Op):
[v, slice(1, 3, None)]
>>> b.owner.op.get_constant_idx(b.owner.inputs)
NotScalarConstantError: v
:param only_process_constants: If True, we only attempt to obtain
the value of an index/slice if it is directly a constant, and we do
not try to dig through dimshuffles, fills, allocs, and other ops to
figure out its value.
"""
real_idx = get_idx_list(inputs, self.idx_list)
......@@ -417,12 +423,14 @@ class Subtensor(Op):
conv(val.step))
else:
try:
return get_scalar_constant_value(val)
return get_scalar_constant_value(val,
only_process_constants=only_process_constants)
except theano.tensor.NotScalarConstantError:
if allow_partial:
return val
else:
raise
return map(conv, real_idx)
def __init__(self, idx_list):
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论