提交 adf83fa5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix bug in Numba Elemwise test that uses MyMultiOut

上级 65ae5df3
......@@ -64,20 +64,26 @@ class MyMultiOut(Op):
nin = 2
nout = 2
def make_node(self, a, b):
return Apply(self, [a, b], [a.type(), b.type()])
def impl(self, a, b):
@staticmethod
def impl(a, b):
res1 = 2 * a
res2 = 2 * b
return [res1, res2]
def make_node(self, a, b):
return Apply(self, [a, b], [a.type(), b.type()])
def perform(self, node, inputs, outputs):
res1, res2 = self.impl(inputs[0], inputs[1])
outputs[0][0] = res1
outputs[1][0] = res2
my_multi_out = Elemwise(MyMultiOut())
my_multi_out.ufunc = MyMultiOut.impl
my_multi_out.ufunc.nin = 2
my_multi_out.ufunc.nout = 2
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)
......@@ -317,7 +323,7 @@ def test_create_numba_signature(v, expected, force_scalar):
rng.randn(100).astype(config.floatX),
rng.randn(100).astype(config.floatX),
],
lambda x, y: Elemwise(MyMultiOut())(x, y),
lambda x, y: my_multi_out(x, y),
NotImplementedError,
),
],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论