提交 d21050ba authored 作者: Joseph Turian's avatar Joseph Turian

merge

...@@ -8,12 +8,35 @@ or correct documentation. ...@@ -8,12 +8,35 @@ or correct documentation.
What happens if grad is not defined? What happens if grad is not defined?
====================================== ======================================
You should define gradient, even if it is undefined.
[give log factorial example]
If an Op does not define ``grad``, but this Op does not appear in the path when If an Op does not define ``grad``, but this Op does not appear in the path when
you compute the gradient, then there is no problem. you compute the gradient, then there is no problem.
If an Op does not define ``grad``, and this Op *does* appear in the path when If an Op does not define ``grad``, and this Op *does* appear in the path when
you compute the gradient, **WRITEME**. you compute the gradient, **WRITEME**.
Gradients for a particular result can be one of four kinds:
1) forgot to implement it
You will get an exception of the following form.
theano.gof.utils.MethodNotDefined: ('grad', <class 'pylearn.algorithms.sandbox.cost.LogFactorial'>, 'LogFactorial')
2) a symbolic result
3) None / zero
4) undefined mathematically
currently, there is no way for a grad() method to distinguish between cases 3
and 4
but the distinction is important because graphs with type-3 gradients are ok
to run, whereas graphs with type-4 gradients are not.
so I suggested that Joseph return a type-4 gradient by defining an Op with no
perform method.
the idea would be that this would suit the graph-construction phase, but would
prevent linking.
how does that sound to you?
**This documentation is useful when we show users how to write Ops.** **This documentation is useful when we show users how to write Ops.**
====================================== ======================================
......
...@@ -542,7 +542,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -542,7 +542,7 @@ class _Linker(gof.link.LocalLinker):
node_output_storage = [storage_map[r] for r in node.outputs] node_output_storage = [storage_map[r] for r in node.outputs]
try: try:
if not self.maker.mode.check_c_code: if not self.maker.mode.check_c_code:
raise utils.AbstractFunctionError() raise utils.MethodNotDefined()
e = Env(*graph.clone(node.inputs, node.outputs)) e = Env(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes #WARNING: STOCHASTIC ORDER e.toposort = lambda: e.nodes #WARNING: STOCHASTIC ORDER
...@@ -575,7 +575,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -575,7 +575,7 @@ class _Linker(gof.link.LocalLinker):
thunk.outputs = node_output_storage thunk.outputs = node_output_storage
thunks_c.append(thunk) thunks_c.append(thunk)
except (NotImplementedError, utils.AbstractFunctionError): except (NotImplementedError, utils.MethodNotDefined):
thunks_c.append(None) thunks_c.append(None)
p = node.op.perform p = node.op.perform
......
...@@ -39,5 +39,5 @@ from type import \ ...@@ -39,5 +39,5 @@ from type import \
Type, Generic, generic Type, Generic, generic
from utils import \ from utils import \
object2, AbstractFunctionError object2, MethodNotDefined
...@@ -402,7 +402,7 @@ class CLinker(link.Linker): ...@@ -402,7 +402,7 @@ class CLinker(link.Linker):
consts.append(result) consts.append(result)
self.orphans.remove(result) self.orphans.remove(result)
continue continue
except (utils.AbstractFunctionError, NotImplementedError): except (utils.MethodNotDefined, NotImplementedError):
pass pass
# orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same # orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same
policy = [[get_c_declare, get_c_extract, get_c_cleanup], policy = [[get_c_declare, get_c_extract, get_c_cleanup],
...@@ -465,11 +465,11 @@ class CLinker(link.Linker): ...@@ -465,11 +465,11 @@ class CLinker(link.Linker):
op = node.op op = node.op
try: behavior = op.c_code(node, name, isyms, osyms, sub) try: behavior = op.c_code(node, name, isyms, osyms, sub)
except utils.AbstractFunctionError: except utils.MethodNotDefined:
raise NotImplementedError("%s cannot produce C code" % op) raise NotImplementedError("%s cannot produce C code" % op)
try: cleanup = op.c_code_cleanup(node, name, isyms, osyms, sub) try: cleanup = op.c_code_cleanup(node, name, isyms, osyms, sub)
except utils.AbstractFunctionError: except utils.MethodNotDefined:
cleanup = "" cleanup = ""
blocks.append(CodeBlock("", behavior, cleanup, sub)) blocks.append(CodeBlock("", behavior, cleanup, sub))
...@@ -517,7 +517,7 @@ class CLinker(link.Linker): ...@@ -517,7 +517,7 @@ class CLinker(link.Linker):
ret = [] ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret.append(x.c_support_code()) try: ret.append(x.c_support_code())
except utils.AbstractFunctionError: pass except utils.MethodNotDefined: pass
return ret return ret
def compile_args(self): def compile_args(self):
...@@ -530,7 +530,7 @@ class CLinker(link.Linker): ...@@ -530,7 +530,7 @@ class CLinker(link.Linker):
ret = [] ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_compile_args() try: ret += x.c_compile_args()
except utils.AbstractFunctionError: pass except utils.MethodNotDefined: pass
return ret return ret
def headers(self): def headers(self):
...@@ -543,7 +543,7 @@ class CLinker(link.Linker): ...@@ -543,7 +543,7 @@ class CLinker(link.Linker):
ret = [] ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_headers() try: ret += x.c_headers()
except utils.AbstractFunctionError: pass except utils.MethodNotDefined: pass
return ret return ret
def libraries(self): def libraries(self):
...@@ -556,7 +556,7 @@ class CLinker(link.Linker): ...@@ -556,7 +556,7 @@ class CLinker(link.Linker):
ret = [] ret = []
for x in [y.type for y in self.results] + [y.op for y in self.node_order]: for x in [y.type for y in self.results] + [y.op for y in self.node_order]:
try: ret += x.c_libraries() try: ret += x.c_libraries()
except utils.AbstractFunctionError: pass except utils.MethodNotDefined: pass
return ret return ret
def __compile__(self, input_storage = None, output_storage = None): def __compile__(self, input_storage = None, output_storage = None):
...@@ -840,7 +840,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -840,7 +840,7 @@ class OpWiseCLinker(link.LocalLinker):
thunk.inputs = node_input_storage thunk.inputs = node_input_storage
thunk.outputs = node_output_storage thunk.outputs = node_output_storage
thunks.append(thunk) thunks.append(thunk)
except (NotImplementedError, utils.AbstractFunctionError): except (NotImplementedError, utils.MethodNotDefined):
if self.fallback_on_perform: if self.fallback_on_perform:
p = node.op.perform p = node.op.perform
thunk = lambda p = p, i = node_input_storage, o = node_output_storage, n = node: p(n, [x[0] for x in i], o) thunk = lambda p = p, i = node_input_storage, o = node_output_storage, n = node: p(n, [x[0] for x in i], o)
......
...@@ -69,7 +69,7 @@ class Linker(object): ...@@ -69,7 +69,7 @@ class Linker(object):
print new_e.data # 3.0 print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown) print e.data # 3.0 iff inplace == True (else unknown)
""" """
raise utils.AbstractFunctionError() raise utils.MethodNotDefined("make_thunk", type(self), self.__class__.__name__)
## DELETEME ## ## DELETEME ##
def make_function(self, unpack_single = True, **kwargs): def make_function(self, unpack_single = True, **kwargs):
...@@ -306,7 +306,7 @@ class LocalLinker(Linker): ...@@ -306,7 +306,7 @@ class LocalLinker(Linker):
# 3. output storage # 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1) # 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1) # 5. order: list of nodes, in the order they will be run by the function in (1)
raise AbstractFunctionError raise MethodNotDefined("make_all", type(self), self.__class__.__name__)
def gc_helper(node_list): def gc_helper(node_list):
""" """
......
...@@ -45,10 +45,10 @@ class CLinkerOp(object): ...@@ -45,10 +45,10 @@ class CLinkerOp(object):
WRITEME WRITEME
:Exceptions: :Exceptions:
- `AbstractFunctionError`: the subclass does not override this method - `MethodNotDefined`: the subclass does not override this method
""" """
raise utils.AbstractFunctionError('%s.c_code is not defined' \ raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__) % self.__class__.__name__)
def c_code_cleanup(self, node, name, inputs, outputs, sub): def c_code_cleanup(self, node, name, inputs, outputs, sub):
...@@ -77,10 +77,11 @@ class CLinkerOp(object): ...@@ -77,10 +77,11 @@ class CLinkerOp(object):
WRITEME WRITEME
:Exceptions: :Exceptions:
- `AbstractFunctionError`: the subclass does not override this method - `MethodNotDefined`: the subclass does not override this method
""" """
raise utils.AbstractFunctionError() raise utils.MethodNotDefined('%s.c_code_cleanup' \
% self.__class__.__name__)
def c_compile_args(self): def c_compile_args(self):
"""Optional: Return a list of recommended gcc compiler arguments. """Optional: Return a list of recommended gcc compiler arguments.
...@@ -93,7 +94,8 @@ class CLinkerOp(object): ...@@ -93,7 +94,8 @@ class CLinkerOp(object):
WRITEME WRITEME
""" """
raise utils.AbstractFunctionError() raise utils.MethodNotDefined('%s.c_compile_args' \
% self.__class__.__name__)
def c_headers(self): def c_headers(self):
"""Optional: Return a list of header files that must be included to compile the C code. """Optional: Return a list of header files that must be included to compile the C code.
...@@ -105,10 +107,11 @@ class CLinkerOp(object): ...@@ -105,10 +107,11 @@ class CLinkerOp(object):
WRITEME WRITEME
:Exceptions: :Exceptions:
- `AbstractFunctionError`: the subclass does not override this method - `MethodNotDefined`: the subclass does not override this method
""" """
raise utils.AbstractFunctionError() raise utils.MethodNotDefined('%s.c_headers' \
% self.__class__.__name__)
def c_libraries(self): def c_libraries(self):
"""Optional: Return a list of libraries to link against to manipulate this `Op`. """Optional: Return a list of libraries to link against to manipulate this `Op`.
...@@ -118,10 +121,11 @@ class CLinkerOp(object): ...@@ -118,10 +121,11 @@ class CLinkerOp(object):
WRITEME WRITEME
:Exceptions: :Exceptions:
- `AbstractFunctionError`: the subclass does not override this method - `MethodNotDefined`: the subclass does not override this method
""" """
raise utils.AbstractFunctionError() raise utils.MethodNotDefined('%s.c_libraries' \
% self.__class__.__name__)
def c_support_code(self): def c_support_code(self):
"""Optional: Return support code for use by the code that is returned by `c_code`. """Optional: Return support code for use by the code that is returned by `c_code`.
...@@ -133,10 +137,11 @@ class CLinkerOp(object): ...@@ -133,10 +137,11 @@ class CLinkerOp(object):
WRITEME WRITEME
:Exceptions: :Exceptions:
- `AbstractFunctionError`: the subclass does not override this method - `MethodNotDefined`: the subclass does not override this method
""" """
raise utils.AbstractFunctionError() raise utils.MethodNotDefined('%s.c_support_code' \
% self.__class__.__name__)
class PureOp(object): class PureOp(object):
""" """
...@@ -185,10 +190,10 @@ class PureOp(object): ...@@ -185,10 +190,10 @@ class PureOp(object):
All subclasses should over-ride this function. All subclasses should over-ride this function.
:Exceptions: :Exceptions:
- `AbstractFunctionError`: the subclass does not override this method - `MethodNotDefined`: the subclass does not override this method
""" """
raise utils.AbstractFunctionError(self) raise utils.MethodNotDefined("make_node", type(self), self.__class__.__name__)
def __call__(self, *inputs): def __call__(self, *inputs):
"""Optional: Return some or all output[s] of `make_node`. """Optional: Return some or all output[s] of `make_node`.
...@@ -241,10 +246,10 @@ class PureOp(object): ...@@ -241,10 +246,10 @@ class PureOp(object):
sees fit. sees fit.
:Exceptions: :Exceptions:
- `AbstractFunctionError`: the subclass does not override this method - `MethodNotDefined`: the subclass does not override this method
""" """
raise utils.AbstractFunctionError(self) raise utils.MethodNotDefined("perform", type(self), self.__class__.__name__)
class Op(utils.object2, PureOp, CLinkerOp): class Op(utils.object2, PureOp, CLinkerOp):
"""Convenience class to bundle `PureOp` and `CLinkerOp`""" """Convenience class to bundle `PureOp` and `CLinkerOp`"""
......
...@@ -342,7 +342,7 @@ class LocalOptimizer(object): ...@@ -342,7 +342,7 @@ class LocalOptimizer(object):
""" """
raise utils.AbstractFunctionError() raise utils.MethodNotDefined("transform", type(self), self.__class__.__name__)
def add_requirements(self, env): def add_requirements(self, env):
"""If this local optimization wants to add some requirements to the env, """If this local optimization wants to add some requirements to the env,
......
...@@ -4,7 +4,7 @@ __docformat__ = "restructuredtext en" ...@@ -4,7 +4,7 @@ __docformat__ = "restructuredtext en"
import copy import copy
import utils import utils
from utils import AbstractFunctionError, object2 from utils import MethodNotDefined, object2
from graph import Result from graph import Result
import traceback import traceback
...@@ -41,10 +41,10 @@ class CLinkerType(object): ...@@ -41,10 +41,10 @@ class CLinkerType(object):
WRITEME WRITEME
:Exceptions: :Exceptions:
- `AbstractFunctionError`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise AbstractFunctionError() raise MethodNotDefined("c_literal", type(self), self.__class__.__name__)
def c_declare(self, name, sub): def c_declare(self, name, sub):
"""Required: Return c code to declare variables that will be """Required: Return c code to declare variables that will be
...@@ -59,12 +59,12 @@ class CLinkerType(object): ...@@ -59,12 +59,12 @@ class CLinkerType(object):
WRITEME WRITEME
:Exceptions: :Exceptions:
- `AbstractFunctionError`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise AbstractFunctionError() raise MethodNotDefined()
def c_init(self, name, sub): def c_init(self, name, sub):
raise AbstractFunctionError() raise MethodNotDefined("c_init", type(self), self.__class__.__name__)
def c_extract(self, name, sub): def c_extract(self, name, sub):
"""Required: Return c code to extract a PyObject * instance. """Required: Return c code to extract a PyObject * instance.
...@@ -89,10 +89,10 @@ class CLinkerType(object): ...@@ -89,10 +89,10 @@ class CLinkerType(object):
WRITEME WRITEME
:Exceptions: :Exceptions:
- `AbstractFunctionError`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise AbstractFunctionError() raise MethodNotDefined("c_extract", type(self), self.__class__.__name__)
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
"""Optional: Return c code to clean up after `c_extract`. """Optional: Return c code to clean up after `c_extract`.
...@@ -110,10 +110,10 @@ class CLinkerType(object): ...@@ -110,10 +110,10 @@ class CLinkerType(object):
WRITEME WRITEME
:Exceptions: :Exceptions:
- `AbstractFunctionError`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise AbstractFunctionError() raise MethodNotDefined()
def c_sync(self, name, sub): def c_sync(self, name, sub):
"""Required: Return c code to pack C types back into a PyObject. """Required: Return c code to pack C types back into a PyObject.
...@@ -131,10 +131,10 @@ class CLinkerType(object): ...@@ -131,10 +131,10 @@ class CLinkerType(object):
WRITEME WRITEME
:Exceptions: :Exceptions:
- `AbstractFunctionError`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise AbstractFunctionError() raise MethodNotDefined("c_sync", type(self), self.__class__.__name__)
def c_compile_args(self): def c_compile_args(self):
"""Optional: Return a list of compile args recommended to compile the """Optional: Return a list of compile args recommended to compile the
...@@ -143,10 +143,10 @@ class CLinkerType(object): ...@@ -143,10 +143,10 @@ class CLinkerType(object):
WRITEME: example of formatting for -I, -L, -f args. WRITEME: example of formatting for -I, -L, -f args.
:Exceptions: :Exceptions:
- `AbstractFunctionError`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise AbstractFunctionError() raise MethodNotDefined("c_compile_args", type(self), self.__class__.__name__)
def c_headers(self): def c_headers(self):
"""Optional: Return a list of header files required by code returned by """Optional: Return a list of header files required by code returned by
...@@ -155,10 +155,10 @@ class CLinkerType(object): ...@@ -155,10 +155,10 @@ class CLinkerType(object):
WRITEME: example of local file, standard file. WRITEME: example of local file, standard file.
:Exceptions: :Exceptions:
- `AbstractFunctionError`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise AbstractFunctionError() raise MethodNotDefined("c_headers", type(self), self.__class__.__name__)
def c_libraries(self): def c_libraries(self):
"""Optional: Return a list of libraries required by code returned by """Optional: Return a list of libraries required by code returned by
...@@ -174,10 +174,10 @@ class CLinkerType(object): ...@@ -174,10 +174,10 @@ class CLinkerType(object):
QUESTION: What about via the c_compile_args? a -L option is allowed no? QUESTION: What about via the c_compile_args? a -L option is allowed no?
:Exceptions: :Exceptions:
- `AbstractFunctionError`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise AbstractFunctionError() raise MethodNotDefined("c_libraries", type(self), self.__class__.__name__)
def c_support_code(self): def c_support_code(self):
"""Optional: Return utility code for use by a `Result` or `Op` to be """Optional: Return utility code for use by a `Result` or `Op` to be
...@@ -187,10 +187,10 @@ class CLinkerType(object): ...@@ -187,10 +187,10 @@ class CLinkerType(object):
with many instances of the same type? with many instances of the same type?
:Exceptions: :Exceptions:
- `AbstractFunctionError`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise AbstractFunctionError() raise MethodNotDefined("c_support_code", type(self), self.__class__.__name__)
class PureType(object): class PureType(object):
"""Interface specification for result type instances. """Interface specification for result type instances.
...@@ -214,10 +214,10 @@ class PureType(object): ...@@ -214,10 +214,10 @@ class PureType(object):
argument. If it is False, filter may cast it to an appropriate type. argument. If it is False, filter may cast it to an appropriate type.
:Exceptions: :Exceptions:
- `AbstractFunctionError`: subclass doesn't implement this function. - `MethodNotDefined`: subclass doesn't implement this function.
""" """
raise AbstractFunctionError() raise MethodNotDefined("filter", type(self), self.__class__.__name__)
def is_valid_value(self, a): def is_valid_value(self, a):
"""Required: Return True for any python object `a` that would be a legal value for a Result of this Type""" """Required: Return True for any python object `a` that would be a legal value for a Result of this Type"""
......
...@@ -9,9 +9,7 @@ def hashgen(): ...@@ -9,9 +9,7 @@ def hashgen():
return hashgen.next return hashgen.next
hashgen.next = 0 hashgen.next = 0
class OmegaError(Exception): pass class MethodNotDefined(Exception):
class AbstractFunctionError(Exception):
""" """
To be raised by functions defined as part of an interface. To be raised by functions defined as part of an interface.
......
...@@ -334,10 +334,10 @@ class ScalarOp(Op): ...@@ -334,10 +334,10 @@ class ScalarOp(Op):
storage[0] = result storage[0] = result
def impl(self, *inputs): def impl(self, *inputs):
raise AbstractFunctionError() raise utils.MethodNotDefined("impl", type(self), self.__class__.__name__)
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
raise AbstractFunctionError() raise utils.MethodNotDefined("grad", type(self), self.__class__.__name__)
def __eq__(self, other): def __eq__(self, other):
test = type(self) == type(other) \ test = type(self) == type(other) \
......
...@@ -11,7 +11,7 @@ import numpy ...@@ -11,7 +11,7 @@ import numpy
from copy import copy from copy import copy
from .. import gof from .. import gof
from ..gof import Result, Op, utils, AbstractFunctionError, Type, Constant, Apply, Value from ..gof import Result, Op, utils, Type, Constant, Apply, Value
from .. import gradient from .. import gradient
......
...@@ -12,7 +12,6 @@ from theano import gradient ...@@ -12,7 +12,6 @@ from theano import gradient
from theano import gof from theano import gof
from theano.gof.python25 import any from theano.gof.python25 import any
from theano import gof from theano import gof
from theano.gof.utils import AbstractFunctionError
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.compile.mode import default_mode from theano.compile.mode import default_mode
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论