提交 6acc1699 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use CAReduceDtype{add} instead of CAReduce{add}

A test was using CAReduce{add} instead of Sum, so the accumulation dtype could not be set. Also fixed CAReduceDtype.make_node to allow use of CAReduceDtype{Add} directly, instead of Sum.
上级 ad8a1755
...@@ -64,7 +64,9 @@ def test_careduce(): ...@@ -64,7 +64,9 @@ def test_careduce():
TODO: test with broadcast TODO: test with broadcast
""" """
for scalar_op in [theano.scalar.add, theano.scalar.maximum]: for scalar_op, careduce_op in [
(theano.scalar.add, tensor.elemwise.CAReduceDtype),
(theano.scalar.maximum, tensor.CAReduce)]:
for shape, pattern in [((1,1),(1,)), for shape, pattern in [((1,1),(1,)),
((1,0),(1,)), ((1,0),(1,)),
((0,1),(1,)), ((0,1),(1,)),
...@@ -119,7 +121,7 @@ def test_careduce(): ...@@ -119,7 +121,7 @@ def test_careduce():
]: ]:
op = tensor.CAReduce(scalar_op, axis=pattern) op = careduce_op(scalar_op, axis=pattern)
pat = tensor_pattern_to_gpu_pattern(shape, pattern) pat = tensor_pattern_to_gpu_pattern(shape, pattern)
#GpuCAReduce{maximum} support only those patterns #GpuCAReduce{maximum} support only those patterns
if scalar_op is theano.scalar.maximum and pat not in [ if scalar_op is theano.scalar.maximum and pat not in [
...@@ -187,7 +189,7 @@ def test_careduce(): ...@@ -187,7 +189,7 @@ def test_careduce():
((5,4),[0,1]),((5,4),[0]), ((5,4),[0,1]),((5,4),[0]),
((5,4,3),[0]),((5,4,3),[0,1]),((5,4,3),[2]),((5,4,3),[0,1,2]), ((5,4,3),[0]),((5,4,3),[0,1]),((5,4,3),[2]),((5,4,3),[0,1,2]),
((5,4,3,2),[0,1,2,3]), ((5,4,3,2),[0,2,3])]: ((5,4,3,2),[0,1,2,3]), ((5,4,3,2),[0,2,3])]:
op = tensor.CAReduce(scalar_op, axis=pattern) op = careduce_op(scalar_op, axis=pattern)
pat = tensor_pattern_to_gpu_pattern(shape, pattern) pat = tensor_pattern_to_gpu_pattern(shape, pattern)
#GpuCAReduce{maximum} support only those patterns #GpuCAReduce{maximum} support only those patterns
if scalar_op is theano.scalar.maximum and pat not in [ if scalar_op is theano.scalar.maximum and pat not in [
...@@ -219,7 +221,7 @@ def test_careduce(): ...@@ -219,7 +221,7 @@ def test_careduce():
((5,4,3),[0]),((5,4,3),[0,1]), ((5,4,3),[0]),((5,4,3),[0,1]),
((5,4,3),[2]),((5,4,3),[0,1,2]), ((5,4,3),[2]),((5,4,3),[0,1,2]),
((5,4,3,2),[0,1,2,3]), ((5,4,3,2),[0,2,3])]: ((5,4,3,2),[0,1,2,3]), ((5,4,3,2),[0,2,3])]:
op = tensor.CAReduce(scalar_op, axis=pattern) op = careduce_op(scalar_op, axis=pattern)
pat = tensor_pattern_to_gpu_pattern(shape, pattern) pat = tensor_pattern_to_gpu_pattern(shape, pattern)
#GpuCAReduce{maximum} support only those patterns #GpuCAReduce{maximum} support only those patterns
if scalar_op is theano.scalar.maximum and pat not in [ if scalar_op is theano.scalar.maximum and pat not in [
......
...@@ -1242,6 +1242,7 @@ class CAReduce(Op): ...@@ -1242,6 +1242,7 @@ class CAReduce(Op):
# We can't call self.__class__() as there is class that # We can't call self.__class__() as there is class that
# inherit from CAReduce that don't have the same signature # inherit from CAReduce that don't have the same signature
op = copy(self) op = copy(self)
op.set_ufunc(op.scalar_op)
op.axis = axis op.axis = axis
else: else:
op = self op = self
...@@ -1733,8 +1734,10 @@ class CAReduceDtype(CAReduce): ...@@ -1733,8 +1734,10 @@ class CAReduceDtype(CAReduce):
# Don't build another instance # Don't build another instance
op = self op = self
else: else:
op = self.__class__(axis=self.axis, op = copy(self)
dtype=dtype, acc_dtype=acc_dtype) op.set_ufunc(self.scalar_op)
op.dtype = dtype
op.acc_dtype = acc_dtype
assert op.acc_dtype is not None assert op.acc_dtype is not None
return CAReduce.make_node(op, input) return CAReduce.make_node(op, input)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论