Commit c042a9c4, authored by Xavier Bouthillier

Merge pull request #3362 from Thrandis/gpu_reshape

GpuReshape opt.
...@@ -120,6 +120,9 @@ gpu_optimizer.register('local_remove_all_assert', ...@@ -120,6 +120,9 @@ gpu_optimizer.register('local_remove_all_assert',
theano.tensor.opt.local_remove_all_assert, theano.tensor.opt.local_remove_all_assert,
'unsafe') 'unsafe')
# Register the generic reshape-chain rewrite, specialized to GpuReshape,
# under the name 'local_gpu_reshape_chain'.
_gpu_reshape_chain = theano.tensor.opt.local_reshape_chain(GpuReshape)
register_opt(name='local_gpu_reshape_chain')(_gpu_reshape_chain)
# This is a partial list of CPU ops that can be in some circonstance # This is a partial list of CPU ops that can be in some circonstance
# moved to the GPU. This list is used by an optimization. # moved to the GPU. This list is used by an optimization.
...@@ -944,35 +947,6 @@ def local_gpu_reshape(node): ...@@ -944,35 +947,6 @@ def local_gpu_reshape(node):
return False return False
@local_optimizer([GpuReshape])
def local_gpu_reshape_chain(node):
    """
    Collapse two stacked GPU reshapes into one:
    GpuReshape(GpuReshape(x, shape1), shape2) -> GpuReshape(x, shape2)
    """
    if not tensor.opt.opt.check_chain(node, GpuReshape, GpuReshape):
        return False

    # TODO: this can permit a failing program to run by eliminating
    # the lower reshape.
    inner_input = node.inputs[0].owner.inputs[0]
    merged = node.op(inner_input, node.inputs[1])

    # Only substitute when the merged reshape keeps the broadcastable
    # pattern of the node we replace.  Originally we may have been able
    # to infer that one reshape dimension was 1, but a later graph
    # transformation can replace the shape with one for which this can
    # no longer be guessed.  We should figure out why that information
    # is lost; in the meantime, it is safer not to apply the rewrite.
    if merged.broadcastable != node.outputs[0].broadcastable:
        return False
    return [merged]


gpu_cut_copies.register('cut_local_gpu_reshape_chain',
                        local_gpu_reshape_chain,
                        'fast_run', 'gpu')
@register_opt() @register_opt()
@local_optimizer([gpu_from_host, tensor.Flatten]) @local_optimizer([gpu_from_host, tensor.Flatten])
def local_gpu_flatten(node): def local_gpu_flatten(node):
......
...@@ -32,6 +32,7 @@ from theano.scalar.basic_scipy import erfinv ...@@ -32,6 +32,7 @@ from theano.scalar.basic_scipy import erfinv
from theano.sandbox.blocksparse import sparse_block_dot from theano.sandbox.blocksparse import sparse_block_dot
from theano.sandbox.cuda.blocksparse import GpuSparseBlockGemv, GpuSparseBlockOuter from theano.sandbox.cuda.blocksparse import GpuSparseBlockGemv, GpuSparseBlockOuter
if theano.config.mode == 'FAST_COMPILE': if theano.config.mode == 'FAST_COMPILE':
mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu') mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu')
mode_without_gpu = theano.compile.mode.get_mode('FAST_RUN').excluding('gpu') mode_without_gpu = theano.compile.mode.get_mode('FAST_RUN').excluding('gpu')
...@@ -819,14 +820,10 @@ class test_diag(theano.tensor.tests.test_nlinalg.test_diag): ...@@ -819,14 +820,10 @@ class test_diag(theano.tensor.tests.test_nlinalg.test_diag):
self).__init__(name) self).__init__(name)
def test_local_gpu_reshape(): class Test_GpuReshape(test_opt.Test_Reshape):
mode = mode_with_gpu def setUp(self):
a = tensor.fmatrix() self.mode = mode_with_gpu
b = basic_ops.GpuReshape(3)(a, [2, 3, 4]) self.op = basic_ops.GpuReshape
c = basic_ops.GpuReshape(1)(b, [24])
f = theano.function([a], c, mode=mode)
topo = f.maker.fgraph.toposort()
assert sum(isinstance(node.op, basic_ops.GpuReshape) for node in topo) == 1
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -3413,23 +3413,24 @@ def local_flatten_lift(node): ...@@ -3413,23 +3413,24 @@ def local_flatten_lift(node):
################## ##################
@gof.local_optimizer([T.Reshape]) def local_reshape_chain(op):
def local_reshape_chain(node): @gof.local_optimizer([op])
def f(node):
""" """
Reshape(Reshape(shape1),shape2) -> Reshape(shape2) Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
""" """
if not opt.check_chain(node, T.Reshape, T.Reshape): if not opt.check_chain(node, op, op):
return False return False
# TODO: this can permit a failing program to run by eliminating # TODO: this can permit a failing program to run by eliminating
# the lower reshape # the lower reshape
rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
# It might happen that the desired output of this node has a broadcastable # It might happen that the desired output of this node has a
# pattern that does not match that of 'rval'. This is when originally, we # broadcastable pattern that does not match that of 'rval'. This is
# were able to figure out that one of the dimensions of the reshape is one, # when originally, we were able to figure out that one of the
# but some other transformation replaced the shape by one for which this # dimensions of the reshape is one, but some other transformation
# cannot be guessed. # replaced the shape by one for which this cannot be guessed.
# We should try to figure out why we lost the information about this # We should try to figure out why we lost the information about this
# constant value... but in the meantime, better not apply this # constant value... but in the meantime, better not apply this
# optimization. # optimization.
...@@ -3437,7 +3438,10 @@ def local_reshape_chain(node): ...@@ -3437,7 +3438,10 @@ def local_reshape_chain(node):
return [rval] return [rval]
else: else:
return False return False
register_canonicalize(local_reshape_chain)
return f
register_canonicalize(local_reshape_chain(T.Reshape),
name='local_reshape_chain')
@register_canonicalize @register_canonicalize
......
...@@ -5433,6 +5433,20 @@ def test_local_flatten_lift(): ...@@ -5433,6 +5433,20 @@ def test_local_flatten_lift():
assert isinstance(topo[1].op, tensor.Elemwise) assert isinstance(topo[1].op, tensor.Elemwise)
class Test_Reshape(unittest.TestCase):
    """Check that a chain of two reshape ops is merged into a single op."""

    def setUp(self):
        # Optimization mode and the reshape op class under test.
        # Subclasses (e.g. a GPU variant) can override both.
        self.mode = mode_opt
        self.op = tensor.Reshape

    def test_local_reshape(self):
        m = tensor.fmatrix()
        once = self.op(3)(m, [2, 3, 4])
        twice = self.op(1)(once, [24])
        fn = theano.function([m], twice, mode=self.mode)
        nodes = fn.maker.fgraph.toposort()
        # After optimization only one reshape op should remain.
        n_reshapes = sum(isinstance(apply_node.op, self.op)
                         for apply_node in nodes)
        assert n_reshapes == 1
def test_local_reshape_lift(): def test_local_reshape_lift():
x = tensor.tensor4() x = tensor.tensor4()
out = T.exp(x).reshape([x.size]) out = T.exp(x).reshape([x.size])
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to comment