提交 0a7ad736 authored 作者: nouiz's avatar nouiz

Merge pull request #387 from lamblin/alloc_of_broadcasted

Use C code for Alloc of broadcasted scalar
...@@ -2542,7 +2542,7 @@ class Alloc(gof.Op): ...@@ -2542,7 +2542,7 @@ class Alloc(gof.Op):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
# TODO: use the elemwise code generator here # TODO: use the elemwise code generator here
if node.inputs[0].ndim == 0: if python_all(node.inputs[0].broadcastable):
# filling with a scalar is a common use of alloc # filling with a scalar is a common use of alloc
# that we can implement relatively easily # that we can implement relatively easily
vv = inp[0] vv = inp[0]
......
...@@ -1208,6 +1208,8 @@ AllocTester = makeBroadcastTester( ...@@ -1208,6 +1208,8 @@ AllocTester = makeBroadcastTester(
op = alloc, op = alloc,
expected = (lambda x, *shp: numpy.zeros(shp, dtype=x.dtype) + x), expected = (lambda x, *shp: numpy.zeros(shp, dtype=x.dtype) + x),
good = dict( good = dict(
correct01 = (rand(), numpy.int32(7)),
correct01_bcast = (rand(1), numpy.int32(7)),
correct02 = (rand(), numpy.int32(4), numpy.int32(7)), correct02 = (rand(), numpy.int32(4), numpy.int32(7)),
correct12 = (rand(7), numpy.int32(4), numpy.int32(7)), correct12 = (rand(7), numpy.int32(4), numpy.int32(7)),
correct13 = (rand(7), numpy.int32(2), numpy.int32(4), numpy.int32(7)), correct13 = (rand(7), numpy.int32(2), numpy.int32(4), numpy.int32(7)),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论