提交 53486aae authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent test value computation from complaining about subtensor creation

The temporary values created during subtensor broadcast-value computations would cause the test value computation steps to complain/err. This change prevents that.
上级 9f56e418
...@@ -2183,6 +2183,19 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2183,6 +2183,19 @@ class TestInferShape(utt.InferShapeTester):
AdvancedSubtensor, AdvancedSubtensor,
) )
admat.tag.test_value = admat_val
aivec.tag.test_value = aivec_val
bivec.tag.test_value = bivec_val
# Make sure it doesn't complain about test values
with theano.change_flags(compute_test_value="raise"):
self._compile_and_check(
[admat, aivec],
[admat[1:3, aivec]],
[admat_val, aivec_val],
AdvancedSubtensor,
)
def test_AdvancedSubtensor_bool(self): def test_AdvancedSubtensor_bool(self):
n = dmatrix() n = dmatrix()
n_val = np.arange(6).reshape((2, 3)) n_val = np.arange(6).reshape((2, 3))
......
...@@ -2321,24 +2321,28 @@ class AdvancedSubtensor(Op): ...@@ -2321,24 +2321,28 @@ class AdvancedSubtensor(Op):
# `Subtensor` calls, so we create a fake symbolic shape tuple and # `Subtensor` calls, so we create a fake symbolic shape tuple and
# identify the broadcast dimensions from the shape result of this # identify the broadcast dimensions from the shape result of this
# entire subtensor operation. # entire subtensor operation.
fake_shape = tuple( with theano.change_flags(compute_test_value="off"):
theano.tensor.tensor(dtype="int64", broadcastable=()) if not bcast else 1 fake_shape = tuple(
for bcast in x.broadcastable theano.tensor.tensor(dtype="int64", broadcastable=())
) if not bcast
else 1
for bcast in x.broadcastable
)
bcast_index = tuple( bcast_index = tuple(
chain.from_iterable( chain.from_iterable(
theano.tensor.basic.nonzero(idx) theano.tensor.basic.nonzero(idx)
if getattr(idx, "ndim", 0) > 0 and getattr(idx, "dtype", None) == "bool" if getattr(idx, "ndim", 0) > 0
else (idx,) and getattr(idx, "dtype", None) == "bool"
for idx in index else (idx,)
for idx in index
)
) )
)
bcast = [ bcast = [
getattr(i, "value", i) == 1 getattr(i, "value", i) == 1
for i in indexed_result_shape(fake_shape, bcast_index) for i in indexed_result_shape(fake_shape, bcast_index)
] ]
return gof.Apply( return gof.Apply(
self, self,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论