Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4bbac540
提交
4bbac540
authored
3月 15, 2011
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
差异文件
merge
上级
8ea860cd
59561698
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
133 行增加
和
100 行删除
+133
-100
graph.py
theano/gof/graph.py
+1
-1
scan.py
theano/scan.py
+40
-42
basic.py
theano/tensor/basic.py
+55
-48
elemwise.py
theano/tensor/elemwise.py
+37
-9
没有找到文件。
theano/gof/graph.py
浏览文件 @
4bbac540
...
@@ -417,7 +417,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
...
@@ -417,7 +417,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
raise
ValueError
(
'mode should be bfs or dfs'
,
mode
)
raise
ValueError
(
'mode should be bfs or dfs'
,
mode
)
rval_set
=
set
()
rval_set
=
set
()
rval_list
=
list
()
rval_list
=
list
()
if
mode
is
'bfs'
:
start_pop
=
start
.
popleft
if
mode
==
'bfs'
:
start_pop
=
start
.
popleft
else
:
start_pop
=
start
.
pop
else
:
start_pop
=
start
.
pop
expand_inv
=
{}
expand_inv
=
{}
while
start
:
while
start
:
...
...
theano/scan.py
浏览文件 @
4bbac540
...
@@ -106,7 +106,7 @@ def map( fn
...
@@ -106,7 +106,7 @@ def map( fn
:param go_backwards: Boolean value that decides the direction of
:param go_backwards: Boolean value that decides the direction of
iteration. True means that sequences are parsed
iteration. True means that sequences are parsed
from the end towards the begining, while False
from the end towards the begin
n
ing, while False
is the other way around.
is the other way around.
:param mode: See ``scan``.
:param mode: See ``scan``.
...
@@ -301,7 +301,7 @@ def scan( fn
...
@@ -301,7 +301,7 @@ def scan( fn
scan)
scan)
The order of the sequences is the same as the one in the list
The order of the sequences is the same as the one in the list
`sequences` given to scan. The order of the outputs is the sa
n
e
`sequences` given to scan. The order of the outputs is the sa
m
e
as the order of ``output_info``. For any sequence or output the
as the order of ``output_info``. For any sequence or output the
order of the time slices is the same as the order of the time
order of the time slices is the same as the order of the time
taps provided. For example if one writes the following :
taps provided. For example if one writes the following :
...
@@ -314,7 +314,7 @@ def scan( fn
...
@@ -314,7 +314,7 @@ def scan( fn
, outputs_info = [ dict( Output1, taps = [-3,-5])
, outputs_info = [ dict( Output1, taps = [-3,-5])
, dict( Output2, taps = None)
, dict( Output2, taps = None)
, Output3 ]
, Output3 ]
, non_sequences = [ Argument1, Argument
2])
, non_sequences = [ Argument1, Argument2])
``fn`` should expect the following arguments in this given order:
``fn`` should expect the following arguments in this given order:
...
@@ -341,7 +341,7 @@ def scan( fn
...
@@ -341,7 +341,7 @@ def scan( fn
`fn` should return an update dictionary ( that tells how to
`fn` should return an update dictionary ( that tells how to
update any shared variable after each iteration ste). The
update any shared variable after each iteration ste). The
dictionary can optionally be given as a list of tuples. There is
dictionary can optionally be given as a list of tuples. There is
no constraint on the order of these two list, ``fn`` can return
no constraint on the order of these two list
s
, ``fn`` can return
either ``(outputs_list, update_dictionary)`` or ``(update_dictionary,
either ``(outputs_list, update_dictionary)`` or ``(update_dictionary,
outputs_list)`` or just one of the two (in case the other is
outputs_list)`` or just one of the two (in case the other is
empty).
empty).
...
@@ -369,7 +369,7 @@ def scan( fn
...
@@ -369,7 +369,7 @@ def scan( fn
:param outputs_info:
:param outputs_info:
``outputs_info`` is the list of Theano variables or dictionaries
``outputs_info`` is the list of Theano variables or dictionaries
describing the initial state of the outputs computed
describing the initial state of the outputs computed
recurrently. When this initial state
s are given as dictionary
recurrently. When this initial state
is given as a dictionary,
optional information can be provided about the output corresponding
optional information can be provided about the output corresponding
to these initial states. The dictionary should have the following
to these initial states. The dictionary should have the following
keys:
keys:
...
@@ -388,11 +388,11 @@ def scan( fn
...
@@ -388,11 +388,11 @@ def scan( fn
the initial state, which in this case should have the shape
the initial state, which in this case should have the shape
(5,)+output.shape. If this variable containing the initial
(5,)+output.shape. If this variable containing the initial
state is called ``init_y`` then ``init_y[0]`` *corresponds to*
state is called ``init_y`` then ``init_y[0]`` *corresponds to*
``output[-5]``; ``init_y[1]`` *correponds to* ``output[-4]``;
``output[-5]``; ``init_y[1]`` *corre
s
ponds to* ``output[-4]``;
``init_y[2]`` corresponds to ``output[-3]``; ``init_y[3]``
``init_y[2]`` corresponds to ``output[-3]``; ``init_y[3]``
coresponds to ``output[-2]``; ``init_y[4]`` corresponds to
coresponds to ``output[-2]``; ``init_y[4]`` corresponds to
``output[-1]``. While this order might seem strange, it comes
``output[-1]``. While this order might seem strange, it comes
natural from splitting an array at a given point. Assume that
natural
ly
from splitting an array at a given point. Assume that
we have a array ``x``, and we choose ``k`` to be time step
we have a array ``x``, and we choose ``k`` to be time step
``0``. Then our initial state would be ``x[:k]``, while the
``0``. Then our initial state would be ``x[:k]``, while the
output will be ``x[k:]``. Looking at this split, elements in
output will be ``x[k:]``. Looking at this split, elements in
...
@@ -401,17 +401,10 @@ def scan( fn
...
@@ -401,17 +401,10 @@ def scan( fn
``fn``. They are provided as a list of *negative* integers,
``fn``. They are provided as a list of *negative* integers,
where a value ``k`` implies that at iteration step ``t`` scan will
where a value ``k`` implies that at iteration step ``t`` scan will
pass to ``fn`` the slice ``t+k``.
pass to ``fn`` the slice ``t+k``.
* ``inplace`` -- One of the Theano variables provided as
* ``inplace`` -- DEPRECATED. Previously, one could specify with this
``sequences``. ``scan`` will try to compute this output *in
option whether the output should overwrite some particular input,
place* of the provided input *iff* it respects the following
but it is now inferred automatically. If you specify this option
constraints:
it will be ignored.
* There is no other output that is denied to be computed in
place for whatever reason.
* ``fn`` is not using past taps of the input sequence that
will get overwritten by the output
* ``return_steps`` -- Integer representing the number of steps
* ``return_steps`` -- Integer representing the number of steps
to return for the current steps. For example, if ``k`` is
to return for the current steps. For example, if ``k`` is
provided, ``scan`` will return ``output[-k:]``. This is meant as a
provided, ``scan`` will return ``output[-k:]``. This is meant as a
...
@@ -422,7 +415,7 @@ def scan( fn
...
@@ -422,7 +415,7 @@ def scan( fn
* ``store_steps`` -- Integer representing the number of
* ``store_steps`` -- Integer representing the number of
intermediate steps ``scan`` should use for a given output. Use
intermediate steps ``scan`` should use for a given output. Use
this key only if you really know what you are doing. In general
this key only if you really know what you are doing. In general
it is recommended to let scan decide for you the am
m
ount of memory
it is recommended to let scan decide for you the amount of memory
it should use.
it should use.
``scan`` will follow this logic if partial information is given:
``scan`` will follow this logic if partial information is given:
...
@@ -437,12 +430,12 @@ def scan( fn
...
@@ -437,12 +430,12 @@ def scan( fn
* If you wrap an output in a dictionary but you do not provide any
* If you wrap an output in a dictionary but you do not provide any
initial state, it assumes that you are not using any form of
initial state, it assumes that you are not using any form of
taps.
taps.
* If you provide
a
``None`` instead of a variable or a dictionary
* If you provide ``None`` instead of a variable or a dictionary
``scan`` assumes that you will not use any taps for this output
``scan`` assumes that you will not use any taps for this output
(like for example in case of a map)
(like for example in case of a map)
If ``outputs_info`` is an empty list or None, ``scan`` assumes
If ``outputs_info`` is an empty list or None, ``scan`` assumes
that no tap is used for any of the o
tu
puts. If information is
that no tap is used for any of the o
ut
puts. If information is
provided just for a subset of the outputs an exception is
provided just for a subset of the outputs an exception is
raised (because there is no convention on how scan should map
raised (because there is no convention on how scan should map
the provided information to the outputs of ``fn``)
the provided information to the outputs of ``fn``)
...
@@ -450,8 +443,8 @@ def scan( fn
...
@@ -450,8 +443,8 @@ def scan( fn
:param non_sequences:
:param non_sequences:
``non_sequences`` is the list of arguments that are passed to
``non_sequences`` is the list of arguments that are passed to
``fn`` at each step
s. One can opt to exclude
shared variables
``fn`` at each step
. It is not necessary to list
shared variables
used in ``fn``
from this list
.
used in ``fn``
here, since they will be identified automatically
.
:param n_steps:
:param n_steps:
...
@@ -469,9 +462,10 @@ def scan( fn
...
@@ -469,9 +462,10 @@ def scan( fn
:param truncate_gradient:
:param truncate_gradient:
``truncate_gradient`` is the number of steps to use in truncated
``truncate_gradient`` is the number of steps to use in truncated
BPTT. If you compute gradients through a scan op, they are
BPTT (backpropagation through time). If you compute gradients
through a scan op, they are
computed using backpropagation through time. By providing a
computed using backpropagation through time. By providing a
different value th
e
n -1, you choose to use truncated BPTT instead
different value th
a
n -1, you choose to use truncated BPTT instead
of classical BPTT, where you go for only ``truncate_gradient``
of classical BPTT, where you go for only ``truncate_gradient``
number of steps back in time.
number of steps back in time.
...
@@ -512,33 +506,32 @@ def scan( fn
...
@@ -512,33 +506,32 @@ def scan( fn
"""
"""
# General observation : this code is executed only once, at creation
# General observation : this code is executed only once, at creation
# of the computational graph, so we don't yet need to be smart about
# of the computational graph, so we don't yet need to be smart about
# anything (
to speed things up)
# anything (to speed things up)
# check if inputs are just single variables instead of lists
# check if inputs are just single variables instead of lists
if
not
(
type
(
sequences
)
in
(
list
,
tuple
))
and
sequences
!=
None
:
if
sequences
==
None
:
seqs
=
[
sequences
]
elif
sequences
==
None
:
seqs
=
[]
seqs
=
[]
elif
not
(
type
(
sequences
)
in
(
list
,
tuple
)):
seqs
=
[
sequences
]
else
:
else
:
seqs
=
sequences
seqs
=
sequences
if
not
(
type
(
outputs_info
)
in
(
list
,
tuple
))
and
outputs_info
!=
None
:
if
outputs_info
==
None
:
outs_info
=
[
outputs_info
]
elif
outputs_info
==
None
:
outs_info
=
[]
outs_info
=
[]
elif
not
(
type
(
outputs_info
)
in
(
list
,
tuple
)):
outs_info
=
[
outputs_info
]
else
:
else
:
outs_info
=
outputs_info
outs_info
=
outputs_info
if
(
not
(
type
(
non_sequences
)
in
(
list
,
tuple
))
if
non_sequences
==
None
:
and
non_sequences
!=
None
):
non_seqs
=
[
non_sequences
]
elif
non_sequences
==
None
:
non_seqs
=
[]
non_seqs
=
[]
elif
not
(
type
(
non_sequences
)
in
(
list
,
tuple
)):
non_seqs
=
[
non_sequences
]
else
:
else
:
non_seqs
=
non_sequences
non_seqs
=
non_sequences
# If we provided a known number of steps (
before compilation)
# If we provided a known number of steps (before compilation)
# and if that number is 1 or -1, then we can skip the Scan Op,
# and if that number is 1 or -1, then we can skip the Scan Op,
# and just apply the inner function once
# and just apply the inner function once
# To do that we check here to see the nature of n_steps
# To do that we check here to see the nature of n_steps
...
@@ -570,7 +563,7 @@ def scan( fn
...
@@ -570,7 +563,7 @@ def scan( fn
sequences_taps
=
{}
sequences_taps
=
{}
outputs_taps
=
{}
outputs_taps
=
{}
# Assume that for any output we want to store everythin that it produces
# Assume that for any output we want to store everythin
g
that it produces
store_steps
=
[]
store_steps
=
[]
return_steps
=
{}
return_steps
=
{}
...
@@ -591,8 +584,8 @@ def scan( fn
...
@@ -591,8 +584,8 @@ def scan( fn
# See if the user actually provided the None value to taps,
# See if the user actually provided the None value to taps,
# which would indicate that the sequence was provided but
# which would indicate that the sequence was provided but
# not used by the internal function; Only if the user has
# not used by the internal function; Only if the user has
# not provided anything add the defaul [0]
# not provided anything add the defaul
t
[0]
#
Possible reason to provide a squence and not use it
is
#
A possible reason to provide a sequence and not use it
is
# if you want to compute the output
# if you want to compute the output
# inplace of this input; it is a very unlikely behaviour but
# inplace of this input; it is a very unlikely behaviour but
# we do want to cover it for completeness
# we do want to cover it for completeness
...
@@ -635,7 +628,7 @@ def scan( fn
...
@@ -635,7 +628,7 @@ def scan( fn
raise
ValueError
(
'If you are using slices of an output you need to '
\
raise
ValueError
(
'If you are using slices of an output you need to '
\
'provide an initial state for it'
,
outs_info
[
i
])
'provide an initial state for it'
,
outs_info
[
i
])
# if there is an intial state but no tap, we will add the default value
# if there is an intial state but no tap, we will add the default value
# for taps, namely [-1] ( previous value); not that this will happen
# for taps, namely [-1] ( previous value); not
e
that this will happen
# even though you have provided for taps the value None, which is a bit
# even though you have provided for taps the value None, which is a bit
# strange (why would one provide an initial state but tell scan not to
# strange (why would one provide an initial state but tell scan not to
# use it ? ), just that in that case we will throw in a warning message
# use it ? ), just that in that case we will throw in a warning message
...
@@ -658,9 +651,14 @@ def scan( fn
...
@@ -658,9 +651,14 @@ def scan( fn
if
outs_info
[
i
]
.
get
(
'taps'
,
None
):
if
outs_info
[
i
]
.
get
(
'taps'
,
None
):
# Create a separate outputs_taps dictionary with all the outputs taps; This
# Create a separate outputs_taps dictionary with all the outputs taps; This
# is how the Scan Op expects this information, separ
e
ted from the variables
# is how the Scan Op expects this information, separ
a
ted from the variables
outputs_taps
[
i
]
=
outs_info
[
i
][
'taps'
]
outputs_taps
[
i
]
=
outs_info
[
i
][
'taps'
]
if
outs_info
[
i
]
.
get
(
'inplace'
,
None
):
if
outs_info
[
i
]
.
get
(
'inplace'
,
None
):
warning
(
"DEPRECATED: you should not set the inplace parameter for an output in scan(...). "
"This can cause problems for the early stages of the optimizer "
"and there is a late optimization which automatically figures it out."
)
# The same is true for the inplace info; it has to go into a separate
# The same is true for the inplace info; it has to go into a separate
# dictionary based on index; Note that the input we're replacing should also
# dictionary based on index; Note that the input we're replacing should also
# come as an index, therefore we have to look for it at this point
# come as an index, therefore we have to look for it at this point
...
...
theano/tensor/basic.py
浏览文件 @
4bbac540
...
@@ -1336,8 +1336,9 @@ def _redefine_asRoutine(real_symbol_value):
...
@@ -1336,8 +1336,9 @@ def _redefine_asRoutine(real_symbol_value):
return
real_symbol_value
return
real_symbol_value
return
decorator
return
decorator
def
_scal_elemwise
(
symbol
):
def
_scal_elemwise
_with_nfunc
(
nfunc
,
nin
,
nout
):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
def
construct
(
symbol
):
symbolname
=
symbol
.
__name__
symbolname
=
symbol
.
__name__
inplace
=
symbolname
.
endswith
(
'_inplace'
)
inplace
=
symbolname
.
endswith
(
'_inplace'
)
if
inplace
:
if
inplace
:
...
@@ -1349,10 +1350,10 @@ def _scal_elemwise(symbol):
...
@@ -1349,10 +1350,10 @@ def _scal_elemwise(symbol):
if
inplace
:
if
inplace
:
scalar_op
=
getattr
(
scal
,
symbolname
[:
-
len
(
'_inplace'
)])
scalar_op
=
getattr
(
scal
,
symbolname
[:
-
len
(
'_inplace'
)])
inplace_scalar_op
=
scalar_op
.
__class__
(
scal
.
transfer_type
(
0
))
inplace_scalar_op
=
scalar_op
.
__class__
(
scal
.
transfer_type
(
0
))
rval
=
elemwise
.
Elemwise
(
inplace_scalar_op
,
{
0
:
0
},
name
=
n
)
rval
=
elemwise
.
Elemwise
(
inplace_scalar_op
,
{
0
:
0
},
name
=
n
,
nfunc_spec
=
((
nfunc
,
nin
,
nout
)
if
nfunc
else
None
)
)
else
:
else
:
scalar_op
=
getattr
(
scal
,
symbolname
)
scalar_op
=
getattr
(
scal
,
symbolname
)
rval
=
elemwise
.
Elemwise
(
scalar_op
,
name
=
n
)
rval
=
elemwise
.
Elemwise
(
scalar_op
,
name
=
n
,
nfunc_spec
=
((
nfunc
,
nin
,
nout
)
if
nfunc
else
None
)
)
if
getattr
(
symbol
,
'__doc__'
,
False
):
if
getattr
(
symbol
,
'__doc__'
,
False
):
rval
.
__doc__
=
symbol
.
__doc__
+
'
\n
'
+
rval
.
__doc__
rval
.
__doc__
=
symbol
.
__doc__
+
'
\n
'
+
rval
.
__doc__
...
@@ -1365,6 +1366,9 @@ def _scal_elemwise(symbol):
...
@@ -1365,6 +1366,9 @@ def _scal_elemwise(symbol):
pprint
.
assign
(
rval
,
printing
.
FunctionPrinter
(
symbolname
))
pprint
.
assign
(
rval
,
printing
.
FunctionPrinter
(
symbolname
))
return
rval
return
rval
return
construct
_scal_elemwise
=
_scal_elemwise_with_nfunc
(
None
,
None
,
None
)
#########################
#########################
...
@@ -1865,27 +1869,27 @@ def largest(*args):
...
@@ -1865,27 +1869,27 @@ def largest(*args):
# Comparison
# Comparison
##########################
##########################
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'less'
,
2
,
1
)
def
lt
(
a
,
b
):
def
lt
(
a
,
b
):
"""a < b"""
"""a < b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'greater'
,
2
,
1
)
def
gt
(
a
,
b
):
def
gt
(
a
,
b
):
"""a > b"""
"""a > b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'less_equal'
,
2
,
1
)
def
le
(
a
,
b
):
def
le
(
a
,
b
):
"""a <= b"""
"""a <= b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'greater_equal'
,
2
,
1
)
def
ge
(
a
,
b
):
def
ge
(
a
,
b
):
"""a >= b"""
"""a >= b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'equal'
,
2
,
1
)
def
eq
(
a
,
b
):
def
eq
(
a
,
b
):
"""a == b"""
"""a == b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'not_equal'
,
2
,
1
)
def
neq
(
a
,
b
):
def
neq
(
a
,
b
):
"""a != b"""
"""a != b"""
...
@@ -1903,19 +1907,19 @@ def switch(cond, ift, iff):
...
@@ -1903,19 +1907,19 @@ def switch(cond, ift, iff):
# Bit-wise
# Bit-wise
##########################
##########################
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'bitwise_and'
,
2
,
1
)
def
and_
(
a
,
b
):
def
and_
(
a
,
b
):
"""bitwise a & b"""
"""bitwise a & b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'bitwise_or'
,
2
,
1
)
def
or_
(
a
,
b
):
def
or_
(
a
,
b
):
"""bitwise a | b"""
"""bitwise a | b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'bitwise_xor'
,
2
,
1
)
def
xor
(
a
,
b
):
def
xor
(
a
,
b
):
"""bitwise a ^ b"""
"""bitwise a ^ b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'invert'
,
1
,
1
)
def
invert
(
a
):
def
invert
(
a
):
"""bitwise ~a"""
"""bitwise ~a"""
...
@@ -1923,7 +1927,7 @@ def invert(a):
...
@@ -1923,7 +1927,7 @@ def invert(a):
# Math
# Math
##########################
##########################
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'abs'
,
1
,
1
)
def
abs_
(
a
):
def
abs_
(
a
):
"""|`a`|
"""|`a`|
...
@@ -1934,43 +1938,43 @@ def abs_(a):
...
@@ -1934,43 +1938,43 @@ def abs_(a):
pprint
.
assign
(
abs_
,
printing
.
PatternPrinter
((
'|
%(0)
s|'
,
-
1000
)))
pprint
.
assign
(
abs_
,
printing
.
PatternPrinter
((
'|
%(0)
s|'
,
-
1000
)))
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'exp'
,
1
,
1
)
def
exp
(
a
):
def
exp
(
a
):
"""e^`a`"""
"""e^`a`"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'negative'
,
1
,
1
)
def
neg
(
a
):
def
neg
(
a
):
"""-a"""
"""-a"""
@_scal_elemwise
@_scal_elemwise
# numpy.reciprocal does integer division on integer inputs (which is not very interesting)
def
inv
(
a
):
def
inv
(
a
):
"""1.0/a"""
"""1.0/a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'log'
,
1
,
1
)
def
log
(
a
):
def
log
(
a
):
"""base e logarithm of a"""
"""base e logarithm of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'log2'
,
1
,
1
)
def
log2
(
a
):
def
log2
(
a
):
"""base 2 logarithm of a"""
"""base 2 logarithm of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'log10'
,
1
,
1
)
def
log10
(
a
):
def
log10
(
a
):
"""base 10 logarithm of a"""
"""base 10 logarithm of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'log1p'
,
1
,
1
)
def
log1p
(
a
):
def
log1p
(
a
):
"""log(1+a)"""
"""log(1+a)"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'sign'
,
1
,
1
)
def
sgn
(
a
):
def
sgn
(
a
):
"""sign of a"""
"""sign of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'ceil'
,
1
,
1
)
def
ceil
(
a
):
def
ceil
(
a
):
"""ceiling of a"""
"""ceiling of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'floor'
,
1
,
1
)
def
floor
(
a
):
def
floor
(
a
):
"""floor of a"""
"""floor of a"""
...
@@ -1989,7 +1993,10 @@ def round(a, mode="half_away_from_zero"):
...
@@ -1989,7 +1993,10 @@ def round(a, mode="half_away_from_zero"):
else
:
else
:
raise
Exception
(
"round mode
%
s is not implemented."
%
mode
)
raise
Exception
(
"round mode
%
s is not implemented."
%
mode
)
@_scal_elemwise
# def __round_half_to_even(a, dest):
# dest[:] = numpy.around(a)
@_scal_elemwise_with_nfunc
(
'around'
,
1
,
0
)
def
round_half_to_even
(
a
):
def
round_half_to_even
(
a
):
"""round_half_to_even(a)"""
"""round_half_to_even(a)"""
...
@@ -1997,35 +2004,35 @@ def round_half_to_even(a):
...
@@ -1997,35 +2004,35 @@ def round_half_to_even(a):
def
round_half_away_from_zero
(
a
):
def
round_half_away_from_zero
(
a
):
"""round_half_away_from_zero(a)"""
"""round_half_away_from_zero(a)"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'square'
,
1
,
1
)
def
sqr
(
a
):
def
sqr
(
a
):
"""square of a"""
"""square of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'sqrt'
,
1
,
1
)
def
sqrt
(
a
):
def
sqrt
(
a
):
"""square root of a"""
"""square root of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'cos'
,
1
,
1
)
def
cos
(
a
):
def
cos
(
a
):
"""cosine of a"""
"""cosine of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'sin'
,
1
,
1
)
def
sin
(
a
):
def
sin
(
a
):
"""sine of a"""
"""sine of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'tan'
,
1
,
1
)
def
tan
(
a
):
def
tan
(
a
):
"""tangent of a"""
"""tangent of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'cosh'
,
1
,
1
)
def
cosh
(
a
):
def
cosh
(
a
):
"""hyperbolic cosine of a"""
"""hyperbolic cosine of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'sinh'
,
1
,
1
)
def
sinh
(
a
):
def
sinh
(
a
):
"""hyperbolic sine of a"""
"""hyperbolic sine of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'tanh'
,
1
,
1
)
def
tanh
(
a
):
def
tanh
(
a
):
"""hyperbolic tangent of a"""
"""hyperbolic tangent of a"""
...
@@ -2037,19 +2044,19 @@ def erf(a):
...
@@ -2037,19 +2044,19 @@ def erf(a):
def
erfc
(
a
):
def
erfc
(
a
):
"""complementary error function"""
"""complementary error function"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'real'
,
1
,
0
)
def
real
(
z
):
def
real
(
z
):
"""Return real component of complex-valued tensor `z`"""
"""Return real component of complex-valued tensor `z`"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'imag'
,
1
,
0
)
def
imag
(
z
):
def
imag
(
z
):
"""Return imaginary component of complex-valued tensor `z`"""
"""Return imaginary component of complex-valued tensor `z`"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'angle'
,
1
,
0
)
def
angle
(
z
):
def
angle
(
z
):
"""Return polar-coordinate angle of complex-valued tensor `z`"""
"""Return polar-coordinate angle of complex-valued tensor `z`"""
@_scal_elemwise
@_scal_elemwise
# numpy.complex cannot build tensors
def
complex
(
real
,
imag
):
def
complex
(
real
,
imag
):
"""Return complex-valued tensor with `real` and `imag` components"""
"""Return complex-valued tensor with `real` and `imag` components"""
...
@@ -2475,13 +2482,13 @@ setdefault = default # legacy
...
@@ -2475,13 +2482,13 @@ setdefault = default # legacy
##########################
##########################
# Arithmetics
# Arithmetics
##########################
##########################
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'maximum'
,
2
,
1
)
def
maximum
(
x
,
y
):
def
maximum
(
x
,
y
):
"""elemwise maximum. See max for the maximum in one tensor
"""elemwise maximum. See max for the maximum in one tensor
"""
"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'minimum'
,
2
,
1
)
def
minimum
(
x
,
y
):
def
minimum
(
x
,
y
):
"""elemwise minimum. See min for the minimum in one tensor
"""elemwise minimum. See min for the minimum in one tensor
"""
"""
...
@@ -2495,47 +2502,47 @@ def div_proxy(x, y):
...
@@ -2495,47 +2502,47 @@ def div_proxy(x, y):
else
:
else
:
return
true_div
(
x
,
y
)
return
true_div
(
x
,
y
)
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'add'
,
2
,
1
)
def
add
(
a
,
*
other_terms
):
def
add
(
a
,
*
other_terms
):
"""elementwise addition"""
"""elementwise addition"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'subtract'
,
2
,
1
)
def
sub
(
a
,
b
):
def
sub
(
a
,
b
):
"""elementwise subtraction"""
"""elementwise subtraction"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'multiply'
,
2
,
1
)
def
mul
(
a
,
*
other_terms
):
def
mul
(
a
,
*
other_terms
):
"""elementwise multiplication"""
"""elementwise multiplication"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'true_divide'
,
2
,
1
)
def
true_div
(
a
,
b
):
def
true_div
(
a
,
b
):
"""elementwise [true] division (inverse of multiplication)"""
"""elementwise [true] division (inverse of multiplication)"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'floor_divide'
,
2
,
1
)
def
floor_div
(
a
,
b
):
def
floor_div
(
a
,
b
):
"""elementwise [floor] division (inverse of multiplication)"""
"""elementwise [floor] division (inverse of multiplication)"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'floor_divide'
,
2
,
1
)
# not a c/p error, floor_div and int_div are the same thing
def
int_div
(
a
,
b
):
def
int_div
(
a
,
b
):
"""elementwise integer-division"""
"""elementwise integer-division"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'mod'
,
2
,
1
)
def
mod
(
a
,
b
):
def
mod
(
a
,
b
):
"""elementwise modulo"""
"""elementwise modulo"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'power'
,
2
,
1
)
def
pow
(
a
,
b
):
def
pow
(
a
,
b
):
"""elementwise power"""
"""elementwise power"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'clip'
,
3
,
1
)
def
clip
(
x
,
min
,
max
):
def
clip
(
x
,
min
,
max
):
"""clip x to be between min and max"""
"""clip x to be between min and max"""
# see decorator for function body
# see decorator for function body
...
...
theano/tensor/elemwise.py
浏览文件 @
4bbac540
...
@@ -361,6 +361,21 @@ class DimShufflePrinter:
...
@@ -361,6 +361,21 @@ class DimShufflePrinter:
pprint
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
DimShuffle
),
DimShufflePrinter
())
pprint
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
DimShuffle
),
DimShufflePrinter
())
def
_make_nfunc
(
name
,
nin
,
nout
):
f
=
getattr
(
numpy
,
name
)
return
f
# if name.endswith("*"):
# name = name[:-1]
# f = getattr(numpy, name)
# def fn(*args):
# args[-1][:] = f(*(args[:-1]))
# return fn
# else:
# f = getattr(numpy, name)
# return f
################
################
### Elemwise ###
### Elemwise ###
################
################
...
@@ -392,7 +407,7 @@ class Elemwise(Op):
...
@@ -392,7 +407,7 @@ class Elemwise(Op):
Elemwise(log)(rand(3, 4, 5))
Elemwise(log)(rand(3, 4, 5))
"""
"""
def
__init__
(
self
,
scalar_op
,
inplace_pattern
=
{},
name
=
None
):
def
__init__
(
self
,
scalar_op
,
inplace_pattern
=
{},
name
=
None
,
nfunc_spec
=
None
):
"""
"""
Usage: Elemwise(scalar_op, inplace_pattern = {})
Usage: Elemwise(scalar_op, inplace_pattern = {})
...
@@ -406,10 +421,14 @@ class Elemwise(Op):
...
@@ -406,10 +421,14 @@ class Elemwise(Op):
self
.
scalar_op
=
scalar_op
self
.
scalar_op
=
scalar_op
self
.
inplace_pattern
=
inplace_pattern
self
.
inplace_pattern
=
inplace_pattern
self
.
destroy_map
=
dict
((
o
,
[
i
])
for
o
,
i
in
inplace_pattern
.
items
())
self
.
destroy_map
=
dict
((
o
,
[
i
])
for
o
,
i
in
inplace_pattern
.
items
())
if
scalar_op
.
nin
>
0
:
self
.
ufunc
=
numpy
.
frompyfunc
(
scalar_op
.
impl
,
scalar_op
.
nin
,
scalar_op
.
nout
)
else
:
self
.
ufunc
=
None
self
.
ufunc
=
None
self
.
nfunc
=
None
self
.
nfunc_spec
=
nfunc_spec
if
nfunc_spec
:
self
.
nfunc
=
_make_nfunc
(
*
nfunc_spec
)
elif
scalar_op
.
nin
>
0
:
self
.
ufunc
=
numpy
.
frompyfunc
(
scalar_op
.
impl
,
scalar_op
.
nin
,
scalar_op
.
nout
)
#precompute the hash of this node
#precompute the hash of this node
self
.
_rehash
()
self
.
_rehash
()
...
@@ -417,16 +436,19 @@ class Elemwise(Op):
...
@@ -417,16 +436,19 @@ class Elemwise(Op):
def
__getstate__
(
self
):
def
__getstate__
(
self
):
d
=
copy
(
self
.
__dict__
)
d
=
copy
(
self
.
__dict__
)
d
.
pop
(
'ufunc'
)
d
.
pop
(
'ufunc'
)
d
.
pop
(
'nfunc'
)
d
.
pop
(
'__epydoc_asRoutine'
,
None
)
d
.
pop
(
'__epydoc_asRoutine'
,
None
)
d
.
pop
(
'_hashval'
)
d
.
pop
(
'_hashval'
)
return
d
return
d
def
__setstate__
(
self
,
d
):
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
self
.
__dict__
.
update
(
d
)
if
self
.
scalar_op
.
nin
>
0
:
self
.
ufunc
=
numpy
.
frompyfunc
(
self
.
scalar_op
.
impl
,
self
.
scalar_op
.
nin
,
self
.
scalar_op
.
nout
)
else
:
self
.
ufunc
=
None
self
.
ufunc
=
None
self
.
nfunc
=
None
if
getattr
(
self
,
'nfunc_spec'
,
None
):
self
.
nfunc
=
_make_nfunc
(
*
self
.
nfunc_spec
)
elif
self
.
scalar_op
.
nin
>
0
:
self
.
ufunc
=
numpy
.
frompyfunc
(
self
.
scalar_op
.
impl
,
self
.
scalar_op
.
nin
,
self
.
scalar_op
.
nout
)
self
.
_rehash
()
self
.
_rehash
()
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
...
@@ -621,10 +643,16 @@ class Elemwise(Op):
...
@@ -621,10 +643,16 @@ class Elemwise(Op):
else
:
else
:
odat
=
numpy
.
ndarray
(
shape
,
dtype
=
output
.
type
.
dtype
)
odat
=
numpy
.
ndarray
(
shape
,
dtype
=
output
.
type
.
dtype
)
storage
[
0
]
=
odat
storage
[
0
]
=
odat
ufunc_args
=
inputs
# + output_storage
if
self
.
nfunc
and
len
(
inputs
)
==
self
.
nfunc_spec
[
1
]:
ufunc
=
self
.
nfunc
nout
=
1
else
:
# the second calling form is used because in certain versions of numpy
# the second calling form is used because in certain versions of numpy
# the first (faster) version leads to segfaults
# the first (faster) version leads to segfaults
ufunc_args
=
inputs
# + output_storage
ufunc
=
self
.
ufunc
or
numpy
.
frompyfunc
(
self
.
scalar_op
.
impl
,
len
(
inputs
),
self
.
scalar_op
.
nout
)
ufunc
=
self
.
ufunc
or
numpy
.
frompyfunc
(
self
.
scalar_op
.
impl
,
len
(
inputs
),
self
.
scalar_op
.
nout
)
nout
=
ufunc
.
nout
try
:
try
:
variables
=
ufunc
(
*
ufunc_args
)
variables
=
ufunc
(
*
ufunc_args
)
...
@@ -633,7 +661,7 @@ class Elemwise(Op):
...
@@ -633,7 +661,7 @@ class Elemwise(Op):
'for params of shape'
,
[
arg
.
shape
for
arg
in
ufunc_args
]
'for params of shape'
,
[
arg
.
shape
for
arg
in
ufunc_args
]
e
.
args
=
e
.
args
+
errormsg
e
.
args
=
e
.
args
+
errormsg
raise
raise
if
ufunc
.
nout
==
1
:
variables
=
[
variables
]
if
nout
==
1
:
variables
=
[
variables
]
for
variable
,
storage
in
zip
(
variables
,
output_storage
):
for
variable
,
storage
in
zip
(
variables
,
output_storage
):
if
hasattr
(
variable
,
'shape'
)
and
storage
[
0
]
.
shape
!=
variable
.
shape
:
if
hasattr
(
variable
,
'shape'
)
and
storage
[
0
]
.
shape
!=
variable
.
shape
:
storage
[
0
]
.
resize
(
variable
.
shape
)
storage
[
0
]
.
resize
(
variable
.
shape
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论