Commit ff2a27d5

Authored April 26, 2011 by Ian Goodfellow
Merge commit. Parents: a233573d, 83c7e294

Showing 19 changed files with 458 additions and 441 deletions (+458 / -441)
Changed files:

  doc/internal/how_to_release.txt                +2   -1
  doc/library/config.txt                         +75  -4
  doc/tutorial/shape_info.txt                    +15  -12
  theano/__init__.py                             +30  -0
  theano/compile/debugmode.py                    +7   -1
  theano/compile/function_module.py              +45  -1
  theano/compile/io.py                           +6   -0
  theano/compile/pfunc.py                        +13  -1
  theano/compile/tests/test_function_module.py   +2   -5
  theano/configdefaults.py                       +39  -1
  theano/scalar/basic.py                         +4   -4
  theano/scan_module/scan.py                     +1   -9
  theano/scan_module/scan_op.py                  +93  -91
  theano/scan_module/scan_utils.py               +0   -296
  theano/scan_module/tests/test_scan.py          +27  -0
  theano/sparse/tests/test_basic.py              +57  -0
  theano/tensor/basic.py                         +16  -7
  theano/tensor/tests/test_basic.py              +26  -3
  theano/tensor/tests/test_opt.py                +0   -5
doc/internal/how_to_release.txt

@@ -25,7 +25,8 @@ Edit ``setup.py`` to contain the newest version number ::
 * Change the ``version`` and ``release`` variables to new version number.
 * Change the upper copyright year to the current year if necessary.
-* Update the year in the Theano/LICENSE.txt file.
+* Update the year in the ``Theano/LICENSE.txt`` file too, if necessary.
 ``NEWS.txt`` usually contains the name and date of the release, change them
 too.
doc/library/config.txt
浏览文件 @
ff2a27d5
...
@@ -296,8 +296,79 @@ import theano and print the config variable, as in:
...
@@ -296,8 +296,79 @@ import theano and print the config variable, as in:
Default: False
Default: False
Remove compiled file when not needed anymore.
If False, source code files are removed when they are not needed anymore.
This mean it remove file that he tried to compile but failed.
This means files whose compilation failed are deleted.
Set to True to keep the source file that failed to compile to
Set to True to keep those files in order to debug compilation errors.
debug them.
.. attribute:: config.DebugMode
This section contains various attributes configuring the behaviour
of mode :class:`~debugmode.DebugMode`.
.. attribute:: config.numpy.seterr_all
String Value: ``'ignore'``, ``'warn'``, ``'raise'``, ``'call'``,
``'print'``, ``'log'``, ``'None'``
Default: ``'ignore'``
Set the default behaviour described by `numpy.seterr
<http://docs.scipy.org/doc/numpy/reference/generated/numpy.seterr.html>`__.
``'None'`` means that numpy's default behaviour will not be changed (unless
one of the other `config.numpy.seterr_*` overrides it), but this behaviour
can change between numpy releases.
This flag sets the default behaviour for all kinds of floating-pont
errors, and it can be overriden for specific errors by setting one
(or more) of the flags below.
This flag's value cannot be modified during the program execution.
.. attribute:: config.numpy.seterr_divide
String Value: ``'None'``, ``'ignore'``, ``'warn'``, ``'raise'``,
``'call'``, ``'print'``, ``'log'``
Default: ``'None'``
Sets numpy's behavior for division by zero. ``'None'`` means using the
default, defined by config.numpy.seterr_all.
This flag's value cannot be modified during the program execution.
.. attribute:: config.numpy.seterr_over
String Value: ``'None'``, ``'ignore'``, ``'warn'``, ``'raise'``,
``'call'``, ``'print'``, ``'log'``
Default: ``'None'``
Sets numpy's behavior for floating-point overflow. ``'None'`` means
using the default, defined by config.numpy.seterr_all.
This flag's value cannot be modified during the program execution.
.. attribute:: config.numpy.seterr_under
String Value: ``'None'``, ``'ignore'``, ``'warn'``, ``'raise'``,
``'call'``, ``'print'``, ``'log'``
Default: ``'None'``
Sets numpy's behavior for floating-point underflow. ``'None'`` means
using the default, defined by config.numpy.seterr_all.
This flag's value cannot be modified during the program execution.
.. attribute:: numpy.seterr_invalid
String Value: ``'None'``, ``'ignore'``, ``'warn'``, ``'raise'``,
``'call'``, ``'print'``, ``'log'``
Default: ``'None'``
Sets numpy's behavior for invalid floating-point operation. ``'None'``
means using the default, defined by :attr:`config.numpy.seterr_all`.
This flag's value cannot be modified during the program execution.
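These flags are read once at import time (see the theano/__init__.py hunk below), so in practice they must be set before importing theano, typically through the THEANO_FLAGS environment variable. A minimal usage sketch, assuming the standard comma-separated THEANO_FLAGS syntax (only the flag names come from this commit, the rest is illustrative):

    # Assumed invocation, set before Python starts:
    # THEANO_FLAGS='numpy.seterr_all=warn,numpy.seterr_under=ignore' python my_script.py
    import numpy
    import theano

    # The values are exposed on theano.config; because the flags are declared
    # with allow_override=False they cannot be changed after import.
    print(theano.config.numpy.seterr_all)   # 'warn' under the flags above
    print(numpy.geterr())                    # numpy's error state now reflects the flags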
doc/tutorial/shape_info.txt
浏览文件 @
ff2a27d5
...
@@ -38,8 +38,8 @@ output.
...
@@ -38,8 +38,8 @@ output.
Shape inference problem
Shape inference problem
=======================
=======================
Theano
do shape information propag
ation in the graph. Sometimes this
Theano
propagates shape inform
ation in the graph. Sometimes this
can
had error. E
xample:
can
lead to errors. For e
xample:
.. code-block:: python
.. code-block:: python
...
@@ -71,10 +71,10 @@ can had error. Example:
...
@@ -71,10 +71,10 @@ can had error. Example:
# |Shape_i{1} [@55959184] '' 0
# |Shape_i{1} [@55959184] '' 0
# | |<TensorType(float64, matrix)> [@55583888]
# | |<TensorType(float64, matrix)> [@55583888]
print f(xv,yv)# DO
N
T RAISE AN ERROR AS SHOULD BE.
print f(xv,yv)# DO
ES NO
T RAISE AN ERROR AS SHOULD BE.
#[8,4]
#[8,4]
f = theano.function([x,y], z)# Do
n'
t take the shape.
f = theano.function([x,y], z)# Do
no
t take the shape.
theano.printing.debugprint(f)
theano.printing.debugprint(f)
#Join [@44540496] '' 0
#Join [@44540496] '' 0
# |0 [@44540432]
# |0 [@44540432]
...
@@ -84,22 +84,25 @@ can had error. Example:
...
@@ -84,22 +84,25 @@ can had error. Example:
f(xv,yv)
f(xv,yv)
# Raise a dimensions mismatch error.
# Raise a dimensions mismatch error.
As you see, when you ask for the shape of some computation(join in the
As you see, when you ask for the shape of some computation
(join in the
example), we sometimes compute
the shape without executing the
example), we sometimes compute
an inferred shape directly, without executing
computation
(there is no join in the first output or debugprint).
the computation itself
(there is no join in the first output or debugprint).
This make
the computation of the shape faster, but can hide error
. In
This make
s the computation of the shape faster, but it can hide errors
. In
the example, the computation of the shape of join is done on the first
the example, the computation of the shape of join is done on the first
theano variable in the join, not on the other.
theano variable in the join, not on the other.
This can probably happen with many other op as elemwise, dot, ...
This can probably happen with many other op as elemwise, dot, ...
Indeed, to make some optimizations (for speed or stability, for instance),
Theano can assume that the computation is correct and consistent
in the first place, this is the case here.
You can detect those problem by running the code without this
You can detect those problem by running the code without this
optimization
with the t
heano flag
optimization
, with the T
heano flag
`optimizer_excluding=local_shape_to_shape_i`. You can also have the
`optimizer_excluding=local_shape_to_shape_i`. You can also have the
same effect by running in the mode FAST_COMPILE
(won'
t apply this
same effect by running in the mode FAST_COMPILE
(it will no
t apply this
optimization
and most other optimization too) or DEBUG_MODE(
will test
optimization
, nor most other optimizations) or DEBUG_MODE (it
will test
before and after all optimizations(much slower)).
before and after all optimizations
(much slower)).
Specifing exact shape
Specifing exact shape
...
...
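The behaviour described in this tutorial hunk can be reproduced in a few lines. The sketch below follows the excerpt above (x and y joined into z); the concrete shapes of xv and yv are an assumption, chosen so that the inferred shape is [8, 4] while the actual join is invalid:

    import numpy
    import theano
    import theano.tensor as T

    x = T.matrix('x')
    y = T.matrix('y')
    z = T.join(0, x, y)
    xv = numpy.random.rand(5, 4)
    yv = numpy.random.rand(3, 3)      # incompatible with xv along axis 1

    f = theano.function([x, y], z.shape)
    print(f(xv, yv))                  # prints [8 4]; no error, the shape was only inferred

    g = theano.function([x, y], z)    # actually performs the join
    g(xv, yv)                         # raises a dimension mismatch error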
theano/__init__.py

@@ -81,6 +81,36 @@ import gof
 if config.device.startswith('gpu') or config.init_gpu_device.startswith('gpu'):
     import theano.sandbox.cuda

+# Use config.numpy to call numpy.seterr
+import numpy
+
+if config.numpy.seterr_all == 'None':
+    _all = None
+else:
+    _all = config.numpy.seterr_all
+if config.numpy.seterr_divide == 'None':
+    _divide = None
+else:
+    _divide = config.numpy.seterr_divide
+if config.numpy.seterr_over == 'None':
+    _over = None
+else:
+    _over = config.numpy.seterr_over
+if config.numpy.seterr_under == 'None':
+    _under = None
+else:
+    _under = config.numpy.seterr_under
+if config.numpy.seterr_invalid == 'None':
+    _invalid = None
+else:
+    _invalid = config.numpy.seterr_invalid
+numpy.seterr(all=_all, divide=_divide, over=_over, under=_under,
+             invalid=_invalid)
+del _all, _divide, _over, _under, _invalid
+
 ## import scalar_opt

 ### This is defined here because it is designed to work across symbolic datatypes
theano/compile/debugmode.py

@@ -535,7 +535,13 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, clobber_dr_v
     if warn_input_not_reused and destroyed_res_list:
         dmap = getattr(node.op, 'destroy_map', {})
         for oo, ii in dmap.iteritems():
-            if storage_map[node.outputs[oo]][0] is not storage_map[node.inputs[ii[0]]][0]:
+            out_var = storage_map[node.outputs[oo]][0]
+            in_var = storage_map[node.inputs[ii[0]]][0]
+            if isinstance(node.op, theano.compile.mode.OutputGuard):
+                # The point of OutputGuard is to be declared as destructive
+                # while not destroying anything
+                continue
+            if out_var is not in_var:
                 opt_warning("input idx %d marked as destroyed was not changed for node '%s'" % (ii[0], str(node)))

     if warn_input_not_reused:
theano/compile/function_module.py

@@ -190,7 +190,36 @@ class DeepCopyOp(theano.gof.Op):
         else:
             super(DeepCopyOp, self).c_code(node, name, inames, onames, sub)

+
+class ViewOp(theano.gof.Op):
+    def __init__(self):
+        self.view_map = {0: [0]}
+
+    def __str__(self):
+        return self.__class__.__name__
+
+    def __hash__(self):
+        return hash(type(self))
+
+    def __eq__(self, other):
+        return type(self) == type(other)
+
+    def make_node(self, x):
+        return theano.gof.Apply(self, [x], [x.type()])
+
+    def perform(self, node, args, outs):
+        outs[0][0] = args[0]
+
+    def infer_shape(self, node, input_shapes):
+        return input_shapes
+
+    def grad(self, args, g_outs):
+        return g_outs
+
+
 deep_copy_op = DeepCopyOp()
+view_op = ViewOp()

 DUPLICATE = ['DUPLICATE']  # unique id object used as a placeholder for duplicate entries

 class Function(object):

@@ -771,6 +800,9 @@ def insert_deepcopy(env, wrapped_inputs, wrapped_outputs):
             # We could don't put deep copy if both outputs have borrow==True
             # and not(wrapped_outputs[i].borrow and wrapped_outputs[j].borrow):
             if env.outputs[j] in views_of_output_i:
+                if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
+                    env.change_input('output', i, view_op(env.outputs[i]))
+                else:
                     env.change_input('output', i, deep_copy_op(env.outputs[i]))
                 copied = True
                 break

@@ -784,8 +816,20 @@ def insert_deepcopy(env, wrapped_inputs, wrapped_outputs):
                     continue
                 if input_j in updated_env_inputs:
                     continue
-                # We could don't put deep_copy_op if the input and the output have borrow==True
                 if input_j in views_of_output_i:
+                    # We don't put deep_copy_op if the input and the output have borrow==True
+                    if input_j in env.inputs:
+                        j = env.inputs.index(input_j)
+                        if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
+                            env.change_input('output', i, view_op(env.outputs[i]))
+                            break
+                        else:
+                            env.change_input('output', i, deep_copy_op(env.outputs[i]))
+                            break
+                    elif wrapped_outputs[i].borrow:
+                        env.change_input('output', i, view_op(env.outputs[i]))
+                        break
+                    else:
                         env.change_input('output', i, deep_copy_op(env.outputs[i]))
                         break
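The new ViewOp is the borrow-aware counterpart of DeepCopyOp: when both ends of an aliasing pair were declared with borrow=True, insert_deepcopy now wraps the output in a ViewOp (an identity op with view_map={0: [0]}, so no copy is made) instead of a DeepCopyOp. A hedged sketch of the observable difference; In comes from theano/compile/io.py and Out from theano.compile as imported elsewhere in this commit, and whether an actual alias is returned still depends on the rest of the graph:

    import numpy
    import theano
    import theano.tensor as T
    from theano.compile import Out
    from theano.compile.io import In

    x = T.dvector('x')

    f_copy = theano.function([x], x)               # output is deep-copied by default
    f_view = theano.function([In(x, borrow=True)],
                             Out(x, borrow=True))  # aliasing is permitted

    v = numpy.arange(3.0)
    print(f_copy(v) is v)   # False: a fresh array is always returned
    print(f_view(v) is v)   # may be True: the ViewOp allows returning a view or alias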
theano/compile/io.py

@@ -164,6 +164,11 @@ class In(SymbolicInput):
         True: permit the compiled function to modify the python object being passed as the input
         False: do not permit the compiled function to modify the python object being passed as the input.

+    borrow: Bool (default: False if update is None, True if update is not None)
+        True: permit the output of the compiled function to be aliased to the input
+        False: do not permit any output to be aliased to the input
+
     strict: Bool (default: False)
         True: means that the value you pass for this input must have exactly the right type
         False: the value you pass for this input may be cast automatically to the proper type

@@ -226,6 +231,7 @@ class In(SymbolicInput):
             autoname=autoname,
             implicit=implicit)
         self.value = value
+        self.borrow = borrow
         if self.implicit and value is None:
             raise TypeError('An implicit input must be given a default value')
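A short usage sketch of the new flag (the graph and values are illustrative only): passing borrow=True on an In tells the compiled function that its outputs may alias that input's storage, which is what lets insert_deepcopy pick the cheaper ViewOp path shown in the function_module.py hunk above. Note the documented default: inputs with an update (such as the implicit inputs created for shared variables) borrow by default.

    import theano
    import theano.tensor as T
    from theano.compile.io import In

    x = T.dvector('x')
    y = x * 2

    # The caller promises not to rely on the outputs being independent
    # copies of the input buffer.
    f = theano.function([In(x, borrow=True)], y)
    print(f([1.0, 2.0, 3.0]))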
theano/compile/pfunc.py

@@ -244,7 +244,7 @@ def rebuild_collect_shared( outputs
 class Param(object):
     def __init__(self, variable, default=None, name=None, mutable=False,
-                 strict=False, allow_downcast=None, implicit=None):
+                 strict=False, allow_downcast=None, implicit=None, borrow=False):
         """
         :param variable: A variable in an expression graph to use as a compiled-function parameter

@@ -255,6 +255,11 @@ class Param(object):
         :param mutable: True -> function is allowed to modify this argument.

+        :param borrow: True -> function is allowed to alias some output to
+            this input
+
+            False: do not permit any output to be aliased to the input
+
         :param strict: False -> function arguments may be copied or casted to match the
             type required by the parameter `variable`. True -> function arguments must exactly match the type
             required by `variable`.

@@ -271,9 +276,15 @@ class Param(object):
         self.default = default
         self.name = name
         self.mutable = mutable
+        # Mutable implies borrow. You can get borrow = False because of the
+        # default and it is a bit annoying to require the user to set both
+        # borrow and mutable to True
+        if mutable:
+            borrow = True
         self.strict = strict
         self.allow_downcast = allow_downcast
         self.implicit = implicit
+        self.borrow = borrow

 def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
           no_default_updates=False, accept_inplace=False, name=None,

@@ -396,6 +407,7 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
                   value=param.default,
                   mutable=param.mutable,
                   strict=param.strict,
+                  borrow=param.borrow,
                   allow_downcast=param.allow_downcast,
                   implicit=param.implicit)
     raise TypeError('Unknown parameter type: %s' % type(param))
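The comment added in __init__ makes mutable imply borrow, so a caller who already allows in-place modification does not also have to pass borrow=True explicitly. A hedged sketch (hypothetical graph, values for illustration only):

    import numpy
    import theano
    import theano.tensor as T
    from theano.compile.pfunc import Param

    x = T.dvector('x')

    # mutable=True: the function may overwrite the argument in place.
    # After this commit it also sets borrow=True automatically, and the
    # flag is forwarded to In() by _pfunc_param_to_in().
    f = theano.function([Param(x, mutable=True)], x * 2)
    print(f(numpy.arange(3.0)))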
theano/compile/tests/test_function_module.py

@@ -304,11 +304,8 @@ class T_function(unittest.TestCase):
         assert (out == 4).all()
         out[0] = 3
         out2 = f()
-        # Currently we don't do this optimization!
-        assert out2 is out
-        # As this is a corner case that is not usefull for use
-        assert (out2 == 3).all()
-        # We probably won't optimize it.
+        assert out2 is not out
+        assert (out2 == 4).all()

     def test_borrow_input(self):
theano/configdefaults.py

@@ -71,7 +71,7 @@ AddConfigVar('home',
 #This expanduser works on windows (see discussion on theano-users, July 13 2010)

 AddConfigVar('nocleanup',
-        "suppress the deletion of code files that did not compile cleanly",
+        "Suppress the deletion of code files that did not compile cleanly",
         BoolParam(False))

 AddConfigVar('tensor.cmp_sloppy',

@@ -115,6 +115,44 @@ AddConfigVar('experimental.mrg',
         "Another random number generator that work on the gpu",
         BoolParam(False))

+AddConfigVar('numpy.seterr_all',
+        ("Sets numpy's behaviour for floating-point errors, "
+         "see numpy.seterr. "
+         "'None' means not to change numpy's default, which can be "
+         "different for different numpy releases. "
+         "This flag sets the default behaviour for all kinds of floating-"
+         "point errors, its effect can be overriden for specific errors "
+         "by the following flags: seterr_divide, seterr_over, "
+         "seterr_under and seterr_invalid."),
+        EnumStr('ignore', 'warn', 'raise', 'call', 'print', 'log', 'None',
+            allow_override=False))
+
+AddConfigVar('numpy.seterr_divide',
+        ("Sets numpy's behavior for division by zero, see numpy.seterr. "
+         "'None' means using the default, defined by numpy.seterr_all."),
+        EnumStr('None', 'ignore', 'warn', 'raise', 'call', 'print', 'log',
+            allow_override=False))
+
+AddConfigVar('numpy.seterr_over',
+        ("Sets numpy's behavior for floating-point overflow, "
+         "see numpy.seterr. "
+         "'None' means using the default, defined by numpy.seterr_all."),
+        EnumStr('None', 'ignore', 'warn', 'raise', 'call', 'print', 'log',
+            allow_override=False))
+
+AddConfigVar('numpy.seterr_under',
+        ("Sets numpy's behavior for floating-point underflow, "
+         "see numpy.seterr. "
+         "'None' means using the default, defined by numpy.seterr_all."),
+        EnumStr('None', 'ignore', 'warn', 'raise', 'call', 'print', 'log',
+            allow_override=False))
+
+AddConfigVar('numpy.seterr_invalid',
+        ("Sets numpy's behavior for invalid floating-point operation, "
+         "see numpy.seterr. "
+         "'None' means using the default, defined by numpy.seterr_all."),
+        EnumStr('None', 'ignore', 'warn', 'raise', 'call', 'print', 'log',
+            allow_override=False))
+
 ###
 ### To disable some warning about old bug that are fixed now.
theano/scalar/basic.py

@@ -5,10 +5,10 @@ WARNING
 This directory is for the internal of Theano.

-You are strongly adviced to don't use it except if you know
-what you do!
+You are strongly advised not to use it, except if you know
+what you are doing!

-If you want to use scalar variable in a Theano graph,
+If you want to use a scalar variable in a Theano graph,
 you probably want to use theano.tensor.[c,z,f,d,b,w,i,l,]scalar!
 """

@@ -113,7 +113,7 @@ class Scalar(Type):
             raise TypeError("Could not convert %s (value=%s) to %s" % (
                 type(data), data, self.dtype), e)

     def values_eq_approx(self, a, b, tolerance=1e-4):
-        return abs(a-b) / (a+b) < tolerance
+        return abs(a-b) <= ((abs(a)+abs(b))*tolerance)

     def c_headers(self):
         l = ['<math.h>']
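The new formula is a symmetric absolute/relative test, |a - b| <= (|a| + |b|) * tolerance, instead of dividing by a + b; it therefore no longer divides by zero when a and b cancel out or are both zero. A plain-Python illustration of the new expression:

    def values_eq_approx(a, b, tolerance=1e-4):
        return abs(a - b) <= ((abs(a) + abs(b)) * tolerance)

    print(values_eq_approx(1.0, 1.00005))  # True: well within tolerance
    print(values_eq_approx(0.0, 0.0))      # True (old formula computed 0/0 here)
    print(values_eq_approx(1.0, -1.0))     # False (old formula divided by zero here)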
theano/scan_module/scan.py

@@ -846,15 +846,7 @@ def scan( fn
     info['inplace'] = False
     info['gpu'] = False

-    revised_outs = []
-    for o in new_outs:
-        if (o in inner_inputs or isinstance(o, tensor.Constant)):
-            revised_outs.append(scan_utils.cloneOp(o))
-        else:
-            revised_outs.append(o)
-
-    local_op = scan_op.Scan(inner_inputs, revised_outs, info)
+    local_op = scan_op.Scan(inner_inputs, new_outs, info)

     ##
     ### Step 8. Compute the outputs using the scan op
theano/scan_module/scan_op.py

@@ -18,7 +18,8 @@ import logging
 import numpy
 import sys

-from theano.compile import SharedVariable, function, Param
+from theano.compile import SharedVariable, function, Param, Out
+from theano.compile.function_module import ViewOp, DeepCopyOp
 from theano import compile
 from theano import gradient
 from theano.gof.python25 import all

@@ -166,47 +167,25 @@ class Scan(Op):
         self.info['name'] = self.name
         self.info['mode_instance'] = self.mode_instance

-        if isinstance(self.mode_instance, compile.debugmode.DebugMode):
-            theano_fn = function(inputs
-                                 , outputs
-                                 , mode=self.mode_instance
-                                 , name=self.name)
-            def fn_wrapper(ins_storage, outs_storage):
-                '''
-                Wrap theano_fn to have same interface as scan_utils's
-                scan_function
-                '''
-                outputs = theano_fn(*ins_storage)
-                for (out, out_storage) in zip(outputs, outs_storage):
-                    if out_storage[0] is not None and out_storage[0].shape:
-                        out_storage[0][:] = out
-                    elif out_storage[0] is not None:
-                        out_storage[0].itemset(out)
-                return [[o] for o in outputs]
-            self.fn = fn_wrapper
-            self.fn.maker = scan_utils.EmptyObject()
-            self.fn.maker.inputs = inputs
-            self.fn.maker.outputs = outputs
-            self.fn.maker.env = theano_fn.maker.env
-            self.mask = [0 for x in xrange(self.n_shared_outs)]
-        else:
-            self.mask, self.fn = scan_utils.scan_function(inputs,
-                                                          outputs,
-                                                          nonmutable,
-                                                          mode=self.mode_instance,
-                                                          name=self.name,
-                                                          slices=(self.n_mit_mot_outs +
-                                                                  self.n_mit_sot +
-                                                                  self.n_sit_sot +
-                                                                  self.n_nit_sot))
+        wrapped_inputs = [Param(x, borrow=True) for x in inputs]
+        wrapped_outputs = [Out(x, borrow=True) for x in outputs]
+        self.fn = function(wrapped_inputs,
+                           wrapped_outputs,
+                           mode=self.mode_instance,
+                           name=self.name)
+        self.mask = [0 for x in xrange(self.n_shared_outs)]
+        # If a shared variable is the result of a ViewOp it is a clear
+        # indication that we need to copy that value after the perform of
+        # scan is done
+        slices = (self.n_mit_mot_outs +
+                  self.n_mit_sot +
+                  self.n_sit_sot +
+                  self.n_nit_sot)
+        for i in xrange(slices, slices + self.n_shared_outs):
+            if isinstance(self.fn.maker.env.outputs[i].owner.op, ViewOp):
+                self.mask[i - slices] = 1
+        # check for shared variables in the inputs
+        assert not numpy.any([isinstance(x, SharedVariable)
+                              for x in self.fn.maker.inputs])

         # Pre-computing some values to speed up perform
         self.mintaps = [numpy.min(x) for x in self.tap_array]

@@ -269,7 +248,7 @@ class Scan(Op):
             if inputs[index].dtype != self.inputs[index_i].dtype:
                 raise ValueError(err_msg1 % ('Initial state'
                                              , inputs[index].name
-                                             , idx
+                                             , index
                                              , inputs[index].dtype
                                              , self.inputs[index_i].name
                                              , self.inputs[index_i].dtype))

@@ -277,7 +256,7 @@ class Scan(Op):
             for k in self.mit_mot_out_slices[index - start]:
                 if inputs[index].dtype != self.outputs[index_o].dtype:
                     raise ValueError(err_msg2 % (inputs[index].name
-                                                 , idx
+                                                 , index
                                                  , inputs[index].dtype
                                                  , self.outputs[index_o].dtype))
                 index_o += 1

@@ -289,7 +268,7 @@ class Scan(Op):
             if inputs[index].dtype != self.inputs[index_i].dtype:
                 raise ValueError(err_msg1 % ('Initial state'
                                              , inputs[index].name
-                                             , idx
+                                             , index
                                              , inputs[index].dtype
                                              , self.inputs[index_i].name
                                              , self.inputs[index_i].dtype))

@@ -310,7 +289,7 @@ class Scan(Op):
             if (hasattr(inputs[index], 'dtype') and
                 inputs[index].dtype != self.outputs[index_o].dtype):
                 raise ValueError(err_msg2 % (inputs[index].name
-                                             , idx
+                                             , index
                                              , inputs[index].dtype
                                              , self.outputs[index_o].dtype))
             index += 1

@@ -406,10 +385,8 @@ class Scan(Op):
         if n_steps < 0:
             n_steps = abs(n_steps)
             seqs = [seq[::-1] for seq in args[1:self.seqs_arg_offset]]
-            seqs = zip(seqs, self.vector_seqs)
         else:
             seqs = args[1:self.seqs_arg_offset]
-            seqs = zip(seqs, self.vector_seqs)

         # 2. Allocate memory for the outputs. Construct the list:
         #     store_steps  -- map containting the length of each output

@@ -447,62 +424,81 @@ class Scan(Op):
         offset = self.nit_sot_arg_offset + self.n_nit_sot
         other_args = args[offset:]
-        zipped_outs = [(outs[idx], self.vector_outs[idx], tap,
-                        store_steps[idx], idx)
-                       for idx in xrange(self.n_outs)
-                       for tap in self.tap_array[idx]]
-        end = self.n_outs + self.n_nit_sot
-        sot_outs = zip(outs[self.n_mit_mot:end]
-                       , self.vector_outs[self.n_mit_mot:end]
-                       , store_steps[self.n_mit_mot:end]
-                       , range(self.n_mit_mot, end))
+        input_storage = self.fn.input_storage
+        output_storage = self.fn.output_storage
+        fn = self.fn.fn
+        offset = (self.n_seqs + sum(map(len, self.tap_array[:self.n_outs])) +
+                  self.n_shared_outs)
+        for idx in xrange(len(other_args)):
+            input_storage[idx + offset].storage[0] = other_args[idx]

         ############## THE MAIN LOOP #########################
         for i in xrange(n_steps):
-            # sequences over which scan iterates
-            if i == 1 and self.n_nit_sot > 0:
-                sot_outs = zip(outs[self.n_mit_mot:end]
-                               , self.vector_outs[self.n_mit_mot:end]
-                               , store_steps[self.n_mit_mot:end]
-                               , range(self.n_mit_mot, end))
-            fn_args = [seq[i:i+1].reshape(()) if c else seq[i]
-                       for seq, c in seqs]
-            fn_args += [out[0][(pos[j]+tap)%sz: (pos[j]+tap)%sz+1].reshape(())
-                        if c else out[0][(pos[j]+tap)%sz]
-                        for (out, c, tap, sz, j) in zipped_outs]
+            # 3. collect input slices
+            for idx in xrange(self.n_seqs):
+                if self.vector_seqs[idx]:
+                    input_storage[idx].storage[0] = seqs[idx][i:i+1].reshape(())
+                else:
+                    input_storage[idx].storage[0] = seqs[idx][i]
+            offset = self.n_seqs
+            for idx in xrange(self.n_outs):
+                if self.vector_outs[idx]:
+                    for tap in self.tap_array[idx]:
+                        _idx = (pos[idx] + tap) % store_steps[idx]
+                        input_storage[offset].storage[0] = \
+                                outs[idx][0][_idx:_idx+1].reshape(())
+                        offset += 1
+                else:
+                    for tap in self.tap_array[idx]:
+                        _idx = (pos[idx] + tap) % store_steps[idx]
+                        input_storage[offset].storage[0] = outs[idx][0][_idx]
+                        offset += 1

             a_offset = self.shared_arg_offset
             o_offset = self.n_outs + self.n_nit_sot
-            fn_args += [args[a_offset+j] if i == 0 else outs[o_offset+j][0]
-                        for j in xrange(self.n_shared_outs)]
-            fn_args += other_args
+            if i == 0:
+                for j in xrange(self.n_shared_outs):
+                    input_storage[offset].storage[0] = args[a_offset + j]
+                    offset += 1
+            else:
+                for j in xrange(self.n_shared_outs):
+                    input_storage[offset].storage[0] = outs[o_offset + j][0]
+                    offset += 1

             # 4. collecting slices where the output should be stored
-            fn_out_storage = [[None] for x in xrange(self.n_mit_mot_outs)]
-            if i == 0 and self.n_nit_sot > 0:
-                fn_out_storage += [[None] if store == 1 or c else [out[0][pos[j]]]
-                                   for out, c, store, j in sot_outs[:-self.n_nit_sot]]
-                fn_out_storage += [[None]] * self.n_nit_sot
-            else:
-                fn_out_storage += [[None] if store == 1 or c else [out[0][pos[j]]]
-                                   for out, c, store, j in sot_outs]
-            fn_out_storage += [[None] for x in xrange(self.n_shared_outs)]
+            for idx in xrange(self.n_mit_mot_outs):
+                output_storage[idx].storage[0] = None
+            offset = self.n_mit_mot_outs
+            if i != 0 and self.n_nit_sot > 0:
+                for idx in xrange(self.n_outs + self.n_nit_sot - self.n_mit_mot):
+                    if (store_steps[idx + self.n_mit_mot] == 1 or
+                        self.vector_outs[idx + self.n_mit_mot]):
+                        output_storage[idx + offset].storage[0] = None
+                    else:
+                        output_storage[idx + offset].storage[0] = \
+                                outs[idx + self.n_mit_mot][0][pos[idx + self.n_mit_mot]]
+            else:
+                for idx in xrange(self.n_outs + self.n_nit_sot -
+                                  self.n_mit_mot):
+                    output_storage[idx + offset].storage[0] = None
+            offset += self.n_outs + self.n_nit_sot - self.n_mit_mot
+            for idx in xrange(self.n_shared_outs):
+                output_storage[idx + offset].storage[0] = None

             # 5. compute outputs
-            something = self.fn(fn_args, fn_out_storage)
+            fn()
             offset_out = 0

             # 5.1 Copy over the values for mit_mot outputs
             for j in xrange(self.n_mit_mot):
                 for k in self.mit_mot_out_slices[j]:
-                    outs[j][0][k + pos[j]] = something[offset_out][0]
+                    outs[j][0][k + pos[j]] = output_storage[offset_out].storage[0]
                     offset_out += 1

             # 5.2 Copy over the values for mit_sot/sit_sot outputs

@@ -511,8 +507,10 @@ class Scan(Op):
             offset_out -= self.n_mit_mot
             for j in xrange(begin, end):
-                if store_steps[j] == 1 or self.vector_outs[j]:
-                    outs[j][0][pos[j]] = something[offset_out + j][0]
+                if (store_steps[j] == 1 or self.vector_outs[j] or
+                    outs[j][0][pos[j]] is not
+                        output_storage[offset_out + j].storage[0]):
+                    outs[j][0][pos[j]] = output_storage[offset_out + j].storage[0]

             # 5.3 Copy over the values for nit_sot outputs
             begin = end

@@ -520,10 +518,10 @@ class Scan(Op):
             for j in xrange(begin, end):
                 if i == 0:
                     jout = j + offset_out
-                    shape = (store_steps[j],) + something[jout][0].shape
-                    if len(something[jout][0].shape) == 0:
+                    shape = (store_steps[j],) + output_storage[jout].storage[0].shape
+                    if len(output_storage[jout].storage[0].shape) == 0:
                         self.vector_outs[j] = True
-                    dtype = something[jout][0].dtype
+                    dtype = output_storage[jout].storage[0].dtype
                     if (outs[j][0] is None or
                         outs[j][0].shape[0] < store_steps[j] or
                         outs[j][0].shape[1:] != shape[1:] or

@@ -534,9 +532,10 @@ class Scan(Op):
                         outs[j][0] = numpy.zeros(shape, dtype)
                     elif outs[j][0].shape[0] != store_steps[j]:
                         outs[j][0] = outs[j][0][:store_steps[j]]
-                    outs[j][0][pos[j]] = something[jout][0]
-                elif store_steps[j] == 1 or self.vector_outs[j]:
-                    outs[j][0][pos[j]] = something[j + offset_out][0]
+                    outs[j][0][pos[j]] = output_storage[jout].storage[0]
+                elif (store_steps[j] == 1 or self.vector_outs[j] or
+                      outs[j][0][pos[j]] is not
+                          output_storage[j + offset_out].storage[0]):
+                    outs[j][0][pos[j]] = output_storage[j + offset_out].storage[0]

             # 5.4 Copy over the values for outputs corresponding to shared

@@ -545,7 +544,7 @@ class Scan(Op):
             end += self.n_shared_outs
             for j in xrange(begin, end):
                 jout = j + offset_out
-                outs[j][0] = something[jout][0]
+                outs[j][0] = output_storage[jout].storage[0]

             pos = [(idx + 1) % store for idx, store in
                    itertools.izip(pos, store_steps)]

@@ -571,6 +570,7 @@ class Scan(Op):
                 tmp[:] = outs[idx][0][:pdx]
                 outs[idx][0][:store_steps[idx] - pdx] = outs[idx][0][pdx:]
                 outs[idx][0][store_steps[idx] - pdx:] = tmp
+                del tmp
             else:
                 shape = (store_steps[idx] - pdx,) + outs[idx][0].shape[1:]
                 if cuda.cuda_available and isinstance(outs[idx][0],

@@ -581,6 +581,7 @@ class Scan(Op):
                 tmp[:] = outs[idx][0][pdx:]
                 outs[idx][0][store_steps[idx] - pdx:] = outs[idx][0][:pdx]
                 outs[idx][0][:store_steps[idx] - pdx] = tmp
+                del tmp

         for idx, val in enumerate(self.mask):

@@ -628,9 +629,10 @@ class Scan(Op):
                                             inputs=self.inputs,
                                             input_shapes=inner_ins_shapes)
         # Will be used to check if outs_shape can be expressed without using
-        # variables in self.inputs
+        # variables in self.inputs.
+        # The shapes of node.inputs are valid.
         validator = scan_utils.Validator(
-            valid=[],
+            valid=input_shapes,
             invalid=self.inputs,
             valid_equivalent=out_equivalent)
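The rewritten perform() drives the compiled function through its low-level interface: write values into self.fn.input_storage, run the underlying thunk self.fn.fn(), and read self.fn.output_storage, instead of packing argument lists for the removed ScanInnerFunction. A hedged, self-contained sketch of that calling pattern on a trivial graph (the graph itself is illustrative, not the scan graph):

    import numpy
    import theano
    import theano.tensor as T

    x = T.dvector('x')
    f = theano.function([x], x * 2)

    # Fast-call pattern used by Scan.perform after this commit:
    f.input_storage[0].storage[0] = numpy.arange(3.0)   # place the input value
    f.fn()                                               # run the compiled thunk directly
    print(f.output_storage[0].storage[0])                # -> [ 0.  2.  4.]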
theano/scan_module/scan_utils.py

@@ -86,277 +86,6 @@ def traverse(out, x,x_copy, d):
         d = traverse(inp, x, x_copy, d)
     return d

-class EmptyObject(object):
-    def __init__(self):
-        pass
-
-class ScanInnerFunction(object):
-    """
-    Stripped down, simplified version of theano.function class that has a
-    low overhead at calling a function.
-    """
-    def __init__(self, fn, input_storage, output_storage, env, inputs,
-                 outputs, nonmutable_indices, mode, name):
-        self.fn = fn
-        self.input_storage = input_storage
-        self.n_ins = len(input_storage)
-        self.n_outs = len(output_storage)
-        self.outputs_storage = output_storage
-        self.maker = EmptyObject()
-        self.maker.env = env
-        self.maker.inputs = inputs
-        for i in inputs:
-            i.update = None
-        self.maker.expanded_inputs = inputs
-        self.maker.outputs = outputs
-        self.maker.nonmutable_indices = nonmutable_indices
-        self.maker.mode = mode
-        self.name = name
-
-    def __call__(self, inputs, outputs):
-        t0 = time.time()
-        # put data into the storage
-        for idx in xrange(self.n_ins):
-            self.input_storage[idx][0] = inputs[idx]
-        for idx in xrange(self.n_outs):
-            self.outputs_storage[idx][0] = outputs[idx][0]
-        _t0 = time.time()
-        self.fn()
-        dt_fn = time.time() - _t0
-        for idx in xrange(self.n_outs):
-            if outputs[idx][0] is not None:
-                if outputs[idx][0] is not self.outputs_storage[idx][0]:
-                    if outputs[idx][0].shape:
-                        outputs[idx][0][:] = self.outputs_storage[idx][0]
-                    else:
-                        outputs[idx][0].itemset(self.outputs_storage[idx][0])
-        dt_call = time.time() - t0
-        if hasattr(self.maker.mode, 'fct_call_time'):
-            self.maker.mode.fct_call_time[self] += dt_call
-            self.maker.mode.fct_call[self] += 1
-            self.maker.mode.fn_time += dt_fn
-            self.maker.mode.call_time += dt_call
-        return self.outputs_storage
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        del state['fn']
-        del state['input_storage']
-        del state['outputs_storage']
-        del state['maker'].env
-        return state
-
-    def __setstate__(self):
-        self.__dict__ = state
-        name = self.name
-        mode = self.maker.mode
-        inputs = self.maker.inputs
-        outputs = self.maker.outputs
-        nonmutable_indices = self.maker.nonmutable_indices
-        new_inputs, new_outputs = gof.graph.clone(inputs, ouputs)
-        env = gof.env.Env(new_inputs, new_outputs)
-        nonmutable = []
-        for idx in nonmutable_indices:
-            nonmutable.append(new_inputs[idx])
-        env.extend(Supervisor(inp for inp in nonmutable
-                              if not (hasattr(env, 'destroyers')
-                                      and env.destroyers(inp))))
-        # If named nodes are replaced, keep the name
-        env.extend(gof.toolbox.PreserveNames())
-        optimizer, linker = mode.optimizer, copy.copy(mode.linker)
-        # optimize the env
-        t0 = time.time()
-        optimizer(env)
-        _logger.debug('Optimizing took %f seconds' % (time.time() - t0))
-        if not hasattr(linker, 'accept'):
-            raise ValueError(("'linker' parameter of FunctionFactory "
-                              "should be a Linker with an accept method "
-                              "or one of %s") %
-                             mode_module.predefined_linkers.keys())
-        my_linker = linker.accept(env)
-        input_storage = []
-        output_storage = []
-        for input in inputs:
-            input_storage += [[None]]
-        for output in outputs:
-            output_storage += [[None]]
-        t0 = time.time()
-        _fn, _i, _o = my_linker.make_thunk(input_storage=input_storage,
-                                           output_storage=output_storage)
-        _logger.debug('Linking took %f seconds' % (time.time() - t0))
-        fn = ScanInnerFunction(_fn, input_storage, output_storage, env)
-        t2 = time.time()
-        self.fn = _fn
-        self.input_storage = input_storage
-        self.outputs_storage = output_storage
-        if hasattr(mode, 'fct_call_time'):
-            mode.fct_call_time.setdefault(fn, 0)
-        if hasattr(mode, 'fct_call'):
-            mode.fct_call.set_default(fn, 0)
-
-def scan_function(inputs, outputs, nonmutable_indices=None, mode=None,
-                  name=None, slices=0):
-    """
-    ``Constructor`` of the ScanInnerFunction ( a simplified version of
-    theano.function ). This should only be used internally by Scan.
-
-    :param inputs: theano variable that represent the input of the function
-    :param outputs: theano expression that represents the outputs of the
-        function
-    :param nonmutable_indices: the subset of indices corresponding to
-        nonmutable inputs
-    :param mode: compilation mode for the function
-    :param name: name of the function
-    """
-    t1 = time.time()
-    mode = mode_module.get_mode(mode)
-    if isinstance(mode, (list, tuple)):
-        # "mode comparison" semantics
-        _logger.warning('Passing multiple modes is deprecated (20091019)')
-        if not mode:
-            raise ValueError("Please provide at least one mode.")
-        else:
-            mode = mode[0]
-    ## Replacing the Function Maker
-    if not isinstance(outputs, (list, tuple)):
-        outputs = [outputs]
-    if not isinstance(inputs, (list, tuple)):
-        inputs = [inputs]
-    new_inputs, new_outputs = gof.graph.clone(inputs, outputs)
-    env = gof.env.Env(new_inputs, new_outputs)
-    nonmutable = []
-    for idx in nonmutable_indices:
-        nonmutable.append(new_inputs[idx])
-    env.extend(Supervisor(inp for inp in nonmutable
-                          if not (hasattr(env, 'destroyers')
-                                  and env.destroyers(inp))))
-    # If named nodes are replaced, keep the name
-    env.extend(gof.toolbox.PreserveNames())
-    optimizer, linker = mode.optimizer, copy.copy(mode.linker)
-    # optimize the env
-    t0 = time.time()
-    optimizer(env)
-    _logger.debug('Optimizing took %f seconds' % (time.time() - t0))
-    mask = [0 for x in env.outputs[slices:]]
-    for i, out in enumerate(env.outputs):
-        if (out in env.inputs or isinstance(out, tensor.Constant)):
-            env.change_input('output', i, Clone()(out))
-    for i in xrange(len(env.outputs[slices:])):
-        views_of_output_i = set()
-        view_tree_set(alias_root(env.outputs[i]), views_of_output_i)
-        copied = False
-        # do not allow outputs to be aliased
-        for j in xrange(i + 1, len(env.outputs)):
-            if env.outputs[j] in views_of_output_i:
-                mask[i] = 1
-                copied = True
-                break
-        if not copied:
-            for input_j in env.inputs:
-                # do not allow outputs to be aliased to an inputs (j), unless
-                # a) that j'th input has been 'destroyed' by e.g. in-place computations
-                if hasattr(env, 'get_destroyers_of') and env.get_destroyers_of(input_j):
-                    continue
-                if input_j in views_of_output_i:
-                    mask[i] = 1
-                    break
-    if not hasattr(linker, 'accept'):
-        raise ValueError(("'linker' parameter of FunctionFactory "
-                          "should be a Linker with an accept method "
-                          "or one of %s") %
-                         mode_module.predefined_linkers.keys())
-    my_linker = linker.accept(env)
-    input_storage = []
-    output_storage = []
-    for input in inputs:
-        input_storage += [[None]]
-    for output in outputs:
-        output_storage += [[None]]
-    t0 = time.time()
-    _fn, _i, _o = my_linker.make_thunk(input_storage=input_storage,
-                                       output_storage=output_storage)
-    _logger.debug('Linking took %f seconds' % (time.time() - t0))
-    if hasattr(mode, 'apply_time'):
-        for i, node in enumerate(env.toposort()):
-            mode.apply_time[(i, node)] = 0.0
-            assert len(_fn.thunk_groups[i]) == 1
-            mode.op_cimpl[node.op] = hasattr(_fn.thunk_groups[i][0], 'cthunk')
-    fn = ScanInnerFunction(_fn, input_storage, output_storage, env, inputs,
-                           outputs, nonmutable_indices, mode, name)
-    t2 = time.time()
-    if hasattr(mode, 'compile_time'):
-        mode.compile_time += t2 - t1
-    if hasattr(mode, 'fct_call_time'):
-        mode.fct_call_time.setdefault(fn, 0)
-    if hasattr(mode, 'fct_call'):
-        mode.fct_call.setdefault(fn, 0)
-    return mask, fn
-
 # Hashing a dictionary/list/tuple by xoring the hash of each element
 def hash_listsDictsTuples(x):

@@ -517,33 +246,8 @@ def expand( tensor_var, size):

-class Clone(Op):
-    def __init__(self):
-        self.view_map = {0: [0]}
-
-    def __eq__(self, other):
-        return type(self) == type(other)
-
-    def __hash__(self):
-        return hash(type(self))
-
-    def __str__(self):
-        return 'clone[as_view]'
-
-    def make_node(self, *inputs):
-        x = inputs[0]
-        return Apply(self, inputs, [x.type()])
-
-    def perform(self, node, args, outs):
-        outs[0][0] = args[0]
-
-    def infer_shape(self, node, input_shapes):
-        return input_shapes
-
-    def grad(self, args, g_outs):
-        return g_outs
-
-cloneOp = Clone()
-
 def equal_computations(x, y, strict=False):
     '''
theano/scan_module/tests/test_scan.py

@@ -2007,7 +2007,34 @@ class T_Scan(unittest.TestCase):
         assert scan1.owner.op == scan2.owner.op
         assert hash(scan1.owner.op) == hash(scan2.owner.op)

+    def test_same(self):
+        # This test is checking a bug discovered by Arnaud and it is based
+        # on his code
+        x = theano.tensor.fmatrix('x')
+
+        mem_val = numpy.zeros((2,), dtype='float32')
+        memory = theano.shared(mem_val.copy())
+        W = theano.shared(numpy.random.random((5, 2)).astype('float32'))
+
+        def f(inp, mem):
+            i = theano.tensor.join(0, inp, mem)
+            d = theano.tensor.dot(i, W)
+            return d, d
+
+        outs, updts = theano.scan(f, sequences=[x],
+                                  non_sequences=[],
+                                  outputs_info=[None, memory])
+
+        f = theano.function([x], outs[0])
+        f2 = theano.function([x], outs[1])
+
+        x_val = numpy.random.random((4, 3)).astype('float32')
+
+        f_vals = f(x_val)
+        memory.set_value(mem_val.copy())
+        f2_vals = f2(x_val)
+        assert numpy.allclose(f_vals, f2_vals)
+
 if __name__ == '__main__':
     #'''
theano/sparse/tests/test_basic.py

@@ -190,6 +190,63 @@ class T_AddMul(unittest.TestCase):
             self.assertTrue(numpy.all(val.todense() == numpy.array([[1, 0],
                                                                     [9, 0],
                                                                     [0, 36]])))

+    def test_upcast(self):
+        array1 = numpy.array([[1, 0], [3, 0], [0, 6]], dtype='float32')
+        array2 = numpy.array([[1, 0], [3, 0], [0, 6]], dtype='int32')
+        array3 = numpy.array([[1, 0], [3, 0], [0, 6]], dtype='int8')
+
+        # AddSS and MulSS
+        for mtype in _mtypes:
+            a = mtype(array1)
+            aR = as_sparse_variable(a)
+            b = mtype(array2)
+            bR = as_sparse_variable(b)
+            c = mtype(array3)
+            cR = as_sparse_variable(c)
+
+            # Ops that do not upcast
+            self.assertRaises(NotImplementedError, add, aR, bR)
+            self.assertRaises(NotImplementedError, add, bR, aR)
+            self.assertRaises(NotImplementedError, add, bR, cR)
+            self.assertRaises(NotImplementedError, add, cR, bR)
+            self.assertRaises(NotImplementedError, add, aR, cR)
+            self.assertRaises(NotImplementedError, add, cR, aR)
+
+            self.assertRaises(NotImplementedError, mul, aR, bR)
+            self.assertRaises(NotImplementedError, mul, bR, aR)
+            self.assertRaises(NotImplementedError, mul, bR, cR)
+            self.assertRaises(NotImplementedError, mul, cR, bR)
+            self.assertRaises(NotImplementedError, mul, aR, cR)
+            self.assertRaises(NotImplementedError, mul, cR, aR)
+
+        # AddSD and MulSD
+        for mtype in _mtypes:
+            a = mtype(array1)
+            a_sv = as_sparse_variable(a)
+            a_dv = tensor.as_tensor_variable(array1)
+            b = mtype(array2)
+            b_sv = as_sparse_variable(b)
+            b_dv = tensor.as_tensor_variable(array2)
+            c = mtype(array3)
+            c_sv = as_sparse_variable(c)
+            c_dv = tensor.as_tensor_variable(array3)
+
+            # add does not upcast
+            self.assertRaises(NotImplementedError, add, a_sv, b_dv)
+            self.assertRaises(NotImplementedError, add, b_sv, a_dv)
+            self.assertRaises(NotImplementedError, add, b_sv, c_dv)
+            self.assertRaises(NotImplementedError, add, c_sv, b_dv)
+            self.assertRaises(NotImplementedError, add, a_sv, c_dv)
+            self.assertRaises(NotImplementedError, add, c_sv, a_dv)
+
+            # mul upcasts the dense input if needed
+            self.assertRaises(NotImplementedError, mul, a_sv, b_dv)
+            self.assertRaises(NotImplementedError, mul, b_sv, a_dv)
+            assert mul(b_sv, c_dv).dtype == 'int32'
+            self.assertRaises(NotImplementedError, mul, c_sv, b_dv)
+            assert mul(a_sv, c_dv).dtype == 'float32'
+            self.assertRaises(NotImplementedError, mul, c_sv, a_dv)
+
 class T_conversion(unittest.TestCase):
     def setUp(self):
theano/tensor/basic.py

@@ -3753,7 +3753,7 @@ class Reshape(Op):
         shp_orig = shp
         shp = as_tensor_variable(shp, ndim=1)
         if not shp.dtype.startswith('int'):
-            raise TypeError("Shape must be integers")
+            raise TypeError("Shape must be integers", shp, shp.dtype)
         assert shp.ndim == 1
         if isinstance(shp, TensorConstant):
             bcast = [s == 1 for s in shp.data]

@@ -3788,9 +3788,18 @@ class Reshape(Op):
         g_out, = grads
         return [reshape(g_out, shape(x), ndim=x.ndim), None]

     def infer_shape(self, node, ishapes):
-        #we can't just put node.inputs[1] as not all op support interation
-        #and this is needed in the ShapeOptimizer
-        return [tuple([node.inputs[1][i] for i in range(self.ndim)])]
+        # inputs[1] can contain at most one value of '-1', meaning the actual
+        # shape of the output will be automatically computed by reshape, so
+        # that the total number of elements stays the same.
+        # TODO: Maybe put that formula here?
+        # It's not trivial, because we would have to check if the product of
+        # all the non-minus-one shapes is a divisor of the product of the
+        # original shapes.
+        return [tuple([switch(eq(node.inputs[1][i], -1),
+                              theano.tensor.opt.Shape_i(i)(node.outputs[0]),
+                              node.inputs[1][i])
+                       for i in range(self.ndim)])]

 def reshape(x, newshape, ndim=None, name=None):
     if ndim is None:

@@ -4739,10 +4748,10 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False,
     ret = []
     for p in wrt:
         if p not in gmap and not assume_continuously_differentiable:
-            raise ValueError(("grad method was asked to compute the graident "
+            raise ValueError(("grad method was asked to compute the gradient "
                               "with respect to a variable that is not part of "
-                              "the computational graph of the cost or is used "
+                              "the computational graph of the cost, or is used "
                               "by a non-differentiable operator"), p)
         else:
             ret.append(gmap.get(p, zeros_like(p)))
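The new infer_shape mirrors numpy's rule for a -1 entry: the inferred output shape uses the requested value where it is given and falls back to Shape_i of the actual output where the entry is -1, so exactly one dimension may be left to be deduced from the total number of elements. A plain numpy illustration of that rule:

    import numpy

    a = numpy.arange(12).reshape(3, 4)
    print(a.reshape(4, -1).shape)   # (4, 3): the -1 is resolved so that 4 * 3 == 12
    print(a.reshape(-1, 6).shape)   # (2, 6)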
theano/tensor/tests/test_basic.py

@@ -3290,9 +3290,11 @@ class T_op_cache(unittest.TestCase):
         a = numpy.random.rand(5, 2).astype(config.floatX)
         self.assertTrue(numpy.all(fn_py(a) == fn_c_or_py(a)))

-def test_reshape():
+class T_reshape(unittest.TestCase):
+    def setUp(self):
+        utt.seed_rng()
+
+    def test_reshape(self):
     a = dvector()
     b = dmatrix()
     d = dmatrix()

@@ -3361,9 +3363,30 @@ def test_reshape():
     assert numpy.all(f(numpy.asarray([[0,1,2],[3,4,5]])) == numpy.asarray([[[0],[1],[2]],[[3],[4],[5]]]))
     assert f.maker.env.toposort()[-2].outputs[0].type.broadcastable == (False, False, True)
     assert numpy.all(f_sub(a_val, b_val) == [2, 3])

+    def test_infer_shape(self):
+        a = matrix('a')
+        shapes = ivector('shapes')
+        ndim = 2
+
+        r = a.reshape(shapes, ndim=2)
+        z = zeros_like(r)
+
+        f = function([a, shapes], z.shape)
+
+        rng = numpy.random.RandomState(seed=utt.fetch_seed())
+        a_val = rng.uniform(size=(3, 4)).astype(config.floatX)
+
+        self.assertTrue((f(a_val, [4, 3]) == [4, 3]).all())
+        self.assertTrue((f(a_val, [-1, 3]) == [4, 3]).all())
+        self.assertTrue((f(a_val, [4, -1]) == [4, 3]).all())
+        self.assertRaises(ValueError, f, a_val, [-1, 5])
+        self.assertRaises(ValueError, f, a_val, [7, -1])
+        self.assertRaises(ValueError, f, a_val, [7, 5])
+        self.assertRaises(ValueError, f, a_val, [-1, -1])
+
 def test_make_column_matrix_broadcastable():
     # The goal of the operation made by `b` is to ensure the second dimension
     # of the column matrix is broadcastable.
theano/tensor/tests/test_opt.py

@@ -2787,11 +2787,6 @@ def test_local_tensor_scalar_tensor():
         assert len(cast_nodes) == 0
         f(0)

-@dec.knownfailureif(
-        isinstance(theano.compile.mode.get_default_mode(),
-            theano.compile.debugmode.DebugMode),
-        ("This test fails in DEBUG_MODE, but the generated code is OK. "
-         "It is actually a problem of DEBUG_MODE, see #624."))
 def test_local_scalar_tensor_scalar():
     dtypes = ['int8', 'int16', 'int32', 'int64',
               'uint8', 'uint16', 'uint32', 'uint64',