提交 fcb54f8a authored 作者: Frederic's avatar Frederic

Make Subtensor.make_node() detect more broadcastable dimensions.

上级 5bf326a7
...@@ -391,19 +391,6 @@ class Subtensor(Op): ...@@ -391,19 +391,6 @@ class Subtensor(Op):
exception.subtensor_invalid = True exception.subtensor_invalid = True
raise exception raise exception
# infer the broadcasting pattern
padded = (idx_list
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
broadcastable = [bc and p.start is None and p.stop is None
# No need to check step when there is only
# one element and start and stop are None,
# the output will always have 1 element.
# We could call get_canonical_form_slice() to
# catch more broadcast case. I let this to
# later.
for p, bc in izip(padded, x.type.broadcastable)
if isinstance(p, slice)]
input_types = Subtensor.collapse(idx_list, input_types = Subtensor.collapse(idx_list,
lambda entry: isinstance(entry, gof.Type)) lambda entry: isinstance(entry, gof.Type))
if len(inputs) != len(input_types): if len(inputs) != len(input_types):
...@@ -416,6 +403,31 @@ class Subtensor(Op): ...@@ -416,6 +403,31 @@ class Subtensor(Op):
"Wrong type for Subtensor template. Expected %s, got %s." "Wrong type for Subtensor template. Expected %s, got %s."
% (input.type, expected_type)) % (input.type, expected_type))
# infer the broadcasting pattern
padded = (idx_list
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
broadcastable = []
for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)):
if isinstance(p, slice):
if bc and p.start in [None, 0]:
# No need to check step when there is only
# one element.
# We could call get_canonical_form_slice() to
# catch more broadcast case. I let this to
# later.
if p.stop is None:
broadcastable.append(bc)
continue
try:
start = theano.tensor.get_scalar_constant_value(p.start)
stop = theano.tensor.get_scalar_constant_value(p.stop)
if stop > start:
broadcastable.append(True)
continue
except theano.tensor.NotScalarConstantError:
pass
broadcastable.append(False)
return gof.Apply(self, return gof.Apply(self,
(x, ) + inputs, (x, ) + inputs,
[theano.tensor.tensor(dtype=x.type.dtype, [theano.tensor.tensor(dtype=x.type.dtype,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论