提交 d94cdf4f，作者：abergeron

Merge pull request #4094 from abergeron/lift_ifelse

Add lifter for IfElse.
...@@ -10,6 +10,7 @@ from theano.compile.ops import shape_i ...@@ -10,6 +10,7 @@ from theano.compile.ops import shape_i
from theano.gof import (local_optimizer, EquilibriumDB, from theano.gof import (local_optimizer, EquilibriumDB,
SequenceDB, Optimizer, toolbox) SequenceDB, Optimizer, toolbox)
from theano.gof.optdb import LocalGroupDB from theano.gof.optdb import LocalGroupDB
from theano.ifelse import IfElse
from theano.scalar.basic import Scalar, Pow, Cast from theano.scalar.basic import Scalar, Pow, Cast
from theano.scan_module import scan_utils, scan_op, scan_opt from theano.scan_module import scan_utils, scan_op, scan_opt
...@@ -539,6 +540,16 @@ def local_gpu_pdbbreakpoint_op(node): ...@@ -539,6 +540,16 @@ def local_gpu_pdbbreakpoint_op(node):
return False return False
@register_opt('fast_compile')
@op_lifter([IfElse])
def local_gpua_lazy_ifelse(node, context_name):
    """Rebuild a non-GPU ``IfElse`` node as one created with ``gpu=True``.

    Parameters
    ----------
    node
        The apply node being considered; its op exposes ``gpu`` and
        ``n_outs`` and its first input is the condition.
    context_name
        Context identifier forwarded to ``as_gpuarray_variable`` for every
        non-condition input.

    Returns
    -------
    list or None
        The outputs of the new ``IfElse`` (requested with
        ``return_list=True``), or ``None`` when the op is already flagged
        ``gpu`` and nothing has to be done.
    """
    if node.op.gpu:
        # Op is already flagged gpu: leave the node untouched.
        return

    condition = node.inputs[0]

    # Every input after the condition goes through as_gpuarray_variable;
    # the condition itself is passed along unchanged.
    gpu_branches = []
    for var in node.inputs[1:]:
        gpu_branches.append(as_gpuarray_variable(var, context_name))

    lifted_op = IfElse(node.op.n_outs, gpu=True)
    return lifted_op(condition, *gpu_branches, return_list=True)
@register_opt('fast_compile') @register_opt('fast_compile')
@op_lifter([tensor.Join]) @op_lifter([tensor.Join])
def local_gpua_join(node, context_name): def local_gpua_join(node, context_name):
......
...@@ -3,7 +3,7 @@ import numpy ...@@ -3,7 +3,7 @@ import numpy
import theano import theano
from theano import tensor from theano import tensor
from theano.tests.breakpoint import PdbBreakpoint from theano.tests.breakpoint import PdbBreakpoint
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt, test_ifelse
from theano.tensor.tests import test_basic from theano.tensor.tests import test_basic
import theano.sandbox.gpuarray import theano.sandbox.gpuarray
...@@ -206,6 +206,18 @@ class TestSpecifyShape(test_basic.TestSpecifyShape): ...@@ -206,6 +206,18 @@ class TestSpecifyShape(test_basic.TestSpecifyShape):
input_type = GpuArrayType input_type = GpuArrayType
class test_gpu_ifelse(test_ifelse.test_ifelse):
    """Variant of ``test_ifelse.test_ifelse`` configured for the gpuarray backend.

    Overrides the mode, the shared-variable constructor, the output cast and
    the op factory; presumably the inherited tests pick these up -- confirm
    against ``theano.tests.test_ifelse``.
    """

    # Compilation mode and shared-variable constructor for this variant.
    mode = mode_with_gpu
    shared = staticmethod(gpuarray_shared_constructor)

    def get_ifelse(self, n):
        """Return an ``IfElse`` op for ``n`` outputs with ``gpu`` and ``as_view`` set."""
        ifelse_op = theano.ifelse.IfElse(n, gpu=True, as_view=True)
        return ifelse_op

    @staticmethod
    def cast_output(v):
        """Pass ``v`` through ``as_gpuarray_variable`` using ``test_ctx_name``."""
        return basic_ops.as_gpuarray_variable(v, test_ctx_name)
def test_print_op(): def test_print_op():
""" Test that print ops don't block gpu optimization""" """ Test that print ops don't block gpu optimization"""
b = tensor.fmatrix() b = tensor.fmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论