提交 c49564d2 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix optimizations that depend on a Python scalar being in idx_list.

上级 28be70e6
......@@ -557,7 +557,7 @@ def get_scalar_constant_value(v):
data = v.data
return numpy_scalar(data)
if v.owner:
if getattr(v, 'owner', None):
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
compile.ops.OutputGuard,
compile.DeepCopyOp)):
......
......@@ -1625,17 +1625,32 @@ def local_useless_subtensor(node):
if not hasattr(node.fgraph, 'shape_feature'):
return
shape_of = node.fgraph.shape_feature.shape_of
node_input_idx = 1
for pos, idx in enumerate(node.op.idx_list):
idx_vals = get_idx_list(node.inputs, node.op.idx_list)
for pos, idx in enumerate(idx_vals):
if not isinstance(idx, slice):
# If idx is not a slice, this means we remove this dimension
# from the output, so the subtensor is not useless
return False
if idx.start not in [0, None]:
# Grab the values for start/stop/step
try:
start = get_scalar_constant_value(idx.start)
except NotScalarConstantError:
start = idx.start
try:
stop = get_scalar_constant_value(idx.stop)
except NotScalarConstantError:
stop = idx.stop
try:
step = get_scalar_constant_value(idx.step)
except NotScalarConstantError:
step = idx.step
if start not in [0, None]:
# If the start of the slice is different from 0, or is a
# variable, then we assume the subtensor is not useless
return False
if idx.step not in [1, None]:
if step not in [1, None]:
# If we are going backwards, or skipping elements, then this
# is not a useless subtensor
return False
......@@ -1648,11 +1663,11 @@ def local_useless_subtensor(node):
except NotScalarConstantError:
pass
if isinstance(idx.stop, (int, numpy.integer)):
if idx.stop < length_pos_data:
if isinstance(stop, (int, numpy.integer)):
if stop < length_pos_data:
return False
elif isinstance(idx.stop, theano.scalar.Scalar):
length_pos_shape_i = node.inputs[node_input_idx]
elif isinstance(stop, gof.Variable):
length_pos_shape_i = stop
# length_pos is a tensor variable, but length_pos_shape_i
# is a scalar variable. We try to see if they represent
# the same underlying variable.
......@@ -1675,14 +1690,11 @@ def local_useless_subtensor(node):
assert str(length_pos.type.dtype) == "int64"
assert str(length_pos_shape_i.type.dtype) in ["int8", "int16",
"int32", "int64"]
# We already know that start and step are not variables
# and so they don't appear in the input of the node
node_input_idx += 1
# length_pos_shape_i cannot be None
if length_pos_shape_i != length_pos:
return False
elif idx.stop is None:
elif stop is None:
pass
else:
return False
......@@ -1737,8 +1749,7 @@ def local_subtensor_lift(node):
return [u.owner.op(*new_inputs)]
if isinstance(u.owner.op, T.Rebroadcast):
# make sure that Subtensor and Rebroadcast only have 1 input/output
assert len(node.inputs) == 1
# make sure that Rebroadcast has only 1 input
assert len(u.owner.inputs) == 1
# Subtensor might reduce dim., adapt broadcast pattern accordingly
......@@ -1760,7 +1771,7 @@ def local_subtensor_lift(node):
new_axis += [(j, u.broadcastable[i])]
j += 1
subt_x = Subtensor(node.op.idx_list)(u.owner.inputs[0])
subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x)
return [rbcast_subt_x]
......
......@@ -406,29 +406,28 @@ class Subtensor(Op):
# infer the broadcasting pattern
padded = (idx_list
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
idx_padded = get_idx_list((None,)+inputs, padded)
broadcastable = []
for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)):
for i, (p, bc) in enumerate(izip(idx_padded, x.type.broadcastable)):
if isinstance(p, slice):
if bc and p.start in [None, 0]:
# No need to check step when there is only
# one element.
# We could call get_canonical_form_slice() to
# catch more broadcast case. I let this to
# later.
# figure out the value of start and stop (if they are constant)
try:
if p.start is None:
start = 0
else:
start = get_scalar_constant_value(p.start)
if p.stop is None:
broadcastable.append(bc)
continue
try:
if p.start is None:
start = 0
else:
start = get_scalar_constant_value(p.start)
stop = 1
else:
stop = get_scalar_constant_value(p.stop)
if stop > start:
broadcastable.append(True)
continue
except theano.tensor.NotScalarConstantError:
pass
except theano.tensor.NotScalarConstantError:
start = None
stop = None
if bc and start == 0 and stop > start:
broadcastable.append(True)
continue
broadcastable.append(False)
return gof.Apply(self,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论