提交 ba126742 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert boolean indices to integers when determining broadcast pattern

上级 e96b285d
......@@ -2236,6 +2236,9 @@ class TestInferShape(utt.InferShapeTester):
check_topo=False,
)
abs_res = n[~tensor.isinf(n)]
assert abs_res.broadcastable == (False,)
@change_flags(compute_test_value="raise")
def test_basic_shape():
......
......@@ -2360,7 +2360,7 @@ class BaseAdvancedSubtensor(Op):
__props__ = ()
def make_node(self, x, *index):
def make_node(self, x, *index, is_boolean=False):
x = theano.tensor.as_tensor_variable(x)
index = tuple(map(as_index_variable, index))
......@@ -2372,9 +2372,23 @@ class BaseAdvancedSubtensor(Op):
theano.tensor.tensor(dtype="int64", broadcastable=()) if not bcast else 1
for bcast in x.broadcastable
)
bcast_index = index
if is_boolean:
bcast_index = tuple(
chain.from_iterable(
theano.tensor.basic.nonzero(idx)
if getattr(idx, "ndim", 0) > 0
else (idx,)
for idx in bcast_index
)
)
bcast = [
getattr(i, "value", i) == 1 for i in indexed_result_shape(fake_shape, index)
getattr(i, "value", i) == 1
for i in indexed_result_shape(fake_shape, bcast_index)
]
return gof.Apply(
self,
(x,) + index,
......@@ -2465,6 +2479,9 @@ class AdvancedBooleanSubtensor(BaseAdvancedSubtensor):
"""
def make_node(self, x, *index):
return super().make_node(x, *index, is_boolean=True)
def grad(self, inputs, grads):
(gz,) = grads
x = inputs[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论