提交 b0e70bd0 authored 作者: abergeron's avatar abergeron

Merge pull request #1811 from nouiz/opt

Opt Reduce(join())
...@@ -44,7 +44,7 @@ if pygpu: ...@@ -44,7 +44,7 @@ if pygpu:
init_dev(config.device) init_dev(config.device)
import theano.compile import theano.compile
theano.compile.shared_constructor(gpuarray_shared_constructor) theano.compile.shared_constructor(gpuarray_shared_constructor)
optdb.add_tags('gpuarray_opt', 'fast_run', 'inplace') optdb.add_tags('gpuarray_opt', 'fast_run', 'fast_compile', 'inplace')
elif config.gpuarray.init_device != '': elif config.gpuarray.init_device != '':
init_dev(config.gpuarray.init_device) init_dev(config.gpuarray.init_device)
except Exception: except Exception:
......
...@@ -1650,7 +1650,7 @@ def min(x, axis=None, keepdims=False): ...@@ -1650,7 +1650,7 @@ def min(x, axis=None, keepdims=False):
the result as dimensions with size one. With this option, the result the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor. will broadcast correctly against the original tensor.
""" """
x = as_tensor_variable(x)
str_x_type = str(x.dtype) str_x_type = str(x.dtype)
if str_x_type.startswith('float') or str_x_type in int_dtypes: if str_x_type.startswith('float') or str_x_type in int_dtypes:
return -max(-x, axis=axis, keepdims=keepdims) return -max(-x, axis=axis, keepdims=keepdims)
...@@ -1671,7 +1671,7 @@ def argmin(x, axis=None, keepdims=False): ...@@ -1671,7 +1671,7 @@ def argmin(x, axis=None, keepdims=False):
the result as dimensions with size one. With this option, the result the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor. will broadcast correctly against the original tensor.
""" """
x = as_tensor_variable(x)
str_x_type = str(x.dtype) str_x_type = str(x.dtype)
if str_x_type.startswith('float') or str_x_type in int_dtypes: if str_x_type.startswith('float') or str_x_type in int_dtypes:
return argmax(-x, axis=axis, keepdims=keepdims) return argmax(-x, axis=axis, keepdims=keepdims)
......
...@@ -3308,6 +3308,41 @@ ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any, ...@@ -3308,6 +3308,41 @@ ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any,
T.elemwise.Sum, T.elemwise.Prod, T.elemwise.Sum, T.elemwise.Prod,
T.elemwise.ProdWithoutZeros] T.elemwise.ProdWithoutZeros]
@register_canonicalize
@register_uncanonicalize  # Needed for MaxAndArgmax -> CAReduce
@gof.local_optimizer(ALL_REDUCE)
def local_reduce_join(node):
    """Reduce(Join(0, a[None], b[None]), axis=0) -> Elemwise(a, b)

    e.g. Max(Join(0, a[None], b[None]), axis=0) -> Maximum(a, b)

    Supports Maximum/Minimum with exactly 2 joined inputs, and Add/Mul
    with any number of joined inputs.  Each joined input must be a
    DimShuffle that only prepends a broadcastable axis ('x' + original
    dims), i.e. the pattern produced by stack().
    """
    if (isinstance(node.op, T.CAReduce) and
            node.inputs[0].owner and
            isinstance(node.inputs[0].owner.op, T.Join)):
        join = node.inputs[0].owner
        # The join must be along the constant axis 0 (the stacked axis).
        if T.extract_constant(join.inputs[0]) != 0:
            return

        # The reduction must be exactly over axis 0.  Without this
        # check, e.g. Sum(Join(0, ...)) with axis=None (reduce all
        # axes) would be rewritten into an elemwise whose output has
        # the rank of the joined inputs instead of the reduced rank.
        if node.op.axis is None:
            # Reducing all axes is only equivalent when the joined
            # variable is 1d, i.e. axis 0 is the only axis.
            if node.inputs[0].ndim != 1:
                return
        elif tuple(node.op.axis) != (0,):
            return

        if isinstance(node.op.scalar_op, (scalar.Maximum, scalar.Minimum)):
            # Support only 2 inputs for now
            if len(join.inputs) != 3:
                return
        elif not isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul)):
            return

        new_inp = []
        for inp in join.inputs[1:]:
            inp = inp.owner
            if not inp:
                return
            # Only accept the stack() pattern: a DimShuffle adding one
            # broadcastable dimension on the left and nothing else.
            if (not isinstance(inp.op, DimShuffle) or
                    inp.op.new_order != ('x',) +
                    tuple(range(inp.inputs[0].ndim))):
                return
            new_inp.append(inp.inputs[0])
        ret = Elemwise(node.op.scalar_op)(*new_inp)

        if ret.dtype == node.outputs[0].dtype:
            return [ret]
        # else the reduction does something about the dtype (e.g. an
        # int8 sum accumulates in a wider dtype): don't apply.
