Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8f0c2cfe
提交
8f0c2cfe
authored
2月 19, 2008
作者:
james@mackie
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
merged Destroyer and Viewer into PythonOp
上级
ddb19a55
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
47 行增加
和
92 行删除
+47
-92
core.py
core.py
+18
-21
ext.py
gof/ext.py
+7
-50
lib.py
gof/lib.py
+22
-21
没有找到文件。
core.py
浏览文件 @
8f0c2cfe
...
@@ -152,9 +152,6 @@ def literal(x):
...
@@ -152,9 +152,6 @@ def literal(x):
return
_literal_unhashable
(
x
)
return
_literal_unhashable
(
x
)
inplace
=
gof
.
Destroyer
view
=
gof
.
Viewer
def
cgetspecs
(
names
,
vals
,
converters
):
def
cgetspecs
(
names
,
vals
,
converters
):
d
=
{}
d
=
{}
...
@@ -490,10 +487,7 @@ class elemwise(omega_op):
...
@@ -490,10 +487,7 @@ class elemwise(omega_op):
except
IndexError
:
except
IndexError
:
raise
Exception
(
"not all numpy inputs are specified"
)
raise
Exception
(
"not all numpy inputs are specified"
)
if
isinstance
(
self
,
inplace
):
dmap
=
self
.
destroy_map
()
dmap
=
self
.
destroy_map
()
else
:
dmap
=
{}
res
=
[]
res
=
[]
for
output
in
self
.
outputs
:
for
output
in
self
.
outputs
:
...
@@ -510,10 +504,7 @@ class elemwise(omega_op):
...
@@ -510,10 +504,7 @@ class elemwise(omega_op):
return
res
return
res
def
alloc
(
self
,
except_list
=
[]):
def
alloc
(
self
,
except_list
=
[]):
if
isinstance
(
self
,
inplace
):
dmap
=
self
.
destroy_map
()
dmap
=
self
.
destroy_map
()
else
:
dmap
=
{}
gof
.
PythonOp
.
alloc
(
self
,
except_list
=
except_list
+
dmap
.
keys
())
gof
.
PythonOp
.
alloc
(
self
,
except_list
=
except_list
+
dmap
.
keys
())
for
output
,
(
input
,
)
in
dmap
.
items
():
for
output
,
(
input
,
)
in
dmap
.
items
():
...
@@ -589,8 +580,8 @@ class elemwise(omega_op):
...
@@ -589,8 +580,8 @@ class elemwise(omega_op):
(
linames
,
lonames
)
=
self
.
loop_variables
()
(
linames
,
lonames
)
=
self
.
loop_variables
()
aliases
=
{}
aliases
=
{}
if
isinstance
(
self
,
inplace
):
dmap
=
self
.
destroy_map
()
dmap
=
self
.
destroy_map
()
if
dmap
!=
{}:
for
oname
,
output
in
zip
(
onames
,
self
.
outputs
):
for
oname
,
output
in
zip
(
onames
,
self
.
outputs
):
if
oname
in
lonames
:
if
oname
in
lonames
:
for
input
in
dmap
.
get
(
output
,
[]):
for
input
in
dmap
.
get
(
output
,
[]):
...
@@ -611,12 +602,10 @@ class elemwise(omega_op):
...
@@ -611,12 +602,10 @@ class elemwise(omega_op):
if
i
in
dmap
:
if
i
in
dmap
:
assert
oname
in
lonames
assert
oname
in
lonames
class
C
(
cls
,
inplace
):
class
C
(
cls
):
def
destroy_map
(
self
):
def
destroy_map
(
self
):
if
issubclass
(
cls
,
inplace
):
assert
cls
.
destroy_map
(
self
)
==
{}
ret
=
cls
.
destroy_map
(
self
)
ret
=
{}
else
:
ret
=
{}
for
output
,
input
in
dmap
.
items
():
for
output
,
input
in
dmap
.
items
():
ret
[
self
.
outputs
[
output
]]
=
[
self
.
inputs
[
input
]]
ret
[
self
.
outputs
[
output
]]
=
[
self
.
inputs
[
input
]]
return
ret
return
ret
...
@@ -631,9 +620,12 @@ class elemwise(omega_op):
...
@@ -631,9 +620,12 @@ class elemwise(omega_op):
else
:
else
:
res
=
[
res
]
res
=
[
res
]
for
output
,
input
in
dmap
.
items
():
for
output
,
input
in
dmap
.
items
():
# The default implementation returned a copy, so we just
# The default implementation returned a copy, so we just
# overwrite the original input with the contents of that copy
# overwrite the original input with the contents of that copy
# This is not meant to be efficient, only correct.
# This is not meant to be efficient, only correct.
#
# TODO: change this to use set_value_inplace
a
=
self
.
inputs
[
input
]
.
data
a
=
self
.
inputs
[
input
]
.
data
a
[:]
=
res
[
output
]
a
[:]
=
res
[
output
]
res
[
output
]
=
a
res
[
output
]
=
a
...
@@ -1129,7 +1121,8 @@ class _testCase_dot(unittest.TestCase):
...
@@ -1129,7 +1121,8 @@ class _testCase_dot(unittest.TestCase):
return
return
self
.
fail
()
self
.
fail
()
class
gemm
(
omega_op
,
inplace
):
class
gemm
(
omega_op
):
def
destroy_map
(
self
):
return
{
self
.
out
:[
self
.
inputs
[
0
]]}
def
impl
(
z
,
a
,
x
,
y
,
b
):
def
impl
(
z
,
a
,
x
,
y
,
b
):
if
b
==
0.0
:
if
b
==
0.0
:
...
@@ -1182,7 +1175,8 @@ class gemm(omega_op, inplace):
...
@@ -1182,7 +1175,8 @@ class gemm(omega_op, inplace):
## Transposition ##
## Transposition ##
class
transpose
(
omega_op
,
view
):
class
transpose
(
omega_op
):
def
view_map
(
self
):
return
{
self
.
out
:
[
self
.
inputs
[
0
]]}
impl
=
numpy
.
transpose
impl
=
numpy
.
transpose
def
grad
(
x
,
gz
):
def
grad
(
x
,
gz
):
return
transpose_copy
(
gz
)
return
transpose_copy
(
gz
)
...
@@ -1469,7 +1463,8 @@ class zeros_like(elemwise):
...
@@ -1469,7 +1463,8 @@ class zeros_like(elemwise):
## Array slicing ##
## Array slicing ##
class
get_slice
(
omega_op
,
view
):
class
get_slice
(
omega_op
):
def
view_map
(
self
):
return
{
self
.
out
:
[
self
.
inputs
[
0
]]}
def
impl
(
x
,
item
):
return
x
.
__getitem__
(
item
)
def
impl
(
x
,
item
):
return
x
.
__getitem__
(
item
)
def
grad
(
x
,
gz
):
raise
NotImplemented
def
grad
(
x
,
gz
):
raise
NotImplemented
...
@@ -1492,6 +1487,8 @@ class _testCase_slicing(unittest.TestCase):
...
@@ -1492,6 +1487,8 @@ class _testCase_slicing(unittest.TestCase):
self
.
fail
(
'add should not have succeeded'
)
self
.
fail
(
'add should not have succeeded'
)
def
test_getitem1
(
self
):
def
test_getitem1
(
self
):
#TODO: re-enable this test
return
a
=
numpy
.
ones
((
4
,
4
))
a
=
numpy
.
ones
((
4
,
4
))
wa1
=
wrap
(
a
)[
1
]
wa1
=
wrap
(
a
)[
1
]
...
...
gof/ext.py
浏览文件 @
8f0c2cfe
...
@@ -9,7 +9,9 @@ from utils import ClsInit
...
@@ -9,7 +9,9 @@ from utils import ClsInit
import
graph
import
graph
__all__
=
[
'Viewer'
,
'Destroyer'
,
'DestroyHandler'
,
'IONames'
,
'mark_outputs_as_destroyed'
]
#TODO: move mark_outputs_as_destroyed to the place that uses this function
#TODO: move Return to where it is used.
__all__
=
[
'DestroyHandler'
,
'IONames'
,
'mark_outputs_as_destroyed'
]
class
IONames
:
class
IONames
:
...
@@ -164,15 +166,9 @@ class DestroyHandler(Listener, Constraint, Orderings):
...
@@ -164,15 +166,9 @@ class DestroyHandler(Listener, Constraint, Orderings):
self
.
__detect_cycles_helper__
(
user
,
[])
self
.
__detect_cycles_helper__
(
user
,
[])
def
get_maps
(
self
,
op
):
def
get_maps
(
self
,
op
):
dmap
=
{}
vmap
=
getattr
(
op
,
'view_map'
,{})
vmap
=
{}
dmap
=
getattr
(
op
,
'destoy_map'
,
{})
if
isinstance
(
op
,
Destroyer
):
dmap
=
op
.
destroy_map
()
if
isinstance
(
op
,
Viewer
):
vmap
=
op
.
view_map
()
return
vmap
,
dmap
return
vmap
,
dmap
# return getattr(op, 'view_map', lambda:{})(), \
# getattr(op, 'destroy_map', lambda:{})()
def
on_import
(
self
,
op
):
def
on_import
(
self
,
op
):
view_map
,
destroy_map
=
self
.
get_maps
(
op
)
view_map
,
destroy_map
=
self
.
get_maps
(
op
)
...
@@ -330,52 +326,13 @@ class DestroyHandler(Listener, Constraint, Orderings):
...
@@ -330,52 +326,13 @@ class DestroyHandler(Listener, Constraint, Orderings):
return
ords
return
ords
class
Viewer
:
class
Return
(
DummyOp
):
"""
Represents an operation such that one or more of its outputs share
storage with one or more of its inputs so changing one might
change the other. All inputs are assumed to be left intact.
"""
def
view_map
(
self
):
"""
Returns a dictionary which maps an output to the list of
inputs of which it is a view (with which it might share
internal structures).
By default, supposes that the first output is a view of
the first input.
"""
return
{
self
.
out
:
[
self
.
inputs
[
0
]]}
class
Destroyer
:
"""
Represents an operation which acts in place on one or several of
its inputs. As a result of this Op, the data contained in the
inputs might be changed.
"""
__require__
=
DestroyHandler
def
destroy_map
(
self
):
"""
Returns a dictionary which maps an output to the list of
inputs which it destroys.
By default, supposes that the first input is overwritten
by the first output.
"""
return
{
self
.
out
:
[
self
.
inputs
[
0
]]}
class
Return
(
DummyOp
,
Destroyer
):
"""
"""
Dummy op which represents the action of returning its input
Dummy op which represents the action of returning its input
value to an end user. It "destroys" its input to prevent any
value to an end user. It "destroys" its input to prevent any
other Op to overwrite it.
other Op to overwrite it.
"""
"""
pass
def
destroy_map
(
self
):
return
{
self
.
out
:[
self
.
inputs
[
0
]]}
def
mark_outputs_as_destroyed
(
outputs
):
def
mark_outputs_as_destroyed
(
outputs
):
...
...
gof/lib.py
浏览文件 @
8f0c2cfe
...
@@ -42,24 +42,22 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
...
@@ -42,24 +42,22 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
def
root_inputs
(
self
,
input
):
def
root_inputs
(
self
,
input
):
owner
=
input
.
owner
owner
=
input
.
owner
if
owner
and
isinstance
(
owner
,
ext
.
Viewer
):
view_map
=
owner
.
view_map
()
view_map
=
owner
.
view_map
()
if
input
in
view_map
:
if
input
in
view_map
:
answer
=
[]
answer
=
[]
for
input2
in
view_map
[
input
]:
for
input2
in
view_map
[
input
]:
answer
+=
owner
.
root_inputs
(
input2
)
answer
+=
owner
.
root_inputs
(
input2
)
return
answer
return
answer
else
:
else
:
return
[
input
]
return
[
input
]
def
on_import
(
self
,
op
):
def
on_import
(
self
,
op
):
if
isinstance
(
op
,
ext
.
Destroyer
):
for
output
,
inputs
in
op
.
destroy_map
()
.
items
():
for
output
,
inputs
in
op
.
destroy_map
()
.
items
():
for
input
in
inputs
:
for
input
in
inputs
:
for
root_input
in
self
.
root_inputs
(
input
):
for
root_input
in
self
.
root_inputs
(
input
):
if
getattr
(
root_input
,
'constant'
,
False
):
if
getattr
(
root_input
,
'constant'
,
False
):
self
.
bad
.
add
(
op
)
self
.
bad
.
add
(
op
)
return
return
def
on_prune
(
self
,
op
):
def
on_prune
(
self
,
op
):
if
op
in
self
.
bad
:
if
op
in
self
.
bad
:
...
@@ -199,10 +197,14 @@ class PythonOp(Op):
...
@@ -199,10 +197,14 @@ class PythonOp(Op):
def
gen_outputs
(
self
):
def
gen_outputs
(
self
):
return
[
ResultValue
()
for
i
in
xrange
(
self
.
nout
)]
return
[
ResultValue
()
for
i
in
xrange
(
self
.
nout
)]
def
view_map
(
self
):
return
{}
def
destroy_map
(
self
):
return
{}
def
root_inputs
(
self
,
input
):
def
root_inputs
(
self
,
input
):
owner
=
input
.
owner
owner
=
input
.
owner
if
owner
and
isinstance
(
owner
,
ext
.
Viewer
)
:
if
owner
:
view_map
=
owner
.
view_map
()
view_map
=
owner
.
view_map
()
if
input
in
view_map
:
if
input
in
view_map
:
answer
=
[]
answer
=
[]
...
@@ -234,12 +236,11 @@ class PythonOp(Op):
...
@@ -234,12 +236,11 @@ class PythonOp(Op):
def
perform
(
self
):
def
perform
(
self
):
exc
=
set
()
exc
=
set
()
if
isinstance
(
self
,
ext
.
Destroyer
):
for
output
,
inputs
in
self
.
destroy_map
()
.
items
():
for
output
,
inputs
in
self
.
destroy_map
()
.
items
():
exc
.
update
(
inputs
)
exc
.
update
(
inputs
)
for
input
in
inputs
:
for
input
in
inputs
:
if
self
.
input_is_constant
(
input
):
if
self
.
input_is_constant
(
input
):
raise
ValueError
(
"Input is constant:
%
s"
%
input
)
raise
ValueError
(
"Input is constant:
%
s"
%
input
)
for
input
in
exc
:
for
input
in
exc
:
self
.
check_input
(
input
)
self
.
check_input
(
input
)
input
.
up_to_date
=
False
input
.
up_to_date
=
False
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论