提交 ff9f0ea8 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

test for weird strides, small fixes

上级 9088d234
......@@ -120,6 +120,16 @@ class _test_Broadcast(unittest.TestCase):
f(xv, yv)
assert (xv == yv).all()
def test_weird_strides(self):
x = modes.build(Tensor('float64', [0, 0, 0, 0, 0], name = 'x'))
y = modes.build(Tensor('float64', [0, 0, 0, 0, 0], name = 'y'))
e = Broadcast(Add, (x, y)).out
f = gof.CLinker(env([x, y], [e])).make_function(inplace = False)
xv = numpy.random.rand(2, 2, 2, 2, 2)
yv = numpy.random.rand(2, 2, 2, 2, 2).transpose(4, 0, 3, 1, 2)
zv = xv + yv
assert (f(xv, yv) == zv).all()
class _test_CAReduce(unittest.TestCase):
......
......@@ -397,9 +397,6 @@ class CAReduce(Op):
if dimensions_to_reduce is None:
dimensions_to_reduce = range(len(inputs[0].broadcastable))
self.nin = 1
self.nout = 1
self.inputs = inputs
self.outputs = [Tensor(dtype = inputs[0].dtype,
broadcastable = [x for i, x in enumerate(inputs[0].broadcastable) if i not in dimensions_to_reduce])]
......
......@@ -153,7 +153,7 @@ class _Op(BaseTensorOp):
return self.c_impl(self.inputs, self.outputs) % sub
def c_impl(self, inputs, outputs):
raise AbstractFunctionError()
raise AbstractFunctionError("No c_impl for %s" % self.__class__.__name__)
class _Unary:
nin = 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论