Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d219054e
提交
d219054e
authored
9月 16, 2015
作者:
abergeron
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3380 from nouiz/mixed2
Mixed2
上级
88eac16c
2e4f475a
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
16 个修改的文件
包含
133 行增加
和
47 行删除
+133
-47
blocksparse.txt
doc/library/sandbox/blocksparse.txt
+2
-2
dnn.txt
doc/library/sandbox/cuda/dnn.txt
+7
-0
op.txt
doc/library/sandbox/cuda/op.txt
+1
-0
builders.py
theano/compile/builders.py
+2
-2
nanguardmode.py
theano/compile/nanguardmode.py
+0
-0
link.py
theano/gof/link.py
+9
-2
ifelse.py
theano/ifelse.py
+7
-1
blas.py
theano/sandbox/cuda/blas.py
+5
-1
dnn.py
theano/sandbox/cuda/dnn.py
+11
-4
elemwise.py
theano/sandbox/cuda/elemwise.py
+13
-2
opt.py
theano/sandbox/cuda/opt.py
+5
-0
opt.py
theano/tensor/opt.py
+24
-14
test_opt.py
theano/tensor/tests/test_opt.py
+2
-2
var.py
theano/tensor/var.py
+2
-1
test_flake8.py
theano/tests/test_flake8.py
+0
-1
test_ifelse.py
theano/tests/test_ifelse.py
+43
-15
没有找到文件。
doc/library/sandbox/blocksparse.txt
浏览文件 @
d219054e
.. _libdoc_blocksparse:
===================================================================
===================================================================
========
:mod:`sandbox.blocksparse` -- Block sparse dot operations (gemv and outer)
===================================================================
===================================================================
========
.. module:: sandbox.blocksparse
:platform: Unix, Windows
...
...
doc/library/sandbox/cuda/dnn.txt
浏览文件 @
d219054e
...
...
@@ -24,6 +24,13 @@ There are at least three possible ways of doing so:
``LD_LIBRARY_PATH``, ``LIBRARY_PATH`` and ``CPATH`` to the directory
extracted from the download. If needed, separate multiple directories
with ``:`` as in the ``PATH`` environment variable.
example::
export LD_LIBRARY_PATH=/home/user/path_to_CUDNN_folder/lib64:$LD_LIBRARY_PATH
export CPATH=/home/user/path_to_CUDNN_folder/include:$CPATH
export LIBRARY_PATH=/home/user/path_to_CUDNN_folder/lib64:$LD_LIBRARY_PATH
- And as a third way, also on Linux, you can copy the ``*.h`` files
to ``/usr/include`` and the ``*.so*`` files to ``/lib64``.
...
...
doc/library/sandbox/cuda/op.txt
浏览文件 @
d219054e
...
...
@@ -19,6 +19,7 @@ Blas Op
.. automodule:: theano.sandbox.cuda.blas
:members:
.. autofunction:: theano.sandbox.cuda.blas.batched_dot
Nnet Op
=======
...
...
theano/compile/builders.py
浏览文件 @
d219054e
...
...
@@ -78,8 +78,8 @@ class OpFromGraph(gof.Op):
if
not
isinstance
(
i
,
gof
.
Variable
):
raise
TypeError
(
'inputs and outputs must be Variable instances'
,
i
)
if
'updates'
in
kwargs
:
raise
TypeError
(
'updates are not allowed in kwargs'
)
if
'updates'
in
kwargs
or
'givens'
in
kwargs
:
raise
TypeError
(
'updates a
nd givens a
re not allowed in kwargs'
)
# To support correctly shared variables the inner fct should
# not see them. Otherwise their is problem with the gradient.
...
...
theano/compile/nanguardmode.py
浏览文件 @
d219054e
差异被折叠。
点击展开。
theano/gof/link.py
浏览文件 @
d219054e
...
...
@@ -302,8 +302,15 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
"HINT: Use the Theano flag 'exception_verbosity=high'"
" for a debugprint and storage map footprint of this apply node."
)
exc_value
=
exc_type
(
str
(
exc_value
)
+
detailed_err_msg
+
'
\n
'
+
'
\n
'
.
join
(
hints
))
try
:
exc_value
=
exc_type
(
str
(
exc_value
)
+
detailed_err_msg
+
'
\n
'
+
'
\n
'
.
join
(
hints
))
except
TypeError
:
print
(
"WARNING:
%
s error does not allow us to add extra error message"
%
str
(
exc_type
))
# Some exception need extra parameter in inputs. So forget the
# extra long error message in that case.
pass
reraise
(
exc_type
,
exc_value
,
exc_trace
)
...
...
theano/ifelse.py
浏览文件 @
d219054e
...
...
@@ -395,7 +395,13 @@ def ifelse(condition, then_branch, else_branch, name=None):
@gof.local_optimizer
([
IfElse
])
def
cond_make_inplace
(
node
):
op
=
node
.
op
if
isinstance
(
op
,
IfElse
)
and
not
op
.
as_view
:
if
(
isinstance
(
op
,
IfElse
)
and
not
op
.
as_view
and
# For big graph, do not make inplace scalar to speed up
# optimization.
(
len
(
node
.
fgraph
.
apply_nodes
)
<
500
or
not
all
([
getattr
(
o
.
type
,
'ndim'
,
-
1
)
==
0
for
o
in
node
.
outputs
]))):
return
IfElse
(
n_outs
=
op
.
n_outs
,
as_view
=
True
,
gpu
=
op
.
gpu
,
...
...
theano/sandbox/cuda/blas.py
浏览文件 @
d219054e
...
...
@@ -14,8 +14,8 @@ from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable,
gpu_contiguous
)
from
theano.tensor
import
as_tensor_variable
class
BatchedDotOp
(
GpuOp
):
class
BatchedDotOp
(
GpuOp
):
__props__
=
()
def
make_node
(
self
,
inp1
,
inp2
):
...
...
@@ -213,6 +213,10 @@ class BatchedDotOp(GpuOp):
return
(
1
,)
batched_dot
=
BatchedDotOp
()
"""
Call cublasSgemmBatched. Take 2 3d tensor as input.
"""
class
GpuDot22
(
GpuOp
):
"""
...
...
theano/sandbox/cuda/dnn.py
浏览文件 @
d219054e
...
...
@@ -81,20 +81,28 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
" from one version, but we link with"
" a different version
%
s"
%
str
(
v
))
raise
RuntimeError
(
dnn_available
.
msg
)
if
v
ersion
()
==
-
1
:
if
v
==
-
1
:
dnn_available
.
avail
=
False
dnn_available
.
msg
=
(
"CuDNN v1 detected. This version is no longer "
"supported by Theano. Update your CuDNN installation "
"to a more recent version"
)
raise
RuntimeError
(
dnn_available
.
msg
)
if
v
ersion
()
==
(
20
,
20
):
if
v
==
(
20
,
20
):
dnn_available
.
avail
=
False
dnn_available
.
msg
=
(
"You have installed a release candidate of CuDNN v2."
" This isn't supported anymore."
" Update to CuDNN v2 final version."
)
raise
RuntimeError
(
dnn_available
.
msg
)
if
v
[
0
]
>=
3000
and
v
[
0
]
<
3007
:
# 3007 is the final release of cudnn v3
dnn_available
.
avail
=
False
dnn_available
.
msg
=
(
"You have installed a release candidate of CuDNN v3."
" This isn't supported anymore."
" Update to CuDNN v3 final version."
)
raise
RuntimeError
(
dnn_available
.
msg
)
return
dnn_available
.
avail
...
...
@@ -2380,8 +2388,7 @@ if True:
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
HostFromGpu
))
or
(
node
.
inputs
[
1
]
.
owner
and
isinstance
(
node
.
inputs
[
1
]
.
owner
.
op
,
HostFromGpu
)))):
if
not
dnn_available
()
or
version
()
!=
(
2000
,
2000
):
# Softmax grad is broken in v3 rc1 for this case
if
not
dnn_available
():
return
ins
=
[]
for
n
in
node
.
inputs
:
...
...
theano/sandbox/cuda/elemwise.py
浏览文件 @
d219054e
...
...
@@ -66,7 +66,7 @@ class NaiveAlgo(object):
def
cache_version
(
self
):
ver
=
self
.
scalar_op
.
c_code_cache_version
()
if
ver
:
return
(
19
,
self
.
verbose
,
self
.
sync
,
ver
)
return
(
20
,
self
.
verbose
,
self
.
sync
,
ver
)
else
:
return
ver
...
...
@@ -86,7 +86,9 @@ class NaiveAlgo(object):
def
c_src_kernel
(
self
,
node
,
nodename
,
nd
):
sio
=
StringIO
()
# print 'C_SRC_KERNEL', sio.getvalue()
print
(
"//
%
s"
%
str
(
node
.
op
),
file
=
sio
)
print
(
"// node.op.destroy_map=
%
s"
%
str
(
getattr
(
node
.
op
,
'destroy_map'
,
None
)),
file
=
sio
)
for
ipos
,
i
in
enumerate
(
node
.
inputs
):
print
(
"// Input "
,
ipos
,
str
(
i
.
type
),
file
=
sio
)
for
ipos
,
i
in
enumerate
(
node
.
outputs
):
...
...
@@ -202,6 +204,9 @@ class NaiveAlgo(object):
if
nd
in
(
4
,):
# print some leading comments to make the code easier to read
print
(
"//
%
s"
%
str
(
node
.
op
),
file
=
sio
)
print
(
"// node.op.destroy_map=
%
s"
%
str
(
getattr
(
node
.
op
,
'destroy_map'
,
None
)),
file
=
sio
)
for
ipos
,
i
in
enumerate
(
node
.
inputs
):
print
(
"// Input "
,
ipos
,
str
(
i
.
type
),
file
=
sio
)
for
ipos
,
i
in
enumerate
(
node
.
outputs
):
...
...
@@ -307,6 +312,9 @@ class NaiveAlgo(object):
return
sio
.
getvalue
()
# print some leading comments to make the code easier to read
print
(
"//
%
s"
%
str
(
node
.
op
),
file
=
sio
)
print
(
"// node.op.destroy_map=
%
s"
%
str
(
getattr
(
node
.
op
,
'destroy_map'
,
None
)),
file
=
sio
)
for
ipos
,
i
in
enumerate
(
node
.
inputs
):
print
(
"// Input "
,
ipos
,
str
(
i
.
type
),
file
=
sio
)
for
ipos
,
i
in
enumerate
(
node
.
outputs
):
...
...
@@ -456,6 +464,9 @@ class NaiveAlgo(object):
sio
=
StringIO
()
# print 'C_SRC_KERNEL', sio.getvalue()
print
(
"//
%
s"
%
str
(
node
.
op
),
file
=
sio
)
print
(
"// node.op.destroy_map=
%
s"
%
str
(
getattr
(
node
.
op
,
'destroy_map'
,
None
)),
file
=
sio
)
for
ipos
,
i
in
enumerate
(
node
.
inputs
):
print
(
"// Input "
,
ipos
,
str
(
i
.
type
),
file
=
sio
)
for
ipos
,
i
in
enumerate
(
node
.
outputs
):
...
...
theano/sandbox/cuda/opt.py
浏览文件 @
d219054e
...
...
@@ -795,6 +795,11 @@ def local_gpu_careduce(node):
replace
=
False
if
x
.
owner
and
isinstance
(
x
.
owner
.
op
,
HostFromGpu
):
replace
=
True
# If this is a useless reduce, remove it as
# local_cut_useless_reduce. This is needed as the code
# below do not support when x.ndim == 0.
if
x
.
type
==
node
.
outputs
[
0
]
.
type
:
return
[
x
]
elif
(
all
([
c
!=
"output"
and
isinstance
(
c
.
op
,
GpuFromHost
)
for
c
,
i
in
node
.
outputs
[
0
]
.
clients
])
and
x
.
owner
and
x
.
owner
.
op
.
__class__
in
...
...
theano/tensor/opt.py
浏览文件 @
d219054e
...
...
@@ -296,6 +296,12 @@ def inplace_elemwise_optimizer_op(OP):
# gpuarray GpuElemwise inherit from Elemwise
if
not
type
(
op
)
==
OP
:
continue
# If big graph and the outputs are scalar, do not make it
# inplace.
if
(
check_each_change
!=
1
and
all
([
getattr
(
o
.
type
,
'ndim'
,
-
1
)
==
0
for
o
in
node
.
outputs
])):
continue
baseline
=
op
.
inplace_pattern
protected_inputs
=
[
...
...
@@ -4188,28 +4194,29 @@ def local_sum_prod_mul_by_scalar(node):
"""
# TODO: if the the thing inside the Sum is a division,
# we should get at the numerator....
if
isinstance
(
node
.
op
,
T
.
Sum
)
or
isinstance
(
node
.
op
,
T
.
elemwise
.
Prod
):
if
isinstance
(
node
.
op
,
(
T
.
Sum
,
T
.
elemwise
.
Prod
)
):
node_inps
,
=
node
.
inputs
if
node_inps
.
owner
and
node_inps
.
owner
.
op
==
T
.
mul
:
terms
=
node_inps
.
owner
.
inputs
scalars
=
[
t
.
dimshuffle
()
for
t
in
terms
if
numpy
.
all
(
t
.
type
.
broadcastable
)]
non_scalars
=
[
t
for
t
in
terms
if
not
numpy
.
all
(
t
.
broadcastable
)]
if
len
(
scalars
)
==
0
:
# Nothing to optimize here
return
non_scalars
=
[
t
for
t
in
terms
if
not
numpy
.
all
(
t
.
broadcastable
)]
# Perform the op only on the non-scalar inputs, if applicable
if
len
(
non_scalars
)
==
0
:
new_op_input_nb_elements
=
1
new_op_output
=
1
elif
len
(
non_scalars
)
==
1
:
new_op_input_nb_elements
=
T
.
prod
(
non_scalars
[
0
]
.
shape
)
new_op_input_nb_elements
=
non_scalars
[
0
]
.
size
new_op_output
=
node
.
op
(
non_scalars
[
0
])
else
:
new_op_input
=
T
.
mul
(
*
non_scalars
)
new_op_input_nb_elements
=
T
.
prod
(
new_op_input
.
shape
)
new_op_input_nb_elements
=
new_op_input
.
size
new_op_output
=
node
.
op
(
new_op_input
)
# If node.op is a T.elemwise.Prod, then the scalars need to be
...
...
@@ -4226,7 +4233,10 @@ def local_sum_prod_mul_by_scalar(node):
if
new_op_input_nb_elements
!=
1
:
mul_inputs
.
append
(
new_op_output
)
return
[
T
.
mul
(
*
mul_inputs
)]
if
len
(
mul_inputs
)
==
1
:
return
mul_inputs
else
:
return
[
T
.
mul
(
*
mul_inputs
)]
if
isinstance
(
node
.
op
,
T
.
Sum
)
and
node_inps
.
owner
and
node_inps
.
owner
.
op
==
T
.
neg
:
return
[
T
.
neg
(
node
.
op
(
node_inps
.
owner
.
inputs
[
0
]))]
...
...
@@ -4453,25 +4463,25 @@ def local_sum_prod_div_dimshuffle(node):
if
isinstance
(
node
.
op
,
T
.
Sum
):
op_on_compatible_dims
=
T
.
sum
(
numerator
,
axis
=
compatible_dims
)
div_op
=
T
.
true_div
(
rval
=
T
.
true_div
(
op_on_compatible_dims
,
optimized_dimshuffle
)
op_on_incompatible_dims
=
T
.
sum
(
div_op
,
axis
=
reordered_incompatible_dims
)
if
len
(
reordered_incompatible_dims
)
>
0
:
rval
=
T
.
sum
(
rval
,
axis
=
reordered_incompatible_dims
)
elif
isinstance
(
node
.
op
,
T
.
elemwise
.
Prod
):
op_on_compatible_dims
=
T
.
prod
(
numerator
,
axis
=
compatible_dims
)
dtype
=
numerator
.
dtype
div_op
=
T
.
true_div
(
rval
=
T
.
true_div
(
op_on_compatible_dims
,
(
optimized_dimshuffle
**
T
.
prod
([
numerator
.
shape
[
ax
]
.
astype
(
dtype
)
for
ax
in
compatible_dims
])))
op_on_incompatible_dims
=
T
.
prod
(
div_op
,
axis
=
reordered_incompatible_dims
)
return
[
op_on_incompatible_dims
]
if
len
(
reordered_incompatible_dims
)
>
0
:
rval
=
T
.
prod
(
rval
,
axis
=
reordered_incompatible_dims
)
return
[
rval
]
@register_canonicalize
...
...
theano/tensor/tests/test_opt.py
浏览文件 @
d219054e
...
...
@@ -4810,7 +4810,7 @@ class T_local_sum_prod(unittest.TestCase):
# Case 2
test_reduction_opt
([
vect
,
scalar1
],
[
v_val
,
s1_val
],
T
.
elemwise
.
Prod
,
(
s1_val
*
v_val
)
.
prod
(),
2
)
(
s1_val
*
v_val
)
.
prod
(),
1
)
# Case 3
test_reduction_opt
([
vect
,
mat
,
scalar1
],
[
v_val
,
m_val
,
s1_val
],
...
...
@@ -4823,7 +4823,7 @@ class T_local_sum_prod(unittest.TestCase):
# Case 5
test_reduction_opt
([
vect
,
scalar1
,
scalar2
],
[
v_val
,
s1_val
,
s2_val
],
T
.
elemwise
.
Prod
,
(
s1_val
*
s2_val
*
v_val
)
.
prod
(),
2
)
1
)
# Case 6
test_reduction_opt
([
vect
,
mat
,
scalar1
,
scalar2
],
...
...
theano/tensor/var.py
浏览文件 @
d219054e
...
...
@@ -280,7 +280,8 @@ class _tensor_py_operators:
shape
=
property
(
lambda
self
:
theano
.
tensor
.
basic
.
shape
(
self
))
size
=
property
(
lambda
self
:
theano
.
tensor
.
basic
.
prod
(
self
.
shape
))
size
=
property
(
lambda
self
:
self
.
shape
[
0
]
if
self
.
ndim
==
1
else
theano
.
tensor
.
basic
.
prod
(
self
.
shape
))
# We can't implement __len__ to provide a better error message.
def
any
(
self
,
axis
=
None
,
keepdims
=
False
):
...
...
theano/tests/test_flake8.py
浏览文件 @
d219054e
...
...
@@ -30,7 +30,6 @@ whitelist_flake8 = [
"tests/test_gradient.py"
,
"tests/test_config.py"
,
"tests/diverse_tests.py"
,
"tests/test_ifelse.py"
,
"tests/test_rop.py"
,
"tests/test_2nd_order_grads.py"
,
"tests/run_tests_in_batch.py"
,
...
...
theano/tests/test_ifelse.py
浏览文件 @
d219054e
...
...
@@ -3,20 +3,22 @@
"""
from
__future__
import
print_function
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
import
unittest
import
numpy
from
nose.plugins.skip
import
SkipTest
from
six.moves
import
reduce
import
theano
from
theano
import
tensor
import
theano.ifelse
from
theano.ifelse
import
IfElse
,
ifelse
from
theano.tests
import
unittest_tools
as
utt
from
theano.tests
import
unittest_tools
as
utt
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
class
test_ifelse
(
unittest
.
TestCase
,
utt
.
TestOptimizationMixin
):
...
...
@@ -51,6 +53,32 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
assert
numpy
.
allclose
(
vx
,
f
(
1
,
vx
,
vy
))
assert
numpy
.
allclose
(
vy
,
f
(
0
,
vx
,
vy
))
def
test_not_lazy_if_inplace
(
self
):
# Tests that if the outputs are scalars and the graph is big,
# we disable the inplace opt to speed up optimization
x
=
tensor
.
vector
(
'x'
,
dtype
=
self
.
dtype
)
y
=
tensor
.
vector
(
'y'
,
dtype
=
self
.
dtype
)
c
=
tensor
.
iscalar
(
'c'
)
mode
=
theano
.
compile
.
get_mode
(
self
.
mode
)
.
excluding
(
# Disable many opt to keep the graph big enough to disable
# the opt.
'fusion'
,
'local_add_canonizer'
,
'inplace'
,
'constant_folding'
,
'constant_folding'
)
y2
=
reduce
(
lambda
x
,
y
:
x
+
y
,
[
y
]
+
list
(
range
(
200
)))
f
=
theano
.
function
([
c
,
x
,
y
],
ifelse
(
c
,
x
,
y2
),
mode
=
mode
)
# For not inplace ifelse
self
.
assertFunctionContains1
(
f
,
IfElse
(
1
))
rng
=
numpy
.
random
.
RandomState
(
utt
.
fetch_seed
())
xlen
=
rng
.
randint
(
200
)
ylen
=
rng
.
randint
(
200
)
vx
=
numpy
.
asarray
(
rng
.
uniform
(
size
=
(
xlen
,)),
self
.
dtype
)
vy
=
numpy
.
asarray
(
rng
.
uniform
(
size
=
(
ylen
,)),
self
.
dtype
)
assert
numpy
.
allclose
(
vx
,
f
(
1
,
vx
,
vy
))
assert
numpy
.
allclose
(
vy
+
sum
(
range
(
200
)),
f
(
0
,
vx
,
vy
))
def
test_mixed_dtype
(
self
):
x1
=
tensor
.
vector
(
'x1'
,
dtype
=
'int32'
)
x2
=
tensor
.
vector
(
'x2'
,
dtype
=
self
.
dtype
)
...
...
@@ -65,9 +93,9 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
xlen
=
rng
.
randint
(
200
)
ylen
=
rng
.
randint
(
200
)
vx1
=
numpy
.
asarray
(
rng
.
uniform
(
size
=
(
xlen
,))
*
3
,
'int32'
)
vx1
=
numpy
.
asarray
(
rng
.
uniform
(
size
=
(
xlen
,))
*
3
,
'int32'
)
vx2
=
numpy
.
asarray
(
rng
.
uniform
(
size
=
(
xlen
,)),
self
.
dtype
)
vy1
=
numpy
.
asarray
(
rng
.
uniform
(
size
=
(
ylen
,))
*
3
,
'int32'
)
vy1
=
numpy
.
asarray
(
rng
.
uniform
(
size
=
(
ylen
,))
*
3
,
'int32'
)
vy2
=
numpy
.
asarray
(
rng
.
uniform
(
size
=
(
ylen
,)),
self
.
dtype
)
o1
,
o2
=
f
(
1
,
vx1
,
vx2
,
vy1
,
vy2
)
...
...
@@ -288,8 +316,8 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
z2
=
ifelse
(
c
,
x
+
2
,
y
+
2
)
z
=
z1
+
z2
f
=
theano
.
function
([
c
,
x
,
y
],
z
)
assert
len
([
x
for
x
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
IfElse
)])
==
1
assert
len
([
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
n
.
op
,
IfElse
)])
==
1
def
test_remove_useless_inputs1
(
self
):
raise
SkipTest
(
"Optimization temporarily disabled"
)
...
...
@@ -299,8 +327,8 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
z
=
ifelse
(
c
,
(
x
,
x
),
(
y
,
y
))
f
=
theano
.
function
([
c
,
x
,
y
],
z
)
ifnode
=
[
x
for
x
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
IfElse
)][
0
]
ifnode
=
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
n
.
op
,
IfElse
)][
0
]
assert
len
(
ifnode
.
inputs
)
==
3
def
test_remove_useless_inputs2
(
self
):
...
...
@@ -418,12 +446,12 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
c
=
tensor
.
iscalar
(
'c'
)
out
=
ifelse
(
c
,
ifelse
(
c
,
x1
,
x2
)
+
ifelse
(
c
,
y1
,
y2
)
+
w1
,
ifelse
(
c
,
x1
,
x2
)
+
ifelse
(
c
,
y1
,
y2
)
+
w2
)
ifelse
(
c
,
x1
,
x2
)
+
ifelse
(
c
,
y1
,
y2
)
+
w1
,
ifelse
(
c
,
x1
,
x2
)
+
ifelse
(
c
,
y1
,
y2
)
+
w2
)
f
=
theano
.
function
([
x1
,
x2
,
y1
,
y2
,
w1
,
w2
,
c
],
out
,
allow_input_downcast
=
True
)
assert
len
([
x
for
x
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
IfElse
)])
==
1
if
isinstance
(
x
.
op
,
IfElse
)])
==
1
rng
=
numpy
.
random
.
RandomState
(
utt
.
fetch_seed
())
vx1
=
rng
.
uniform
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论