Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
aecbbb99
提交
aecbbb99
authored
9月 23, 2016
作者:
abergeron
提交者:
GitHub
9月 23, 2016
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5001 from nouiz/Composite_name
Postpone Composite name creating
上级
d6c3505c
ac079830
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
132 行增加
和
81 行删除
+132
-81
debugmode.py
theano/compile/debugmode.py
+1
-3
cc.py
theano/gof/cc.py
+3
-0
op.py
theano/gof/op.py
+4
-5
printing.py
theano/printing.py
+42
-32
basic.py
theano/scalar/basic.py
+65
-35
elemwise.py
theano/tensor/elemwise.py
+14
-5
opt.py
theano/tensor/opt.py
+3
-1
没有找到文件。
theano/compile/debugmode.py
浏览文件 @
aecbbb99
...
@@ -1838,9 +1838,7 @@ class _Linker(gof.link.LocalLinker):
...
@@ -1838,9 +1838,7 @@ class _Linker(gof.link.LocalLinker):
thunk
.
outputs
=
[
storage_map
[
v
]
for
v
in
node
.
outputs
]
thunk
.
outputs
=
[
storage_map
[
v
]
for
v
in
node
.
outputs
]
thunk_other
=
thunk
thunk_other
=
thunk
else
:
else
:
new_node
=
node
.
op
.
prepare_node
(
node
,
storage_map
,
compute_map
)
node
.
op
.
prepare_node
(
node
,
storage_map
,
compute_map
)
if
new_node
is
not
None
:
node
=
new_node
debug
=
hasattr
(
node
.
op
,
'debug_perform'
)
debug
=
hasattr
(
node
.
op
,
'debug_perform'
)
...
...
theano/gof/cc.py
浏览文件 @
aecbbb99
...
@@ -1582,6 +1582,9 @@ class CLinker(link.Linker):
...
@@ -1582,6 +1582,9 @@ class CLinker(link.Linker):
# If we can't get a key, then forget the cache mechanism.
# If we can't get a key, then forget the cache mechanism.
module
=
self
.
compile_cmodule
()
module
=
self
.
compile_cmodule
()
else
:
else
:
# Set compute_map as None as clinker do not support lazy evaluation
for
node
in
self
.
node_order
:
node
.
op
.
prepare_node
(
node
,
storage_map
,
None
)
module
=
get_module_cache
()
.
module_from_key
(
module
=
get_module_cache
()
.
module_from_key
(
key
=
key
,
lnk
=
self
,
keep_lock
=
keep_lock
)
key
=
key
,
lnk
=
self
,
keep_lock
=
keep_lock
)
...
...
theano/gof/op.py
浏览文件 @
aecbbb99
...
@@ -795,7 +795,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
...
@@ -795,7 +795,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
Make any special modifications that the Op needs before doing
Make any special modifications that the Op needs before doing
make_thunk().
make_thunk().
This can
either modify the node inplace or return a new one
.
This can
modify the node inplace and should return nothing
.
"""
"""
pass
pass
...
@@ -916,10 +916,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
...
@@ -916,10 +916,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
"""
"""
logger
=
logging
.
getLogger
(
'theano.gof.op.Op'
)
logger
=
logging
.
getLogger
(
'theano.gof.op.Op'
)
new_node
=
self
.
prepare_node
(
node
,
storage_map
=
storage_map
,
self
.
prepare_node
(
node
,
storage_map
=
storage_map
,
compute_map
=
compute_map
)
compute_map
=
compute_map
)
if
new_node
is
not
None
:
node
=
new_node
if
not
hasattr
(
self
,
'_op_use_c_code'
):
if
not
hasattr
(
self
,
'_op_use_c_code'
):
warnings
.
warn
(
warnings
.
warn
(
"The __getstate__ method of '
%
s' is not implemented correctly."
"The __getstate__ method of '
%
s' is not implemented correctly."
...
...
theano/printing.py
浏览文件 @
aecbbb99
...
@@ -345,6 +345,11 @@ class PrinterState(gof.utils.scratchpad):
...
@@ -345,6 +345,11 @@ class PrinterState(gof.utils.scratchpad):
else
:
else
:
self
.
__dict__
.
update
(
props
)
self
.
__dict__
.
update
(
props
)
self
.
__dict__
.
update
(
more_props
)
self
.
__dict__
.
update
(
more_props
)
# A dict from the object to print to its string
# representation. If it is a dag and not a tree, it allow to
# parse each node of the graph only once. They will still be
# printed many times
self
.
memo
=
{}
def
clone
(
self
,
props
=
None
,
**
more_props
):
def
clone
(
self
,
props
=
None
,
**
more_props
):
if
props
is
None
:
if
props
is
None
:
...
@@ -361,6 +366,8 @@ class OperatorPrinter:
...
@@ -361,6 +366,8 @@ class OperatorPrinter:
assert
self
.
assoc
in
VALID_ASSOC
assert
self
.
assoc
in
VALID_ASSOC
def
process
(
self
,
output
,
pstate
):
def
process
(
self
,
output
,
pstate
):
if
output
in
pstate
.
memo
:
return
pstate
.
memo
[
output
]
pprinter
=
pstate
.
pprinter
pprinter
=
pstate
.
pprinter
node
=
output
.
owner
node
=
output
.
owner
if
node
is
None
:
if
node
is
None
:
...
@@ -393,9 +400,11 @@ class OperatorPrinter:
...
@@ -393,9 +400,11 @@ class OperatorPrinter:
else
:
else
:
s
=
(
"
%
s "
%
self
.
operator
)
.
join
(
input_strings
)
s
=
(
"
%
s "
%
self
.
operator
)
.
join
(
input_strings
)
if
parenthesize
:
if
parenthesize
:
r
eturn
"(
%
s)"
%
s
r
=
"(
%
s)"
%
s
else
:
else
:
return
s
r
=
s
pstate
.
memo
[
output
]
=
r
return
r
class
PatternPrinter
:
class
PatternPrinter
:
...
@@ -409,6 +418,8 @@ class PatternPrinter:
...
@@ -409,6 +418,8 @@ class PatternPrinter:
self
.
patterns
.
append
((
pattern
[
0
],
pattern
[
1
:]))
self
.
patterns
.
append
((
pattern
[
0
],
pattern
[
1
:]))
def
process
(
self
,
output
,
pstate
):
def
process
(
self
,
output
,
pstate
):
if
output
in
pstate
.
memo
:
return
pstate
.
memo
[
output
]
pprinter
=
pstate
.
pprinter
pprinter
=
pstate
.
pprinter
node
=
output
.
owner
node
=
output
.
owner
if
node
is
None
:
if
node
is
None
:
...
@@ -425,7 +436,9 @@ class PatternPrinter:
...
@@ -425,7 +436,9 @@ class PatternPrinter:
for
i
,
x
in
enumerate
(
pp_process
(
input
,
precedence
)
for
i
,
x
in
enumerate
(
pp_process
(
input
,
precedence
)
for
input
,
precedence
in
for
input
,
precedence
in
zip
(
node
.
inputs
,
precedences
)))
zip
(
node
.
inputs
,
precedences
)))
return
pattern
%
d
r
=
pattern
%
d
pstate
.
memo
[
output
]
=
r
return
r
class
FunctionPrinter
:
class
FunctionPrinter
:
...
@@ -434,6 +447,8 @@ class FunctionPrinter:
...
@@ -434,6 +447,8 @@ class FunctionPrinter:
self
.
names
=
names
self
.
names
=
names
def
process
(
self
,
output
,
pstate
):
def
process
(
self
,
output
,
pstate
):
if
output
in
pstate
.
memo
:
return
pstate
.
memo
[
output
]
pprinter
=
pstate
.
pprinter
pprinter
=
pstate
.
pprinter
node
=
output
.
owner
node
=
output
.
owner
if
node
is
None
:
if
node
is
None
:
...
@@ -441,40 +456,27 @@ class FunctionPrinter:
...
@@ -441,40 +456,27 @@ class FunctionPrinter:
"not the result of an operation"
%
self
.
names
)
"not the result of an operation"
%
self
.
names
)
idx
=
node
.
outputs
.
index
(
output
)
idx
=
node
.
outputs
.
index
(
output
)
name
=
self
.
names
[
idx
]
name
=
self
.
names
[
idx
]
r
eturn
"
%
s(
%
s)"
%
(
name
,
", "
.
join
(
r
=
"
%
s(
%
s)"
%
(
name
,
", "
.
join
(
[
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=-
1000
))
[
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=-
1000
))
for
input
in
node
.
inputs
]))
for
input
in
node
.
inputs
]))
pstate
.
memo
[
output
]
=
r
return
r
class
MemberPrinter
:
def
__init__
(
self
,
*
names
):
self
.
names
=
names
def
process
(
self
,
output
,
pstate
):
pprinter
=
pstate
.
pprinter
node
=
output
.
owner
if
node
is
None
:
raise
TypeError
(
"function
%
s cannot represent a variable that is"
" not the result of an operation"
%
self
.
function
)
idx
=
node
.
outputs
.
index
(
output
)
name
=
self
.
names
[
idx
]
input
=
node
.
inputs
[
0
]
return
"
%
s.
%
s"
%
(
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=
1000
)),
name
)
class
IgnorePrinter
:
class
IgnorePrinter
:
def
process
(
self
,
output
,
pstate
):
def
process
(
self
,
output
,
pstate
):
if
output
in
pstate
.
memo
:
return
pstate
.
memo
[
output
]
pprinter
=
pstate
.
pprinter
pprinter
=
pstate
.
pprinter
node
=
output
.
owner
node
=
output
.
owner
if
node
is
None
:
if
node
is
None
:
raise
TypeError
(
"function
%
s cannot represent a variable that is"
raise
TypeError
(
"function
%
s cannot represent a variable that is"
" not the result of an operation"
%
self
.
function
)
" not the result of an operation"
%
self
.
function
)
input
=
node
.
inputs
[
0
]
input
=
node
.
inputs
[
0
]
return
"
%
s"
%
pprinter
.
process
(
input
,
pstate
)
r
=
"
%
s"
%
pprinter
.
process
(
input
,
pstate
)
pstate
.
memo
[
output
]
=
r
return
r
class
DefaultPrinter
:
class
DefaultPrinter
:
...
@@ -482,22 +484,30 @@ class DefaultPrinter:
...
@@ -482,22 +484,30 @@ class DefaultPrinter:
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
def
process
(
self
,
r
,
pstate
):
def
process
(
self
,
output
,
pstate
):
if
output
in
pstate
.
memo
:
return
pstate
.
memo
[
output
]
pprinter
=
pstate
.
pprinter
pprinter
=
pstate
.
pprinter
node
=
r
.
owner
node
=
output
.
owner
if
node
is
None
:
if
node
is
None
:
return
LeafPrinter
()
.
process
(
r
,
pstate
)
return
LeafPrinter
()
.
process
(
output
,
pstate
)
r
eturn
"
%
s(
%
s)"
%
(
str
(
node
.
op
),
", "
.
join
(
r
=
"
%
s(
%
s)"
%
(
str
(
node
.
op
),
", "
.
join
(
[
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=-
1000
))
[
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=-
1000
))
for
input
in
node
.
inputs
]))
for
input
in
node
.
inputs
]))
pstate
.
memo
[
output
]
=
r
return
r
class
LeafPrinter
:
class
LeafPrinter
:
def
process
(
self
,
r
,
pstate
):
def
process
(
self
,
output
,
pstate
):
if
r
.
name
in
greek
:
if
output
in
pstate
.
memo
:
return
greek
[
r
.
name
]
return
pstate
.
memo
[
output
]
if
output
.
name
in
greek
:
r
=
greek
[
output
.
name
]
else
:
else
:
return
str
(
r
)
r
=
str
(
output
)
pstate
.
memo
[
output
]
=
r
return
r
class
PPrinter
:
class
PPrinter
:
...
...
theano/scalar/basic.py
浏览文件 @
aecbbb99
...
@@ -3462,6 +3462,8 @@ class Composite(ScalarOp):
...
@@ -3462,6 +3462,8 @@ class Composite(ScalarOp):
init_param
=
(
'inputs'
,
'outputs'
)
init_param
=
(
'inputs'
,
'outputs'
)
def
__str__
(
self
):
def
__str__
(
self
):
if
self
.
name
is
None
:
self
.
init_name
()
return
self
.
name
return
self
.
name
def
make_new_inplace
(
self
,
output_types_preference
=
None
,
name
=
None
):
def
make_new_inplace
(
self
,
output_types_preference
=
None
,
name
=
None
):
...
@@ -3485,6 +3487,9 @@ class Composite(ScalarOp):
...
@@ -3485,6 +3487,9 @@ class Composite(ScalarOp):
Return the C code for this Composite Op.
Return the C code for this Composite Op.
"""
"""
# It was already called
if
hasattr
(
self
,
'_c_code'
):
return
subd
=
dict
(
chain
(
subd
=
dict
(
chain
(
((
e
,
"
%%
(i
%
i)s"
%
i
)
for
i
,
e
in
enumerate
(
self
.
fgraph
.
inputs
)),
((
e
,
"
%%
(i
%
i)s"
%
i
)
for
i
,
e
in
enumerate
(
self
.
fgraph
.
inputs
)),
((
e
,
"
%%
(o
%
i)s"
%
i
)
for
i
,
e
in
enumerate
(
self
.
fgraph
.
outputs
))))
((
e
,
"
%%
(o
%
i)s"
%
i
)
for
i
,
e
in
enumerate
(
self
.
fgraph
.
outputs
))))
...
@@ -3533,21 +3538,46 @@ class Composite(ScalarOp):
...
@@ -3533,21 +3538,46 @@ class Composite(ScalarOp):
Return a list of functions that compute each output of self.
Return a list of functions that compute each output of self.
"""
"""
# In the case where the graph is a dag, but not a tree like:
# add(*1 -> mul(x, y), *1)
# We have an efficent way to build the executable (we build
# and traverse each node only once).
# But we don't have an efficient execution. We will execute
# like a tree, so nodes that have more then 1 client will be
# executed as many times as there number of clients. In the
# example aboce, it will calculate *1 twice. Doing otherwise
# imply making a complicated execution engine.
# We need the fast creation of the executor as we always do it
# even if we will use the c code. The Python implementation is
# already slow, so it is not as much important to have a fast
# execution there.
memo
=
{}
def
compose_impl
(
r
):
def
compose_impl
(
r
):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1)
if
r
in
memo
:
# it will calculate *1 twice
return
memo
[
r
]
# it also doesn't follow fgraph.toposort but that's (presumably)
# still correct since we only have scalar ops
if
r
in
self
.
fgraph
.
inputs
:
if
r
in
self
.
fgraph
.
inputs
:
idx
=
self
.
fgraph
.
inputs
.
index
(
r
)
idx
=
self
.
fgraph
.
inputs
.
index
(
r
)
return
lambda
inputs
:
inputs
[
idx
]
def
f
(
inputs
):
return
inputs
[
idx
]
memo
[
r
]
=
f
return
f
elif
r
.
owner
is
None
:
# in fgraph.orphans:
elif
r
.
owner
is
None
:
# in fgraph.orphans:
return
lambda
inputs
:
r
.
data
def
f
(
inputs
):
return
r
.
data
memo
[
r
]
=
f
return
f
node
=
r
.
owner
node
=
r
.
owner
producers
=
[
compose_impl
(
input
)
for
input
in
node
.
inputs
]
producers
=
[
compose_impl
(
input
)
for
input
in
node
.
inputs
]
def
f
(
inputs
):
def
f
(
inputs
):
return
node
.
op
.
impl
(
*
[
p
(
inputs
)
for
p
in
producers
])
return
node
.
op
.
impl
(
*
[
p
(
inputs
)
for
p
in
producers
])
memo
[
r
]
=
f
return
f
return
f
self
.
_impls
=
[
compose_impl
(
r
)
for
r
in
self
.
fgraph
.
outputs
]
self
.
_impls
=
[
compose_impl
(
r
)
for
r
in
self
.
fgraph
.
outputs
]
...
@@ -3556,32 +3586,19 @@ class Composite(ScalarOp):
...
@@ -3556,32 +3586,19 @@ class Composite(ScalarOp):
Return a readable string representation of self.fgraph.
Return a readable string representation of self.fgraph.
"""
"""
try
:
rval
=
self
.
name
rval
=
self
.
name
if
rval
is
None
:
except
AttributeError
:
for
i
,
r
in
enumerate
(
self
.
fgraph
.
inputs
):
if
0
:
r
.
name
=
'i
%
i'
%
i
l
=
[]
for
i
,
r
in
enumerate
(
self
.
fgraph
.
outputs
):
for
n
in
self
.
fgraph
.
toposort
():
r
.
name
=
'o
%
i'
%
i
if
hasattr
(
n
.
op
,
"name"
)
and
n
.
op
.
name
is
not
None
:
io
=
set
(
self
.
fgraph
.
inputs
+
self
.
fgraph
.
outputs
)
v
=
n
.
op
.
name
for
i
,
r
in
enumerate
(
self
.
fgraph
.
variables
):
if
v
.
startswith
(
"Composite"
):
if
r
not
in
io
and
len
(
r
.
clients
)
>
1
:
v
=
v
[
len
(
"Composite"
):]
r
.
name
=
't
%
i'
%
i
else
:
rval
=
"Composite{
%
s}"
%
', '
.
join
([
pprint
(
output
)
for
output
v
=
n
.
op
.
__class__
.
__name__
in
self
.
fgraph
.
outputs
])
l
.
append
(
v
)
self
.
name
=
rval
rval
=
"Composite{"
+
","
.
join
(
l
)
+
"}"
else
:
for
i
,
r
in
enumerate
(
self
.
fgraph
.
inputs
):
r
.
name
=
'i
%
i'
%
i
for
i
,
r
in
enumerate
(
self
.
fgraph
.
outputs
):
r
.
name
=
'o
%
i'
%
i
io
=
set
(
self
.
fgraph
.
inputs
+
self
.
fgraph
.
outputs
)
for
i
,
r
in
enumerate
(
self
.
fgraph
.
variables
):
if
r
not
in
io
and
len
(
r
.
clients
)
>
1
:
r
.
name
=
't
%
i'
%
i
rval
=
"Composite{
%
s}"
%
', '
.
join
([
pprint
(
output
)
for
output
in
self
.
fgraph
.
outputs
])
self
.
name
=
rval
def
init_fgraph
(
self
):
def
init_fgraph
(
self
):
# The clone done by FunctionGraph is needed as we don't want
# The clone done by FunctionGraph is needed as we don't want
...
@@ -3642,9 +3659,15 @@ class Composite(ScalarOp):
...
@@ -3642,9 +3659,15 @@ class Composite(ScalarOp):
self
.
nin
=
len
(
inputs
)
self
.
nin
=
len
(
inputs
)
self
.
nout
=
len
(
outputs
)
self
.
nout
=
len
(
outputs
)
self
.
init_fgraph
()
# self.fgraph
self
.
init_fgraph
()
# self.fgraph
self
.
init_name
()
# self.name
self
.
init_c_code
()
# self._c_code and self.nodenames
# Postpone the creation in case it isn't needed.
# self.init_name() # self.name
self
.
name
=
None
def
prepare_node
(
self
,
node
,
storage_map
,
compute_map
):
self
.
init_py_impls
()
# self._impls
self
.
init_py_impls
()
# self._impls
for
n
in
theano
.
gof
.
graph
.
list_of_nodes
(
self
.
inputs
,
self
.
outputs
):
n
.
op
.
prepare_node
(
n
,
None
,
None
)
def
output_types
(
self
,
input_types
):
def
output_types
(
self
,
input_types
):
if
tuple
(
input_types
)
!=
self
.
inputs_type
:
if
tuple
(
input_types
)
!=
self
.
inputs_type
:
...
@@ -3688,6 +3711,9 @@ class Composite(ScalarOp):
...
@@ -3688,6 +3711,9 @@ class Composite(ScalarOp):
raise
NotImplementedError
(
"grad is not implemented for Composite"
)
raise
NotImplementedError
(
"grad is not implemented for Composite"
)
def
c_code
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
def
c_code
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
if
not
hasattr
(
self
,
'_c_code'
):
self
.
init_c_code
()
d
=
dict
(
chain
(
izip
((
"i
%
i"
%
i
for
i
in
xrange
(
len
(
inames
))),
inames
),
d
=
dict
(
chain
(
izip
((
"i
%
i"
%
i
for
i
in
xrange
(
len
(
inames
))),
inames
),
izip
((
"o
%
i"
%
i
for
i
in
xrange
(
len
(
onames
))),
izip
((
"o
%
i"
%
i
for
i
in
xrange
(
len
(
onames
))),
onames
)),
**
sub
)
onames
)),
**
sub
)
...
@@ -3745,9 +3771,13 @@ class Composite(ScalarOp):
...
@@ -3745,9 +3771,13 @@ class Composite(ScalarOp):
return
False
return
False
# see __hash__ for comment on why there is no mention of fgraph
# see __hash__ for comment on why there is no mention of fgraph
# or module cache key here.
# or module cache key here.
if
not
hasattr
(
self
,
'_c_code'
):
self
.
init_c_code
()
# self._c_code and self.nodenames
return
(
self
.
_c_code
==
other
.
_c_code
)
return
(
self
.
_c_code
==
other
.
_c_code
)
def
__hash__
(
self
):
def
__hash__
(
self
):
if
not
hasattr
(
self
,
'_c_code'
):
self
.
init_c_code
()
# self._c_code and self.nodenames
rval
=
hash
((
type
(
self
),
rval
=
hash
((
type
(
self
),
self
.
nin
,
self
.
nin
,
self
.
nout
,
self
.
nout
,
...
@@ -3764,7 +3794,7 @@ class Composite(ScalarOp):
...
@@ -3764,7 +3794,7 @@ class Composite(ScalarOp):
def
__getstate__
(
self
):
def
__getstate__
(
self
):
rval
=
dict
(
self
.
__dict__
)
rval
=
dict
(
self
.
__dict__
)
del
rval
[
'_impls'
]
rval
.
pop
(
'_impls'
,
None
)
del
rval
[
'fgraph'
]
del
rval
[
'fgraph'
]
return
rval
return
rval
...
...
theano/tensor/elemwise.py
浏览文件 @
aecbbb99
...
@@ -849,6 +849,14 @@ second dimension
...
@@ -849,6 +849,14 @@ second dimension
char
=
numpy
.
sctype2char
(
out_dtype
)
char
=
numpy
.
sctype2char
(
out_dtype
)
sig
=
char
*
node
.
nin
+
'->'
+
char
*
node
.
nout
sig
=
char
*
node
.
nin
+
'->'
+
char
*
node
.
nout
node
.
tag
.
sig
=
sig
node
.
tag
.
sig
=
sig
node
.
tag
.
fake_node
=
Apply
(
self
.
scalar_op
,
[
get_scalar_type
(
dtype
=
input
.
type
.
dtype
)
.
make_variable
()
for
input
in
node
.
inputs
],
[
get_scalar_type
(
dtype
=
output
.
type
.
dtype
)
.
make_variable
()
for
output
in
node
.
outputs
])
self
.
scalar_op
.
prepare_node
(
node
.
tag
.
fake_node
,
None
,
None
)
def
perform
(
self
,
node
,
inputs
,
output_storage
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
if
len
(
node
.
inputs
)
>=
32
:
if
len
(
node
.
inputs
)
>=
32
:
...
@@ -991,6 +999,11 @@ second dimension
...
@@ -991,6 +999,11 @@ second dimension
return
rval
return
rval
def
_c_all
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
def
_c_all
(
self
,
node
,
nodename
,
inames
,
onames
,
sub
):
# Some ops call directly the Elemwise._c_all or Elemwise.c_code
# To not request all of them to call prepare_node(), do it here.
# There is no harm if it get called multile time.
if
not
hasattr
(
node
.
tag
,
'fake_node'
):
self
.
prepare_node
(
node
,
None
,
None
)
_inames
=
inames
_inames
=
inames
_onames
=
onames
_onames
=
onames
...
@@ -1109,11 +1122,7 @@ second dimension
...
@@ -1109,11 +1122,7 @@ second dimension
# We generate the C code of the inner loop using the scalar op
# We generate the C code of the inner loop using the scalar op
task_code
=
self
.
scalar_op
.
c_code
(
task_code
=
self
.
scalar_op
.
c_code
(
Apply
(
self
.
scalar_op
,
node
.
tag
.
fake_node
,
[
get_scalar_type
(
dtype
=
input
.
type
.
dtype
)
.
make_variable
()
for
input
in
node
.
inputs
],
[
get_scalar_type
(
dtype
=
output
.
type
.
dtype
)
.
make_variable
()
for
output
in
node
.
outputs
]),
nodename
+
'_scalar_'
,
nodename
+
'_scalar_'
,
[
"
%
s_i"
%
s
for
s
in
_inames
],
[
"
%
s_i"
%
s
for
s
in
_inames
],
[
"
%
s_i"
%
s
for
s
in
onames
],
[
"
%
s_i"
%
s
for
s
in
onames
],
...
...
theano/tensor/opt.py
浏览文件 @
aecbbb99
...
@@ -7183,7 +7183,9 @@ def local_add_mul_fusion(node):
...
@@ -7183,7 +7183,9 @@ def local_add_mul_fusion(node):
for
inp
in
node
.
inputs
:
for
inp
in
node
.
inputs
:
if
(
inp
.
owner
and
if
(
inp
.
owner
and
isinstance
(
inp
.
owner
.
op
,
Elemwise
)
and
isinstance
(
inp
.
owner
.
op
,
Elemwise
)
and
isinstance
(
inp
.
owner
.
op
.
scalar_op
,
s_op
)):
isinstance
(
inp
.
owner
.
op
.
scalar_op
,
s_op
)
and
# Do not duplicate the operation.
len
(
inp
.
clients
)
==
1
):
new_inp
.
extend
(
inp
.
owner
.
inputs
)
new_inp
.
extend
(
inp
.
owner
.
inputs
)
fused
=
True
fused
=
True
else
:
else
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论