提交 53486aae authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent test value computation from complaining about subtensor creation

The temporary values created during subtensor broadcast-value computations would cause the test value computation steps to complain/err. This change prevents that.
上级 9f56e418
...@@ -2183,6 +2183,19 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2183,6 +2183,19 @@ class TestInferShape(utt.InferShapeTester):
AdvancedSubtensor, AdvancedSubtensor,
) )
admat.tag.test_value = admat_val
aivec.tag.test_value = aivec_val
bivec.tag.test_value = bivec_val
# Make sure it doesn't complain about test values
with theano.change_flags(compute_test_value="raise"):
self._compile_and_check(
[admat, aivec],
[admat[1:3, aivec]],
[admat_val, aivec_val],
AdvancedSubtensor,
)
def test_AdvancedSubtensor_bool(self): def test_AdvancedSubtensor_bool(self):
n = dmatrix() n = dmatrix()
n_val = np.arange(6).reshape((2, 3)) n_val = np.arange(6).reshape((2, 3))
......
...@@ -2321,15 +2321,19 @@ class AdvancedSubtensor(Op): ...@@ -2321,15 +2321,19 @@ class AdvancedSubtensor(Op):
# `Subtensor` calls, so we create a fake symbolic shape tuple and # `Subtensor` calls, so we create a fake symbolic shape tuple and
# identify the broadcast dimensions from the shape result of this # identify the broadcast dimensions from the shape result of this
# entire subtensor operation. # entire subtensor operation.
with theano.change_flags(compute_test_value="off"):
fake_shape = tuple( fake_shape = tuple(
theano.tensor.tensor(dtype="int64", broadcastable=()) if not bcast else 1 theano.tensor.tensor(dtype="int64", broadcastable=())
if not bcast
else 1
for bcast in x.broadcastable for bcast in x.broadcastable
) )
bcast_index = tuple( bcast_index = tuple(
chain.from_iterable( chain.from_iterable(
theano.tensor.basic.nonzero(idx) theano.tensor.basic.nonzero(idx)
if getattr(idx, "ndim", 0) > 0 and getattr(idx, "dtype", None) == "bool" if getattr(idx, "ndim", 0) > 0
and getattr(idx, "dtype", None) == "bool"
else (idx,) else (idx,)
for idx in index for idx in index
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论