Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4bbac540
提交
4bbac540
authored
3月 15, 2011
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
差异文件
merge
上级
8ea860cd
59561698
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
155 行增加
和
122 行删除
+155
-122
graph.py
theano/gof/graph.py
+1
-1
scan.py
theano/scan.py
+40
-42
basic.py
theano/tensor/basic.py
+75
-68
elemwise.py
theano/tensor/elemwise.py
+39
-11
没有找到文件。
theano/gof/graph.py
浏览文件 @
4bbac540
...
...
@@ -417,7 +417,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
raise
ValueError
(
'mode should be bfs or dfs'
,
mode
)
rval_set
=
set
()
rval_list
=
list
()
if
mode
is
'bfs'
:
start_pop
=
start
.
popleft
if
mode
==
'bfs'
:
start_pop
=
start
.
popleft
else
:
start_pop
=
start
.
pop
expand_inv
=
{}
while
start
:
...
...
theano/scan.py
浏览文件 @
4bbac540
...
...
@@ -106,7 +106,7 @@ def map( fn
:param go_backwards: Boolean value that decides the direction of
iteration. True means that sequences are parsed
from the end towards the begining, while False
from the end towards the begin
n
ing, while False
is the other way around.
:param mode: See ``scan``.
...
...
@@ -301,7 +301,7 @@ def scan( fn
scan)
The order of the sequences is the same as the one in the list
`sequences` given to scan. The order of the outputs is the sa
n
e
`sequences` given to scan. The order of the outputs is the sa
m
e
as the order of ``output_info``. For any sequence or output the
order of the time slices is the same as the order of the time
taps provided. For example if one writes the following :
...
...
@@ -314,7 +314,7 @@ def scan( fn
, outputs_info = [ dict( Output1, taps = [-3,-5])
, dict( Output2, taps = None)
, Output3 ]
, non_sequences = [ Argument1, Argument
2])
, non_sequences = [ Argument1, Argument2])
``fn`` should expect the following arguments in this given order:
...
...
@@ -341,7 +341,7 @@ def scan( fn
`fn` should return an update dictionary ( that tells how to
update any shared variable after each iteration ste). The
dictionary can optionally be given as a list of tuples. There is
no constraint on the order of these two list, ``fn`` can return
no constraint on the order of these two list
s
, ``fn`` can return
either ``(outputs_list, update_dictionary)`` or ``(update_dictionary,
outputs_list)`` or just one of the two (in case the other is
empty).
...
...
@@ -369,7 +369,7 @@ def scan( fn
:param outputs_info:
``outputs_info`` is the list of Theano variables or dictionaries
describing the initial state of the outputs computed
recurrently. When this initial state
s are given as dictionary
recurrently. When this initial state
is given as a dictionary,
optional information can be provided about the output corresponding
to these initial states. The dictionary should have the following
keys:
...
...
@@ -388,11 +388,11 @@ def scan( fn
the initial state, which in this case should have the shape
(5,)+output.shape. If this variable containing the initial
state is called ``init_y`` then ``init_y[0]`` *corresponds to*
``output[-5]``; ``init_y[1]`` *correponds to* ``output[-4]``;
``output[-5]``; ``init_y[1]`` *corre
s
ponds to* ``output[-4]``;
``init_y[2]`` corresponds to ``output[-3]``; ``init_y[3]``
coresponds to ``output[-2]``; ``init_y[4]`` corresponds to
``output[-1]``. While this order might seem strange, it comes
natural from splitting an array at a given point. Assume that
natural
ly
from splitting an array at a given point. Assume that
we have a array ``x``, and we choose ``k`` to be time step
``0``. Then our initial state would be ``x[:k]``, while the
output will be ``x[k:]``. Looking at this split, elements in
...
...
@@ -401,17 +401,10 @@ def scan( fn
``fn``. They are provided as a list of *negative* integers,
where a value ``k`` implies that at iteration step ``t`` scan will
pass to ``fn`` the slice ``t+k``.
* ``inplace`` -- One of the Theano variables provided as
``sequences``. ``scan`` will try to compute this output *in
place* of the provided input *iff* it respects the following
constraints:
* There is no other output that is denied to be computed in
place for whatever reason.
* ``fn`` is not using past taps of the input sequence that
will get overwritten by the output
* ``inplace`` -- DEPRECATED. Previously, one could specify with this
option whether the output should overwrite some particular input,
but it is now inferred automatically. If you specify this option
it will be ignored.
* ``return_steps`` -- Integer representing the number of steps
to return for the current steps. For example, if ``k`` is
provided, ``scan`` will return ``output[-k:]``. This is meant as a
...
...
@@ -422,7 +415,7 @@ def scan( fn
* ``store_steps`` -- Integer representing the number of
intermediate steps ``scan`` should use for a given output. Use
this key only if you really know what you are doing. In general
it is recommended to let scan decide for you the am
m
ount of memory
it is recommended to let scan decide for you the amount of memory
it should use.
``scan`` will follow this logic if partial information is given:
...
...
@@ -437,12 +430,12 @@ def scan( fn
* If you wrap an output in a dictionary but you do not provide any
initial state, it assumes that you are not using any form of
taps.
* If you provide
a
``None`` instead of a variable or a dictionary
* If you provide ``None`` instead of a variable or a dictionary
``scan`` assumes that you will not use any taps for this output
(like for example in case of a map)
If ``outputs_info`` is an empty list or None, ``scan`` assumes
that no tap is used for any of the o
tu
puts. If information is
that no tap is used for any of the o
ut
puts. If information is
provided just for a subset of the outputs an exception is
raised (because there is no convention on how scan should map
the provided information to the outputs of ``fn``)
...
...
@@ -450,8 +443,8 @@ def scan( fn
:param non_sequences:
``non_sequences`` is the list of arguments that are passed to
``fn`` at each step
s. One can opt to exclude
shared variables
used in ``fn``
from this list
.
``fn`` at each step
. It is not necessary to list
shared variables
used in ``fn``
here, since they will be identified automatically
.
:param n_steps:
...
...
@@ -469,9 +462,10 @@ def scan( fn
:param truncate_gradient:
``truncate_gradient`` is the number of steps to use in truncated
BPTT. If you compute gradients through a scan op, they are
BPTT (backpropagation through time). If you compute gradients
through a scan op, they are
computed using backpropagation through time. By providing a
different value th
e
n -1, you choose to use truncated BPTT instead
different value th
a
n -1, you choose to use truncated BPTT instead
of classical BPTT, where you go for only ``truncate_gradient``
number of steps back in time.
...
...
@@ -512,33 +506,32 @@ def scan( fn
"""
# General observation : this code is executed only once, at creation
# of the computational graph, so we don't yet need to be smart about
# anything (
to speed things up)
# anything (to speed things up)
# check if inputs are just single variables instead of lists
if
not
(
type
(
sequences
)
in
(
list
,
tuple
))
and
sequences
!=
None
:
seqs
=
[
sequences
]
elif
sequences
==
None
:
if
sequences
==
None
:
seqs
=
[]
elif
not
(
type
(
sequences
)
in
(
list
,
tuple
)):
seqs
=
[
sequences
]
else
:
seqs
=
sequences
if
not
(
type
(
outputs_info
)
in
(
list
,
tuple
))
and
outputs_info
!=
None
:
outs_info
=
[
outputs_info
]
elif
outputs_info
==
None
:
if
outputs_info
==
None
:
outs_info
=
[]
elif
not
(
type
(
outputs_info
)
in
(
list
,
tuple
)):
outs_info
=
[
outputs_info
]
else
:
outs_info
=
outputs_info
if
(
not
(
type
(
non_sequences
)
in
(
list
,
tuple
))
and
non_sequences
!=
None
):
non_seqs
=
[
non_sequences
]
elif
non_sequences
==
None
:
if
non_sequences
==
None
:
non_seqs
=
[]
elif
not
(
type
(
non_sequences
)
in
(
list
,
tuple
)):
non_seqs
=
[
non_sequences
]
else
:
non_seqs
=
non_sequences
# If we provided a known number of steps (
before compilation)
# If we provided a known number of steps (before compilation)
# and if that number is 1 or -1, then we can skip the Scan Op,
# and just apply the inner function once
# To do that we check here to see the nature of n_steps
...
...
@@ -570,7 +563,7 @@ def scan( fn
sequences_taps
=
{}
outputs_taps
=
{}
# Assume that for any output we want to store everythin that it produces
# Assume that for any output we want to store everythin
g
that it produces
store_steps
=
[]
return_steps
=
{}
...
...
@@ -591,8 +584,8 @@ def scan( fn
# See if the user actually provided the None value to taps,
# which would indicate that the sequence was provided but
# not used by the internal function; Only if the user has
# not provided anything add the defaul [0]
#
Possible reason to provide a squence and not use it
is
# not provided anything add the defaul
t
[0]
#
A possible reason to provide a sequence and not use it
is
# if you want to compute the output
# inplace of this input; it is a very unlikely behaviour but
# we do want to cover it for completeness
...
...
@@ -635,7 +628,7 @@ def scan( fn
raise
ValueError
(
'If you are using slices of an output you need to '
\
'provide an initial state for it'
,
outs_info
[
i
])
# if there is an intial state but no tap, we will add the default value
# for taps, namely [-1] ( previous value); not that this will happen
# for taps, namely [-1] ( previous value); not
e
that this will happen
# even though you have provided for taps the value None, which is a bit
# strange (why would one provide an initial state but tell scan not to
# use it ? ), just that in that case we will throw in a warning message
...
...
@@ -658,9 +651,14 @@ def scan( fn
if
outs_info
[
i
]
.
get
(
'taps'
,
None
):
# Create a separate outputs_taps dictionary with all the outputs taps; This
# is how the Scan Op expects this information, separ
e
ted from the variables
# is how the Scan Op expects this information, separ
a
ted from the variables
outputs_taps
[
i
]
=
outs_info
[
i
][
'taps'
]
if
outs_info
[
i
]
.
get
(
'inplace'
,
None
):
warning
(
"DEPRECATED: you should not set the inplace parameter for an output in scan(...). "
"This can cause problems for the early stages of the optimizer "
"and there is a late optimization which automatically figures it out."
)
# The same is true for the inplace info; it has to go into a separate
# dictionary based on index; Note that the input we're replacing should also
# come as an index, therefore we have to look for it at this point
...
...
theano/tensor/basic.py
浏览文件 @
4bbac540
...
...
@@ -1336,35 +1336,39 @@ def _redefine_asRoutine(real_symbol_value):
return
real_symbol_value
return
decorator
def
_scal_elemwise
(
symbol
):
def
_scal_elemwise
_with_nfunc
(
nfunc
,
nin
,
nout
):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
symbolname
=
symbol
.
__name__
inplace
=
symbolname
.
endswith
(
'_inplace'
)
if
inplace
:
msg
=
"inplace"
else
:
msg
=
"no_inplace"
n
=
"Elemwise{
%
s,
%
s}"
%
(
symbolname
,
msg
)
def
construct
(
symbol
):
symbolname
=
symbol
.
__name__
inplace
=
symbolname
.
endswith
(
'_inplace'
)
if
inplace
:
msg
=
"inplace"
else
:
msg
=
"no_inplace"
n
=
"Elemwise{
%
s,
%
s}"
%
(
symbolname
,
msg
)
if
inplace
:
scalar_op
=
getattr
(
scal
,
symbolname
[:
-
len
(
'_inplace'
)])
inplace_scalar_op
=
scalar_op
.
__class__
(
scal
.
transfer_type
(
0
))
rval
=
elemwise
.
Elemwise
(
inplace_scalar_op
,
{
0
:
0
},
name
=
n
)
else
:
scalar_op
=
getattr
(
scal
,
symbolname
)
rval
=
elemwise
.
Elemwise
(
scalar_op
,
name
=
n
)
if
inplace
:
scalar_op
=
getattr
(
scal
,
symbolname
[:
-
len
(
'_inplace'
)])
inplace_scalar_op
=
scalar_op
.
__class__
(
scal
.
transfer_type
(
0
))
rval
=
elemwise
.
Elemwise
(
inplace_scalar_op
,
{
0
:
0
},
name
=
n
,
nfunc_spec
=
((
nfunc
,
nin
,
nout
)
if
nfunc
else
None
)
)
else
:
scalar_op
=
getattr
(
scal
,
symbolname
)
rval
=
elemwise
.
Elemwise
(
scalar_op
,
name
=
n
,
nfunc_spec
=
((
nfunc
,
nin
,
nout
)
if
nfunc
else
None
)
)
if
getattr
(
symbol
,
'__doc__'
,
False
):
rval
.
__doc__
=
symbol
.
__doc__
+
'
\n
'
+
rval
.
__doc__
if
getattr
(
symbol
,
'__doc__'
,
False
):
rval
.
__doc__
=
symbol
.
__doc__
+
'
\n
'
+
rval
.
__doc__
#for the meaning of this see the ./epydoc script
# it makes epydoc display rval as if it were a function, not an object
rval
.
__epydoc_asRoutine
=
symbol
rval
.
__module__
=
'tensor'
#for the meaning of this see the ./epydoc script
# it makes epydoc display rval as if it were a function, not an object
rval
.
__epydoc_asRoutine
=
symbol
rval
.
__module__
=
'tensor'
pprint
.
assign
(
rval
,
printing
.
FunctionPrinter
(
symbolname
))
pprint
.
assign
(
rval
,
printing
.
FunctionPrinter
(
symbolname
))
return
rval
return
rval
return
construct
_scal_elemwise
=
_scal_elemwise_with_nfunc
(
None
,
None
,
None
)
#########################
...
...
@@ -1865,27 +1869,27 @@ def largest(*args):
# Comparison
##########################
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'less'
,
2
,
1
)
def
lt
(
a
,
b
):
"""a < b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'greater'
,
2
,
1
)
def
gt
(
a
,
b
):
"""a > b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'less_equal'
,
2
,
1
)
def
le
(
a
,
b
):
"""a <= b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'greater_equal'
,
2
,
1
)
def
ge
(
a
,
b
):
"""a >= b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'equal'
,
2
,
1
)
def
eq
(
a
,
b
):
"""a == b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'not_equal'
,
2
,
1
)
def
neq
(
a
,
b
):
"""a != b"""
...
...
@@ -1903,19 +1907,19 @@ def switch(cond, ift, iff):
# Bit-wise
##########################
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'bitwise_and'
,
2
,
1
)
def
and_
(
a
,
b
):
"""bitwise a & b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'bitwise_or'
,
2
,
1
)
def
or_
(
a
,
b
):
"""bitwise a | b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'bitwise_xor'
,
2
,
1
)
def
xor
(
a
,
b
):
"""bitwise a ^ b"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'invert'
,
1
,
1
)
def
invert
(
a
):
"""bitwise ~a"""
...
...
@@ -1923,7 +1927,7 @@ def invert(a):
# Math
##########################
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'abs'
,
1
,
1
)
def
abs_
(
a
):
"""|`a`|
...
...
@@ -1934,43 +1938,43 @@ def abs_(a):
pprint
.
assign
(
abs_
,
printing
.
PatternPrinter
((
'|
%(0)
s|'
,
-
1000
)))
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'exp'
,
1
,
1
)
def
exp
(
a
):
"""e^`a`"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'negative'
,
1
,
1
)
def
neg
(
a
):
"""-a"""
@_scal_elemwise
@_scal_elemwise
# numpy.reciprocal does integer division on integer inputs (which is not very interesting)
def
inv
(
a
):
"""1.0/a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'log'
,
1
,
1
)
def
log
(
a
):
"""base e logarithm of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'log2'
,
1
,
1
)
def
log2
(
a
):
"""base 2 logarithm of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'log10'
,
1
,
1
)
def
log10
(
a
):
"""base 10 logarithm of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'log1p'
,
1
,
1
)
def
log1p
(
a
):
"""log(1+a)"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'sign'
,
1
,
1
)
def
sgn
(
a
):
"""sign of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'ceil'
,
1
,
1
)
def
ceil
(
a
):
"""ceiling of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'floor'
,
1
,
1
)
def
floor
(
a
):
"""floor of a"""
...
...
@@ -1989,7 +1993,10 @@ def round(a, mode="half_away_from_zero"):
else
:
raise
Exception
(
"round mode
%
s is not implemented."
%
mode
)
@_scal_elemwise
# def __round_half_to_even(a, dest):
# dest[:] = numpy.around(a)
@_scal_elemwise_with_nfunc
(
'around'
,
1
,
0
)
def
round_half_to_even
(
a
):
"""round_half_to_even(a)"""
...
...
@@ -1997,35 +2004,35 @@ def round_half_to_even(a):
def
round_half_away_from_zero
(
a
):
"""round_half_away_from_zero(a)"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'square'
,
1
,
1
)
def
sqr
(
a
):
"""square of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'sqrt'
,
1
,
1
)
def
sqrt
(
a
):
"""square root of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'cos'
,
1
,
1
)
def
cos
(
a
):
"""cosine of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'sin'
,
1
,
1
)
def
sin
(
a
):
"""sine of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'tan'
,
1
,
1
)
def
tan
(
a
):
"""tangent of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'cosh'
,
1
,
1
)
def
cosh
(
a
):
"""hyperbolic cosine of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'sinh'
,
1
,
1
)
def
sinh
(
a
):
"""hyperbolic sine of a"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'tanh'
,
1
,
1
)
def
tanh
(
a
):
"""hyperbolic tangent of a"""
...
...
@@ -2037,19 +2044,19 @@ def erf(a):
def
erfc
(
a
):
"""complementary error function"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'real'
,
1
,
0
)
def
real
(
z
):
"""Return real component of complex-valued tensor `z`"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'imag'
,
1
,
0
)
def
imag
(
z
):
"""Return imaginary component of complex-valued tensor `z`"""
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'angle'
,
1
,
0
)
def
angle
(
z
):
"""Return polar-coordinate angle of complex-valued tensor `z`"""
@_scal_elemwise
@_scal_elemwise
# numpy.complex cannot build tensors
def
complex
(
real
,
imag
):
"""Return complex-valued tensor with `real` and `imag` components"""
...
...
@@ -2475,13 +2482,13 @@ setdefault = default # legacy
##########################
# Arithmetics
##########################
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'maximum'
,
2
,
1
)
def
maximum
(
x
,
y
):
"""elemwise maximum. See max for the maximum in one tensor
"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'minimum'
,
2
,
1
)
def
minimum
(
x
,
y
):
"""elemwise minimum. See min for the minimum in one tensor
"""
...
...
@@ -2495,47 +2502,47 @@ def div_proxy(x, y):
else
:
return
true_div
(
x
,
y
)
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'add'
,
2
,
1
)
def
add
(
a
,
*
other_terms
):
"""elementwise addition"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'subtract'
,
2
,
1
)
def
sub
(
a
,
b
):
"""elementwise subtraction"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'multiply'
,
2
,
1
)
def
mul
(
a
,
*
other_terms
):
"""elementwise multiplication"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'true_divide'
,
2
,
1
)
def
true_div
(
a
,
b
):
"""elementwise [true] division (inverse of multiplication)"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'floor_divide'
,
2
,
1
)
def
floor_div
(
a
,
b
):
"""elementwise [floor] division (inverse of multiplication)"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'floor_divide'
,
2
,
1
)
# not a c/p error, floor_div and int_div are the same thing
def
int_div
(
a
,
b
):
"""elementwise integer-division"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'mod'
,
2
,
1
)
def
mod
(
a
,
b
):
"""elementwise modulo"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'power'
,
2
,
1
)
def
pow
(
a
,
b
):
"""elementwise power"""
# see decorator for function body
@_scal_elemwise
@_scal_elemwise
_with_nfunc
(
'clip'
,
3
,
1
)
def
clip
(
x
,
min
,
max
):
"""clip x to be between min and max"""
# see decorator for function body
...
...
theano/tensor/elemwise.py
浏览文件 @
4bbac540
...
...
@@ -361,6 +361,21 @@ class DimShufflePrinter:
pprint
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
DimShuffle
),
DimShufflePrinter
())
def
_make_nfunc
(
name
,
nin
,
nout
):
f
=
getattr
(
numpy
,
name
)
return
f
# if name.endswith("*"):
# name = name[:-1]
# f = getattr(numpy, name)
# def fn(*args):
# args[-1][:] = f(*(args[:-1]))
# return fn
# else:
# f = getattr(numpy, name)
# return f
################
### Elemwise ###
################
...
...
@@ -392,7 +407,7 @@ class Elemwise(Op):
Elemwise(log)(rand(3, 4, 5))
"""
def
__init__
(
self
,
scalar_op
,
inplace_pattern
=
{},
name
=
None
):
def
__init__
(
self
,
scalar_op
,
inplace_pattern
=
{},
name
=
None
,
nfunc_spec
=
None
):
"""
Usage: Elemwise(scalar_op, inplace_pattern = {})
...
...
@@ -406,10 +421,14 @@ class Elemwise(Op):
self
.
scalar_op
=
scalar_op
self
.
inplace_pattern
=
inplace_pattern
self
.
destroy_map
=
dict
((
o
,
[
i
])
for
o
,
i
in
inplace_pattern
.
items
())
if
scalar_op
.
nin
>
0
:
self
.
ufunc
=
None
self
.
nfunc
=
None
self
.
nfunc_spec
=
nfunc_spec
if
nfunc_spec
:
self
.
nfunc
=
_make_nfunc
(
*
nfunc_spec
)
elif
scalar_op
.
nin
>
0
:
self
.
ufunc
=
numpy
.
frompyfunc
(
scalar_op
.
impl
,
scalar_op
.
nin
,
scalar_op
.
nout
)
else
:
self
.
ufunc
=
None
#precompute the hash of this node
self
.
_rehash
()
...
...
@@ -417,16 +436,19 @@ class Elemwise(Op):
def
__getstate__
(
self
):
d
=
copy
(
self
.
__dict__
)
d
.
pop
(
'ufunc'
)
d
.
pop
(
'nfunc'
)
d
.
pop
(
'__epydoc_asRoutine'
,
None
)
d
.
pop
(
'_hashval'
)
return
d
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
if
self
.
scalar_op
.
nin
>
0
:
self
.
ufunc
=
None
self
.
nfunc
=
None
if
getattr
(
self
,
'nfunc_spec'
,
None
):
self
.
nfunc
=
_make_nfunc
(
*
self
.
nfunc_spec
)
elif
self
.
scalar_op
.
nin
>
0
:
self
.
ufunc
=
numpy
.
frompyfunc
(
self
.
scalar_op
.
impl
,
self
.
scalar_op
.
nin
,
self
.
scalar_op
.
nout
)
else
:
self
.
ufunc
=
None
self
.
_rehash
()
def
make_node
(
self
,
*
inputs
):
...
...
@@ -621,10 +643,16 @@ class Elemwise(Op):
else
:
odat
=
numpy
.
ndarray
(
shape
,
dtype
=
output
.
type
.
dtype
)
storage
[
0
]
=
odat
# the second calling form is used because in certain versions of numpy
# the first (faster) version leads to segfaults
ufunc_args
=
inputs
# + output_storage
ufunc
=
self
.
ufunc
or
numpy
.
frompyfunc
(
self
.
scalar_op
.
impl
,
len
(
inputs
),
self
.
scalar_op
.
nout
)
if
self
.
nfunc
and
len
(
inputs
)
==
self
.
nfunc_spec
[
1
]:
ufunc
=
self
.
nfunc
nout
=
1
else
:
# the second calling form is used because in certain versions of numpy
# the first (faster) version leads to segfaults
ufunc
=
self
.
ufunc
or
numpy
.
frompyfunc
(
self
.
scalar_op
.
impl
,
len
(
inputs
),
self
.
scalar_op
.
nout
)
nout
=
ufunc
.
nout
try
:
variables
=
ufunc
(
*
ufunc_args
)
...
...
@@ -633,7 +661,7 @@ class Elemwise(Op):
'for params of shape'
,
[
arg
.
shape
for
arg
in
ufunc_args
]
e
.
args
=
e
.
args
+
errormsg
raise
if
ufunc
.
nout
==
1
:
variables
=
[
variables
]
if
nout
==
1
:
variables
=
[
variables
]
for
variable
,
storage
in
zip
(
variables
,
output_storage
):
if
hasattr
(
variable
,
'shape'
)
and
storage
[
0
]
.
shape
!=
variable
.
shape
:
storage
[
0
]
.
resize
(
variable
.
shape
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论