提交 4249a01c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Create a BroadcastTo Op for NumPy's broadcast_to function

上级 89f18b35
......@@ -3,10 +3,11 @@ import pytest
import theano
from tests import unittest_tools as utt
from theano import config, function
from theano import change_flags, config, function
from theano import tensor as tt
from theano.tensor.extra_ops import (
Bartlett,
BroadcastTo,
CpuContiguous,
CumOp,
DiffOp,
......@@ -20,6 +21,7 @@ from theano.tensor.extra_ops import (
bartlett,
bincount,
broadcast_shape,
broadcast_to,
compress,
cpu_contiguous,
cumprod,
......@@ -1305,3 +1307,70 @@ def test_broadcast_shape():
y_tt = tt.ones(y_shapes)
b_tt = broadcast_shape(x_tt, y_tt)
assert isinstance(b_tt[-1].owner.op, tt.opt.Assert)
class TestBroadcastTo(utt.InferShapeTester):
    """Tests for the ``BroadcastTo`` `Op` and its `broadcast_to` helper."""

    rng = np.random.RandomState(43)

    def setup_method(self):
        super().setup_method()
        self.op_class = BroadcastTo
        self.op = broadcast_to

    @change_flags(compute_test_value="raise")
    def test_perform(self):
        # Broadcast a scalar against a partially symbolic shape.
        scalar_var = tt.scalar()
        scalar_var.tag.test_value = 5

        sym_dim = tt.iscalar("s_1")
        sym_dim.tag.test_value = 4

        res = broadcast_to(scalar_var, (sym_dim, 1))

        assert res.broadcastable == (False, True)

        res_val = res.get_test_value()
        assert np.array_equal(res_val, np.broadcast_to(5, (4, 1)))
        # The broadcast result should be a view of the input, not a copy.
        assert np.shares_memory(res_val, scalar_var.get_test_value())

    @pytest.mark.parametrize(
        "fn,input_dims",
        [
            [lambda x: broadcast_to(x, (1,)), (1,)],
            [lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
            [lambda x: broadcast_to(x, (6, 2, 5, 3)), (5, 1)],
            [lambda x: broadcast_to(x, (6, 2, 1, 3)), (2, 1, 3)],
        ],
    )
    def test_gradient(self, fn, input_dims):
        test_point = np.random.rand(*input_dims).astype(config.floatX)
        utt.verify_grad(fn, [test_point], n_tests=1, rng=self.rng)

    def test_infer_shape(self):
        # Case 1: broadcast a variable to its own (symbolic) shape.
        x = tt.tensor(config.floatX, [False, True, False])
        x_shape = list(x.shape)
        self._compile_and_check(
            [x] + x_shape,
            [self.op(x, x_shape)],
            [np.random.rand(2, 1, 3).astype(config.floatX), 2, 1, 3],
            self.op_class,
        )

        # Case 2: broadcast to a larger, fully symbolic shape.
        y = tt.tensor(config.floatX, [False, True, False])
        y_shape = [tt.iscalar() for _ in range(4)]
        self._compile_and_check(
            [y] + y_shape,
            [self.op(y, y_shape)],
            [np.random.rand(2, 1, 3).astype(config.floatX), 6, 2, 5, 3],
            self.op_class,
        )
......@@ -1553,3 +1553,56 @@ def broadcast_shape_iter(arrays, **kwargs):
result_dims.append(one)
return tuple(result_dims)
class BroadcastTo(Op):
    """An `Op` implementing ``numpy.broadcast_to``.

    The output is produced by `np.broadcast_to`, which returns a view of
    the input; ``view_map`` declares that aliasing to Theano.
    """

    # Output 0 aliases input 0 (the NumPy result is a view, not a copy).
    view_map = {0: [0]}

    def __call__(self, a, shape, **kwargs):
        # Unpack the `shape` sequence so each dimension becomes a separate
        # scalar input to `make_node`.
        return super().__call__(a, *shape, **kwargs)

    def make_node(self, a, *shape):
        a = basic.as_tensor_variable(a)

        shape = basic.as_tensor_variable(shape, ndim=1)

        # Validate the shape entries and derive the output's broadcastable
        # pattern from them (presumably a dimension is broadcastable iff its
        # shape entry is the constant 1 — confirm against
        # `basic.alloc_validate_shape`).
        shape, bcast = basic.alloc_validate_shape(shape)

        # Build an output of the same TensorType subclass and dtype as `a`,
        # but with the broadcast pattern of the target shape.
        out = type(a.type)(dtype=a.type.dtype, broadcastable=bcast)()

        return theano.Apply(self, [a] + shape, [out])

    def perform(self, node, inputs, output_storage):
        # inputs = [a, dim_0, dim_1, ...]; the output is a view of `a`,
        # consistent with `view_map` above.
        a, *shape = inputs

        z = output_storage[0]

        z[0] = np.broadcast_to(a, shape)

    def grad(self, inputs, outputs_gradients):
        a, *shape = inputs
        (dout,) = outputs_gradients

        # Determine the dimensions that were added by broadcasting; they are
        # summed away so the gradient has `a`'s number of dimensions.
        new_dims = list(range(dout.ndim - a.ndim))

        d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)

        # Determine the dimensions that were broadcast: broadcastable in `a`
        # but not in the (trailing part of the) target shape.
        _, shape_bcast = basic.alloc_validate_shape(shape)
        bcast_sums = [
            i
            for i, (a_b, s_b) in enumerate(zip(a.broadcastable, shape_bcast[-a.ndim :]))
            if a_b and not s_b
        ]

        if bcast_sums:
            # Sum over the broadcast dimensions but keep them as length-1
            # axes so the gradient's shape matches `a`.
            d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)

        # The shape inputs are integer-valued, so gradients with respect to
        # them are undefined.
        return [d_wrt_a] + [
            grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
        ]

    def infer_shape(self, node, ins_shapes):
        # The output's shape is exactly the list of shape inputs.
        return [node.inputs[1:]]


# User-facing singleton: ``broadcast_to(a, shape)``.
broadcast_to = BroadcastTo()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论