Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8f0c2cfe
提交
8f0c2cfe
authored
2月 19, 2008
作者:
james@mackie
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
merged Destroyer and Viewer into PythonOp
上级
ddb19a55
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
47 行增加
和
92 行删除
+47
-92
core.py
core.py
+18
-21
ext.py
gof/ext.py
+7
-50
lib.py
gof/lib.py
+22
-21
没有找到文件。
core.py
浏览文件 @
8f0c2cfe
...
...
@@ -152,9 +152,6 @@ def literal(x):
return
_literal_unhashable
(
x
)
inplace
=
gof
.
Destroyer
view
=
gof
.
Viewer
def
cgetspecs
(
names
,
vals
,
converters
):
d
=
{}
...
...
@@ -490,10 +487,7 @@ class elemwise(omega_op):
except
IndexError
:
raise
Exception
(
"not all numpy inputs are specified"
)
if
isinstance
(
self
,
inplace
):
dmap
=
self
.
destroy_map
()
else
:
dmap
=
{}
dmap
=
self
.
destroy_map
()
res
=
[]
for
output
in
self
.
outputs
:
...
...
@@ -510,10 +504,7 @@ class elemwise(omega_op):
return
res
def
alloc
(
self
,
except_list
=
[]):
if
isinstance
(
self
,
inplace
):
dmap
=
self
.
destroy_map
()
else
:
dmap
=
{}
dmap
=
self
.
destroy_map
()
gof
.
PythonOp
.
alloc
(
self
,
except_list
=
except_list
+
dmap
.
keys
())
for
output
,
(
input
,
)
in
dmap
.
items
():
...
...
@@ -589,8 +580,8 @@ class elemwise(omega_op):
(
linames
,
lonames
)
=
self
.
loop_variables
()
aliases
=
{}
if
isinstance
(
self
,
inplace
):
dmap
=
self
.
destroy_map
()
dmap
=
self
.
destroy_map
()
if
dmap
!=
{}:
for
oname
,
output
in
zip
(
onames
,
self
.
outputs
):
if
oname
in
lonames
:
for
input
in
dmap
.
get
(
output
,
[]):
...
...
@@ -611,12 +602,10 @@ class elemwise(omega_op):
if
i
in
dmap
:
assert
oname
in
lonames
class
C
(
cls
,
inplace
):
class
C
(
cls
):
def
destroy_map
(
self
):
if
issubclass
(
cls
,
inplace
):
ret
=
cls
.
destroy_map
(
self
)
else
:
ret
=
{}
assert
cls
.
destroy_map
(
self
)
==
{}
ret
=
{}
for
output
,
input
in
dmap
.
items
():
ret
[
self
.
outputs
[
output
]]
=
[
self
.
inputs
[
input
]]
return
ret
...
...
@@ -631,9 +620,12 @@ class elemwise(omega_op):
else
:
res
=
[
res
]
for
output
,
input
in
dmap
.
items
():
# The default implementation returned a copy, so we just
# overwrite the original input with the contents of that copy
# This is not meant to be efficient, only correct.
#
# TODO: change this to use set_value_inplace
a
=
self
.
inputs
[
input
]
.
data
a
[:]
=
res
[
output
]
res
[
output
]
=
a
...
...
@@ -1129,7 +1121,8 @@ class _testCase_dot(unittest.TestCase):
return
self
.
fail
()
class
gemm
(
omega_op
,
inplace
):
class
gemm
(
omega_op
):
def
destroy_map
(
self
):
return
{
self
.
out
:[
self
.
inputs
[
0
]]}
def
impl
(
z
,
a
,
x
,
y
,
b
):
if
b
==
0.0
:
...
...
@@ -1182,7 +1175,8 @@ class gemm(omega_op, inplace):
## Transposition ##
class
transpose
(
omega_op
,
view
):
class
transpose
(
omega_op
):
def
view_map
(
self
):
return
{
self
.
out
:
[
self
.
inputs
[
0
]]}
impl
=
numpy
.
transpose
def
grad
(
x
,
gz
):
return
transpose_copy
(
gz
)
...
...
@@ -1469,7 +1463,8 @@ class zeros_like(elemwise):
## Array slicing ##
class
get_slice
(
omega_op
,
view
):
class
get_slice
(
omega_op
):
def
view_map
(
self
):
return
{
self
.
out
:
[
self
.
inputs
[
0
]]}
def
impl
(
x
,
item
):
return
x
.
__getitem__
(
item
)
def
grad
(
x
,
gz
):
raise
NotImplemented
...
...
@@ -1492,6 +1487,8 @@ class _testCase_slicing(unittest.TestCase):
self
.
fail
(
'add should not have succeeded'
)
def
test_getitem1
(
self
):
#TODO: re-enable this test
return
a
=
numpy
.
ones
((
4
,
4
))
wa1
=
wrap
(
a
)[
1
]
...
...
gof/ext.py
浏览文件 @
8f0c2cfe
...
...
@@ -9,7 +9,9 @@ from utils import ClsInit
import
graph
__all__
=
[
'Viewer'
,
'Destroyer'
,
'DestroyHandler'
,
'IONames'
,
'mark_outputs_as_destroyed'
]
#TODO: move mark_outputs_as_destroyed to the place that uses this function
#TODO: move Return to where it is used.
__all__
=
[
'DestroyHandler'
,
'IONames'
,
'mark_outputs_as_destroyed'
]
class
IONames
:
...
...
@@ -164,15 +166,9 @@ class DestroyHandler(Listener, Constraint, Orderings):
self
.
__detect_cycles_helper__
(
user
,
[])
def
get_maps
(
self
,
op
):
dmap
=
{}
vmap
=
{}
if
isinstance
(
op
,
Destroyer
):
dmap
=
op
.
destroy_map
()
if
isinstance
(
op
,
Viewer
):
vmap
=
op
.
view_map
()
vmap
=
getattr
(
op
,
'view_map'
,{})
dmap
=
getattr
(
op
,
'destoy_map'
,
{})
return
vmap
,
dmap
# return getattr(op, 'view_map', lambda:{})(), \
# getattr(op, 'destroy_map', lambda:{})()
def
on_import
(
self
,
op
):
view_map
,
destroy_map
=
self
.
get_maps
(
op
)
...
...
@@ -330,52 +326,13 @@ class DestroyHandler(Listener, Constraint, Orderings):
return
ords
class
Viewer
:
"""
Represents an operation such that one or more of its outputs share
storage with one or more of its inputs so changing one might
change the other. All inputs are assumed to be left intact.
"""
def
view_map
(
self
):
"""
Returns a dictionary which maps an output to the list of
inputs of which it is a view (with which it might share
internal structures).
By default, supposes that the first output is a view of
the first input.
"""
return
{
self
.
out
:
[
self
.
inputs
[
0
]]}
class
Destroyer
:
"""
Represents an operation which acts in place on one or several of
its inputs. As a result of this Op, the data contained in the
inputs might be changed.
"""
__require__
=
DestroyHandler
def
destroy_map
(
self
):
"""
Returns a dictionary which maps an output to the list of
inputs which it destroys.
By default, supposes that the first input is overwritten
by the first output.
"""
return
{
self
.
out
:
[
self
.
inputs
[
0
]]}
class
Return
(
DummyOp
,
Destroyer
):
class
Return
(
DummyOp
):
"""
Dummy op which represents the action of returning its input
value to an end user. It "destroys" its input to prevent any
other Op to overwrite it.
"""
pass
def
destroy_map
(
self
):
return
{
self
.
out
:[
self
.
inputs
[
0
]]}
def
mark_outputs_as_destroyed
(
outputs
):
...
...
gof/lib.py
浏览文件 @
8f0c2cfe
...
...
@@ -42,24 +42,22 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
def
root_inputs
(
self
,
input
):
owner
=
input
.
owner
if
owner
and
isinstance
(
owner
,
ext
.
Viewer
):
view_map
=
owner
.
view_map
()
if
input
in
view_map
:
answer
=
[]
for
input2
in
view_map
[
input
]:
answer
+=
owner
.
root_inputs
(
input2
)
return
answer
view_map
=
owner
.
view_map
()
if
input
in
view_map
:
answer
=
[]
for
input2
in
view_map
[
input
]:
answer
+=
owner
.
root_inputs
(
input2
)
return
answer
else
:
return
[
input
]
def
on_import
(
self
,
op
):
if
isinstance
(
op
,
ext
.
Destroyer
):
for
output
,
inputs
in
op
.
destroy_map
()
.
items
():
for
input
in
inputs
:
for
root_input
in
self
.
root_inputs
(
input
):
if
getattr
(
root_input
,
'constant'
,
False
):
self
.
bad
.
add
(
op
)
return
for
output
,
inputs
in
op
.
destroy_map
()
.
items
():
for
input
in
inputs
:
for
root_input
in
self
.
root_inputs
(
input
):
if
getattr
(
root_input
,
'constant'
,
False
):
self
.
bad
.
add
(
op
)
return
def
on_prune
(
self
,
op
):
if
op
in
self
.
bad
:
...
...
@@ -199,10 +197,14 @@ class PythonOp(Op):
def
gen_outputs
(
self
):
return
[
ResultValue
()
for
i
in
xrange
(
self
.
nout
)]
def
view_map
(
self
):
return
{}
def
destroy_map
(
self
):
return
{}
def
root_inputs
(
self
,
input
):
owner
=
input
.
owner
if
owner
and
isinstance
(
owner
,
ext
.
Viewer
)
:
if
owner
:
view_map
=
owner
.
view_map
()
if
input
in
view_map
:
answer
=
[]
...
...
@@ -234,12 +236,11 @@ class PythonOp(Op):
def
perform
(
self
):
exc
=
set
()
if
isinstance
(
self
,
ext
.
Destroyer
):
for
output
,
inputs
in
self
.
destroy_map
()
.
items
():
exc
.
update
(
inputs
)
for
input
in
inputs
:
if
self
.
input_is_constant
(
input
):
raise
ValueError
(
"Input is constant:
%
s"
%
input
)
for
output
,
inputs
in
self
.
destroy_map
()
.
items
():
exc
.
update
(
inputs
)
for
input
in
inputs
:
if
self
.
input_is_constant
(
input
):
raise
ValueError
(
"Input is constant:
%
s"
%
input
)
for
input
in
exc
:
self
.
check_input
(
input
)
input
.
up_to_date
=
False
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论