@register_canonicalize @register_canonicalize
@gof.local_optimizer(ALL_REDUCE) @gof.local_optimizer(ALL_REDUCE)
def local_cut_useless_reduce(node): def local_cut_useless_reduce(node):
......
...@@ -3776,8 +3776,10 @@ class T_local_sum(unittest.TestCase): ...@@ -3776,8 +3776,10 @@ class T_local_sum(unittest.TestCase):
class T_local_reduce(unittest.TestCase): class T_local_reduce(unittest.TestCase):
def setUp(self): def setUp(self):
self.mode = theano.compile.get_default_mode().including('canonicalize', self.mode = theano.compile.get_default_mode().including(
'specialize') 'canonicalize',
'specialize',
'uncanonicalize', 'local_max_and_argmax')
def test_local_reduce_broadcast_all_0(self): def test_local_reduce_broadcast_all_0(self):
for fct in [tensor.sum, tensor.all, tensor.any, tensor.prod, for fct in [tensor.sum, tensor.all, tensor.any, tensor.prod,
...@@ -3827,6 +3829,28 @@ class T_local_reduce(unittest.TestCase): ...@@ -3827,6 +3829,28 @@ class T_local_reduce(unittest.TestCase):
isinstance(node.op, T.CAReduce) isinstance(node.op, T.CAReduce)
for node in f.maker.fgraph.toposort()]) for node in f.maker.fgraph.toposort()])
def test_local_reduce_join(self):
    """Reductions over a stack of matrices must be rewritten into a
    single Elemwise on the underlying matrices (local_reduce_join)."""
    mx, my, mz = matrix(), matrix(), matrix()
    nx = numpy.asarray([[1, 0], [3, 4]], dtype=config.floatX)
    ny = numpy.asarray([[4, 0], [2, 1]], dtype=config.floatX)
    nz = numpy.asarray([[5, 0], [1, 2]], dtype=config.floatX)

    cases = [
        (T.max((mx, my), 0), numpy.max((nx, ny), 0)),
        (T.min((mx, my), 0), numpy.min((nx, ny), 0)),
        (T.sum((mx, my, mz), 0), numpy.sum((nx, ny, nz), 0)),
        (T.prod((mx, my, mz), 0), numpy.prod((nx, ny, nz), 0)),
        (T.prod((mx, my.T, mz), 0), numpy.prod((nx, ny.T, nz), 0)),
    ]
    for out, expected in cases:
        f = theano.function([mx, my, mz], out,
                            on_unused_input='ignore', mode=self.mode)
        # Numerical result must match NumPy.
        assert (f(nx, ny, nz) == expected).all(), out
        # The optimized graph must be a (nearly) single Elemwise.
        graph = f.maker.fgraph.toposort()
        assert len(graph) <= 2, out
        assert isinstance(graph[-1].op, T.Elemwise), out
class T_local_sum_dimshuffle(unittest.TestCase): class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论