Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4f8d207b
提交
4f8d207b
authored
11月 19, 2016
作者:
khaotik
提交者:
khaotik
1月 27, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
PEP8
上级
c0fda9c0
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
29 行增加
和
24 行删除
+29
-24
builders.py
theano/compile/builders.py
+26
-19
test_builders.py
theano/compile/tests/test_builders.py
+3
-5
没有找到文件。
theano/compile/builders.py
浏览文件 @
4f8d207b
...
@@ -26,8 +26,6 @@ class OpFromGraphBase(gof.Op):
...
@@ -26,8 +26,6 @@ class OpFromGraphBase(gof.Op):
'inputs and outputs must be Variable instances'
,
i
)
'inputs and outputs must be Variable instances'
,
i
)
if
'updates'
in
kwargs
or
'givens'
in
kwargs
:
if
'updates'
in
kwargs
or
'givens'
in
kwargs
:
raise
TypeError
(
'updates and givens are not allowed here'
)
raise
TypeError
(
'updates and givens are not allowed here'
)
# To correctly support shared variables the inner fct should
# To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient.
# not see them. Otherwise there is a problem with the gradient.
self
.
shared_inputs
=
[
var
for
var
in
gof
.
graph
.
inputs
(
outputs
)
self
.
shared_inputs
=
[
var
for
var
in
gof
.
graph
.
inputs
(
outputs
)
...
@@ -46,7 +44,6 @@ class OpFromGraphBase(gof.Op):
...
@@ -46,7 +44,6 @@ class OpFromGraphBase(gof.Op):
assert
not
update_expr
assert
not
update_expr
assert
not
shared_inputs
assert
not
shared_inputs
self
.
internal_inputs
=
internal_inputs
self
.
internal_inputs
=
internal_inputs
self
.
internal_outputs
=
internal_outputs
self
.
internal_outputs
=
internal_outputs
self
.
inputs
=
inputs
self
.
inputs
=
inputs
...
@@ -77,7 +74,7 @@ class OpFromGraphBase(gof.Op):
...
@@ -77,7 +74,7 @@ class OpFromGraphBase(gof.Op):
grad_ops_l
=
self
.
grad_ops
grad_ops_l
=
self
.
grad_ops
if
isinstance
(
grad_ops_l
,
list
):
if
isinstance
(
grad_ops_l
,
list
):
assert
len
(
grad_ops_l
)
<=
len
(
self
.
internal_inputs
)
assert
len
(
grad_ops_l
)
<=
len
(
self
.
internal_inputs
)
if
len
(
grad_ops_l
)
<
len
(
self
.
internal_inputs
):
if
len
(
grad_ops_l
)
<
len
(
self
.
internal_inputs
):
grad_ops_l
+=
[
None
]
*
(
grad_ops_l
+=
[
None
]
*
(
len
(
self
.
internal_inputs
)
-
len
(
grad_ops_l
))
len
(
self
.
internal_inputs
)
-
len
(
grad_ops_l
))
# It is normal if some inputs are not needed in order
# It is normal if some inputs are not needed in order
...
@@ -92,10 +89,12 @@ class OpFromGraphBase(gof.Op):
...
@@ -92,10 +89,12 @@ class OpFromGraphBase(gof.Op):
disconnected_inputs
=
'ignore'
)
disconnected_inputs
=
'ignore'
)
),
on_unused_input
=
'ignore'
),
on_unused_input
=
'ignore'
)
for
go
,
inp
in
izip
(
grad_ops_l
,
self
.
internal_inputs
)]
)
for
go
,
inp
in
izip
(
grad_ops_l
,
self
.
internal_inputs
)]
# since OpFromGraphBase only accepts input sequence,
# since OpFromGraphBase only accepts input sequence,
# additional filtering is needed
# additional filtering is needed
grad_ops
=
lambda
inps
,
grds
:[
def
grad_ops
(
inps
,
grds
):
(
go
(
inps
,
grds
)
if
ov
else
go
(
*
(
inps
+
grds
)))
nonlocal
gs
,
grad_ops_l
return
[(
go
(
inps
,
grds
)
if
ov
else
go
(
*
(
inps
+
grds
)))
for
go
,
ov
in
izip
(
gs
,
grad_ops_l
)]
for
go
,
ov
in
izip
(
gs
,
grad_ops_l
)]
else
:
else
:
grad_ops
=
grad_ops_l
grad_ops
=
grad_ops_l
...
@@ -111,10 +110,12 @@ class OpFromGraphBase(gof.Op):
...
@@ -111,10 +110,12 @@ class OpFromGraphBase(gof.Op):
if
g
is
None
:
if
g
is
None
:
grad_ops_l
.
append
(
lambda
*
args
:
None
)
grad_ops_l
.
append
(
lambda
*
args
:
None
)
else
:
else
:
grad_ops_l
.
append
(
type
(
self
)(
grad_inps
,
grad_ops_l
.
append
(
type
(
self
)(
[
g
],
grad_inps
,
[
g
],
on_unused_input
=
'ignore'
))
on_unused_input
=
'ignore'
))
grad_ops
=
lambda
inps
,
grds
:[
go
(
*
(
inps
+
grds
))
for
go
in
grad_ops_l
]
def
grad_ops
(
inps
,
grds
):
nonlocal
grad_ops_l
return
[
go
(
*
(
inps
+
grds
))
for
go
in
grad_ops_l
]
self
.
grad_ops
=
grad_ops
self
.
grad_ops
=
grad_ops
self
.
cached_grad_ops
=
True
self
.
cached_grad_ops
=
True
return
grad_ops
(
inputs
,
output_grads
)
return
grad_ops
(
inputs
,
output_grads
)
...
@@ -125,8 +126,8 @@ class OpFromGraphBase(gof.Op):
...
@@ -125,8 +126,8 @@ class OpFromGraphBase(gof.Op):
raise
TypeError
(
"Wrong type, expected
%
s but got
%
s"
%
raise
TypeError
(
"Wrong type, expected
%
s but got
%
s"
%
(
type
,
input
.
type
))
(
type
,
input
.
type
))
apply_node
=
gof
.
Apply
(
self
,
apply_node
=
gof
.
Apply
(
list
(
inputs
)
+
self
.
shared_inputs
,
self
,
list
(
inputs
)
+
self
.
shared_inputs
,
[
type
()
for
type
in
self
.
output_types
])
[
type
()
for
type
in
self
.
output_types
])
apply_node
.
internal_inputs
=
self
.
internal_inputs
apply_node
.
internal_inputs
=
self
.
internal_inputs
apply_node
.
internal_outputs
=
self
.
internal_outputs
apply_node
.
internal_outputs
=
self
.
internal_outputs
...
@@ -137,7 +138,8 @@ class OpFromGraphBase(gof.Op):
...
@@ -137,7 +138,8 @@ class OpFromGraphBase(gof.Op):
Return connection pattern of subfgraph defined by inputs and outputs.
Return connection pattern of subfgraph defined by inputs and outputs.
"""
"""
return
io_connection_pattern
(
self
.
internal_inputs
,
self
.
internal_outputs
)
return
io_connection_pattern
(
self
.
internal_inputs
,
self
.
internal_outputs
)
def
infer_shape
(
self
,
node
,
shapes
):
def
infer_shape
(
self
,
node
,
shapes
):
out_shp
=
theano
.
scan_module
.
scan_utils
.
infer_shape
(
out_shp
=
theano
.
scan_module
.
scan_utils
.
infer_shape
(
...
@@ -162,9 +164,11 @@ class OpFromGraphBase(gof.Op):
...
@@ -162,9 +164,11 @@ class OpFromGraphBase(gof.Op):
used
+=
nb
used
+=
nb
return
ret
return
ret
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
raise
NotImplementedError
()
raise
NotImplementedError
()
class
OpFromGraphPrecompiled
(
OpFromGraphBase
):
class
OpFromGraphPrecompiled
(
OpFromGraphBase
):
"""
"""
The Op's inner graph is compiled into a theano function.
The Op's inner graph is compiled into a theano function.
...
@@ -183,12 +187,15 @@ class OpFromGraphPrecompiled(OpFromGraphBase):
...
@@ -183,12 +187,15 @@ class OpFromGraphPrecompiled(OpFromGraphBase):
# we wont need this copy anymore
# we wont need this copy anymore
output
[
0
]
=
variable
.
copy
()
output
[
0
]
=
variable
.
copy
()
class
OpFromGraphInline
(
OpFromGraphBase
):
class
OpFromGraphInline
(
OpFromGraphBase
):
"""
"""
The Op's inner graph is expanded into the outer graph at compile time
The Op's inner graph is expanded into the outer graph at compile time
"""
"""
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
raise
RuntimeError
(
type
(
self
)
.
__name__
+
' is not supposed to be executed at runtime'
)
raise
RuntimeError
(
type
(
self
)
.
__name__
+
' is not supposed to be executed at runtime'
)
@gof.local_optimizer
([
OpFromGraphInline
])
@gof.local_optimizer
([
OpFromGraphInline
])
def
inline_ofg_expansion
(
node
):
def
inline_ofg_expansion
(
node
):
...
@@ -199,7 +206,7 @@ def inline_ofg_expansion(node):
...
@@ -199,7 +206,7 @@ def inline_ofg_expansion(node):
return
False
return
False
outputs
=
theano
.
clone
(
outputs
=
theano
.
clone
(
op
.
internal_outputs
,
{
op
.
internal_outputs
,
{
u
:
v
for
u
,
v
in
izip
(
u
:
v
for
u
,
v
in
izip
(
node
.
op
.
internal_inputs
,
node
.
inputs
)})
node
.
op
.
internal_inputs
,
node
.
inputs
)})
return
outputs
return
outputs
...
@@ -218,7 +225,8 @@ OpFromGraph = OpFromGraphPrecompiled
...
@@ -218,7 +225,8 @@ OpFromGraph = OpFromGraphPrecompiled
# API for OpFromGraph*
# API for OpFromGraph*
def
op_from_graph
(
def
op_from_graph
(
inputs
,
outputs
,
inline
=
False
,
grad_overrides
=
None
,
**
kwargs
):
inputs
,
outputs
,
inline
=
False
,
grad_overrides
=
None
,
**
kwargs
):
"""
"""
This creates an `Op` from inputs and outputs lists of variables.
This creates an `Op` from inputs and outputs lists of variables.
The signature is similar to theano.function() and the resulting
The signature is similar to theano.function() and the resulting
...
@@ -270,8 +278,8 @@ def op_from_graph(
...
@@ -270,8 +278,8 @@ def op_from_graph(
invisible to the user. They can be as input to the node or in the
invisible to the user. They can be as input to the node or in the
inner graph.
inner graph.
- We support unused inputs. This is needed for the grad.
- We support unused inputs. This is needed for the grad.
- `inline=True` will cause better runtime optimization at the cost of
longer
- `inline=True` will cause better runtime optimization at the cost of
compilation, only works with optimizer "fast_run" or "fast_compile"
longer
compilation, only works with optimizer "fast_run" or "fast_compile"
Examples
Examples
--------
--------
...
@@ -329,4 +337,3 @@ def op_from_graph(
...
@@ -329,4 +337,3 @@ def op_from_graph(
cls_opfromgraph
=
OpFromGraphPrecompiled
cls_opfromgraph
=
OpFromGraphPrecompiled
return
cls_opfromgraph
(
return
cls_opfromgraph
(
inputs
,
outputs
,
grad_overrides
=
grad_overrides
,
**
kwargs
)
inputs
,
outputs
,
grad_overrides
=
grad_overrides
,
**
kwargs
)
theano/compile/tests/test_builders.py
浏览文件 @
4f8d207b
...
@@ -18,7 +18,6 @@ test_params = unittest_tools.parameterized.expand(
...
@@ -18,7 +18,6 @@ test_params = unittest_tools.parameterized.expand(
class
T_OpFromGraph
(
unittest_tools
.
InferShapeTester
):
class
T_OpFromGraph
(
unittest_tools
.
InferShapeTester
):
@test_params
@test_params
def
test_straightforward
(
self
,
cls_ofg
):
def
test_straightforward
(
self
,
cls_ofg
):
x
,
y
,
z
=
T
.
matrices
(
'xyz'
)
x
,
y
,
z
=
T
.
matrices
(
'xyz'
)
...
@@ -122,7 +121,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
...
@@ -122,7 +121,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
@test_params
@test_params
def
test_grad_override
(
self
,
cls_ofg
):
def
test_grad_override
(
self
,
cls_ofg
):
x
,
y
=
T
.
vectors
(
'xy'
)
x
,
y
=
T
.
vectors
(
'xy'
)
def
go
(
inps
,
gs
):
def
go
(
inps
,
gs
):
x
,
y
=
inps
x
,
y
=
inps
...
@@ -132,8 +131,8 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
...
@@ -132,8 +131,8 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
# single override case
# single override case
op_mul
=
cls_ofg
([
x
,
y
],
[
x
*
y
],
grad_overrides
=
go
)
op_mul
=
cls_ofg
([
x
,
y
],
[
x
*
y
],
grad_overrides
=
go
)
xx
,
yy
=
T
.
vector
(
'xx'
),
T
.
vector
(
'yy'
)
xx
,
yy
=
T
.
vector
(
'xx'
),
T
.
vector
(
'yy'
)
zz
=
T
.
sum
(
op_mul
(
xx
,
yy
))
zz
=
T
.
sum
(
op_mul
(
xx
,
yy
))
dx
,
dy
=
T
.
grad
(
zz
,
[
xx
,
yy
])
dx
,
dy
=
T
.
grad
(
zz
,
[
xx
,
yy
])
fn
=
function
([
xx
,
yy
],
[
dx
,
dy
])
fn
=
function
([
xx
,
yy
],
[
dx
,
dy
])
xv
=
numpy
.
random
.
rand
(
16
)
.
astype
(
config
.
floatX
)
xv
=
numpy
.
random
.
rand
(
16
)
.
astype
(
config
.
floatX
)
...
@@ -247,4 +246,3 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
...
@@ -247,4 +246,3 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
np
.
ones
([
3
,
4
],
dtype
=
config
.
floatX
)],
np
.
ones
([
3
,
4
],
dtype
=
config
.
floatX
)],
cls_ofg
,
cls_ofg
,
check_topo
=
is_compile
)
check_topo
=
is_compile
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论