提交 c042a9c4 authored 作者: Xavier Bouthillier's avatar Xavier Bouthillier

Merge pull request #3362 from Thrandis/gpu_reshape

GpuReshape opt.
......@@ -120,6 +120,9 @@ gpu_optimizer.register('local_remove_all_assert',
theano.tensor.opt.local_remove_all_assert,
'unsafe')
# Register local_reshape_chain, instantiated for GpuReshape, so that chains
# of GpuReshape ops are collapsed into a single reshape on the GPU as well.
register_opt(name='local_gpu_reshape_chain')(
theano.tensor.opt.local_reshape_chain(GpuReshape))
# This is a partial list of CPU ops that can in some circumstances be
# moved to the GPU. This list is used by an optimization.
......@@ -944,35 +947,6 @@ def local_gpu_reshape(node):
return False
@local_optimizer([GpuReshape])
def local_gpu_reshape_chain(node):
    """
    GpuReshape(GpuReshape(x, shape1), shape2) -> GpuReshape(x, shape2)

    Collapse two stacked GpuReshape ops into one, keeping only the outer
    target shape.
    """
    if not tensor.opt.opt.check_chain(node, GpuReshape, GpuReshape):
        return False

    # TODO: this can permit a failing program to run by eliminating
    # the lower reshape.
    rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])

    # The desired output may carry a broadcastable pattern that 'rval' does
    # not match: originally we could infer that one dimension of the reshape
    # was 1, but a later transformation replaced the shape with one for which
    # this can no longer be guessed.  We should work out why that constant
    # information is lost; until then it is safer not to apply the rewrite.
    if rval.broadcastable == node.outputs[0].broadcastable:
        return [rval]
    return False


gpu_cut_copies.register('cut_local_gpu_reshape_chain',
                        local_gpu_reshape_chain,
                        'fast_run', 'gpu')
@register_opt()
@local_optimizer([gpu_from_host, tensor.Flatten])
def local_gpu_flatten(node):
......
......@@ -32,6 +32,7 @@ from theano.scalar.basic_scipy import erfinv
from theano.sandbox.blocksparse import sparse_block_dot
from theano.sandbox.cuda.blocksparse import GpuSparseBlockGemv, GpuSparseBlockOuter
# NOTE(review): this looks like only the FAST_COMPILE branch of a larger
# conditional that selects the test compilation modes; even under
# FAST_COMPILE the GPU test modes are built from FAST_RUN so the GPU
# optimizations under test are actually applied -- confirm against the
# full file.
if theano.config.mode == 'FAST_COMPILE':
mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu')
mode_without_gpu = theano.compile.mode.get_mode('FAST_RUN').excluding('gpu')
......@@ -819,14 +820,10 @@ class test_diag(theano.tensor.tests.test_nlinalg.test_diag):
self).__init__(name)
def test_local_gpu_reshape():
    """Chained GpuReshape ops must be merged into a single GpuReshape."""
    mode = mode_with_gpu
    inp = tensor.fmatrix()
    reshaped_once = basic_ops.GpuReshape(3)(inp, [2, 3, 4])
    reshaped_twice = basic_ops.GpuReshape(1)(reshaped_once, [24])
    fn = theano.function([inp], reshaped_twice, mode=mode)
    nodes = fn.maker.fgraph.toposort()
    n_reshapes = sum(isinstance(node.op, basic_ops.GpuReshape)
                     for node in nodes)
    assert n_reshapes == 1
class Test_GpuReshape(test_opt.Test_Reshape):
    """Run the generic Reshape-optimization test cases against GpuReshape."""

    def setUp(self):
        # Reuse the CPU test suite, swapping in the GPU mode and op class.
        self.mode = mode_with_gpu
        self.op = basic_ops.GpuReshape
if __name__ == '__main__':
......
......@@ -3413,31 +3413,35 @@ def local_flatten_lift(node):
##################
def local_reshape_chain(op):
    """
    Build a local optimizer collapsing two nested applications of `op`
    (a Reshape-like op class) into a single one:

        op(op(x, shape1), shape2) -> op(x, shape2)

    Parameterizing on `op` lets the same rewrite serve both ``T.Reshape``
    on the CPU and ``GpuReshape`` on the GPU.
    """
    @gof.local_optimizer([op])
    def f(node):
        """
        Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
        """
        if not opt.check_chain(node, op, op):
            return False

        # TODO: this can permit a failing program to run by eliminating
        # the lower reshape
        rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])

        # It might happen that the desired output of this node has a
        # broadcastable pattern that does not match that of 'rval'. This is
        # when originally, we were able to figure out that one of the
        # dimensions of the reshape is one, but some other transformation
        # replaced the shape by one for which this cannot be guessed.
        # We should try to figure out why we lost the information about this
        # constant value... but in the meantime, better not apply this
        # optimization.
        if rval.broadcastable == node.outputs[0].broadcastable:
            return [rval]
        else:
            return False
    return f

register_canonicalize(local_reshape_chain(T.Reshape),
                      name='local_reshape_chain')
@register_canonicalize
......
......@@ -5433,6 +5433,20 @@ def test_local_flatten_lift():
assert isinstance(topo[1].op, tensor.Elemwise)
class Test_Reshape(unittest.TestCase):
    """Check that chained Reshape ops are merged by the optimizer."""

    def setUp(self):
        self.mode = mode_opt
        self.op = tensor.Reshape

    def test_local_reshape(self):
        # Reshape(Reshape(x)) should compile down to a single Reshape node.
        x = tensor.fmatrix()
        mid = self.op(3)(x, [2, 3, 4])
        out = self.op(1)(mid, [24])
        fn = theano.function([x], out, mode=self.mode)
        nodes = fn.maker.fgraph.toposort()
        assert sum(isinstance(node.op, self.op) for node in nodes) == 1
def test_local_reshape_lift():
x = tensor.tensor4()
out = T.exp(x).reshape([x.size])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论