Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ddb19a55
提交
ddb19a55
authored
2月 16, 2008
作者:
bergstrj@iro.umontreal.ca
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
m
上级
efc14dba
显示空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
119 行增加
和
71 行删除
+119
-71
core.py
core.py
+12
-16
lib.py
gof/lib.py
+59
-21
result.py
gof/result.py
+21
-15
utils.py
gof/utils.py
+7
-2
grad.py
grad.py
+3
-3
rand.py
rand.py
+1
-1
sparse.py
sparse.py
+16
-13
没有找到文件。
core.py
浏览文件 @
ddb19a55
...
@@ -11,7 +11,8 @@ import numpy
...
@@ -11,7 +11,8 @@ import numpy
from
scipy
import
weave
from
scipy
import
weave
import
gof
import
gof
from
gof
import
current_mode
,
set_mode
,
build_mode
,
eval_mode
,
build_eval_mode
,
pop_mode
,
UNCOMPUTED
,
UNDEFINED
,
PythonR
from
gof
import
current_mode
,
set_mode
,
build_mode
,
eval_mode
,
build_eval_mode
from
gof
import
pop_mode
,
UNCOMPUTED
,
UNDEFINED
,
ResultValue
import
type_spec
import
type_spec
import
cutils
import
cutils
...
@@ -105,12 +106,12 @@ def input(x):
...
@@ -105,12 +106,12 @@ def input(x):
elif
isinstance
(
x
,
gof
.
Result
):
elif
isinstance
(
x
,
gof
.
Result
):
raise
TypeError
(
"
%
s is already a result."
%
x
)
raise
TypeError
(
"
%
s is already a result."
%
x
)
else
:
else
:
return
PythonR
(
x
)
return
ResultValue
(
x
)
def
wrap
(
x
):
def
wrap
(
x
):
if
isinstance
(
x
,
NumpyR
):
if
isinstance
(
x
,
NumpyR
):
return
x
return
x
elif
isinstance
(
x
,
PythonR
):
elif
isinstance
(
x
,
ResultValue
):
return
x
return
x
elif
isinstance
(
x
,
omega_op
):
elif
isinstance
(
x
,
omega_op
):
return
x
.
out
return
x
.
out
...
@@ -144,7 +145,7 @@ def _literal_unhashable(x):
...
@@ -144,7 +145,7 @@ def _literal_unhashable(x):
return
r
return
r
def
literal
(
x
):
def
literal
(
x
):
"""Return a
PythonR
instance wrapping a literal."""
"""Return a
ResultValue
instance wrapping a literal."""
if
_hashable
(
x
):
if
_hashable
(
x
):
return
_literal_hashable
(
x
)
return
_literal_hashable
(
x
)
else
:
else
:
...
@@ -253,18 +254,18 @@ class omega_op(gof.PythonOp):
...
@@ -253,18 +254,18 @@ class omega_op(gof.PythonOp):
self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
In general, grad() should return a list of
PythonR
instances whose
In general, grad() should return a list of
ResultValue
instances whose
length matches that of self.inputs, and whose elements are the
length matches that of self.inputs, and whose elements are the
gradients of self.inputs.
gradients of self.inputs.
There is a (but often used) special feature in place to automatically
There is a (but often used) special feature in place to automatically
wrap the return value of grad() in a list if it is a
PythonR
instance
wrap the return value of grad() in a list if it is a
ResultValue
instance
and the op is unary. This makes many grad implementations a little
and the op is unary. This makes many grad implementations a little
cuter.
cuter.
"""
"""
inputgs
=
self
.
grad
(
*
(
self
.
inputs
+
[
grad_d
[
output
]
for
output
in
self
.
outputs
]))
inputgs
=
self
.
grad
(
*
(
self
.
inputs
+
[
grad_d
[
output
]
for
output
in
self
.
outputs
]))
if
len
(
self
.
inputs
)
==
1
and
isinstance
(
inputgs
,
gof
.
PythonR
):
if
len
(
self
.
inputs
)
==
1
and
isinstance
(
inputgs
,
gof
.
ResultValue
):
inputgs
=
[
inputgs
]
inputgs
=
[
inputgs
]
else
:
else
:
assert
len
(
inputgs
)
==
len
(
self
.
inputs
)
assert
len
(
inputgs
)
==
len
(
self
.
inputs
)
...
@@ -660,9 +661,10 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
...
@@ -660,9 +661,10 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
return
normal_f
(
x
,
y
)
return
normal_f
(
x
,
y
)
return
f
return
f
class
NumpyR
(
gof
.
PythonR
):
class
NumpyR
(
gof
.
ResultValue
):
"""The class for storing ndarray return values from omega ops.
"""The class for storing ndarray return values from omega ops.
The class provides additional functionality compared to the normal PythonR:
The class provides additional functionality compared to the normal
ResultValue:
- operator overloads that correspond to omega ops such as add() and scale()
- operator overloads that correspond to omega ops such as add() and scale()
- special attributes that make it behave like an ndarray when passed to
- special attributes that make it behave like an ndarray when passed to
numpy functions.
numpy functions.
...
@@ -681,13 +683,7 @@ class NumpyR(gof.PythonR):
...
@@ -681,13 +683,7 @@ class NumpyR(gof.PythonR):
__array__
=
property
(
lambda
self
:
self
.
data
.
__array__
)
__array__
=
property
(
lambda
self
:
self
.
data
.
__array__
)
__array_struct__
=
property
(
lambda
self
:
self
.
data
.
__array_struct__
)
__array_struct__
=
property
(
lambda
self
:
self
.
data
.
__array_struct__
)
def
set_value
(
self
,
value
):
def
set_value_filter
(
self
,
value
):
return
numpy
.
asarray
(
value
)
if
value
is
UNCOMPUTED
:
self
.
data
=
UNCOMPUTED
else
:
self
.
data
=
numpy
.
asarray
(
value
)
self
.
refresh
()
self
.
up_to_date
=
True
def
set_value_inplace
(
self
,
value
):
def
set_value_inplace
(
self
,
value
):
if
value
is
UNCOMPUTED
:
if
value
is
UNCOMPUTED
:
...
...
gof/lib.py
浏览文件 @
ddb19a55
from
op
import
Op
from
op
import
Op
from
result
import
Result
#, HolderResult
from
result
import
Result
#, HolderResult
from
utils
import
ClsInit
,
Keyword
from
utils
import
ClsInit
,
Keyword
,
AbstractFunctionError
import
opt
import
opt
import
env
import
env
import
features
import
features
...
@@ -16,7 +16,7 @@ __all__ = ['UNCOMPUTED',
...
@@ -16,7 +16,7 @@ __all__ = ['UNCOMPUTED',
'eval_mode'
,
'eval_mode'
,
'build_eval_mode'
,
'build_eval_mode'
,
'pop_mode'
,
'pop_mode'
,
'
PythonR
'
,
'
ResultValue
'
,
'DummyOp'
,
'DummyOp'
,
'DummyRemover'
,
'DummyRemover'
,
'PythonOp'
,
'PythonOp'
,
...
@@ -27,7 +27,6 @@ __all__ = ['UNCOMPUTED',
...
@@ -27,7 +27,6 @@ __all__ = ['UNCOMPUTED',
UNCOMPUTED
=
Keyword
(
"UNCOMPUTED"
,
False
)
UNCOMPUTED
=
Keyword
(
"UNCOMPUTED"
,
False
)
UNDEFINED
=
Keyword
(
"UNDEFINED"
,
False
)
UNDEFINED
=
Keyword
(
"UNDEFINED"
,
False
)
def
make_static
(
cls
,
fname
):
def
make_static
(
cls
,
fname
):
f
=
getattr
(
cls
,
fname
)
f
=
getattr
(
cls
,
fname
)
if
hasattr
(
f
,
'im_func'
):
if
hasattr
(
f
,
'im_func'
):
...
@@ -79,7 +78,26 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
...
@@ -79,7 +78,26 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
class
PythonR
(
Result
):
class
ResultValue
(
Result
):
"""Augment Result to wrap a computed value.
Attributes:
data -
spec -
constant -
up_to_date -
Properties:
Methods:
set_value_filter - ABSTRACT
set_value_inplace - ABSTRACT
alloc - ABSTRACT
Notes:
"""
__slots__
=
[
'data'
,
'spec'
,
'constant'
,
'up_to_date'
]
__slots__
=
[
'data'
,
'spec'
,
'constant'
,
'up_to_date'
]
...
@@ -90,37 +108,57 @@ class PythonR(Result):
...
@@ -90,37 +108,57 @@ class PythonR(Result):
self
.
up_to_date
=
True
self
.
up_to_date
=
True
self
.
spec
=
None
self
.
spec
=
None
def
__str__
(
self
):
return
str
(
self
.
data
)
def
__repr__
(
self
):
return
repr
(
self
.
data
)
#TODO: document this function, what does it do?
def
refresh
(
self
):
self
.
spec
=
id
(
self
.
data
)
####################################################
#
# Functionality provided by this class
#
def
set_value
(
self
,
value
):
def
set_value
(
self
,
value
):
if
self
.
constant
:
if
self
.
constant
:
raise
Exception
(
"This Result is a constant. Its value cannot be changed."
)
raise
Exception
(
"This Result is a constant. Its value cannot be changed."
)
if
value
is
None
or
value
is
UNCOMPUTED
:
if
value
is
None
or
value
is
UNCOMPUTED
:
self
.
data
=
UNCOMPUTED
self
.
data
=
UNCOMPUTED
elif
isinstance
(
value
,
PythonR
):
elif
isinstance
(
value
,
ResultValue
):
self
.
set_value
(
value
.
data
)
self
.
set_value
(
value
.
data
)
else
:
else
:
try
:
self
.
data
=
self
.
set_value_filter
(
value
)
except
AbstractFunctionError
,
e
:
self
.
data
=
value
self
.
data
=
value
self
.
up_to_date
=
True
self
.
up_to_date
=
True
self
.
refresh
()
self
.
refresh
()
def
set_value_inplace
(
self
,
value
):
def
compute
(
self
):
raise
NotImplementedError
()
#HACK: this is potentially very broken behaviour
"""Overrides Op.compute(). Only recurses if self.data is UNCOMPUTED"""
if
self
.
data
is
UNCOMPUTED
:
Result
.
compute
(
self
)
def
__str__
(
self
):
####################################################
return
str
(
self
.
data
)
#
# Pure virtual functions for subclasses to implement
#
def
__repr__
(
self
):
# Perform error checking or automatic conversion of value, and return the
return
repr
(
self
.
data
)
# result (which will be stored as self.data)
# Called by: set_value()
def
set_value_filter
(
self
,
value
):
raise
AbstractFunctionError
()
def
refresh
(
self
):
# For mutable data types, overwrite the current contents with value
self
.
spec
=
id
(
self
.
data
)
# Also, call refresh and set up_to_date = True
def
set_value_inplace
(
self
,
value
):
raise
AbstractFunctionError
()
def
alloc
(
self
):
# Instantiate data (according to spec)
raise
TypeError
(
"Cannot allocate following this specification."
)
def
alloc
(
self
):
raise
AbstractFunctionError
(
)
def
compute
(
self
):
"""Overrides Op.compute(). Only recurses if self.data is UNCOMPUTED"""
if
self
.
data
is
UNCOMPUTED
:
self
.
owner
.
compute
()
class
PythonOp
(
Op
):
class
PythonOp
(
Op
):
...
@@ -157,10 +195,10 @@ class PythonOp(Op):
...
@@ -157,10 +195,10 @@ class PythonOp(Op):
def
__validate__
(
self
):
def
__validate__
(
self
):
for
input
in
self
.
inputs
:
for
input
in
self
.
inputs
:
assert
isinstance
(
input
,
PythonR
)
assert
isinstance
(
input
,
ResultValue
)
def
gen_outputs
(
self
):
def
gen_outputs
(
self
):
return
[
PythonR
()
for
i
in
xrange
(
self
.
nout
)]
return
[
ResultValue
()
for
i
in
xrange
(
self
.
nout
)]
def
root_inputs
(
self
,
input
):
def
root_inputs
(
self
,
input
):
owner
=
input
.
owner
owner
=
input
.
owner
...
...
gof/result.py
浏览文件 @
ddb19a55
...
@@ -7,6 +7,7 @@ value that is the input or the output of an Op.
...
@@ -7,6 +7,7 @@ value that is the input or the output of an Op.
from
err
import
GofError
from
err
import
GofError
from
utils
import
AbstractFunctionError
__all__
=
[
'Result'
,
'BrokenLink'
,
'BrokenLinkError'
]
__all__
=
[
'Result'
,
'BrokenLink'
,
'BrokenLinkError'
]
...
@@ -40,13 +41,17 @@ class BrokenLinkError(GofError):
...
@@ -40,13 +41,17 @@ class BrokenLinkError(GofError):
############################
############################
class
Result
(
object
):
class
Result
(
object
):
"""
"""Storage node for data in a graph of Op instances.
The Result class represents a datum for use in a graph of Ops. It
has two slots:
- owner: represents the Op which computes this Result. Contains either None
Attributes:
owner - represents the Op which computes this Result. Contains either None
or an instance of Op.
or an instance of Op.
- index: the index of this Result in owner.outputs.
index - the index of this Result in owner.outputs.
Methods:
-
Notes:
Result has no __init__ or __new__ routine. It is the Op's
Result has no __init__ or __new__ routine. It is the Op's
responsibility to set the owner field of its results.
responsibility to set the owner field of its results.
...
@@ -70,7 +75,8 @@ class Result(object):
...
@@ -70,7 +75,8 @@ class Result(object):
self
.
_owner
=
None
self
.
_owner
=
None
return
self
.
_owner
return
self
.
_owner
owner
=
property
(
get_owner
,
doc
=
"The Op of which this Result is an output or None if there is no such Op."
)
owner
=
property
(
get_owner
,
doc
=
"The Op of which this Result is an output or None if there is no such Op."
)
def
set_owner
(
self
,
owner
,
index
):
def
set_owner
(
self
,
owner
,
index
):
if
self
.
owner
is
not
None
:
if
self
.
owner
is
not
None
:
...
@@ -94,21 +100,21 @@ class Result(object):
...
@@ -94,21 +100,21 @@ class Result(object):
self
.
_owner
=
owner
self
.
_owner
=
owner
self
.
_index
=
index
self
.
_index
=
index
def
compute
(
self
):
"""If self has an owner, recursively compute it.
def
set_value
(
self
,
value
):
This is a mutually recursive function with gof.op.Op
"""
Copies the provided value in this Result. It is not required to
implement this method.
"""
raise
NotImplementedError
(
"This Result does not support set_value."
)
def
compute
(
self
):
"""
"""If self has an owner, recursively compute it."""
if
self
.
owner
:
if
self
.
owner
:
self
.
owner
.
compute
()
self
.
owner
.
compute
()
def
perform
(
self
):
def
perform
(
self
):
"""Calls self.owner.perform() if self.owner exists."""
"""Calls self.owner.perform() if self.owner exists.
This is a mutually recursive function with gof.op.Op
"""
if
self
.
owner
:
if
self
.
owner
:
self
.
owner
.
perform
()
self
.
owner
.
perform
()
...
...
gof/utils.py
浏览文件 @
ddb19a55
...
@@ -3,9 +3,14 @@
...
@@ -3,9 +3,14 @@
# import result
# import result
class
OmegaError
(
Exception
):
class
OmegaError
(
Exception
):
pass
pass
class
AbstractFunctionError
(
Exception
):
"""To be raised by functions defined as part of an interface.
When the user sees such an error, it is because an important interface
function has been left out of an implementation class.
"""
def
all_bases
(
cls
,
accept
):
def
all_bases
(
cls
,
accept
):
...
...
grad.py
浏览文件 @
ddb19a55
...
@@ -156,18 +156,18 @@ class update_gradient_via_grad:
...
@@ -156,18 +156,18 @@ class update_gradient_via_grad:
self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
In general, grad() should return a list of
PythonR
instances whose
In general, grad() should return a list of
ResultValue
instances whose
length matches that of self.inputs, and whose elements are the
length matches that of self.inputs, and whose elements are the
gradients of self.inputs.
gradients of self.inputs.
There is a (but often used) special feature in place to automatically
There is a (but often used) special feature in place to automatically
wrap the return value of grad() in a list if it is a
PythonR
instance
wrap the return value of grad() in a list if it is a
ResultValue
instance
and the op is unary. This makes many grad implementations a little
and the op is unary. This makes many grad implementations a little
cuter.
cuter.
"""
"""
inputgs
=
self
.
grad
(
*
(
self
.
inputs
+
[
grad_d
[
output
]
for
output
in
self
.
outputs
]))
inputgs
=
self
.
grad
(
*
(
self
.
inputs
+
[
grad_d
[
output
]
for
output
in
self
.
outputs
]))
if
len
(
self
.
inputs
)
==
1
and
isinstance
(
inputgs
,
gof
.
PythonR
):
if
len
(
self
.
inputs
)
==
1
and
isinstance
(
inputgs
,
gof
.
ResultValue
):
inputgs
=
[
inputgs
]
inputgs
=
[
inputgs
]
else
:
else
:
assert
len
(
inputgs
)
==
len
(
self
.
inputs
)
assert
len
(
inputgs
)
==
len
(
self
.
inputs
)
...
...
rand.py
浏览文件 @
ddb19a55
...
@@ -16,7 +16,7 @@ class RandomState(gof.Op, gof.ext.IONames):
...
@@ -16,7 +16,7 @@ class RandomState(gof.Op, gof.ext.IONames):
def
__init__
(
self
,
seed
):
def
__init__
(
self
,
seed
):
inputs
=
[
wrap
(
seed
)]
inputs
=
[
wrap
(
seed
)]
outputs
=
[
PythonR
()]
outputs
=
[
ResultValue
()]
gof
.
Op
.
__init__
(
self
,
inputs
,
outputs
)
gof
.
Op
.
__init__
(
self
,
inputs
,
outputs
)
def
thunk
(
self
):
def
thunk
(
self
):
...
...
sparse.py
浏览文件 @
ddb19a55
...
@@ -9,24 +9,27 @@ import grad
...
@@ -9,24 +9,27 @@ import grad
# Wrapper type
# Wrapper type
class
SparseR
(
gof
.
PythonR
):
class
SparseR
(
gof
.
ResultValue
):
"""
"""
Attribute:
Attribute:
format - a subclass of sparse.spmatrix indicating self.data.__class__
format - a subclass of sparse.spmatrix indicating self.data.__class__
Properties:
T - read-only: return a transpose of self
Methods:
Notes:
"""
"""
def
__init__
(
self
,
x
=
core
.
UNCOMPUTED
,
constant
=
False
,
def
__init__
(
self
,
x
=
core
.
UNCOMPUTED
,
constant
=
False
,
format
=
sparse
.
csr_matrix
):
format
=
sparse
.
csr_matrix
):
gof
.
PythonR
.
__init__
(
self
,
x
,
constant
)
gof
.
ResultValue
.
__init__
(
self
,
x
,
constant
)
self
.
format
=
isinstance
(
x
,
sparse
.
spmatrix
)
and
x
.
__class__
or
format
self
.
format
=
isinstance
(
x
,
sparse
.
spmatrix
)
and
x
.
__class__
or
format
def
set_value
(
self
,
value
):
def
set_value_filter
(
self
,
value
):
"""Extend base impl, assert value is sparse matrix"""
if
isinstance
(
value
,
sparse
.
spmatrix
):
return
value
gof
.
PythonR
.
set_value
(
self
,
value
)
return
sparse
.
csr_matrix
(
value
)
if
self
.
data
is
not
core
.
UNCOMPUTED
:
if
not
isinstance
(
self
.
data
,
sparse
.
spmatrix
):
print
self
.
data
.
__class__
print
self
.
owner
.
__class__
raise
TypeError
((
'hrm'
,
value
))
def
__add__
(
left
,
right
):
return
add
(
left
,
right
)
def
__add__
(
left
,
right
):
return
add
(
left
,
right
)
def
__radd__
(
right
,
left
):
return
add
(
left
,
right
)
def
__radd__
(
right
,
left
):
return
add
(
left
,
right
)
...
@@ -148,11 +151,11 @@ class _testCase_dot(unittest.TestCase):
...
@@ -148,11 +151,11 @@ class _testCase_dot(unittest.TestCase):
m
=
mtype
(
a
)
m
=
mtype
(
a
)
ab
=
m
.
dot
(
b
)
ab
=
m
.
dot
(
b
)
try
:
try
:
z
=
dot
(
SparseR
(
m
),
gof
.
lib
.
PythonR
(
b
))
z
=
dot
(
SparseR
(
m
),
gof
.
lib
.
ResultValue
(
b
))
self
.
failUnless
(
z
.
data
.
shape
==
ab
.
shape
)
self
.
failUnless
(
z
.
data
.
shape
==
ab
.
shape
)
self
.
failUnless
(
type
(
z
.
data
)
==
type
(
ab
))
self
.
failUnless
(
type
(
z
.
data
)
==
type
(
ab
))
except
Exception
,
e
:
except
Exception
,
e
:
print
mtype
,
e
,
str
(
e
)
print
'cccc'
,
mtype
,
e
,
str
(
e
)
raise
raise
def
test_basic2
(
self
):
def
test_basic2
(
self
):
"""dot: sparse right"""
"""dot: sparse right"""
...
@@ -164,7 +167,7 @@ class _testCase_dot(unittest.TestCase):
...
@@ -164,7 +167,7 @@ class _testCase_dot(unittest.TestCase):
sparse
.
lil_matrix
]:
#, sparse.coo_matrix]:
sparse
.
lil_matrix
]:
#, sparse.coo_matrix]:
m
=
mtype
(
b
)
m
=
mtype
(
b
)
ab
=
m
.
transpose
()
.
dot
(
a
.
transpose
())
.
transpose
()
ab
=
m
.
transpose
()
.
dot
(
a
.
transpose
())
.
transpose
()
z
=
dot
(
gof
.
lib
.
PythonR
(
a
),
SparseR
(
mtype
(
b
)))
z
=
dot
(
gof
.
lib
.
ResultValue
(
a
),
SparseR
(
mtype
(
b
)))
self
.
failUnless
(
z
.
data
.
shape
==
ab
.
shape
)
self
.
failUnless
(
z
.
data
.
shape
==
ab
.
shape
)
self
.
failUnless
(
type
(
z
.
data
)
==
type
(
ab
))
self
.
failUnless
(
type
(
z
.
data
)
==
type
(
ab
))
def
test_graph_bprop0
(
self
):
def
test_graph_bprop0
(
self
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论