提交 53486aae authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent test value computation from complaining about subtensor creation

The temporary values created during subtensor broadcast-value computations would cause the test value computation steps to complain/err. This change prevents that.
上级 9f56e418
......@@ -2183,6 +2183,19 @@ class TestInferShape(utt.InferShapeTester):
AdvancedSubtensor,
)
admat.tag.test_value = admat_val
aivec.tag.test_value = aivec_val
bivec.tag.test_value = bivec_val
# Make sure it doesn't complain about test values
with theano.change_flags(compute_test_value="raise"):
self._compile_and_check(
[admat, aivec],
[admat[1:3, aivec]],
[admat_val, aivec_val],
AdvancedSubtensor,
)
def test_AdvancedSubtensor_bool(self):
n = dmatrix()
n_val = np.arange(6).reshape((2, 3))
......
......@@ -2321,24 +2321,28 @@ class AdvancedSubtensor(Op):
# `Subtensor` calls, so we create a fake symbolic shape tuple and
# identify the broadcast dimensions from the shape result of this
# entire subtensor operation.
fake_shape = tuple(
theano.tensor.tensor(dtype="int64", broadcastable=()) if not bcast else 1
for bcast in x.broadcastable
)
with theano.change_flags(compute_test_value="off"):
fake_shape = tuple(
theano.tensor.tensor(dtype="int64", broadcastable=())
if not bcast
else 1
for bcast in x.broadcastable
)
bcast_index = tuple(
chain.from_iterable(
theano.tensor.basic.nonzero(idx)
if getattr(idx, "ndim", 0) > 0 and getattr(idx, "dtype", None) == "bool"
else (idx,)
for idx in index
bcast_index = tuple(
chain.from_iterable(
theano.tensor.basic.nonzero(idx)
if getattr(idx, "ndim", 0) > 0
and getattr(idx, "dtype", None) == "bool"
else (idx,)
for idx in index
)
)
)
bcast = [
getattr(i, "value", i) == 1
for i in indexed_result_shape(fake_shape, bcast_index)
]
bcast = [
getattr(i, "value", i) == 1
for i in indexed_result_shape(fake_shape, bcast_index)
]
return gof.Apply(
self,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论