提交 6acc1699 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use CAReduceDtype{add} instead of CAReduce{add}

A test was using CAReduce{add} instead of Sum, so the accumulation dtype could not be set. Also fixed CAReduceDtype.make_node so that CAReduceDtype{add} can be used directly, instead of only through Sum.
上级 ad8a1755
......@@ -64,7 +64,9 @@ def test_careduce():
TODO: test with broadcast
"""
for scalar_op in [theano.scalar.add, theano.scalar.maximum]:
for scalar_op, careduce_op in [
(theano.scalar.add, tensor.elemwise.CAReduceDtype),
(theano.scalar.maximum, tensor.CAReduce)]:
for shape, pattern in [((1,1),(1,)),
((1,0),(1,)),
((0,1),(1,)),
......@@ -119,7 +121,7 @@ def test_careduce():
]:
op = tensor.CAReduce(scalar_op, axis=pattern)
op = careduce_op(scalar_op, axis=pattern)
pat = tensor_pattern_to_gpu_pattern(shape, pattern)
#GpuCAReduce{maximum} support only those patterns
if scalar_op is theano.scalar.maximum and pat not in [
......@@ -187,7 +189,7 @@ def test_careduce():
((5,4),[0,1]),((5,4),[0]),
((5,4,3),[0]),((5,4,3),[0,1]),((5,4,3),[2]),((5,4,3),[0,1,2]),
((5,4,3,2),[0,1,2,3]), ((5,4,3,2),[0,2,3])]:
op = tensor.CAReduce(scalar_op, axis=pattern)
op = careduce_op(scalar_op, axis=pattern)
pat = tensor_pattern_to_gpu_pattern(shape, pattern)
#GpuCAReduce{maximum} support only those patterns
if scalar_op is theano.scalar.maximum and pat not in [
......@@ -219,7 +221,7 @@ def test_careduce():
((5,4,3),[0]),((5,4,3),[0,1]),
((5,4,3),[2]),((5,4,3),[0,1,2]),
((5,4,3,2),[0,1,2,3]), ((5,4,3,2),[0,2,3])]:
op = tensor.CAReduce(scalar_op, axis=pattern)
op = careduce_op(scalar_op, axis=pattern)
pat = tensor_pattern_to_gpu_pattern(shape, pattern)
#GpuCAReduce{maximum} support only those patterns
if scalar_op is theano.scalar.maximum and pat not in [
......
......@@ -1242,6 +1242,7 @@ class CAReduce(Op):
# We can't call self.__class__() as there is class that
# inherit from CAReduce that don't have the same signature
op = copy(self)
op.set_ufunc(op.scalar_op)
op.axis = axis
else:
op = self
......@@ -1733,8 +1734,10 @@ class CAReduceDtype(CAReduce):
# Don't build another instance
op = self
else:
op = self.__class__(axis=self.axis,
dtype=dtype, acc_dtype=acc_dtype)
op = copy(self)
op.set_ufunc(self.scalar_op)
op.dtype = dtype
op.acc_dtype = acc_dtype
assert op.acc_dtype is not None
return CAReduce.make_node(op, input)
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论