Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f1d61ed9
提交
f1d61ed9
authored
10月 22, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added pprint (as printing) to the theano package root
上级
d5ae91e1
隐藏空白字符变更
内嵌
并排
正在显示
11 个修改的文件
包含
230 行增加
和
148 行删除
+230
-148
logistic_regression.py
examples/logistic_regression.py
+10
-4
__init__.py
theano/__init__.py
+7
-1
mode.py
theano/compile/mode.py
+9
-2
module.py
theano/compile/module.py
+19
-13
toolbox.py
theano/gof/toolbox.py
+1
-1
printing.py
theano/printing.py
+18
-106
basic.py
theano/scalar/basic.py
+17
-7
basic.py
theano/tensor/basic.py
+91
-10
elemwise.py
theano/tensor/elemwise.py
+27
-0
inplace.py
theano/tensor/inplace.py
+24
-0
opt.py
theano/tensor/opt.py
+7
-4
没有找到文件。
examples/logistic_regression.py
浏览文件 @
f1d61ed9
...
@@ -4,7 +4,8 @@ import theano
...
@@ -4,7 +4,8 @@ import theano
from
theano
import
tensor
as
T
from
theano
import
tensor
as
T
from
theano.tensor
import
nnet_ops
from
theano.tensor
import
nnet_ops
from
theano.compile
import
module
from
theano.compile
import
module
from
theano.sandbox
import
pprint
from
theano
import
printing
,
pprint
from
theano
import
compile
import
numpy
as
N
import
numpy
as
N
...
@@ -17,6 +18,7 @@ class LogisticRegressionN(module.FancyModule):
...
@@ -17,6 +18,7 @@ class LogisticRegressionN(module.FancyModule):
self
.
w
=
N
.
random
.
randn
(
n_in
,
n_out
)
self
.
w
=
N
.
random
.
randn
(
n_in
,
n_out
)
self
.
b
=
N
.
random
.
randn
(
n_out
)
self
.
b
=
N
.
random
.
randn
(
n_out
)
self
.
lr
=
0.01
self
.
lr
=
0.01
self
.
__hide__
=
[
'params'
]
def
__init__
(
self
,
x
=
None
,
targ
=
None
):
def
__init__
(
self
,
x
=
None
,
targ
=
None
):
super
(
LogisticRegressionN
,
self
)
.
__init__
()
#boilerplate
super
(
LogisticRegressionN
,
self
)
.
__init__
()
#boilerplate
...
@@ -84,8 +86,8 @@ class LogisticRegression2(module.FancyModule):
...
@@ -84,8 +86,8 @@ class LogisticRegression2(module.FancyModule):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
pprint
.
pp
.
assign
(
nnet_ops
.
crossentropy_softmax_1hot_with_bias_dx
,
pprint
.
FunctionPrinter
(
'xsoftmaxdx'
))
pprint
.
assign
(
nnet_ops
.
crossentropy_softmax_1hot_with_bias_dx
,
printing
.
FunctionPrinter
(
'xsoftmaxdx'
))
pprint
.
pp
.
assign
(
nnet_ops
.
crossentropy_softmax_argmax_1hot_with_bias
,
pprint
.
FunctionPrinter
(
'nll'
,
'softmax'
,
'argmax'
))
pprint
.
assign
(
nnet_ops
.
crossentropy_softmax_argmax_1hot_with_bias
,
printing
.
FunctionPrinter
(
'nll'
,
'softmax'
,
'argmax'
))
if
1
:
if
1
:
lrc
=
LogisticRegressionN
()
lrc
=
LogisticRegressionN
()
...
@@ -94,17 +96,21 @@ if __name__ == '__main__':
...
@@ -94,17 +96,21 @@ if __name__ == '__main__':
print
'================'
print
'================'
print
lrc
.
update
.
pretty
(
mode
=
theano
.
Mode
(
'py'
,
'fast_run'
))
print
lrc
.
update
.
pretty
(
mode
=
theano
.
Mode
(
'py'
,
'fast_run'
))
print
'================'
print
'================'
# print lrc.update.pretty(mode = compile.FAST_RUN.excluding('inplace'))
# print '================'
# sys.exit(0)
# sys.exit(0)
lr
=
lrc
.
make
(
10
,
2
,
mode
=
theano
.
Mode
(
'c|py'
,
'fast_run'
))
lr
=
lrc
.
make
(
10
,
2
,
mode
=
theano
.
Mode
(
'c|py'
,
'fast_run'
))
#lr = lrc.make(10, 2, mode=compile.FAST_RUN.excluding('fast_run'))
#lr = lrc.make(10, 2, mode=theano.Mode('py', 'merge')) #'FAST_RUN')
#lr = lrc.make(10, 2, mode=theano.Mode('py', 'merge')) #'FAST_RUN')
data_x
=
N
.
random
.
randn
(
5
,
10
)
data_x
=
N
.
random
.
randn
(
5
,
10
)
data_y
=
(
N
.
random
.
randn
(
5
)
>
0
)
data_y
=
(
N
.
random
.
randn
(
5
)
>
0
)
for
i
in
xrange
(
10000
):
for
i
in
xrange
(
10000
):
xe
=
lr
.
update
(
data_x
,
data_y
)
lr
.
lr
=
0.02
xe
=
lr
.
update
(
data_x
,
data_y
)
if
i
%
100
==
0
:
if
i
%
100
==
0
:
print
i
,
xe
print
i
,
xe
...
...
theano/__init__.py
浏览文件 @
f1d61ed9
...
@@ -41,13 +41,19 @@ from compile import \
...
@@ -41,13 +41,19 @@ from compile import \
SymbolicOutput
,
Out
,
\
SymbolicOutput
,
Out
,
\
Mode
,
\
Mode
,
\
predefined_modes
,
predefined_linkers
,
predefined_optimizers
,
\
predefined_modes
,
predefined_linkers
,
predefined_optimizers
,
\
FunctionMaker
,
function
,
OpFromGraph
#, eval_outputs, fast_compute
FunctionMaker
,
function
,
OpFromGraph
,
\
Component
,
External
,
Member
,
KitComponent
,
Method
,
\
Composite
,
ComponentList
,
Module
,
FancyModule
from
printing
import
\
pprint
,
pp
import
tensor
import
tensor
import
scalar
import
scalar
import
sparse
import
sparse
import
gradient
import
gradient
## import scalar_opt
## import scalar_opt
import
subprocess
as
_subprocess
import
subprocess
as
_subprocess
...
...
theano/compile/mode.py
浏览文件 @
f1d61ed9
...
@@ -102,12 +102,19 @@ class Mode(object):
...
@@ -102,12 +102,19 @@ class Mode(object):
optimizer
=
predefined_optimizers
[
optimizer
]
optimizer
=
predefined_optimizers
[
optimizer
]
if
isinstance
(
optimizer
,
gof
.
Query
):
if
isinstance
(
optimizer
,
gof
.
Query
):
self
.
provided_optimizer
=
optimizer
self
.
provided_optimizer
=
optimizer
optimizer
=
optdb
.
query
(
optimizer
)
self
.
_optimizer
=
optimizer
self
.
optimizer
=
optimizer
def
__str__
(
self
):
def
__str__
(
self
):
return
"Mode(linker =
%
s, optimizer =
%
s)"
%
(
self
.
provided_linker
,
self
.
provided_optimizer
)
return
"Mode(linker =
%
s, optimizer =
%
s)"
%
(
self
.
provided_linker
,
self
.
provided_optimizer
)
def
__get_optimizer
(
self
):
if
isinstance
(
self
.
_optimizer
,
gof
.
Query
):
return
optdb
.
query
(
self
.
_optimizer
)
else
:
return
self
.
_optimizer
optimizer
=
property
(
__get_optimizer
)
def
including
(
self
,
*
tags
):
def
including
(
self
,
*
tags
):
return
Mode
(
self
.
provided_linker
,
self
.
provided_optimizer
.
including
(
*
tags
))
return
Mode
(
self
.
provided_linker
,
self
.
provided_optimizer
.
including
(
*
tags
))
...
...
theano/compile/module.py
浏览文件 @
f1d61ed9
from
..
import
gof
from
..
import
gof
from
..printing
import
pprint
from
collections
import
defaultdict
from
collections
import
defaultdict
from
itertools
import
chain
from
itertools
import
chain
from
functools
import
partial
from
functools
import
partial
from
copy
import
copy
from
copy
import
copy
import
mode
import
io
import
function_module
as
F
import
function_module
as
F
#from ..sandbox import pprint
def
join
(
*
args
):
def
join
(
*
args
):
...
@@ -117,7 +117,7 @@ class External(_RComponent):
...
@@ -117,7 +117,7 @@ class External(_RComponent):
def
pretty
(
self
,
**
kwargs
):
def
pretty
(
self
,
**
kwargs
):
rval
=
super
(
External
,
self
)
.
pretty
()
rval
=
super
(
External
,
self
)
.
pretty
()
if
self
.
r
.
owner
:
if
self
.
r
.
owner
:
rval
+=
'
\n
=
%
s'
%
(
pprint
.
pp2
.
process
(
self
.
r
,
dict
(
target
=
self
.
r
)))
rval
+=
'
\n
=
%
s'
%
(
pprint
(
self
.
r
,
dict
(
target
=
self
.
r
)))
return
rval
return
rval
...
@@ -196,13 +196,15 @@ class Method(Component):
...
@@ -196,13 +196,15 @@ class Method(Component):
else
:
else
:
return
gof
.
Container
(
r
,
storage
=
[
None
])
return
gof
.
Container
(
r
,
storage
=
[
None
])
inputs
=
self
.
inputs
inputs
=
self
.
inputs
inputs
=
[
mode
.
In
(
result
=
input
,
inputs
=
[
io
.
In
(
result
=
input
,
value
=
get_storage
(
input
))
value
=
get_storage
(
input
),
mutable
=
False
)
for
input
in
inputs
]
for
input
in
inputs
]
inputs
+=
[
mode
.
In
(
result
=
k
,
inputs
+=
[
io
.
In
(
result
=
k
,
update
=
v
,
update
=
v
,
value
=
get_storage
(
k
,
True
),
value
=
get_storage
(
k
,
True
),
strict
=
True
)
mutable
=
True
,
strict
=
True
)
for
k
,
v
in
self
.
updates
.
iteritems
()]
for
k
,
v
in
self
.
updates
.
iteritems
()]
outputs
=
self
.
outputs
outputs
=
self
.
outputs
_inputs
=
[
x
.
result
for
x
in
inputs
]
_inputs
=
[
x
.
result
for
x
in
inputs
]
...
@@ -210,8 +212,9 @@ class Method(Component):
...
@@ -210,8 +212,9 @@ class Method(Component):
+
[
x
.
update
for
x
in
inputs
if
getattr
(
x
,
'update'
,
False
)],
+
[
x
.
update
for
x
in
inputs
if
getattr
(
x
,
'update'
,
False
)],
blockers
=
_inputs
):
blockers
=
_inputs
):
if
input
not
in
_inputs
and
not
isinstance
(
input
,
gof
.
Value
):
if
input
not
in
_inputs
and
not
isinstance
(
input
,
gof
.
Value
):
inputs
+=
[
mode
.
In
(
result
=
input
,
inputs
+=
[
io
.
In
(
result
=
input
,
value
=
get_storage
(
input
,
True
))]
value
=
get_storage
(
input
,
True
),
mutable
=
False
)]
inputs
+=
[(
kit
,
get_storage
(
kit
,
True
))
for
kit
in
self
.
kits
]
inputs
+=
[(
kit
,
get_storage
(
kit
,
True
))
for
kit
in
self
.
kits
]
return
F
.
function
(
inputs
,
outputs
,
mode
)
return
F
.
function
(
inputs
,
outputs
,
mode
)
...
@@ -234,11 +237,14 @@ class Method(Component):
...
@@ -234,11 +237,14 @@ class Method(Component):
nup
=
len
(
k
)
nup
=
len
(
k
)
eff_in
=
tuple
(
inputs
)
+
tuple
(
k
)
eff_in
=
tuple
(
inputs
)
+
tuple
(
k
)
eff_out
=
tuple
(
outputs
)
+
tuple
(
v
)
eff_out
=
tuple
(
outputs
)
+
tuple
(
v
)
env
=
gof
.
Env
(
*
gof
.
graph
.
clone
(
eff_in
+
tuple
(
gof
.
graph
.
inputs
(
eff_out
)),
supp_in
=
tuple
(
gof
.
graph
.
inputs
(
eff_out
))
env
=
gof
.
Env
(
*
gof
.
graph
.
clone
(
eff_in
+
supp_in
,
eff_out
))
eff_out
))
sup
=
F
.
Supervisor
(
set
(
env
.
inputs
)
.
difference
(
env
.
inputs
[
len
(
inputs
):
len
(
eff_in
)]))
env
.
extend
(
sup
)
mode
.
optimizer
.
optimize
(
env
)
mode
.
optimizer
.
optimize
(
env
)
inputs
,
outputs
,
updates
=
env
.
inputs
[:
nin
],
env
.
outputs
[:
nout
],
dict
(
zip
(
env
.
inputs
[
nin
:],
env
.
outputs
[
nout
:]))
inputs
,
outputs
,
updates
=
env
.
inputs
[:
nin
],
env
.
outputs
[:
nout
],
dict
(
zip
(
env
.
inputs
[
nin
:],
env
.
outputs
[
nout
:]))
rval
+=
pprint
.
pp
.
process_graph
(
inputs
,
outputs
,
updates
,
False
)
rval
+=
pprint
(
inputs
,
outputs
,
updates
,
False
)
return
rval
return
rval
def
__str__
(
self
):
def
__str__
(
self
):
...
...
theano/gof/toolbox.py
浏览文件 @
f1d61ed9
...
@@ -101,7 +101,7 @@ class ReplaceValidate(History, Validator):
...
@@ -101,7 +101,7 @@ class ReplaceValidate(History, Validator):
try
:
try
:
env
.
replace
(
r
,
new_r
)
env
.
replace
(
r
,
new_r
)
except
Exception
,
e
:
except
Exception
,
e
:
if
not
'The type of the replacement must be the same'
in
str
(
e
)
or
not
'does not belong to this Env'
in
str
(
e
):
if
'The type of the replacement must be the same'
not
in
str
(
e
)
and
'does not belong to this Env'
not
in
str
(
e
):
print
>>
sys
.
stderr
,
"<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>"
,
type
(
e
),
e
print
>>
sys
.
stderr
,
"<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>"
,
type
(
e
),
e
env
.
revert
(
chk
)
# this might fail if the error is in a listener: (env.replace kinda needs better internal error handling)
env
.
revert
(
chk
)
# this might fail if the error is in a listener: (env.replace kinda needs better internal error handling)
raise
raise
...
...
theano/
sandbox/pprint
.py
→
theano/
printing
.py
浏览文件 @
f1d61ed9
from
..
import
tensor
as
T
import
gof
from
..
import
scalar
as
S
from
..
import
gof
from
copy
import
copy
from
copy
import
copy
import
sys
import
sys
...
@@ -88,7 +86,7 @@ class FunctionPrinter:
...
@@ -88,7 +86,7 @@ class FunctionPrinter:
raise
TypeError
(
"function
%
s cannot represent a result with no associated operation"
%
self
.
names
)
raise
TypeError
(
"function
%
s cannot represent a result with no associated operation"
%
self
.
names
)
idx
=
node
.
outputs
.
index
(
output
)
idx
=
node
.
outputs
.
index
(
output
)
name
=
self
.
names
[
idx
]
name
=
self
.
names
[
idx
]
return
"
%
s(
%
s)"
%
(
name
,
", "
.
join
([
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=
1000
))
return
"
%
s(
%
s)"
%
(
name
,
", "
.
join
([
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=
-
1000
))
for
input
in
node
.
inputs
]))
for
input
in
node
.
inputs
]))
class
MemberPrinter
:
class
MemberPrinter
:
...
@@ -119,65 +117,6 @@ class IgnorePrinter:
...
@@ -119,65 +117,6 @@ class IgnorePrinter:
return
"
%
s"
%
pprinter
.
process
(
input
,
pstate
)
return
"
%
s"
%
pprinter
.
process
(
input
,
pstate
)
class
DimShufflePrinter
:
def
__p
(
self
,
new_order
,
pstate
,
r
):
if
new_order
!=
()
and
new_order
[
0
]
==
'x'
:
# return "%s" % self.__p(new_order[1:], pstate, r)
return
"[
%
s]"
%
self
.
__p
(
new_order
[
1
:],
pstate
,
r
)
if
list
(
new_order
)
==
range
(
r
.
type
.
ndim
):
return
pstate
.
pprinter
.
process
(
r
)
if
list
(
new_order
)
==
list
(
reversed
(
range
(
r
.
type
.
ndim
))):
return
"
%
s.T"
%
pstate
.
pprinter
.
process
(
r
)
return
"DimShuffle{
%
s}(
%
s)"
%
(
", "
.
join
(
map
(
str
,
new_order
)),
pstate
.
pprinter
.
process
(
r
))
def
process
(
self
,
r
,
pstate
):
if
r
.
owner
is
None
:
raise
TypeError
(
"Can only print DimShuffle."
)
elif
isinstance
(
r
.
owner
.
op
,
T
.
DimShuffle
):
ord
=
r
.
owner
.
op
.
new_order
return
self
.
__p
(
ord
,
pstate
,
r
.
owner
.
inputs
[
0
])
else
:
raise
TypeError
(
"Can only print DimShuffle."
)
class
SubtensorPrinter
:
def
process
(
self
,
r
,
pstate
):
if
r
.
owner
is
None
:
raise
TypeError
(
"Can only print Subtensor."
)
elif
isinstance
(
r
.
owner
.
op
,
T
.
Subtensor
):
idxs
=
r
.
owner
.
op
.
idx_list
inputs
=
list
(
r
.
owner
.
inputs
)
input
=
inputs
.
pop
()
sidxs
=
[]
inbrack_pstate
=
pstate
.
clone
(
precedence
=
-
1000
)
for
entry
in
idxs
:
if
isinstance
(
entry
,
int
):
sidxs
.
append
(
str
(
entry
))
elif
isinstance
(
entry
,
S
.
Scalar
):
sidxs
.
append
(
inbrack_pstate
.
pprinter
.
process
(
inputs
.
pop
()))
elif
isinstance
(
entry
,
slice
):
sidxs
.
append
(
"
%
s:
%
s
%
s"
%
(
""
if
entry
.
start
is
None
or
entry
.
start
==
0
else
entry
.
start
,
""
if
entry
.
stop
is
None
or
entry
.
stop
==
sys
.
maxint
else
entry
.
stop
,
""
if
entry
.
step
is
None
else
":
%
s"
%
entry
.
step
))
return
"
%
s[
%
s]"
%
(
pstate
.
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=
1000
)),
", "
.
join
(
sidxs
))
else
:
raise
TypeError
(
"Can only print Subtensor."
)
class
MakeVectorPrinter
:
def
process
(
self
,
r
,
pstate
):
if
r
.
owner
is
None
:
raise
TypeError
(
"Can only print make_vector."
)
elif
isinstance
(
r
.
owner
.
op
,
T
.
MakeVector
):
return
"[
%
s]"
%
", "
.
join
(
pstate
.
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=
1000
))
for
input
in
r
.
owner
.
inputs
)
else
:
raise
TypeError
(
"Can only print make_vector."
)
class
DefaultPrinter
:
class
DefaultPrinter
:
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -263,6 +202,16 @@ class PPrinter:
...
@@ -263,6 +202,16 @@ class PPrinter:
strings
.
sort
()
strings
.
sort
()
return
"
\n
"
.
join
(
s
[
1
]
for
s
in
strings
)
return
"
\n
"
.
join
(
s
[
1
]
for
s
in
strings
)
def
__call__
(
self
,
*
args
):
if
len
(
args
)
==
1
:
return
self
.
process
(
*
args
)
elif
len
(
args
)
==
2
and
isinstance
(
args
[
1
],
(
PrinterState
,
dict
)):
return
self
.
process
(
*
args
)
elif
len
(
args
)
>
2
:
return
self
.
process_graph
(
*
args
)
else
:
raise
TypeError
(
'Not enough arguments to call.'
)
...
@@ -276,47 +225,10 @@ greek = dict(alpha = u"\u03B1",
...
@@ -276,47 +225,10 @@ greek = dict(alpha = u"\u03B1",
epsilon
=
u"
\u03B5
"
)
epsilon
=
u"
\u03B5
"
)
ppow
=
OperatorPrinter
(
'**'
,
1
,
'right'
)
pprint
=
PPrinter
()
pneg
=
OperatorPrinter
(
'-'
,
0
,
'either'
)
pprint
.
assign
(
lambda
pstate
,
r
:
True
,
DefaultPrinter
())
pmul
=
OperatorPrinter
(
'*'
,
-
1
,
'either'
)
pprint
.
assign
(
lambda
pstate
,
r
:
hasattr
(
pstate
,
'target'
)
and
pstate
.
target
is
not
r
and
r
.
name
is
not
None
,
pdiv
=
OperatorPrinter
(
'/'
,
-
1
,
'left'
)
LeafPrinter
())
padd
=
OperatorPrinter
(
'+'
,
-
2
,
'either'
)
psub
=
OperatorPrinter
(
'-'
,
-
2
,
'left'
)
pp
=
pprint
pdot
=
OperatorPrinter
(
special
[
'middle_dot'
],
-
1
,
'left'
)
psum
=
OperatorPrinter
(
special
[
'big_sigma'
]
+
' '
,
-
2
,
'left'
)
from
..tensor
import
inplace
as
I
def
pprinter
():
pp
=
PPrinter
()
pp
.
assign
(
lambda
pstate
,
r
:
True
,
DefaultPrinter
())
pp
.
assign
(
T
.
add
,
padd
)
pp
.
assign
(
T
.
mul
,
pmul
)
pp
.
assign
(
T
.
sub
,
psub
)
pp
.
assign
(
T
.
neg
,
pneg
)
pp
.
assign
(
T
.
div
,
pdiv
)
pp
.
assign
(
T
.
pow
,
ppow
)
pp
.
assign
(
T
.
dot
,
pdot
)
pp
.
assign
(
T
.
Sum
(),
FunctionPrinter
(
'sum'
))
pp
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
T
.
DimShuffle
),
DimShufflePrinter
())
pp
.
assign
(
T
.
tensor_copy
,
IgnorePrinter
())
pp
.
assign
(
T
.
log
,
FunctionPrinter
(
'log'
))
pp
.
assign
(
T
.
tanh
,
FunctionPrinter
(
'tanh'
))
pp
.
assign
(
I
.
transpose_inplace
,
MemberPrinter
(
'T'
))
pp
.
assign
(
T
.
abs_
,
PatternPrinter
((
'|
%(0)
s|'
,
-
1000
)))
pp
.
assign
(
T
.
sgn
,
FunctionPrinter
(
'sgn'
))
pp
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
T
.
Filler
)
and
r
.
owner
.
op
.
value
==
0
,
FunctionPrinter
(
'seros'
))
pp
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
T
.
Filler
)
and
r
.
owner
.
op
.
value
==
1
,
FunctionPrinter
(
'ones'
))
pp
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
T
.
Subtensor
),
SubtensorPrinter
())
pp
.
assign
(
T
.
shape
,
MemberPrinter
(
'shape'
))
pp
.
assign
(
T
.
fill
,
FunctionPrinter
(
'fill'
))
#pp.assign(T.vertical_stack, FunctionPrinter('vstack'))
pp
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
T
.
MakeVector
),
MakeVectorPrinter
())
return
pp
pp
=
pprinter
()
pp2
=
pprinter
()
pp2
.
assign
(
lambda
pstate
,
r
:
hasattr
(
pstate
,
'target'
)
and
pstate
.
target
is
not
r
and
r
.
name
is
not
None
,
LeafPrinter
())
theano/scalar/basic.py
浏览文件 @
f1d61ed9
...
@@ -252,17 +252,26 @@ def upcast_out(*types):
...
@@ -252,17 +252,26 @@ def upcast_out(*types):
return
Scalar
(
dtype
=
Scalar
.
upcast
(
*
types
)),
return
Scalar
(
dtype
=
Scalar
.
upcast
(
*
types
)),
def
same_out
(
type
):
def
same_out
(
type
):
return
type
,
return
type
,
class
transfer_type
:
class
transfer_type
(
gof
.
utils
.
object2
)
:
def
__init__
(
self
,
i
):
def
__init__
(
self
,
*
transfer
):
assert
type
(
i
)
==
int
assert
all
(
type
(
x
)
==
int
for
x
in
transfer
)
self
.
i
=
i
self
.
transfer
=
transfer
def
__call__
(
self
,
*
types
):
def
__call__
(
self
,
*
types
):
return
types
[
self
.
i
],
upcast
=
upcast_out
(
*
types
)
class
specific_out
:
return
[
upcast
if
i
is
None
else
types
[
i
]
for
i
in
self
.
transfer
]
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
transfer
==
other
.
transfer
def
__hash__
(
self
):
return
hash
(
self
.
transfer
)
class
specific_out
(
gof
.
utils
.
object2
):
def
__init__
(
self
,
*
spec
):
def
__init__
(
self
,
*
spec
):
self
.
spec
=
spec
self
.
spec
=
spec
def
__call__
(
self
,
*
types
):
def
__call__
(
self
,
*
types
):
return
self
.
spec
return
self
.
spec
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
spec
==
other
.
spec
def
__hash__
(
self
):
return
hash
(
self
.
spec
)
def
int_out
(
*
types
):
def
int_out
(
*
types
):
return
int64
,
return
int64
,
def
float_out
(
*
types
):
def
float_out
(
*
types
):
...
@@ -328,9 +337,10 @@ class ScalarOp(Op):
...
@@ -328,9 +337,10 @@ class ScalarOp(Op):
raise
AbstractFunctionError
()
raise
AbstractFunctionError
()
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
\
test
=
type
(
self
)
==
type
(
other
)
\
and
getattr
(
self
,
'output_types_preference'
,
None
)
\
and
getattr
(
self
,
'output_types_preference'
,
None
)
\
==
getattr
(
other
,
'output_types_preference'
,
None
)
==
getattr
(
other
,
'output_types_preference'
,
None
)
return
test
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
getattr
(
self
,
'output_types_preference'
,
0
))
return
hash
(
getattr
(
self
,
'output_types_preference'
,
0
))
...
...
theano/tensor/basic.py
浏览文件 @
f1d61ed9
...
@@ -20,7 +20,8 @@ import elemwise
...
@@ -20,7 +20,8 @@ import elemwise
from
..
import
scalar
as
scal
from
..
import
scalar
as
scal
from
..gof.python25
import
partial
from
..gof.python25
import
partial
from
..
import
compile
from
..
import
compile
,
printing
from
..printing
import
pprint
### set up the external interface
### set up the external interface
...
@@ -614,6 +615,8 @@ def _scal_elemwise(symbol):
...
@@ -614,6 +615,8 @@ def _scal_elemwise(symbol):
rval
.
__epydoc_asRoutine
=
symbol
rval
.
__epydoc_asRoutine
=
symbol
rval
.
__module__
=
'tensor'
rval
.
__module__
=
'tensor'
pprint
.
assign
(
rval
,
printing
.
FunctionPrinter
(
symbolname
))
return
rval
return
rval
...
@@ -661,33 +664,34 @@ def cast(t, dtype):
...
@@ -661,33 +664,34 @@ def cast(t, dtype):
return
mapping
[
dtype
](
t
)
return
mapping
[
dtype
](
t
)
#to be removed as we get the epydoc routine-documenting thing going -JB 20080924
#to be removed as we get the epydoc routine-documenting thing going -JB 20080924
def
_conversion
(
real_value
):
def
_conversion
(
real_value
,
name
):
__oplist_tag
(
real_value
,
'casting'
)
__oplist_tag
(
real_value
,
'casting'
)
real_value
.
__module__
=
'tensor.basic'
real_value
.
__module__
=
'tensor.basic'
pprint
.
assign
(
real_value
,
printing
.
FunctionPrinter
(
name
))
return
real_value
return
real_value
convert_to_int8
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
int8
))))
convert_to_int8
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
int8
)))
,
'int8'
)
"""Cast to 8-bit integer"""
"""Cast to 8-bit integer"""
convert_to_int16
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
int16
))))
convert_to_int16
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
int16
)))
,
'int16'
)
"""Cast to 16-bit integer"""
"""Cast to 16-bit integer"""
convert_to_int32
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
int32
))))
convert_to_int32
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
int32
)))
,
'int32'
)
"""Cast to 32-bit integer"""
"""Cast to 32-bit integer"""
convert_to_int64
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
int64
))))
convert_to_int64
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
int64
)))
,
'int64'
)
"""Cast to 64-bit integer"""
"""Cast to 64-bit integer"""
convert_to_float32
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
float32
))))
convert_to_float32
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
float32
)))
,
'float32'
)
"""Cast to single-precision floating point"""
"""Cast to single-precision floating point"""
convert_to_float64
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
float64
))))
convert_to_float64
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
float64
)))
,
'float64'
)
"""Cast to double-precision floating point"""
"""Cast to double-precision floating point"""
convert_to_complex64
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
complex64
))))
convert_to_complex64
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
complex64
)))
,
'complex64'
)
"""Cast to single-precision complex"""
"""Cast to single-precision complex"""
convert_to_complex128
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
complex128
))))
convert_to_complex128
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
Identity
(
scal
.
specific_out
(
scal
.
complex128
)))
,
'complex128'
)
"""Cast to double-precision complex"""
"""Cast to double-precision complex"""
...
@@ -713,6 +717,9 @@ class Shape(Op):
...
@@ -713,6 +717,9 @@ class Shape(Op):
def
shape
(
a
):
def
shape
(
a
):
pass
pass
pprint
.
assign
(
shape
,
printing
.
MemberPrinter
(
'shape'
))
class
MaxAndArgmax
(
Op
):
class
MaxAndArgmax
(
Op
):
"""Calculate the max and argmax over a given axis"""
"""Calculate the max and argmax over a given axis"""
nin
=
2
# tensor, axis
nin
=
2
# tensor, axis
...
@@ -834,6 +841,9 @@ def abs_(a):
...
@@ -834,6 +841,9 @@ def abs_(a):
"""
"""
pprint
.
assign
(
abs_
,
printing
.
PatternPrinter
((
'|
%(0)
s|'
,
-
1000
)))
@_scal_elemwise
@_scal_elemwise
def
exp
(
a
):
def
exp
(
a
):
"""e^`a`"""
"""e^`a`"""
...
@@ -902,6 +912,8 @@ def second(a, b):
...
@@ -902,6 +912,8 @@ def second(a, b):
"""Create a matrix by filling the shape of a with b"""
"""Create a matrix by filling the shape of a with b"""
fill
=
second
fill
=
second
pprint
.
assign
(
fill
,
printing
.
FunctionPrinter
(
'fill'
))
@constructor
@constructor
def
ones_like
(
model
):
def
ones_like
(
model
):
...
@@ -967,10 +979,15 @@ def one():
...
@@ -967,10 +979,15 @@ def one():
"""WRITEME"""
"""WRITEME"""
return
Ones
(
0
)([])
return
Ones
(
0
)([])
pprint
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
Filler
)
and
r
.
owner
.
op
.
value
==
0
,
printing
.
FunctionPrinter
(
'zeros'
))
pprint
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
Filler
)
and
r
.
owner
.
op
.
value
==
1
,
printing
.
FunctionPrinter
(
'ones'
))
@_redefine
(
elemwise
.
Elemwise
(
scal
.
identity
))
@_redefine
(
elemwise
.
Elemwise
(
scal
.
identity
))
def
tensor_copy
(
a
):
def
tensor_copy
(
a
):
"""Create a duplicate of `a` (with duplicated storage)"""
"""Create a duplicate of `a` (with duplicated storage)"""
pprint
.
assign
(
tensor_copy
,
printing
.
IgnorePrinter
())
@_redefine
(
elemwise
.
Elemwise
(
scal
.
identity
,
inplace_pattern
=
{
0
:
[
0
]}))
@_redefine
(
elemwise
.
Elemwise
(
scal
.
identity
,
inplace_pattern
=
{
0
:
[
0
]}))
def
view
(
a
):
def
view
(
a
):
...
@@ -981,6 +998,9 @@ def sum(input, axis = None):
...
@@ -981,6 +998,9 @@ def sum(input, axis = None):
"""WRITEME"""
"""WRITEME"""
return
elemwise
.
Sum
(
axis
)(
input
)
return
elemwise
.
Sum
(
axis
)(
input
)
pprint
.
assign
(
Sum
(),
printing
.
FunctionPrinter
(
'sum'
))
@constructor
@constructor
def
mean
(
input
,
axis
=
None
):
def
mean
(
input
,
axis
=
None
):
"""WRITEME"""
"""WRITEME"""
...
@@ -1043,6 +1063,14 @@ def mod(a, b):
...
@@ -1043,6 +1063,14 @@ def mod(a, b):
def
pow
(
a
,
b
):
def
pow
(
a
,
b
):
"""elementwise power"""
"""elementwise power"""
pprint
.
assign
(
add
,
printing
.
OperatorPrinter
(
'+'
,
-
2
,
'either'
))
pprint
.
assign
(
mul
,
printing
.
OperatorPrinter
(
'*'
,
-
1
,
'either'
))
pprint
.
assign
(
sub
,
printing
.
OperatorPrinter
(
'-'
,
-
2
,
'left'
))
pprint
.
assign
(
neg
,
printing
.
OperatorPrinter
(
'-'
,
0
,
'either'
))
pprint
.
assign
(
div
,
printing
.
OperatorPrinter
(
'/'
,
-
1
,
'left'
))
pprint
.
assign
(
pow
,
printing
.
OperatorPrinter
(
'**'
,
1
,
'right'
))
##########################
##########################
# View Operations
# View Operations
...
@@ -1214,6 +1242,36 @@ class Subtensor(Op):
...
@@ -1214,6 +1242,36 @@ class Subtensor(Op):
return
"
%
s{
%
s}"
%
(
self
.
__class__
.
__name__
,
", "
.
join
(
indices
))
return
"
%
s{
%
s}"
%
(
self
.
__class__
.
__name__
,
", "
.
join
(
indices
))
class
SubtensorPrinter
:
def
process
(
self
,
r
,
pstate
):
if
r
.
owner
is
None
:
raise
TypeError
(
"Can only print Subtensor."
)
elif
isinstance
(
r
.
owner
.
op
,
Subtensor
):
idxs
=
r
.
owner
.
op
.
idx_list
inputs
=
list
(
r
.
owner
.
inputs
)
input
=
inputs
.
pop
()
sidxs
=
[]
inbrack_pstate
=
pstate
.
clone
(
precedence
=
-
1000
)
for
entry
in
idxs
:
if
isinstance
(
entry
,
int
):
sidxs
.
append
(
str
(
entry
))
elif
isinstance
(
entry
,
scal
.
Scalar
):
sidxs
.
append
(
inbrack_pstate
.
pprinter
.
process
(
inputs
.
pop
()))
elif
isinstance
(
entry
,
slice
):
sidxs
.
append
(
"
%
s:
%
s
%
s"
%
(
""
if
entry
.
start
is
None
or
entry
.
start
==
0
else
entry
.
start
,
""
if
entry
.
stop
is
None
or
entry
.
stop
==
sys
.
maxint
else
entry
.
stop
,
""
if
entry
.
step
is
None
else
":
%
s"
%
entry
.
step
))
return
"
%
s[
%
s]"
%
(
pstate
.
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=
1000
)),
", "
.
join
(
sidxs
))
else
:
raise
TypeError
(
"Can only print Subtensor."
)
pprint
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
Subtensor
),
SubtensorPrinter
())
class
SetSubtensor
(
Subtensor
):
class
SetSubtensor
(
Subtensor
):
"""WRITEME"""
"""WRITEME"""
view_map
=
{}
view_map
=
{}
...
@@ -1474,6 +1532,11 @@ def join(axis, *tensors):
...
@@ -1474,6 +1532,11 @@ def join(axis, *tensors):
"""
"""
pprint
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
Join
),
printing
.
FunctionPrinter
(
'join'
))
@constructor
@constructor
def
shape_padleft
(
tensor
,
n_ones
):
def
shape_padleft
(
tensor
,
n_ones
):
"""Reshape `tensor` by left-padding the shape with `n_ones` 1s
"""Reshape `tensor` by left-padding the shape with `n_ones` 1s
...
@@ -1630,6 +1693,21 @@ make_lvector = MakeVector(lscalar)
...
@@ -1630,6 +1693,21 @@ make_lvector = MakeVector(lscalar)
"""WRITEME"""
"""WRITEME"""
class
MakeVectorPrinter
:
def
process
(
self
,
r
,
pstate
):
if
r
.
owner
is
None
:
raise
TypeError
(
"Can only print make_vector."
)
elif
isinstance
(
r
.
owner
.
op
,
MakeVector
):
return
"[
%
s]"
%
", "
.
join
(
pstate
.
pprinter
.
process
(
input
,
pstate
.
clone
(
precedence
=
1000
))
for
input
in
r
.
owner
.
inputs
)
else
:
raise
TypeError
(
"Can only print make_vector."
)
pprint
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
MakeVector
),
MakeVectorPrinter
())
#########################
#########################
# Linalg : Dot
# Linalg : Dot
#########################
#########################
...
@@ -1696,6 +1774,7 @@ class Dot(Op):
...
@@ -1696,6 +1774,7 @@ class Dot(Op):
def
__str__
(
self
):
def
__str__
(
self
):
return
"dot"
return
"dot"
dot
=
Dot
()
dot
=
Dot
()
pprint
.
assign
(
dot
,
printing
.
OperatorPrinter
(
printing
.
special
[
'middle_dot'
],
-
1
,
'left'
))
class
Outer
(
Op
):
class
Outer
(
Op
):
""" Compute vector-vector outer product
""" Compute vector-vector outer product
...
@@ -1964,6 +2043,8 @@ class Gemm(Op):
...
@@ -1964,6 +2043,8 @@ class Gemm(Op):
gemm
=
Gemm
()
gemm
=
Gemm
()
pprint
.
assign
(
gemm
,
printing
.
FunctionPrinter
(
'gemm'
))
#########################
#########################
# Gradient
# Gradient
...
...
theano/tensor/elemwise.py
浏览文件 @
f1d61ed9
...
@@ -6,6 +6,8 @@ from .. import gof
...
@@ -6,6 +6,8 @@ from .. import gof
from
..gof
import
Op
,
Apply
from
..gof
import
Op
,
Apply
from
..
import
scalar
from
..
import
scalar
from
..scalar
import
Scalar
from
..scalar
import
Scalar
from
..
import
printing
from
..printing
import
pprint
from
..gof.python25
import
all
from
..gof.python25
import
all
from
copy
import
copy
from
copy
import
copy
...
@@ -182,6 +184,31 @@ class DimShuffle(Op):
...
@@ -182,6 +184,31 @@ class DimShuffle(Op):
return
DimShuffle
(
gz
.
type
.
broadcastable
,
grad_order
)(
gz
),
return
DimShuffle
(
gz
.
type
.
broadcastable
,
grad_order
)(
gz
),
class
DimShufflePrinter
:
def
__p
(
self
,
new_order
,
pstate
,
r
):
if
new_order
!=
()
and
new_order
[
0
]
==
'x'
:
# return "%s" % self.__p(new_order[1:], pstate, r)
return
"[
%
s]"
%
self
.
__p
(
new_order
[
1
:],
pstate
,
r
)
if
list
(
new_order
)
==
range
(
r
.
type
.
ndim
):
return
pstate
.
pprinter
.
process
(
r
)
if
list
(
new_order
)
==
list
(
reversed
(
range
(
r
.
type
.
ndim
))):
return
"
%
s.T"
%
pstate
.
pprinter
.
process
(
r
)
return
"DimShuffle{
%
s}(
%
s)"
%
(
", "
.
join
(
map
(
str
,
new_order
)),
pstate
.
pprinter
.
process
(
r
))
def
process
(
self
,
r
,
pstate
):
if
r
.
owner
is
None
:
raise
TypeError
(
"Can only print DimShuffle."
)
elif
isinstance
(
r
.
owner
.
op
,
DimShuffle
):
ord
=
r
.
owner
.
op
.
new_order
return
self
.
__p
(
ord
,
pstate
,
r
.
owner
.
inputs
[
0
])
else
:
raise
TypeError
(
"Can only print DimShuffle."
)
pprint
.
assign
(
lambda
pstate
,
r
:
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
DimShuffle
),
DimShufflePrinter
())
################
################
### Elemwise ###
### Elemwise ###
################
################
...
...
theano/tensor/inplace.py
浏览文件 @
f1d61ed9
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
from
basic
import
_scal_elemwise
,
_transpose_inplace
from
basic
import
_scal_elemwise
,
_transpose_inplace
from
..
import
scalar
as
scal
from
..
import
scalar
as
scal
import
elemwise
import
elemwise
from
..
import
printing
from
..printing
import
pprint
def
_scal_inplace
(
symbol
):
def
_scal_inplace
(
symbol
):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
...
@@ -24,6 +26,16 @@ def _scal_inplace(symbol):
...
@@ -24,6 +26,16 @@ def _scal_inplace(symbol):
rval
.
__epydoc_asRoutine
=
symbol
rval
.
__epydoc_asRoutine
=
symbol
rval
.
__module__
=
'theano.tensor.inplace'
rval
.
__module__
=
'theano.tensor.inplace'
def
chk
(
pstate
,
r
):
if
not
r
.
owner
:
return
False
op
=
r
.
owner
.
op
# print op, rval, r.owner and op == rval
# print op.inplace_pattern, rval.inplace_pattern, op.inplace_pattern == rval.inplace_pattern
# print op.scalar_op, rval.scalar_op, op.scalar_op == rval.scalar_op
return
r
.
owner
.
op
==
rval
pprint
.
assign
(
chk
,
printing
.
FunctionPrinter
(
symbolname
.
replace
(
'_inplace'
,
'='
)))
return
rval
return
rval
...
@@ -132,6 +144,8 @@ def second_inplace(a):
...
@@ -132,6 +144,8 @@ def second_inplace(a):
"""Fill `a` with `b`"""
"""Fill `a` with `b`"""
fill_inplace
=
second_inplace
fill_inplace
=
second_inplace
pprint
.
assign
(
fill_inplace
,
printing
.
FunctionPrinter
(
'fill='
))
@_scal_inplace
@_scal_inplace
def
add_inplace
(
a
,
b
):
def
add_inplace
(
a
,
b
):
...
@@ -157,7 +171,17 @@ def mod_inplace(a, b):
...
@@ -157,7 +171,17 @@ def mod_inplace(a, b):
def
pow_inplace
(
a
,
b
):
def
pow_inplace
(
a
,
b
):
"""elementwise power (inplace on `a`)"""
"""elementwise power (inplace on `a`)"""
pprint
.
assign
(
add_inplace
,
printing
.
OperatorPrinter
(
'+='
,
-
2
,
'either'
))
pprint
.
assign
(
mul_inplace
,
printing
.
OperatorPrinter
(
'*='
,
-
1
,
'either'
))
pprint
.
assign
(
sub_inplace
,
printing
.
OperatorPrinter
(
'-='
,
-
2
,
'left'
))
pprint
.
assign
(
neg_inplace
,
printing
.
OperatorPrinter
(
'-='
,
0
,
'either'
))
pprint
.
assign
(
div_inplace
,
printing
.
OperatorPrinter
(
'/='
,
-
1
,
'left'
))
pprint
.
assign
(
pow_inplace
,
printing
.
OperatorPrinter
(
'**='
,
1
,
'right'
))
transpose_inplace
=
_transpose_inplace
transpose_inplace
=
_transpose_inplace
"""WRITEME"""
"""WRITEME"""
pprint
.
assign
(
transpose_inplace
,
printing
.
MemberPrinter
(
'T'
))
theano/tensor/opt.py
浏览文件 @
f1d61ed9
...
@@ -31,7 +31,7 @@ def in2out(*local_opts):
...
@@ -31,7 +31,7 @@ def in2out(*local_opts):
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1
=
gof
.
PatternSub
((
I
.
sub_inplace
,
gemm_pattern_1
=
gof
.
PatternSub
((
T
.
sub
,
'd'
,
'd'
,
(
T
.
mul
,
(
T
.
mul
,
dict
(
pattern
=
(
T
.
DimShuffle
((),
[
'x'
,
'x'
],
inplace
=
True
),
'a'
),
dict
(
pattern
=
(
T
.
DimShuffle
((),
[
'x'
,
'x'
],
inplace
=
True
),
'a'
),
...
@@ -77,7 +77,10 @@ def _insert_inplace_optimizer(env):
...
@@ -77,7 +77,10 @@ def _insert_inplace_optimizer(env):
for
candidate_input
in
candidate_inputs
:
for
candidate_input
in
candidate_inputs
:
inplace_pattern
=
dict
(
baseline
,
**
{
candidate_output
:
candidate_input
})
inplace_pattern
=
dict
(
baseline
,
**
{
candidate_output
:
candidate_input
})
try
:
try
:
new
=
Elemwise
(
op
.
scalar_op
,
inplace_pattern
)
.
make_node
(
*
node
.
inputs
)
new
=
Elemwise
(
op
.
scalar_op
.
__class__
(
scalar
.
transfer_type
(
*
[
inplace_pattern
.
get
(
i
,
None
)
for
i
in
xrange
(
len
(
node
.
outputs
))])),
inplace_pattern
)
.
make_node
(
*
node
.
inputs
)
env
.
replace_all_validate
(
zip
(
node
.
outputs
,
new
.
outputs
))
env
.
replace_all_validate
(
zip
(
node
.
outputs
,
new
.
outputs
))
except
Exception
,
e
:
except
Exception
,
e
:
continue
continue
...
@@ -89,7 +92,7 @@ insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
...
@@ -89,7 +92,7 @@ insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
inplace_optimizer
=
gof
.
InplaceOptimizer
(
inplace_optimizer
=
gof
.
InplaceOptimizer
(
gof
.
SeqOptimizer
(
out2in
(
gemm_pattern_1
),
gof
.
SeqOptimizer
(
out2in
(
gemm_pattern_1
),
out2in
(
dot_to_gemm
),
#
out2in(dot_to_gemm),
insert_inplace_optimizer
,
insert_inplace_optimizer
,
failure_callback
=
gof
.
keep_going
))
failure_callback
=
gof
.
keep_going
))
compile
.
optdb
.
register
(
'inplace'
,
inplace_optimizer
,
99
,
'fast_run'
)
compile
.
optdb
.
register
(
'inplace'
,
inplace_optimizer
,
99
,
'fast_run'
)
...
@@ -537,7 +540,7 @@ def mul_calculate(num, denum, aslist = False):
...
@@ -537,7 +540,7 @@ def mul_calculate(num, denum, aslist = False):
return
[
v
]
return
[
v
]
return
v
return
v
local_mul_canonizer
=
Canonizer
(
T
.
mul
,
T
.
div
,
T
.
inv
,
mul_calculate
,
False
)
local_mul_canonizer
=
Canonizer
(
T
.
mul
,
T
.
div
,
T
.
inv
,
mul_calculate
)
@gof.local_optimizer
([
T
.
neg
])
@gof.local_optimizer
([
T
.
neg
])
def
local_neg_to_mul
(
node
):
def
local_neg_to_mul
(
node
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论