提交 a1679df8 作者: Brandon T. Willard 提交者: Brandon T. Willard

Add a C implementation for BroadcastTo

上级 4413fb47
...@@ -1600,9 +1600,11 @@ def broadcast_shape_iter( ...@@ -1600,9 +1600,11 @@ def broadcast_shape_iter(
return tuple(result_dims) return tuple(result_dims)
class BroadcastTo(Op): class BroadcastTo(COp):
"""An `Op` for `numpy.broadcast_to`.""" """An `Op` for `numpy.broadcast_to`."""
__props__ = ()
view_map = {0: [0]} view_map = {0: [0]}
def __call__(self, a, shape, **kwargs): def __call__(self, a, shape, **kwargs):
...@@ -1652,6 +1654,56 @@ class BroadcastTo(Op): ...@@ -1652,6 +1654,56 @@ class BroadcastTo(Op):
def infer_shape(self, fgraph, node, ins_shapes): def infer_shape(self, fgraph, node, ins_shapes):
return [node.inputs[1:]] return [node.inputs[1:]]
def c_code(self, node, name, inputs, outputs, sub):
    """Return the C source that builds the broadcasted output array.

    The generated code creates a NumPy iterator over the first input using
    the requested shape as the iteration shape, and stores the iterator's
    view of that operand (``NpyIter_GetIterView``) in the output variable.

    Parameters
    ----------
    node
        The node being compiled (unused here).
    name
        The unique name assigned to this node's code (unused here).
    inputs
        C variable names: the array to broadcast followed by one variable
        per output dimension.
    outputs
        C variable names; exactly one, the output array.
    sub
        Substitution dictionary; ``sub["fail"]`` is the failure snippet.

    Returns
    -------
    str
        The C code for this node.
    """
    (x, *shape) = inputs
    (out,) = outputs

    # TODO: Could just use `PyArray_Return`, no?
    # Each dimension is read as element 0 of the corresponding input's data
    # buffer, interpreted with that input's own C dtype.
    dims_array = ", ".join(
        f"((dtype_{dim}*)(PyArray_DATA({dim})))[0]" for dim in shape
    )

    # NOTE: every `fail` snippet is placed after the last C declaration, so
    # a jump emitted by it never skips over a variable initialization.
    src = """
    npy_intp itershape[%(ndims)s] = {%(dims_array)s};
    PyArrayObject *ops[1] = {%(x)s};
    npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
    npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
    PyArray_Descr *op_dtypes[1] = {NULL};
    int oa_ndim = %(ndims)s;
    int* op_axes[1] = {NULL};
    npy_intp buffersize = 0;

    /* Drop the array left over from a previous call; without this, every
       call leaks one reference.  Reset to NULL so that a failure below
       cannot lead to a second decref of the same object. */
    Py_XDECREF(%(out)s);
    %(out)s = NULL;

    NpyIter *iter = NpyIter_AdvancedNew(
        1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
    );
    if (iter == NULL) {
        /* Iterator construction failed; there is nothing to deallocate. */
        %(fail)s;
    }

    %(out)s = NpyIter_GetIterView(iter, 0);
    if (%(out)s == NULL) {
        NpyIter_Deallocate(iter);
        %(fail)s;
    }

    if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
        %(fail)s;
    }
    """ % {
        "ndims": len(shape),
        "dims_array": dims_array,
        "x": x,
        "out": out,
        "fail": sub["fail"],
    }

    return src

def c_code_cache_version(self):
    """Return the version tag used to cache the compiled C code.

    Bumped to 2 because the generated C code changed (output reference
    release and iterator NULL check).
    """
    return (2,)
broadcast_to_ = BroadcastTo() broadcast_to_ = BroadcastTo()
......
...@@ -1242,24 +1242,70 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1242,24 +1242,70 @@ class TestBroadcastTo(utt.InferShapeTester):
assert y.owner.inputs[1].owner is None assert y.owner.inputs[1].owner is None
assert y.owner.inputs[2].owner is None assert y.owner.inputs[2].owner is None
@config.change_flags(compute_test_value="raise") @pytest.mark.parametrize("linker", ["cvm", "py"])
def test_perform(self): def test_perform(self, linker):
a = scalar()
a.tag.test_value = 5
a = aesara.shared(5)
s_1 = iscalar("s_1") s_1 = iscalar("s_1")
s_1.tag.test_value = 4
shape = (s_1, 1) shape = (s_1, 1)
bcast_res = broadcast_to(a, shape) bcast_res = broadcast_to(a, shape)
assert bcast_res.broadcastable == (False, True) assert bcast_res.broadcastable == (False, True)
bcast_fn = aesara.function(
[s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
)
bcast_fn.vm.allow_gc = False
bcast_at = bcast_fn(4)
bcast_np = np.broadcast_to(5, (4, 1)) bcast_np = np.broadcast_to(5, (4, 1))
bcast_at = bcast_res.get_test_value()
assert np.array_equal(bcast_at, bcast_np) assert np.array_equal(bcast_at, bcast_np)
assert np.shares_memory(bcast_at, a.get_test_value())
bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
bcast_in = bcast_fn.vm.storage_map[a]
bcast_out = bcast_fn.vm.storage_map[bcast_var]
if linker != "py":
assert np.shares_memory(bcast_out[0], bcast_in[0])
@pytest.mark.skipif(
    not config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_memory_leak(self):
    """Check that repeated calls of the C implementation do not grow memory.

    A function containing a `BroadcastTo` node is compiled with the "cvm"
    linker and called repeatedly; the traced memory reported by
    `tracemalloc` after each call is compared with the previous call's, and
    the mean difference (in units of 1000) must be close to zero.
    """
    import gc
    import tracemalloc

    from aesara.link.c.cvm import CVM

    n = 100_000
    x = aesara.shared(np.ones(n, dtype=np.float64))
    y = broadcast_to(x, (5, n))

    f = aesara.function([], y, mode=Mode(optimizer=None, linker="cvm"))
    assert isinstance(f.vm, CVM)

    # Make sure the graph really contains the `Op` under test.
    assert len(f.maker.fgraph.apply_nodes) == 2
    assert any(
        isinstance(node.op, BroadcastTo) for node in f.maker.fgraph.apply_nodes
    )

    tracemalloc.start()
    try:
        blocks_last = None
        block_diffs = []
        for _ in range(1, 50):
            x.set_value(np.ones(n))
            f()
            gc.collect()
            # First element is the currently traced size; the peak is unused.
            blocks_i, _peak = tracemalloc.get_traced_memory()
            if blocks_last is not None:
                blocks_diff = (blocks_i - blocks_last) // 10**3
                block_diffs.append(blocks_diff)
            blocks_last = blocks_i
    finally:
        # Always turn tracing off, even if a call above raises, so that it
        # does not stay enabled for the rest of the test session.
        tracemalloc.stop()

    assert np.allclose(np.mean(block_diffs), 0)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fn,input_dims", "fn,input_dims",
......
Markdown 格式
0%
您已将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论