Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d4e705ea
提交
d4e705ea
authored
2月 22, 2008
作者:
bergstrj@iro.umontreal.ca
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
minor fixes, working on producer in build mode... getting state right
上级
32249295
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
92 行增加
和
36 行删除
+92
-36
core.py
core.py
+37
-7
__init__.py
gof/__init__.py
+1
-10
lib.py
gof/lib.py
+5
-4
result.py
gof/result.py
+48
-14
grad.py
grad.py
+1
-1
没有找到文件。
core.py
浏览文件 @
d4e705ea
...
@@ -84,6 +84,8 @@ def _compile_dir():
...
@@ -84,6 +84,8 @@ def _compile_dir():
sys
.
path
.
append
(
cachedir
)
sys
.
path
.
append
(
cachedir
)
return
cachedir
return
cachedir
class
Allocated
:
"""Memory has been allocated, but contents are not the owner's output."""
class
Numpy2
(
ResultBase
):
class
Numpy2
(
ResultBase
):
"""Result storing a numpy ndarray"""
"""Result storing a numpy ndarray"""
__slots__
=
[
'_dtype'
,
'_shape'
,
]
__slots__
=
[
'_dtype'
,
'_shape'
,
]
...
@@ -121,6 +123,7 @@ class Numpy2(ResultBase):
...
@@ -121,6 +123,7 @@ class Numpy2(ResultBase):
def
data_alloc
(
self
):
def
data_alloc
(
self
):
self
.
data
=
numpy
.
ndarray
(
self
.
shape
,
self
.
dtype
)
self
.
data
=
numpy
.
ndarray
(
self
.
shape
,
self
.
dtype
)
self
.
state
=
Allocated
# self._dtype is used when self._data hasn't been set yet
# self._dtype is used when self._data hasn't been set yet
def
__dtype_get
(
self
):
def
__dtype_get
(
self
):
...
@@ -351,15 +354,15 @@ class _testCase_literal(unittest.TestCase):
...
@@ -351,15 +354,15 @@ class _testCase_literal(unittest.TestCase):
def
cgetspecs
(
names
,
vals
,
converters
):
def
cgen
(
name
,
behavior
,
names
,
vals
,
converters
=
None
):
def
cgetspecs
(
names
,
vals
,
converters
):
d
=
{}
d
=
{}
for
name
,
value
in
zip
(
names
,
vals
):
for
name
,
value
in
zip
(
names
,
vals
):
d
[
name
]
=
value
.
data
d
[
name
]
=
value
.
data
specs
=
weave
.
ext_tools
.
assign_variable_types
(
names
,
d
,
type_converters
=
converters
)
#, auto_downcast = 0)
specs
=
weave
.
ext_tools
.
assign_variable_types
(
names
,
d
,
type_converters
=
converters
)
#, auto_downcast = 0)
return
d
,
specs
return
d
,
specs
def
cgen
(
name
,
behavior
,
names
,
vals
,
converters
=
None
):
if
not
converters
:
if
not
converters
:
converters
=
type_spec
.
default
converters
=
type_spec
.
default
for
converter
in
converters
:
for
converter
in
converters
:
...
@@ -872,6 +875,27 @@ array = wrap_producer(numpy.array)
...
@@ -872,6 +875,27 @@ array = wrap_producer(numpy.array)
zeros
=
wrap_producer
(
numpy
.
zeros
)
zeros
=
wrap_producer
(
numpy
.
zeros
)
ones
=
wrap_producer
(
numpy
.
ones
)
ones
=
wrap_producer
(
numpy
.
ones
)
class
_testCase_producer_build_mode
(
unittest
.
TestCase
):
def
test_0
(
self
):
"""producer in build mode"""
build_mode
()
a
=
ones
(
4
)
self
.
failUnless
(
a
.
data
is
None
)
self
.
failUnless
(
a
.
state
is
gof
.
result
.
Empty
)
self
.
failUnless
(
a
.
shape
==
(
4
,))
self
.
failUnless
(
a
.
dtype
==
'float64'
)
pop_mode
()
def
test_1
(
self
):
"""producer in build_eval mode"""
build_eval_mode
()
a
=
ones
(
4
)
self
.
failUnless
((
a
.
data
==
numpy
.
ones
(
4
))
.
all
())
self
.
failUnless
(
a
.
state
is
gof
.
result
.
Computed
)
self
.
failUnless
(
a
.
shape
==
(
4
,))
self
.
failUnless
(
a
.
dtype
==
'float64'
)
pop_mode
()
# Wrapper to ensure that all inputs to the function impl have the same size (foils numpy's broadcasting)
# Wrapper to ensure that all inputs to the function impl have the same size (foils numpy's broadcasting)
def
assert_same_shapes
(
impl
):
def
assert_same_shapes
(
impl
):
...
@@ -923,6 +947,8 @@ add_elemwise_inplace = add_elemwise.inplace_version()
...
@@ -923,6 +947,8 @@ add_elemwise_inplace = add_elemwise.inplace_version()
add_elemwise_inplace
.
set_impl
(
assert_same_shapes
(
numpy
.
ndarray
.
__iadd__
))
add_elemwise_inplace
.
set_impl
(
assert_same_shapes
(
numpy
.
ndarray
.
__iadd__
))
class
add_scalar
(
tensor_scalar_op
):
class
add_scalar
(
tensor_scalar_op
):
impl
=
tensor_scalar_impl
(
numpy
.
ndarray
.
__add__
)
impl
=
tensor_scalar_impl
(
numpy
.
ndarray
.
__add__
)
def
grad
(
x
,
a
,
gz
):
def
grad
(
x
,
a
,
gz
):
...
@@ -932,6 +958,13 @@ class add_scalar(tensor_scalar_op):
...
@@ -932,6 +958,13 @@ class add_scalar(tensor_scalar_op):
add_scalar_inplace
=
add_scalar
.
inplace_version
()
add_scalar_inplace
=
add_scalar
.
inplace_version
()
add_scalar_inplace
.
set_impl
(
tensor_scalar_impl
(
numpy
.
ndarray
.
__iadd__
))
add_scalar_inplace
.
set_impl
(
tensor_scalar_impl
(
numpy
.
ndarray
.
__iadd__
))
class
_testCase_add_build_mode
(
unittest
.
TestCase
):
def
setUp
(
self
):
build_mode
()
numpy
.
random
.
seed
(
44
)
def
tearDown
(
self
):
pop_mode
()
class
twice
(
elemwise
):
class
twice
(
elemwise
):
def
impl
(
x
):
def
impl
(
x
):
return
2.0
*
x
return
2.0
*
x
...
@@ -1100,10 +1133,7 @@ class dot(omega_op):
...
@@ -1100,10 +1133,7 @@ class dot(omega_op):
def
c_impl
((
_x
,
_y
),
(
_z
,
)):
def
c_impl
((
_x
,
_y
),
(
_z
,
)):
return
blas
.
gemm_code
(
''
,
'1.0'
,
'0.0'
)
return
blas
.
gemm_code
(
''
,
'1.0'
,
'0.0'
)
if
0
:
class
_testCase_dot
(
unittest
.
TestCase
):
print
'SKIPPING DOT TESTS'
else
:
class
_testCase_dot
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
build_eval_mode
()
build_eval_mode
()
numpy
.
random
.
seed
(
44
)
numpy
.
random
.
seed
(
44
)
...
...
gof/__init__.py
浏览文件 @
d4e705ea
# from op import *
import
op
,
ext
,
lib
,
link
,
result
,
env
,
prog
,
features
,
opt
,
graph
# from value import *
# from opt import *
# from env import *
# from prog import *
# from diff import *
# import dispatchers
from
op
import
*
from
op
import
*
from
ext
import
*
from
ext
import
*
...
@@ -18,5 +11,3 @@ from features import *
...
@@ -18,5 +11,3 @@ from features import *
from
opt
import
*
from
opt
import
*
import
graph
import
graph
#import utils
gof/lib.py
浏览文件 @
d4e705ea
...
@@ -44,6 +44,8 @@ def compute_from(nodes, history):
...
@@ -44,6 +44,8 @@ def compute_from(nodes, history):
if
hasattr
(
node
,
'owner'
):
#node is storage
if
hasattr
(
node
,
'owner'
):
#node is storage
compute_recursive
(
node
.
owner
)
compute_recursive
(
node
.
owner
)
else
:
#node is op
else
:
#node is op
if
node
.
destroy_map
():
raise
ValueError
(
'compute_from() does not work on nodes with destroy_maps'
)
for
input
in
node
.
inputs
:
for
input
in
node
.
inputs
:
compute_recursive
(
input
)
compute_recursive
(
input
)
node
.
perform
()
node
.
perform
()
...
@@ -95,8 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
...
@@ -95,8 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
else
:
else
:
return
True
return
True
class
DestroyHandler
(
features
.
Listener
,
features
.
Constraint
,
features
.
Orderings
):
class
DestroyHandler
(
features
.
Listener
,
features
.
Constraint
,
features
.
Orderings
):
def
__init__
(
self
,
env
):
def
__init__
(
self
,
env
):
...
@@ -383,13 +383,14 @@ class PythonOp(Op):
...
@@ -383,13 +383,14 @@ class PythonOp(Op):
return
all
([
is_result
(
i
)
for
i
in
self
.
inputs
])
return
all
([
is_result
(
i
)
for
i
in
self
.
inputs
])
def
gen_outputs
(
self
):
def
gen_outputs
(
self
):
raise
NotImplemented
Error
()
raise
AbstractFunction
Error
()
def
view_map
(
self
):
return
{}
def
view_map
(
self
):
return
{}
def
destroy_map
(
self
):
return
{}
def
destroy_map
(
self
):
return
{}
def
root_inputs
(
self
,
input
):
@staticmethod
def
root_inputs
(
input
):
owner
=
input
.
owner
owner
=
input
.
owner
if
owner
:
if
owner
:
view_map
=
owner
.
view_map
()
view_map
=
owner
.
view_map
()
...
...
gof/result.py
浏览文件 @
d4e705ea
...
@@ -11,7 +11,7 @@ from err import GofError
...
@@ -11,7 +11,7 @@ from err import GofError
from
utils
import
AbstractFunctionError
from
utils
import
AbstractFunctionError
__all__
=
[
'is_result'
,
'ResultBase'
,
'BrokenLink'
,
'BrokenLinkError'
]
__all__
=
[
'is_result'
,
'ResultBase'
,
'BrokenLink'
,
'BrokenLinkError'
]
class
BrokenLink
:
class
BrokenLink
:
...
@@ -36,6 +36,10 @@ class BrokenLinkError(GofError):
...
@@ -36,6 +36,10 @@ class BrokenLinkError(GofError):
pass
pass
# ResultBase state keywords
class
Empty
:
pass
class
Computed
:
pass
############################
############################
# Result
# Result
...
@@ -53,6 +57,7 @@ class ResultBase(object):
...
@@ -53,6 +57,7 @@ class ResultBase(object):
_role - None or (owner, index) or BrokenLink
_role - None or (owner, index) or BrokenLink
_data - anything
_data - anything
constant - Boolean
constant - Boolean
state - one of (Empty, Allocated, Computed)
Properties:
Properties:
role - (rw)
role - (rw)
...
@@ -60,13 +65,12 @@ class ResultBase(object):
...
@@ -60,13 +65,12 @@ class ResultBase(object):
index - (ro)
index - (ro)
data - (rw)
data - (rw)
replaced - (rw) : True iff _role is BrokenLink
replaced - (rw) : True iff _role is BrokenLink
computed - (ro) : True iff contents of data are fresh
Abstract Methods:
Abstract Methods:
data_filter
data_filter
Notes:
Notes
(from previous implementation)
:
A Result instance should be immutable: indeed, if some aspect of a
A Result instance should be immutable: indeed, if some aspect of a
Result is changed, operations that use it might suddenly become
Result is changed, operations that use it might suddenly become
...
@@ -89,22 +93,28 @@ class ResultBase(object):
...
@@ -89,22 +93,28 @@ class ResultBase(object):
class
AbstractFunction
(
Exception
):
class
AbstractFunction
(
Exception
):
"""Exception thrown when an abstract function is called"""
"""Exception thrown when an abstract function is called"""
__slots__
=
[
'_role'
,
'
_data'
,
'constant
'
]
__slots__
=
[
'_role'
,
'
constant'
,
'_data'
,
'state
'
]
def
__init__
(
self
,
role
=
None
,
data
=
None
,
constant
=
False
):
def
__init__
(
self
,
role
=
None
,
data
=
None
,
constant
=
False
):
self
.
_role
=
role
self
.
_role
=
role
self
.
constant
=
constant
self
.
constant
=
constant
if
data
is
None
:
#None is not filtered
if
data
is
None
:
#None is not filtered
self
.
_data
=
None
self
.
_data
=
None
self
.
state
=
Empty
else
:
else
:
try
:
try
:
self
.
_data
=
self
.
data_filter
(
data
)
self
.
_data
=
self
.
data_filter
(
data
)
except
ResultBase
.
AbstractFunction
:
except
ResultBase
.
AbstractFunction
:
self
.
_data
=
data
self
.
_data
=
data
self
.
state
=
Computed
#
# role
#
#role is pair: (owner, outputs_position)
def
__get_role
(
self
):
def
__get_role
(
self
):
return
self
.
_role
return
self
.
_role
def
__set_role
(
self
,
role
):
def
__set_role
(
self
,
role
):
owner
,
index
=
role
owner
,
index
=
role
if
self
.
_role
is
not
None
:
if
self
.
_role
is
not
None
:
...
@@ -116,29 +126,41 @@ class ResultBase(object):
...
@@ -116,29 +126,41 @@ class ResultBase(object):
raise
ValueError
(
"Result
%
s was already mapped to a different index."
%
self
)
raise
ValueError
(
"Result
%
s was already mapped to a different index."
%
self
)
return
# because _owner is owner and _index == index
return
# because _owner is owner and _index == index
self
.
_role
=
role
self
.
_role
=
role
role
=
property
(
__get_role
,
__set_role
)
role
=
property
(
__get_role
,
__set_role
)
#owner is role[0]
#
# owner
#
def
__get_owner
(
self
):
def
__get_owner
(
self
):
if
self
.
_role
is
None
:
return
None
if
self
.
_role
is
None
:
return
None
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
return
self
.
_role
[
0
]
return
self
.
_role
[
0
]
owner
=
property
(
__get_owner
,
owner
=
property
(
__get_owner
,
doc
=
"Op of which this Result is an output, or None if role is None"
)
doc
=
"Op of which this Result is an output, or None if role is None"
)
#index is role[1]
#
# index
#
def
__get_index
(
self
):
def
__get_index
(
self
):
if
self
.
_role
is
None
:
return
None
if
self
.
_role
is
None
:
return
None
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
return
self
.
_role
[
1
]
return
self
.
_role
[
1
]
index
=
property
(
__get_index
,
index
=
property
(
__get_index
,
doc
=
"position of self in owner's outputs, or None if role is None"
)
doc
=
"position of self in owner's outputs, or None if role is None"
)
# assigning to self.data will invoke self.data_filter(value) if that
#
# function is defined
# data
#
def
__get_data
(
self
):
def
__get_data
(
self
):
return
self
.
_data
return
self
.
_data
def
__set_data
(
self
,
data
):
def
__set_data
(
self
,
data
):
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
if
self
.
constant
:
raise
Exception
(
'cannot set constant ResultBase'
)
if
self
.
constant
:
raise
Exception
(
'cannot set constant ResultBase'
)
...
@@ -146,27 +168,39 @@ class ResultBase(object):
...
@@ -146,27 +168,39 @@ class ResultBase(object):
self
.
_data
=
self
.
data_filter
(
data
)
self
.
_data
=
self
.
data_filter
(
data
)
except
ResultBase
.
AbstractFunction
:
#use default behaviour
except
ResultBase
.
AbstractFunction
:
#use default behaviour
self
.
_data
=
data
self
.
_data
=
data
self
.
state
=
Computed
data
=
property
(
__get_data
,
__set_data
,
data
=
property
(
__get_data
,
__set_data
,
doc
=
"The storage associated with this result"
)
doc
=
"The storage associated with this result"
)
def
data_filter
(
self
,
data
):
def
data_filter
(
self
,
data
):
"""(abstract) Return an appropriate _data based on data."""
"""(abstract) Return an appropriate _data based on data.
If a subclass overrides this function, then that overriding
implementation will be used in __set_data to map the argument to
self._data. This gives a subclass the opportunity to ensure that
the contents of self._data remain sensible.
"""
raise
ResultBase
.
AbstractFunction
()
raise
ResultBase
.
AbstractFunction
()
#
# replaced
# replaced
def
__get_replaced
(
self
):
return
isinstance
(
self
.
_role
,
ResultBase
.
BrokenLink
)
#
def
__get_replaced
(
self
):
return
isinstance
(
self
.
_role
,
ResultBase
.
BrokenLink
)
def
__set_replaced
(
self
,
replace
):
def
__set_replaced
(
self
,
replace
):
if
replace
==
self
.
replaced
:
return
if
replace
==
self
.
replaced
:
return
if
replace
:
if
replace
:
self
.
_role
=
ResultBase
.
BrokenLink
(
self
.
_role
)
self
.
_role
=
ResultBase
.
BrokenLink
(
self
.
_role
)
else
:
else
:
self
.
_role
=
self
.
_role
.
old_role
self
.
_role
=
self
.
_role
.
old_role
replaced
=
property
(
__get_replaced
,
__set_replaced
,
doc
=
"has this Result been replaced?"
)
replaced
=
property
(
__get_replaced
,
__set_replaced
,
doc
=
"has this Result been replaced?"
)
# computed
#TODO: think about how to handle this more correctly
computed
=
property
(
lambda
self
:
self
.
_data
is
not
None
)
#################
#################
...
...
grad.py
浏览文件 @
d4e705ea
...
@@ -77,7 +77,7 @@ class Grad(object):
...
@@ -77,7 +77,7 @@ class Grad(object):
r
.
shape
,
dr
.
shape
))
r
.
shape
,
dr
.
shape
))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if
r
.
c
omputed
:
if
r
.
state
is
gof
.
result
.
C
omputed
:
self
.
_compute_history
.
add
(
r
)
self
.
_compute_history
.
add
(
r
)
# add dr to self[r]
# add dr to self[r]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论