提交 c49564d2 作者: Arnaud Bergeron

Fix optimizations that depend on python scalar being in idx_list.

上级 28be70e6
...@@ -557,7 +557,7 @@ def get_scalar_constant_value(v): ...@@ -557,7 +557,7 @@ def get_scalar_constant_value(v):
data = v.data data = v.data
return numpy_scalar(data) return numpy_scalar(data)
if v.owner: if getattr(v, 'owner', None):
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast, if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
compile.ops.OutputGuard, compile.ops.OutputGuard,
compile.DeepCopyOp)): compile.DeepCopyOp)):
......
...@@ -1625,17 +1625,32 @@ def local_useless_subtensor(node): ...@@ -1625,17 +1625,32 @@ def local_useless_subtensor(node):
if not hasattr(node.fgraph, 'shape_feature'): if not hasattr(node.fgraph, 'shape_feature'):
return return
shape_of = node.fgraph.shape_feature.shape_of shape_of = node.fgraph.shape_feature.shape_of
node_input_idx = 1 idx_vals = get_idx_list(node.inputs, node.op.idx_list)
for pos, idx in enumerate(node.op.idx_list): for pos, idx in enumerate(idx_vals):
if not isinstance(idx, slice): if not isinstance(idx, slice):
# If idx is not a slice, this means we remove this dimension # If idx is not a slice, this means we remove this dimension
# from the output, so the subtensor is not useless # from the output, so the subtensor is not useless
return False return False
if idx.start not in [0, None]:
# Grab the values for start/stop/step
try:
start = get_scalar_constant_value(idx.start)
except NotScalarConstantError:
start = idx.start
try:
stop = get_scalar_constant_value(idx.stop)
except NotScalarConstantError:
stop = idx.stop
try:
step = get_scalar_constant_value(idx.step)
except NotScalarConstantError:
step = idx.step
if start not in [0, None]:
# If the start of the slice is different from 0, or is a # If the start of the slice is different from 0, or is a
# variable, then we assume the subtensor is not useless # variable, then we assume the subtensor is not useless
return False return False
if idx.step not in [1, None]: if step not in [1, None]:
# If we are going backwards, or skipping elements, then this # If we are going backwards, or skipping elements, then this
# is not a useless subtensor # is not a useless subtensor
return False return False
...@@ -1648,11 +1663,11 @@ def local_useless_subtensor(node): ...@@ -1648,11 +1663,11 @@ def local_useless_subtensor(node):
except NotScalarConstantError: except NotScalarConstantError:
pass pass
if isinstance(idx.stop, (int, numpy.integer)): if isinstance(stop, (int, numpy.integer)):
if idx.stop < length_pos_data: if stop < length_pos_data:
return False return False
elif isinstance(idx.stop, theano.scalar.Scalar): elif isinstance(stop, gof.Variable):
length_pos_shape_i = node.inputs[node_input_idx] length_pos_shape_i = stop
# length_pos is a tensor variable, but length_pos_shape_i # length_pos is a tensor variable, but length_pos_shape_i
# is a scalar variable. We try to see if they represent # is a scalar variable. We try to see if they represent
# the same underlying variable. # the same underlying variable.
...@@ -1675,14 +1690,11 @@ def local_useless_subtensor(node): ...@@ -1675,14 +1690,11 @@ def local_useless_subtensor(node):
assert str(length_pos.type.dtype) == "int64" assert str(length_pos.type.dtype) == "int64"
assert str(length_pos_shape_i.type.dtype) in ["int8", "int16", assert str(length_pos_shape_i.type.dtype) in ["int8", "int16",
"int32", "int64"] "int32", "int64"]
# We already know that start and step are not variables
# and so they don't appear in the input of the node
node_input_idx += 1
# length_pos_shape_i cannot be None # length_pos_shape_i cannot be None
if length_pos_shape_i != length_pos: if length_pos_shape_i != length_pos:
return False return False
elif idx.stop is None: elif stop is None:
pass pass
else: else:
return False return False
...@@ -1737,8 +1749,7 @@ def local_subtensor_lift(node): ...@@ -1737,8 +1749,7 @@ def local_subtensor_lift(node):
return [u.owner.op(*new_inputs)] return [u.owner.op(*new_inputs)]
if isinstance(u.owner.op, T.Rebroadcast): if isinstance(u.owner.op, T.Rebroadcast):
# make sure that Subtensor and Rebroadcast only have 1 input/output # make sure that Rebroadcast has only 1 input
assert len(node.inputs) == 1
assert len(u.owner.inputs) == 1 assert len(u.owner.inputs) == 1
# Subtensor might reduce dim., adapt broadcast pattern accordingly # Subtensor might reduce dim., adapt broadcast pattern accordingly
...@@ -1760,7 +1771,7 @@ def local_subtensor_lift(node): ...@@ -1760,7 +1771,7 @@ def local_subtensor_lift(node):
new_axis += [(j, u.broadcastable[i])] new_axis += [(j, u.broadcastable[i])]
j += 1 j += 1
subt_x = Subtensor(node.op.idx_list)(u.owner.inputs[0]) subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x) rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x)
return [rbcast_subt_x] return [rbcast_subt_x]
......
...@@ -406,29 +406,28 @@ class Subtensor(Op): ...@@ -406,29 +406,28 @@ class Subtensor(Op):
# infer the broadcasting pattern # infer the broadcasting pattern
padded = (idx_list padded = (idx_list
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list))) + [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
idx_padded = get_idx_list((None,)+inputs, padded)
broadcastable = [] broadcastable = []
for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)): for i, (p, bc) in enumerate(izip(idx_padded, x.type.broadcastable)):
if isinstance(p, slice): if isinstance(p, slice):
if bc and p.start in [None, 0]: # figure out the value of start and stop (if they are constant)
# No need to check step when there is only try:
# one element. if p.start is None:
# We could call get_canonical_form_slice() to start = 0
# catch more broadcast case. I let this to else:
# later. start = get_scalar_constant_value(p.start)
if p.stop is None: if p.stop is None:
broadcastable.append(bc) stop = 1
continue else:
try:
if p.start is None:
start = 0
else:
start = get_scalar_constant_value(p.start)
stop = get_scalar_constant_value(p.stop) stop = get_scalar_constant_value(p.stop)
if stop > start: except theano.tensor.NotScalarConstantError:
broadcastable.append(True) start = None
continue stop = None
except theano.tensor.NotScalarConstantError:
pass if bc and start == 0 and stop > start:
broadcastable.append(True)
continue
broadcastable.append(False) broadcastable.append(False)
return gof.Apply(self, return gof.Apply(self,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论