提交 d219054e authored 作者: abergeron's avatar abergeron

Merge pull request #3380 from nouiz/mixed2

Mixed2
.. _libdoc_blocksparse: .. _libdoc_blocksparse:
=================================================================== ===========================================================================
:mod:`sandbox.blocksparse` -- Block sparse dot operations (gemv and outer) :mod:`sandbox.blocksparse` -- Block sparse dot operations (gemv and outer)
=================================================================== ===========================================================================
.. module:: sandbox.blocksparse .. module:: sandbox.blocksparse
:platform: Unix, Windows :platform: Unix, Windows
......
...@@ -24,6 +24,13 @@ There are at least three possible ways of doing so: ...@@ -24,6 +24,13 @@ There are at least three possible ways of doing so:
``LD_LIBRARY_PATH``, ``LIBRARY_PATH`` and ``CPATH`` to the directory ``LD_LIBRARY_PATH``, ``LIBRARY_PATH`` and ``CPATH`` to the directory
extracted from the download. If needed, separate multiple directories extracted from the download. If needed, separate multiple directories
with ``:`` as in the ``PATH`` environment variable. with ``:`` as in the ``PATH`` environment variable.
example::
export LD_LIBRARY_PATH=/home/user/path_to_CUDNN_folder/lib64:$LD_LIBRARY_PATH
export CPATH=/home/user/path_to_CUDNN_folder/include:$CPATH
export LIBRARY_PATH=/home/user/path_to_CUDNN_folder/lib64:$LD_LIBRARY_PATH
- And as a third way, also on Linux, you can copy the ``*.h`` files - And as a third way, also on Linux, you can copy the ``*.h`` files
to ``/usr/include`` and the ``*.so*`` files to ``/lib64``. to ``/usr/include`` and the ``*.so*`` files to ``/lib64``.
......
...@@ -19,6 +19,7 @@ Blas Op ...@@ -19,6 +19,7 @@ Blas Op
.. automodule:: theano.sandbox.cuda.blas .. automodule:: theano.sandbox.cuda.blas
:members: :members:
.. autofunction:: theano.sandbox.cuda.blas.batched_dot
Nnet Op Nnet Op
======= =======
......
...@@ -78,8 +78,8 @@ class OpFromGraph(gof.Op): ...@@ -78,8 +78,8 @@ class OpFromGraph(gof.Op):
if not isinstance(i, gof.Variable): if not isinstance(i, gof.Variable):
raise TypeError( raise TypeError(
'inputs and outputs must be Variable instances', i) 'inputs and outputs must be Variable instances', i)
if 'updates' in kwargs: if 'updates' in kwargs or 'givens' in kwargs:
raise TypeError('updates are not allowed in kwargs') raise TypeError('updates and givens are not allowed in kwargs')
# To support correctly shared variables the inner fct should # To support correctly shared variables the inner fct should
# not see them. Otherwise their is problem with the gradient. # not see them. Otherwise their is problem with the gradient.
......
...@@ -302,8 +302,15 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None): ...@@ -302,8 +302,15 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
"HINT: Use the Theano flag 'exception_verbosity=high'" "HINT: Use the Theano flag 'exception_verbosity=high'"
" for a debugprint and storage map footprint of this apply node.") " for a debugprint and storage map footprint of this apply node.")
exc_value = exc_type(str(exc_value) + detailed_err_msg + try:
'\n' + '\n'.join(hints)) exc_value = exc_type(str(exc_value) + detailed_err_msg +
'\n' + '\n'.join(hints))
except TypeError:
print("WARNING: %s error does not allow us to add extra error message" %
str(exc_type))
# Some exception need extra parameter in inputs. So forget the
# extra long error message in that case.
pass
reraise(exc_type, exc_value, exc_trace) reraise(exc_type, exc_value, exc_trace)
......
...@@ -395,7 +395,13 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -395,7 +395,13 @@ def ifelse(condition, then_branch, else_branch, name=None):
@gof.local_optimizer([IfElse]) @gof.local_optimizer([IfElse])
def cond_make_inplace(node): def cond_make_inplace(node):
op = node.op op = node.op
if isinstance(op, IfElse) and not op.as_view: if (isinstance(op, IfElse) and
not op.as_view and
# For big graph, do not make inplace scalar to speed up
# optimization.
(len(node.fgraph.apply_nodes) < 500 or
not all([getattr(o.type, 'ndim', -1) == 0
for o in node.outputs]))):
return IfElse(n_outs=op.n_outs, return IfElse(n_outs=op.n_outs,
as_view=True, as_view=True,
gpu=op.gpu, gpu=op.gpu,
......
...@@ -14,8 +14,8 @@ from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable, ...@@ -14,8 +14,8 @@ from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable,
gpu_contiguous) gpu_contiguous)
from theano.tensor import as_tensor_variable from theano.tensor import as_tensor_variable
class BatchedDotOp(GpuOp):
class BatchedDotOp(GpuOp):
__props__ = () __props__ = ()
def make_node(self, inp1, inp2): def make_node(self, inp1, inp2):
...@@ -213,6 +213,10 @@ class BatchedDotOp(GpuOp): ...@@ -213,6 +213,10 @@ class BatchedDotOp(GpuOp):
return (1,) return (1,)
batched_dot = BatchedDotOp() batched_dot = BatchedDotOp()
"""
Call cublasSgemmBatched. Take 2 3d tensor as input.
"""
class GpuDot22(GpuOp): class GpuDot22(GpuOp):
""" """
......
...@@ -81,20 +81,28 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) { ...@@ -81,20 +81,28 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
" from one version, but we link with" " from one version, but we link with"
" a different version %s" % str(v)) " a different version %s" % str(v))
raise RuntimeError(dnn_available.msg) raise RuntimeError(dnn_available.msg)
if version() == -1: if v == -1:
dnn_available.avail = False dnn_available.avail = False
dnn_available.msg = ( dnn_available.msg = (
"CuDNN v1 detected. This version is no longer " "CuDNN v1 detected. This version is no longer "
"supported by Theano. Update your CuDNN installation " "supported by Theano. Update your CuDNN installation "
"to a more recent version") "to a more recent version")
raise RuntimeError(dnn_available.msg) raise RuntimeError(dnn_available.msg)
if version() == (20, 20): if v == (20, 20):
dnn_available.avail = False dnn_available.avail = False
dnn_available.msg = ( dnn_available.msg = (
"You have installed a release candidate of CuDNN v2." "You have installed a release candidate of CuDNN v2."
" This isn't supported anymore." " This isn't supported anymore."
" Update to CuDNN v2 final version.") " Update to CuDNN v2 final version.")
raise RuntimeError(dnn_available.msg) raise RuntimeError(dnn_available.msg)
if v[0] >= 3000 and v[0] < 3007:
# 3007 is the final release of cudnn v3
dnn_available.avail = False
dnn_available.msg = (
"You have installed a release candidate of CuDNN v3."
" This isn't supported anymore."
" Update to CuDNN v3 final version.")
raise RuntimeError(dnn_available.msg)
return dnn_available.avail return dnn_available.avail
...@@ -2380,8 +2388,7 @@ if True: ...@@ -2380,8 +2388,7 @@ if True:
isinstance(node.inputs[0].owner.op, HostFromGpu)) or isinstance(node.inputs[0].owner.op, HostFromGpu)) or
(node.inputs[1].owner and (node.inputs[1].owner and
isinstance(node.inputs[1].owner.op, HostFromGpu)))): isinstance(node.inputs[1].owner.op, HostFromGpu)))):
if not dnn_available() or version() != (2000, 2000): if not dnn_available():
# Softmax grad is broken in v3 rc1 for this case
return return
ins = [] ins = []
for n in node.inputs: for n in node.inputs:
......
...@@ -66,7 +66,7 @@ class NaiveAlgo(object): ...@@ -66,7 +66,7 @@ class NaiveAlgo(object):
def cache_version(self): def cache_version(self):
ver = self.scalar_op.c_code_cache_version() ver = self.scalar_op.c_code_cache_version()
if ver: if ver:
return (19, self.verbose, self.sync, ver) return (20, self.verbose, self.sync, ver)
else: else:
return ver return ver
...@@ -86,7 +86,9 @@ class NaiveAlgo(object): ...@@ -86,7 +86,9 @@ class NaiveAlgo(object):
def c_src_kernel(self, node, nodename, nd): def c_src_kernel(self, node, nodename, nd):
sio = StringIO() sio = StringIO()
# print 'C_SRC_KERNEL', sio.getvalue() # print 'C_SRC_KERNEL', sio.getvalue()
print("// %s" % str(node.op), file=sio)
print("// node.op.destroy_map=%s" % str(
getattr(node.op, 'destroy_map', None)), file=sio)
for ipos, i in enumerate(node.inputs): for ipos, i in enumerate(node.inputs):
print("// Input ", ipos, str(i.type), file=sio) print("// Input ", ipos, str(i.type), file=sio)
for ipos, i in enumerate(node.outputs): for ipos, i in enumerate(node.outputs):
...@@ -202,6 +204,9 @@ class NaiveAlgo(object): ...@@ -202,6 +204,9 @@ class NaiveAlgo(object):
if nd in (4,): if nd in (4,):
# print some leading comments to make the code easier to read # print some leading comments to make the code easier to read
print("// %s" % str(node.op), file=sio)
print("// node.op.destroy_map=%s" % str(
getattr(node.op, 'destroy_map', None)), file=sio)
for ipos, i in enumerate(node.inputs): for ipos, i in enumerate(node.inputs):
print("// Input ", ipos, str(i.type), file=sio) print("// Input ", ipos, str(i.type), file=sio)
for ipos, i in enumerate(node.outputs): for ipos, i in enumerate(node.outputs):
...@@ -307,6 +312,9 @@ class NaiveAlgo(object): ...@@ -307,6 +312,9 @@ class NaiveAlgo(object):
return sio.getvalue() return sio.getvalue()
# print some leading comments to make the code easier to read # print some leading comments to make the code easier to read
print("// %s" % str(node.op), file=sio)
print("// node.op.destroy_map=%s" % str(
getattr(node.op, 'destroy_map', None)), file=sio)
for ipos, i in enumerate(node.inputs): for ipos, i in enumerate(node.inputs):
print("// Input ", ipos, str(i.type), file=sio) print("// Input ", ipos, str(i.type), file=sio)
for ipos, i in enumerate(node.outputs): for ipos, i in enumerate(node.outputs):
...@@ -456,6 +464,9 @@ class NaiveAlgo(object): ...@@ -456,6 +464,9 @@ class NaiveAlgo(object):
sio = StringIO() sio = StringIO()
# print 'C_SRC_KERNEL', sio.getvalue() # print 'C_SRC_KERNEL', sio.getvalue()
print("// %s" % str(node.op), file=sio)
print("// node.op.destroy_map=%s" % str(
getattr(node.op, 'destroy_map', None)), file=sio)
for ipos, i in enumerate(node.inputs): for ipos, i in enumerate(node.inputs):
print("// Input ", ipos, str(i.type), file=sio) print("// Input ", ipos, str(i.type), file=sio)
for ipos, i in enumerate(node.outputs): for ipos, i in enumerate(node.outputs):
......
...@@ -795,6 +795,11 @@ def local_gpu_careduce(node): ...@@ -795,6 +795,11 @@ def local_gpu_careduce(node):
replace = False replace = False
if x.owner and isinstance(x.owner.op, HostFromGpu): if x.owner and isinstance(x.owner.op, HostFromGpu):
replace = True replace = True
# If this is a useless reduce, remove it as
# local_cut_useless_reduce. This is needed as the code
# below do not support when x.ndim == 0.
if x.type == node.outputs[0].type:
return [x]
elif (all([c != "output" and isinstance(c.op, GpuFromHost) elif (all([c != "output" and isinstance(c.op, GpuFromHost)
for c, i in node.outputs[0].clients]) for c, i in node.outputs[0].clients])
and x.owner and x.owner.op.__class__ in and x.owner and x.owner.op.__class__ in
......
...@@ -296,6 +296,12 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -296,6 +296,12 @@ def inplace_elemwise_optimizer_op(OP):
# gpuarray GpuElemwise inherit from Elemwise # gpuarray GpuElemwise inherit from Elemwise
if not type(op) == OP: if not type(op) == OP:
continue continue
# If big graph and the outputs are scalar, do not make it
# inplace.
if (check_each_change != 1 and
all([getattr(o.type, 'ndim', -1) == 0
for o in node.outputs])):
continue
baseline = op.inplace_pattern baseline = op.inplace_pattern
protected_inputs = [ protected_inputs = [
...@@ -4188,28 +4194,29 @@ def local_sum_prod_mul_by_scalar(node): ...@@ -4188,28 +4194,29 @@ def local_sum_prod_mul_by_scalar(node):
""" """
# TODO: if the the thing inside the Sum is a division, # TODO: if the the thing inside the Sum is a division,
# we should get at the numerator.... # we should get at the numerator....
if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod): if isinstance(node.op, (T.Sum, T.elemwise.Prod)):
node_inps, = node.inputs node_inps, = node.inputs
if node_inps.owner and node_inps.owner.op == T.mul: if node_inps.owner and node_inps.owner.op == T.mul:
terms = node_inps.owner.inputs terms = node_inps.owner.inputs
scalars = [t.dimshuffle() for t in terms if scalars = [t.dimshuffle() for t in terms if
numpy.all(t.type.broadcastable)] numpy.all(t.type.broadcastable)]
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
if len(scalars) == 0: if len(scalars) == 0:
# Nothing to optimize here # Nothing to optimize here
return return
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
# Perform the op only on the non-scalar inputs, if applicable # Perform the op only on the non-scalar inputs, if applicable
if len(non_scalars) == 0: if len(non_scalars) == 0:
new_op_input_nb_elements = 1 new_op_input_nb_elements = 1
new_op_output = 1 new_op_output = 1
elif len(non_scalars) == 1: elif len(non_scalars) == 1:
new_op_input_nb_elements = T.prod(non_scalars[0].shape) new_op_input_nb_elements = non_scalars[0].size
new_op_output = node.op(non_scalars[0]) new_op_output = node.op(non_scalars[0])
else: else:
new_op_input = T.mul(*non_scalars) new_op_input = T.mul(*non_scalars)
new_op_input_nb_elements = T.prod(new_op_input.shape) new_op_input_nb_elements = new_op_input.size
new_op_output = node.op(new_op_input) new_op_output = node.op(new_op_input)
# If node.op is a T.elemwise.Prod, then the scalars need to be # If node.op is a T.elemwise.Prod, then the scalars need to be
...@@ -4226,7 +4233,10 @@ def local_sum_prod_mul_by_scalar(node): ...@@ -4226,7 +4233,10 @@ def local_sum_prod_mul_by_scalar(node):
if new_op_input_nb_elements != 1: if new_op_input_nb_elements != 1:
mul_inputs.append(new_op_output) mul_inputs.append(new_op_output)
return [T.mul(*mul_inputs)] if len(mul_inputs) == 1:
return mul_inputs
else:
return [T.mul(*mul_inputs)]
if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg: if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg:
return [T.neg(node.op(node_inps.owner.inputs[0]))] return [T.neg(node.op(node_inps.owner.inputs[0]))]
...@@ -4453,25 +4463,25 @@ def local_sum_prod_div_dimshuffle(node): ...@@ -4453,25 +4463,25 @@ def local_sum_prod_div_dimshuffle(node):
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum):
op_on_compatible_dims = T.sum( op_on_compatible_dims = T.sum(
numerator, axis=compatible_dims) numerator, axis=compatible_dims)
div_op = T.true_div( rval = T.true_div(
op_on_compatible_dims, op_on_compatible_dims,
optimized_dimshuffle) optimized_dimshuffle)
op_on_incompatible_dims = T.sum( if len(reordered_incompatible_dims) > 0:
div_op, rval = T.sum(rval,
axis=reordered_incompatible_dims) axis=reordered_incompatible_dims)
elif isinstance(node.op, T.elemwise.Prod): elif isinstance(node.op, T.elemwise.Prod):
op_on_compatible_dims = T.prod( op_on_compatible_dims = T.prod(
numerator, axis=compatible_dims) numerator, axis=compatible_dims)
dtype = numerator.dtype dtype = numerator.dtype
div_op = T.true_div( rval = T.true_div(
op_on_compatible_dims, op_on_compatible_dims,
(optimized_dimshuffle ** (optimized_dimshuffle **
T.prod([numerator.shape[ax].astype(dtype) T.prod([numerator.shape[ax].astype(dtype)
for ax in compatible_dims]))) for ax in compatible_dims])))
op_on_incompatible_dims = T.prod( if len(reordered_incompatible_dims) > 0:
div_op, rval = T.prod(rval,
axis=reordered_incompatible_dims) axis=reordered_incompatible_dims)
return [op_on_incompatible_dims] return [rval]
@register_canonicalize @register_canonicalize
......
...@@ -4810,7 +4810,7 @@ class T_local_sum_prod(unittest.TestCase): ...@@ -4810,7 +4810,7 @@ class T_local_sum_prod(unittest.TestCase):
# Case 2 # Case 2
test_reduction_opt([vect, scalar1], [v_val, s1_val], T.elemwise.Prod, test_reduction_opt([vect, scalar1], [v_val, s1_val], T.elemwise.Prod,
(s1_val * v_val).prod(), 2) (s1_val * v_val).prod(), 1)
# Case 3 # Case 3
test_reduction_opt([vect, mat, scalar1], [v_val, m_val, s1_val], test_reduction_opt([vect, mat, scalar1], [v_val, m_val, s1_val],
...@@ -4823,7 +4823,7 @@ class T_local_sum_prod(unittest.TestCase): ...@@ -4823,7 +4823,7 @@ class T_local_sum_prod(unittest.TestCase):
# Case 5 # Case 5
test_reduction_opt([vect, scalar1, scalar2], [v_val, s1_val, s2_val], test_reduction_opt([vect, scalar1, scalar2], [v_val, s1_val, s2_val],
T.elemwise.Prod, (s1_val * s2_val * v_val).prod(), T.elemwise.Prod, (s1_val * s2_val * v_val).prod(),
2) 1)
# Case 6 # Case 6
test_reduction_opt([vect, mat, scalar1, scalar2], test_reduction_opt([vect, mat, scalar1, scalar2],
......
...@@ -280,7 +280,8 @@ class _tensor_py_operators: ...@@ -280,7 +280,8 @@ class _tensor_py_operators:
shape = property(lambda self: theano.tensor.basic.shape(self)) shape = property(lambda self: theano.tensor.basic.shape(self))
size = property(lambda self: theano.tensor.basic.prod(self.shape)) size = property(lambda self: self.shape[0] if self.ndim == 1 else
theano.tensor.basic.prod(self.shape))
# We can't implement __len__ to provide a better error message. # We can't implement __len__ to provide a better error message.
def any(self, axis=None, keepdims=False): def any(self, axis=None, keepdims=False):
......
...@@ -30,7 +30,6 @@ whitelist_flake8 = [ ...@@ -30,7 +30,6 @@ whitelist_flake8 = [
"tests/test_gradient.py", "tests/test_gradient.py",
"tests/test_config.py", "tests/test_config.py",
"tests/diverse_tests.py", "tests/diverse_tests.py",
"tests/test_ifelse.py",
"tests/test_rop.py", "tests/test_rop.py",
"tests/test_2nd_order_grads.py", "tests/test_2nd_order_grads.py",
"tests/run_tests_in_batch.py", "tests/run_tests_in_batch.py",
......
...@@ -3,20 +3,22 @@ ...@@ -3,20 +3,22 @@
""" """
from __future__ import print_function from __future__ import print_function
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import unittest import unittest
import numpy import numpy
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from six.moves import reduce
import theano import theano
from theano import tensor from theano import tensor
import theano.ifelse import theano.ifelse
from theano.ifelse import IfElse, ifelse from theano.ifelse import IfElse, ifelse
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin): class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
...@@ -51,6 +53,32 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -51,6 +53,32 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
assert numpy.allclose(vx, f(1, vx, vy)) assert numpy.allclose(vx, f(1, vx, vy))
assert numpy.allclose(vy, f(0, vx, vy)) assert numpy.allclose(vy, f(0, vx, vy))
def test_not_lazy_if_inplace(self):
# Tests that if the outputs are scalars and the graph is big,
# we disable the inplace opt to speed up optimization
x = tensor.vector('x', dtype=self.dtype)
y = tensor.vector('y', dtype=self.dtype)
c = tensor.iscalar('c')
mode = theano.compile.get_mode(self.mode).excluding(
# Disable many opt to keep the graph big enough to disable
# the opt.
'fusion', 'local_add_canonizer',
'inplace', 'constant_folding', 'constant_folding')
y2 = reduce(lambda x, y: x + y, [y] + list(range(200)))
f = theano.function([c, x, y], ifelse(c, x, y2), mode=mode)
# For not inplace ifelse
self.assertFunctionContains1(f, IfElse(1))
rng = numpy.random.RandomState(utt.fetch_seed())
xlen = rng.randint(200)
ylen = rng.randint(200)
vx = numpy.asarray(rng.uniform(size=(xlen,)), self.dtype)
vy = numpy.asarray(rng.uniform(size=(ylen,)), self.dtype)
assert numpy.allclose(vx, f(1, vx, vy))
assert numpy.allclose(vy + sum(range(200)), f(0, vx, vy))
def test_mixed_dtype(self): def test_mixed_dtype(self):
x1 = tensor.vector('x1', dtype='int32') x1 = tensor.vector('x1', dtype='int32')
x2 = tensor.vector('x2', dtype=self.dtype) x2 = tensor.vector('x2', dtype=self.dtype)
...@@ -65,9 +93,9 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -65,9 +93,9 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
xlen = rng.randint(200) xlen = rng.randint(200)
ylen = rng.randint(200) ylen = rng.randint(200)
vx1 = numpy.asarray(rng.uniform(size=(xlen,))*3, 'int32') vx1 = numpy.asarray(rng.uniform(size=(xlen,)) * 3, 'int32')
vx2 = numpy.asarray(rng.uniform(size=(xlen,)), self.dtype) vx2 = numpy.asarray(rng.uniform(size=(xlen,)), self.dtype)
vy1 = numpy.asarray(rng.uniform(size=(ylen,))*3, 'int32') vy1 = numpy.asarray(rng.uniform(size=(ylen,)) * 3, 'int32')
vy2 = numpy.asarray(rng.uniform(size=(ylen,)), self.dtype) vy2 = numpy.asarray(rng.uniform(size=(ylen,)), self.dtype)
o1, o2 = f(1, vx1, vx2, vy1, vy2) o1, o2 = f(1, vx1, vx2, vy1, vy2)
...@@ -288,8 +316,8 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -288,8 +316,8 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
z2 = ifelse(c, x + 2, y + 2) z2 = ifelse(c, x + 2, y + 2)
z = z1 + z2 z = z1 + z2
f = theano.function([c, x, y], z) f = theano.function([c, x, y], z)
assert len([x for x in f.maker.fgraph.toposort() assert len([n for n in f.maker.fgraph.toposort()
if isinstance(x.op, IfElse)]) == 1 if isinstance(n.op, IfElse)]) == 1
def test_remove_useless_inputs1(self): def test_remove_useless_inputs1(self):
raise SkipTest("Optimization temporarily disabled") raise SkipTest("Optimization temporarily disabled")
...@@ -299,8 +327,8 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -299,8 +327,8 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
z = ifelse(c, (x, x), (y, y)) z = ifelse(c, (x, x), (y, y))
f = theano.function([c, x, y], z) f = theano.function([c, x, y], z)
ifnode = [x for x in f.maker.fgraph.toposort() ifnode = [n for n in f.maker.fgraph.toposort()
if isinstance(x.op, IfElse)][0] if isinstance(n.op, IfElse)][0]
assert len(ifnode.inputs) == 3 assert len(ifnode.inputs) == 3
def test_remove_useless_inputs2(self): def test_remove_useless_inputs2(self):
...@@ -418,12 +446,12 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -418,12 +446,12 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
c = tensor.iscalar('c') c = tensor.iscalar('c')
out = ifelse(c, out = ifelse(c,
ifelse(c, x1, x2) + ifelse(c, y1, y2) + w1, ifelse(c, x1, x2) + ifelse(c, y1, y2) + w1,
ifelse(c, x1, x2) + ifelse(c, y1, y2) + w2) ifelse(c, x1, x2) + ifelse(c, y1, y2) + w2)
f = theano.function([x1, x2, y1, y2, w1, w2, c], out, f = theano.function([x1, x2, y1, y2, w1, w2, c], out,
allow_input_downcast=True) allow_input_downcast=True)
assert len([x for x in f.maker.fgraph.toposort() assert len([x for x in f.maker.fgraph.toposort()
if isinstance(x.op, IfElse)]) == 1 if isinstance(x.op, IfElse)]) == 1
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
vx1 = rng.uniform() vx1 = rng.uniform()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论