Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a007e686
提交
a007e686
authored
2月 22, 2008
作者:
bergstrj@iro.umontreal.ca
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
removed dead code, installed ResultBase to gof
上级
b0b514e4
全部展开
显示空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
149 行增加
和
621 行删除
+149
-621
autotest.py
autotest.py
+2
-1
core.py
core.py
+5
-215
compile.py
gof/compile.py
+0
-169
ext.py
gof/ext.py
+0
-1
graph.py
gof/graph.py
+1
-1
lib.py
gof/lib.py
+4
-83
listeners.py
gof/listeners.py
+0
-0
op.py
gof/op.py
+1
-1
result.py
gof/result.py
+130
-146
grad.py
grad.py
+6
-4
没有找到文件。
autotest.py
浏览文件 @
a007e686
...
...
@@ -2,7 +2,8 @@ import unittest, os, sys
if
__name__
==
'__main__'
:
suite
=
None
for
filename
in
os
.
listdir
(
'.'
):
filenames
=
os
.
listdir
(
'.'
)
+
[
'gof.'
+
s
for
s
in
os
.
listdir
(
'gof'
)]
for
filename
in
filenames
:
if
filename
[
-
3
:]
==
'.py'
:
modname
=
filename
[:
-
3
]
tests
=
unittest
.
TestLoader
()
.
loadTestsFromModule
(
__import__
(
modname
))
...
...
core.py
浏览文件 @
a007e686
...
...
@@ -12,7 +12,7 @@ from scipy import weave
import
gof
from
gof
import
current_mode
,
set_mode
,
build_mode
,
eval_mode
,
build_eval_mode
from
gof
import
pop_mode
,
is_result
from
gof
import
pop_mode
,
is_result
,
ResultBase
import
type_spec
import
cutils
...
...
@@ -84,138 +84,6 @@ def _compile_dir():
sys
.
path
.
append
(
cachedir
)
return
cachedir
class
ResultBase
(
object
):
"""Base class for storing Op inputs and outputs
Attributes:
_role - None or (owner, index) or BrokenLink
_data - anything
constant - Boolean
Properties:
role - (rw)
owner - (ro)
index - (ro)
data - (rw)
replaced - (rw) : True iff _role is BrokenLink
computed - (ro) : True iff contents of data are fresh
Abstract Methods:
data_filter
"""
class
BrokenLink
:
"""The owner of a Result that was replaced by another Result"""
__slots__
=
[
'old_role'
]
def
__init__
(
self
,
role
):
self
.
old_role
=
role
def
__nonzero__
(
self
):
return
False
class
BrokenLinkError
(
Exception
):
"""Exception thrown when an owner is a BrokenLink"""
class
AbstractFunction
(
Exception
):
"""Exception thrown when an abstract function is called"""
__slots__
=
[
'_role'
,
'_data'
,
'constant'
]
def
__init__
(
self
,
role
=
None
,
data
=
None
,
constant
=
False
):
self
.
_role
=
role
self
.
constant
=
constant
if
data
is
None
:
#None is not filtered
self
.
_data
=
None
else
:
try
:
self
.
_data
=
self
.
data_filter
(
data
)
except
ResultBase
.
AbstractFunction
:
self
.
_data
=
data
#role is pair: (owner, outputs_position)
def
__get_role
(
self
):
return
self
.
_role
def
__set_role
(
self
,
role
):
owner
,
index
=
role
if
self
.
_role
is
not
None
:
# this is either an error or a no-op
_owner
,
_index
=
self
.
_role
if
_owner
is
not
owner
:
raise
ValueError
(
"Result
%
s already has an owner."
%
self
)
if
_index
!=
index
:
raise
ValueError
(
"Result
%
s was already mapped to a different index."
%
self
)
return
# because _owner is owner and _index == index
self
.
_role
=
role
role
=
property
(
__get_role
,
__set_role
)
#owner is role[0]
def
__get_owner
(
self
):
if
self
.
_role
is
None
:
return
None
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
return
self
.
_role
[
0
]
owner
=
property
(
__get_owner
,
doc
=
"Op of which this Result is an output, or None if role is None"
)
#index is role[1]
def
__get_index
(
self
):
if
self
.
_role
is
None
:
return
None
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
return
self
.
_role
[
1
]
index
=
property
(
__get_index
,
doc
=
"position of self in owner's outputs, or None if role is None"
)
# assigning to self.data will invoke self.data_filter(value) if that
# function is defined
def
__get_data
(
self
):
return
self
.
_data
def
__set_data
(
self
,
data
):
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
if
self
.
constant
:
raise
Exception
(
'cannot set constant ResultBase'
)
try
:
self
.
_data
=
self
.
data_filter
(
data
)
except
ResultBase
.
AbstractFunction
:
#use default behaviour
self
.
_data
=
data
data
=
property
(
__get_data
,
__set_data
,
doc
=
"The storage associated with this result"
)
def
data_filter
(
self
,
data
):
"""(abstract) Return an appropriate _data based on data."""
raise
ResultBase
.
AbstractFunction
()
# replaced
def
__get_replaced
(
self
):
return
isinstance
(
self
.
_role
,
ResultBase
.
BrokenLink
)
def
__set_replaced
(
self
,
replace
):
if
replace
==
self
.
replaced
:
return
if
replace
:
self
.
_role
=
ResultBase
.
BrokenLink
(
self
.
_role
)
else
:
self
.
_role
=
self
.
_role
.
old_role
replaced
=
property
(
__get_replaced
,
__set_replaced
,
doc
=
"has this Result been replaced?"
)
# computed
#TODO: think about how to handle this more correctly
computed
=
property
(
lambda
self
:
self
.
_data
is
not
None
)
#################
# NumpyR Compatibility
#
up_to_date
=
property
(
lambda
self
:
True
)
def
refresh
(
self
):
pass
def
set_owner
(
self
,
owner
,
idx
):
self
.
role
=
(
owner
,
idx
)
def
set_value
(
self
,
value
):
self
.
data
=
value
#may raise exception
class
_test_ResultBase
(
unittest
.
TestCase
):
def
setUp
(
self
):
build_eval_mode
()
numpy
.
random
.
seed
(
44
)
def
tearDown
(
self
):
pop_mode
()
def
test_0
(
self
):
r
=
ResultBase
()
class
Numpy2
(
ResultBase
):
"""Result storing a numpy ndarray"""
__slots__
=
[
'_dtype'
,
'_shape'
,
]
...
...
@@ -986,92 +854,14 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
return
normal_f
(
x
,
y
)
return
f
if
0
:
class
NumpyR
(
gof
.
ResultValue
):
"""The class for storing ndarray return values from omega ops.
The class provides additional functionality compared to the normal
ResultValue:
- operator overloads that correspond to omega ops such as add() and scale()
- special attributes that make it behave like an ndarray when passed to
numpy functions.
Attributes:
__array__ - alias of self.data.__array_struct__
__array_struct__ - alias of self.data.__array_struct__
Methods:
set_value() -
"""
# The following attributes make NumpyR instances look like normal ndarray
# instances to many numpy functions, such as argmax(), dot(), svd(), sum(),
# etc. These are documented in the numpy book.
__array__
=
property
(
lambda
self
:
self
.
data
.
__array__
)
__array_struct__
=
property
(
lambda
self
:
self
.
data
.
__array_struct__
)
def
set_value_filter
(
self
,
value
):
return
numpy
.
asarray
(
value
)
def
set_value_inplace
(
self
,
value
):
if
value
is
UNCOMPUTED
:
raise
ValueError
()
else
:
if
0
==
len
(
self
.
data
.
shape
):
self
.
data
.
itemset
(
value
)
# for scalars
else
:
self
.
data
[:]
=
value
# for matrices
self
.
refresh
()
self
.
up_to_date
=
True
def
refresh
(
self
):
if
self
.
data
is
not
UNCOMPUTED
:
self
.
spec
=
(
numpy
.
ndarray
,
self
.
data
.
dtype
,
self
.
data
.
shape
)
def
alloc
(
self
):
shape
,
dtype
=
self
.
spec
[
2
],
self
.
spec
[
1
]
self
.
data
=
numpy
.
ndarray
(
shape
,
dtype
=
dtype
)
def
__add__
(
self
,
y
):
return
add
(
self
,
y
)
def
__radd__
(
self
,
x
):
return
add
(
x
,
self
)
def
__iadd__
(
self
,
y
):
return
add_inplace
(
self
,
y
)
def
__sub__
(
self
,
y
):
return
sub
(
self
,
y
)
def
__rsub__
(
self
,
x
):
return
sub
(
x
,
self
)
def
__isub__
(
self
,
y
):
return
sub_inplace
(
self
,
y
)
def
__mul__
(
self
,
y
):
return
mul
(
self
,
y
)
def
__rmul__
(
self
,
x
):
return
mul
(
x
,
self
)
def
__imul__
(
self
,
y
):
return
mul_inplace
(
self
,
y
)
def
__div__
(
self
,
y
):
return
div
(
self
,
y
)
def
__rdiv__
(
self
,
x
):
return
div
(
x
,
self
)
def
__idiv__
(
self
,
y
):
return
div_inplace
(
self
,
y
)
def
__pow__
(
self
,
y
):
return
pow
(
self
,
y
)
def
__rpow__
(
self
,
x
):
return
pow
(
x
,
self
)
def
__ipow__
(
self
,
y
):
return
pow_inplace
(
self
,
y
)
def
__neg__
(
self
):
return
neg
(
self
)
T
=
property
(
lambda
self
:
transpose
(
self
))
Tc
=
property
(
lambda
self
:
transpose_copy
(
self
))
def
__copy__
(
self
):
return
array_copy
(
self
)
def
__getitem__
(
self
,
item
):
return
get_slice
(
self
,
item
)
def
__getslice__
(
self
,
*
args
):
return
get_slice
(
self
,
slice
(
*
args
))
from
grad
import
Gra
d
from
grad
import
Undefine
d
def
wrap_producer
(
f
):
class
producer
(
omega_op
):
impl
=
f
def
grad
(
*
args
):
return
[
Grad
.
Undefined
]
*
(
len
(
args
)
-
1
)
return
[
Undefined
]
*
(
len
(
args
)
-
1
)
producer
.
__name__
=
f
.
__name__
def
ret
(
dim
,
dtype
=
'float'
,
order
=
'C'
):
return
producer
(
dim
,
dtype
,
order
)
...
...
@@ -1811,11 +1601,11 @@ class sum(elemwise):
class
ones_like
(
elemwise
):
impl
=
numpy
.
ones_like
def
grad
(
x
,
gz
):
return
Grad
.
Undefined
def
grad
(
x
,
gz
):
return
Undefined
class
zeros_like
(
elemwise
):
impl
=
numpy
.
zeros_like
def
grad
(
x
,
gz
):
return
Grad
.
Undefined
def
grad
(
x
,
gz
):
return
Undefined
## Array slicing ##
...
...
gof/compile.py
deleted
100644 → 0
浏览文件 @
b0b514e4
import
env
import
tools
import
utils
class
Compiler
:
""" What is this? Please document.
"""
def
__init__
(
self
,
optimizer
,
features
):
self
.
features
=
set
(
features
)
self
.
features
.
update
(
optimizer
.
require
())
self
.
optimizer
=
optimizer
def
compile
(
self
,
inputs
,
outputs
,
features
):
features
=
self
.
features
.
union
(
features
)
e
=
env
.
Env
(
inputs
,
outputs
,
features
,
False
)
self
.
optimizer
.
apply
(
e
)
if
not
e
.
consistent
():
raise
env
.
InconsistencyError
(
"The graph is inconsistent."
)
return
e
def
__call__
(
self
,
inputs
,
outputs
,
features
):
return
self
.
compile
(
inputs
,
outputs
,
features
)
# def __init__(self, inputs, outputs, preprocessors, features, optimizer):
# self.inputs = inputs
# self.outputs = outputs
# self.features = features
# self.optimizer = optimizer
# features = features + [tools.EquivTool] + optimizer.require()
# features = utils.uniq_features(features)
# self.env = env.Env(inputs,
# outputs,
# features,
# False)
# if not self.env.consistent():
# raise env.InconsistencyError("The graph is inconsistent.")
# self.__optimize__()
# self.thunks = [op.thunk() for op in self.order]
# def __optimize__(self):
# self.optimizer.apply(self.env)
# self.order = self.env.toposort()
# def equiv(self, r):
# return self.env.equiv(r)
# def __getitem__(self, r):
# return self.equiv(r)
# def __setitem__(self, r, value):
# if isinstance(r, tuple):
# for a, b in zip(r, value):
# self.__setitem__(a, b)
# else:
# self.equiv(r).set_value(value)
# def __call__(self, *args):
# if args:
# for input, arg in zip(self.inputs, args):
# if arg is not None:
# input.value = arg
# for thunk, op in zip(self.thunks, self.order):
# try:
# thunk()
# except Exception, e:
# raise e.__class__("Error in " + str(op) + ": " + str(e))
# return [output.value for output in self.outputs]
# import env
# import opt
# from value import AsValue
# class Prog:
# def __init__(self, inputs, outputs, optimizer):
# self.inputs = inputs
# self.outputs = outputs
# self.env = env.Env(inputs,
# outputs,
# False,
# op_db = env.OpDb,
# changed = env.ChangeListener,
# # pr = env.PrintListener,
# scope = env.ScopeListener)
# ## self.adjustments = adjustments
# self.optimizer = optimizer
# ## if self.adjustments:
# ## self.adjustments.apply(self.env)
# if not self.env.consistent():
# raise env.InconsistencyError("The graph is inconsistent.")
# self.optimizer.apply(self.env)
# self.order = self.env.toposort()
# print "==================="
# for op in self.order:
# print op
# print "==================="
# self.thunks = [op.thunk() for op in self.order]
# def equiv(self, v):
# v = AsValue(v)
# return self.env.equiv(v)
# def __getitem__(self, v):
# return self.equiv(v).storage
# def __setitem__(self, v, value):
# if isinstance(v, tuple):
# for a, b in zip(v, value):
# self.__setitem__(a, b)
# else:
# self.equiv(v).value = value
# def __call__(self, *args):
# if args:
# for input, arg in zip(self.inputs, args):
# if arg is not None:
# input.value = arg
# for thunk, op in zip(self.thunks, self.order):
# try:
# thunk()
# except Exception, e:
# raise e.__class__("Error in " + str(op) + ": " + str(e))
# return [output.value for output in self.outputs]
# def prog(i, o):
# if not isinstance(i, (list, tuple)):
# i = [i]
# if not isinstance(o, (list, tuple)):
# o = [o]
# i = [AsValue(input) for input in i]
# o = [AsValue(output) for output in o]
# return Prog(i,
# o,
# opt.TagFilterMultiOptimizer(opt.opt_registry, None, None))
gof/ext.py
浏览文件 @
a007e686
...
...
@@ -2,7 +2,6 @@
from
copy
import
copy
from
op
import
Op
from
lib
import
DummyOp
from
result
import
Result
from
features
import
Listener
,
Constraint
,
Orderings
from
env
import
InconsistencyError
from
utils
import
ClsInit
...
...
gof/graph.py
浏览文件 @
a007e686
from
copy
import
copy
from
result
import
Result
,
BrokenLink
,
BrokenLinkError
from
result
import
BrokenLink
,
BrokenLinkError
from
op
import
Op
import
utils
...
...
gof/lib.py
浏览文件 @
a007e686
from
op
import
Op
from
result
import
Result
,
is_result
from
result
import
is_result
,
ResultBase
from
utils
import
ClsInit
,
Keyword
,
AbstractFunctionError
import
opt
import
env
...
...
@@ -15,7 +15,6 @@ __all__ = [ 'UNDEFINED',
'eval_mode'
,
'build_eval_mode'
,
'pop_mode'
,
#'ResultValue',
'DummyOp'
,
'DummyRemover'
,
'PythonOp'
,
...
...
@@ -98,83 +97,6 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
if
0
:
class
ResultValue
(
Result
):
"""Augment Result to wrap a computed value.
Attributes:
data -
spec -
constant -
up_to_date -
Properties:
Methods:
set_value_filter - ABSTRACT
set_value_inplace - ABSTRACT
alloc - ABSTRACT
Notes:
"""
__slots__
=
[
'data'
,
'spec'
,
'constant'
,
'up_to_date'
]
def
__init__
(
self
,
x
=
None
,
constant
=
False
):
self
.
constant
=
False
self
.
set_value
(
x
)
# allow set_value before constant = True
self
.
constant
=
constant
self
.
up_to_date
=
True
self
.
refresh
()
# to set spec
def
__str__
(
self
):
return
str
(
self
.
data
)
def
__repr__
(
self
):
return
repr
(
self
.
data
)
#TODO: document this function, what does it do?
def
refresh
(
self
):
self
.
spec
=
id
(
self
.
data
)
####################################################
#
# Functionality provided by this class
#
def
set_value
(
self
,
value
):
if
self
.
constant
:
raise
Exception
(
"This Result is a constant. Its value cannot be changed."
)
if
value
is
None
or
value
is
UNCOMPUTED
:
self
.
data
=
UNCOMPUTED
elif
is_result
(
value
):
self
.
set_value
(
value
.
data
)
else
:
try
:
self
.
data
=
self
.
set_value_filter
(
value
)
except
AbstractFunctionError
,
e
:
self
.
data
=
value
self
.
up_to_date
=
True
self
.
refresh
()
####################################################
#
# Pure virtual functions for subclasses to implement
#
# Perform error checking or automatic conversion of value, and return the
# result (which will be stored as self.data)
# Called by: set_value()
def
set_value_filter
(
self
,
value
):
raise
AbstractFunctionError
()
# For mutable data types, overwrite the current contents with value
# Also, call refresh and set up_to_date = True
def
set_value_inplace
(
self
,
value
):
raise
AbstractFunctionError
()
# Instantiate data (according to spec)
def
alloc
(
self
):
raise
AbstractFunctionError
()
class
DestroyHandler
(
features
.
Listener
,
features
.
Constraint
,
features
.
Orderings
):
def
__init__
(
self
,
env
):
...
...
@@ -279,14 +201,14 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings)
if
destroyed
:
# self.parent[output] = None
if
is
instance
(
destroyed
,
Result
):
if
is
_result
(
destroyed
):
destroyed
=
[
destroyed
]
for
input
in
destroyed
:
path
=
self
.
__path__
(
input
)
self
.
__add_destroyer__
(
path
+
[
output
])
elif
views
:
if
is
instance
(
views
,
Result
):
if
is
_result
(
views
):
views
=
[
views
]
if
len
(
views
)
>
1
:
#views was inputs before?
raise
Exception
(
"Output is a view of too many inputs."
)
...
...
@@ -462,7 +384,6 @@ class PythonOp(Op):
def
gen_outputs
(
self
):
raise
NotImplementedError
()
#return [ResultValue() for i in xrange(self.nout)]
def
view_map
(
self
):
return
{}
...
...
@@ -657,7 +578,7 @@ class PythonOpt(opt.Optimizer):
class
DummyOp
(
Op
):
def
__init__
(
self
,
input
):
Op
.
__init__
(
self
,
[
input
],
[
Result
()])
Op
.
__init__
(
self
,
[
input
],
[
Result
Base
()])
def
thunk
(
self
):
return
lambda
:
None
...
...
gof/listeners.py
deleted
100644 → 0
浏览文件 @
b0b514e4
差异被折叠。
点击展开。
gof/op.py
浏览文件 @
a007e686
...
...
@@ -4,7 +4,7 @@ Contains the Op class, which is the base interface for all operations
compatible with gof's graph manipulation routines.
"""
from
result
import
Result
,
BrokenLink
from
result
import
BrokenLink
from
utils
import
ClsInit
,
all_bases
,
all_bases_collect
from
copy
import
copy
...
...
gof/result.py
浏览文件 @
a007e686
...
...
@@ -5,12 +5,13 @@ value that is the input or the output of an Op.
"""
import
unittest
from
err
import
GofError
from
utils
import
AbstractFunctionError
__all__
=
[
'is_result'
,
'Result'
,
'BrokenLink'
,
'BrokenLinkError'
]
__all__
=
[
'is_result'
,
'Result
Base
'
,
'BrokenLink'
,
'BrokenLinkError'
]
class
BrokenLink
:
...
...
@@ -45,24 +46,27 @@ def is_result(obj):
attr_list
=
'owner'
,
return
all
([
hasattr
(
obj
,
attr
)
for
attr
in
attr_list
])
class
Result
(
object
):
"""
Storage node for data in a graph of Op instances.
class
Result
Base
(
object
):
"""
Base class for storing Op inputs and outputs
Attributes:
owner - represents the Op which computes this Result. Contains either None
or an instance of Op.
index - the index of this Result in owner.outputs.
_role - None or (owner, index) or BrokenLink
_data - anything
constant - Boolean
Methods:
-
Properties:
role - (rw)
owner - (ro)
index - (ro)
data - (rw)
replaced - (rw) : True iff _role is BrokenLink
computed - (ro) : True iff contents of data are fresh
Notes:
Abstract Methods:
data_filter
Result has no __init__ or __new__ routine. It is the Op's
responsibility to set the owner field of its results.
The Result class is abstract. It must be subclassed to support the
types of data needed for computation.
Notes:
A Result instance should be immutable: indeed, if some aspect of a
Result is changed, operations that use it might suddenly become
...
...
@@ -71,140 +75,120 @@ class Result(object):
called on the Result which is replaced (this will make its owner a
BrokenLink instance, which behaves like False in conditional
expressions).
"""
__slots__
=
[
'_owner'
,
'_index'
]
def
get_owner
(
self
):
if
not
hasattr
(
self
,
'_owner'
):
self
.
_owner
=
None
return
self
.
_owner
owner
=
property
(
get_owner
,
doc
=
"The Op of which this Result is an output or None if there is no such Op."
)
def
set_owner
(
self
,
owner
,
index
):
if
self
.
owner
is
not
None
:
if
self
.
owner
is
not
owner
:
"""
class
BrokenLink
:
"""The owner of a Result that was replaced by another Result"""
__slots__
=
[
'old_role'
]
def
__init__
(
self
,
role
):
self
.
old_role
=
role
def
__nonzero__
(
self
):
return
False
class
BrokenLinkError
(
Exception
):
"""Exception thrown when an owner is a BrokenLink"""
class
AbstractFunction
(
Exception
):
"""Exception thrown when an abstract function is called"""
__slots__
=
[
'_role'
,
'_data'
,
'constant'
]
def
__init__
(
self
,
role
=
None
,
data
=
None
,
constant
=
False
):
self
.
_role
=
role
self
.
constant
=
constant
if
data
is
None
:
#None is not filtered
self
.
_data
=
None
else
:
try
:
self
.
_data
=
self
.
data_filter
(
data
)
except
ResultBase
.
AbstractFunction
:
self
.
_data
=
data
#role is pair: (owner, outputs_position)
def
__get_role
(
self
):
return
self
.
_role
def
__set_role
(
self
,
role
):
owner
,
index
=
role
if
self
.
_role
is
not
None
:
# this is either an error or a no-op
_owner
,
_index
=
self
.
_role
if
_owner
is
not
owner
:
raise
ValueError
(
"Result
%
s already has an owner."
%
self
)
elif
self
.
index
!=
index
:
if
_
index
!=
index
:
raise
ValueError
(
"Result
%
s was already mapped to a different index."
%
self
)
self
.
_owner
=
owner
self
.
_index
=
index
def
invalidate
(
self
):
if
self
.
owner
is
None
:
raise
Exception
(
"Cannot invalidate a Result instance with no owner."
)
elif
not
isinstance
(
self
.
owner
,
BrokenLink
):
self
.
_owner
=
BrokenLink
(
self
.
_owner
,
self
.
_index
)
del
self
.
_index
def
revalidate
(
self
):
if
isinstance
(
self
.
owner
,
BrokenLink
):
owner
,
index
=
self
.
_owner
.
owner
,
self
.
_owner
.
index
self
.
_owner
=
owner
self
.
_index
=
index
def
perform
(
self
):
"""Calls self.owner.perform() if self.owner exists.
This is a mutually recursive function with gof.op.Op
"""
if
self
.
owner
:
self
.
owner
.
perform
()
# def extract(self):
# """
# Returns a representation of this datum for use in Op.impl.
# Successive calls to extract should always return the same object.
# """
# raise NotImplementedError
# def sync(self):
# """
# After calling Op.impl, synchronizes the Result instance with the
# new contents of the storage. This might usually not be necessary.
# """
# raise NotImplementedError
# def c_libs(self):
# """
# Returns a list of libraries that must be included to work with
# this Result.
# """
# raise NotImplementedError
# def c_imports(self):
# """
# Returns a list of strings representing headers to import when
# building a C interface that uses this Result.
# """
# raise NotImplementedError
# def c_declare(self):
# """
# Returns code which declares and initializes a C variable in
# which this Result can be held.
# """
# raise NotImplementedError
# def pyo_to_c(self):
# raise NotImplementedError
# def c_to_pyo(self):
# raise NotImplementedError
############################
# Utilities
############################
# class SelfContainedResult(Result):
# """
# This represents a Result which acts as its own data container. It
# is recommended to subclass this if you wish to be able to use the
# Result in normal computations as well as working with a graph
# representation.
# """
# # def extract(self):
# # """Returns self."""
# # return self
# # def sync(self):
# # """Does nothing."""
# # pass
# class HolderResult(Result):
# """
# HolderResult adds a 'data' slot which is meant to contain the
# object used by the Op implementation. It is recommended to subclass
# this if you want to be able to use the exact same object at
# different points in a computation.
# """
# __slots__ = ['data']
# # def extract(self):
# # """Returns self.data."""
# # return self.data
# # def sync(self):
# # """
# # Does nothing. Override if you have additional fields or
# # functionality in your subclass which need to be computed from
# # the data.
# # """
# # pass
return
# because _owner is owner and _index == index
self
.
_role
=
role
role
=
property
(
__get_role
,
__set_role
)
#owner is role[0]
def
__get_owner
(
self
):
if
self
.
_role
is
None
:
return
None
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
return
self
.
_role
[
0
]
owner
=
property
(
__get_owner
,
doc
=
"Op of which this Result is an output, or None if role is None"
)
#index is role[1]
def
__get_index
(
self
):
if
self
.
_role
is
None
:
return
None
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
return
self
.
_role
[
1
]
index
=
property
(
__get_index
,
doc
=
"position of self in owner's outputs, or None if role is None"
)
# assigning to self.data will invoke self.data_filter(value) if that
# function is defined
def
__get_data
(
self
):
return
self
.
_data
def
__set_data
(
self
,
data
):
if
self
.
replaced
:
raise
ResultBase
.
BrokenLinkError
()
if
self
.
constant
:
raise
Exception
(
'cannot set constant ResultBase'
)
try
:
self
.
_data
=
self
.
data_filter
(
data
)
except
ResultBase
.
AbstractFunction
:
#use default behaviour
self
.
_data
=
data
data
=
property
(
__get_data
,
__set_data
,
doc
=
"The storage associated with this result"
)
def
data_filter
(
self
,
data
):
"""(abstract) Return an appropriate _data based on data."""
raise
ResultBase
.
AbstractFunction
()
# replaced
def
__get_replaced
(
self
):
return
isinstance
(
self
.
_role
,
ResultBase
.
BrokenLink
)
def
__set_replaced
(
self
,
replace
):
if
replace
==
self
.
replaced
:
return
if
replace
:
self
.
_role
=
ResultBase
.
BrokenLink
(
self
.
_role
)
else
:
self
.
_role
=
self
.
_role
.
old_role
replaced
=
property
(
__get_replaced
,
__set_replaced
,
doc
=
"has this Result been replaced?"
)
# computed
#TODO: think about how to handle this more correctly
computed
=
property
(
lambda
self
:
self
.
_data
is
not
None
)
#################
# NumpyR Compatibility
#
up_to_date
=
property
(
lambda
self
:
True
)
def
refresh
(
self
):
pass
def
set_owner
(
self
,
owner
,
idx
):
self
.
role
=
(
owner
,
idx
)
def
set_value
(
self
,
value
):
self
.
data
=
value
#may raise exception
class
_test_ResultBase
(
unittest
.
TestCase
):
def
setUp
(
self
):
build_eval_mode
()
numpy
.
random
.
seed
(
44
)
def
tearDown
(
self
):
pop_mode
()
def
test_0
(
self
):
r
=
ResultBase
()
if
__name__
==
'__main__'
:
unittest
.
main
()
grad.py
浏览文件 @
a007e686
...
...
@@ -2,6 +2,9 @@ import gof
from
gof.lib
import
compute_from
,
is_result
import
core
class
Undefined
:
"""A special class representing a gradient of 0"""
class
Grad
(
object
):
"""A dictionary-like class, into which derivative expressions may be added.
...
...
@@ -17,7 +20,6 @@ class Grad(object):
__call__()
__getitem__()
"""
class
Undefined
:
pass
def
__init__
(
self
,
dct
=
{}):
self
.
map
=
{}
...
...
@@ -36,7 +38,7 @@ class Grad(object):
try
:
return
self
.
map
[
key
]
except
KeyError
:
return
Grad
.
Undefined
return
Undefined
def
__setitem__
(
self
,
item
,
val
):
"""Map item to its id and store internally."""
...
...
@@ -59,7 +61,7 @@ class Grad(object):
r may be uncomputed or NumpyR
"""
if
dr
is
Grad
.
Undefined
:
if
dr
is
Undefined
:
# nothing to do
return
...
...
@@ -124,7 +126,7 @@ class Grad(object):
if
not
self
.
did_bprop
:
raise
Exception
(
'Grad.__call__ only makes sense after a bprop'
)
rval
=
self
[
item
]
if
rval
is
not
Grad
.
Undefined
\
if
rval
is
not
Undefined
\
and
core
.
current_mode
()
==
'build_eval'
:
compute_from
([
rval
],
self
.
_compute_history
)
return
rval
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论