提交 75405be7 authored 作者: Frederic's avatar Frederic

crash fix to the new tests that will be in the next commit.

Make subtensor.make_node infer more broadcastable. This remove crash from local_subtensor_merge that was loosing broadcastable dimensions.
上级 2d73c170
...@@ -593,6 +593,8 @@ def get_scalar_constant_value(orig_v, elemwise=True): ...@@ -593,6 +593,8 @@ def get_scalar_constant_value(orig_v, elemwise=True):
# mess with the stabilization optimization and be too slow. # mess with the stabilization optimization and be too slow.
# We put all the scalar Ops used by get_canonical_form_slice() # We put all the scalar Ops used by get_canonical_form_slice()
# to allow it to determine the broadcast pattern correctly. # to allow it to determine the broadcast pattern correctly.
elif isinstance(v.owner.op, ScalarFromTensor):
return get_scalar_constant_value(v.owner.inputs[0])
elif isinstance(v.owner.op, scal.ScalarOp): elif isinstance(v.owner.op, scal.ScalarOp):
if isinstance(v.owner.op, scal.Second): if isinstance(v.owner.op, scal.Second):
# We don't need both input to be constant for second # We don't need both input to be constant for second
......
...@@ -16,7 +16,7 @@ from theano.gof.python25 import maxsize ...@@ -16,7 +16,7 @@ from theano.gof.python25 import maxsize
from theano.printing import pprint from theano.printing import pprint
from theano import scalar as scal from theano import scalar as scal
from theano.tensor.basic import (addbroadcast, clip, get_scalar_constant_value, from theano.tensor.basic import (addbroadcast, clip, get_scalar_constant_value,
ARange, TensorType) ARange, TensorType, NotScalarConstantError)
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.tensor.type_other import NoneConst, SliceType, make_slice from theano.tensor.type_other import NoneConst, SliceType, make_slice
from theano import config from theano import config
...@@ -470,15 +470,22 @@ class Subtensor(Op): ...@@ -470,15 +470,22 @@ class Subtensor(Op):
broadcastable = [] broadcastable = []
for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)): for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)):
if isinstance(p, slice): if isinstance(p, slice):
if bc and p.start in [None, 0]: if bc:
start = p.start start = p.start
if start is None: try:
start = 0 start = get_scalar_constant_value(start)
if (p.stop is None or except NotScalarConstantError:
(isinstance(p.stop, (int, numpy.integer)) and pass
p.stop > start)): if start in [None, 0]:
broadcastable.append(True) start = p.start
continue if start is None:
start = 0
if (p.stop is None or
(isinstance(p.stop, (int, numpy.integer,
numpy.ndarray)) and
p.stop > start)):
broadcastable.append(True)
continue
broadcastable.append(False) broadcastable.append(False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论