提交 020003bc authored 作者: Olivier Breuleux's avatar Olivier Breuleux

redefined is_valid_value in terms of filter - renamed values_eq_enough to values_eq_approx

上级 fd071943
...@@ -35,27 +35,29 @@ efficient machine learning algorithms while minimizing human ...@@ -35,27 +35,29 @@ efficient machine learning algorithms while minimizing human
time. Theano was named after the `Greek mathematician`_ who may have time. Theano was named after the `Greek mathematician`_ who may have
been Pythagoras' wife. been Pythagoras' wife.
Theano is released under a BSD license (:ref:`link <license>`)
You can keep reading from :ref:`here <usingtheano>`. You can keep reading from :ref:`here <usingtheano>`.
Getting started Getting started
=============== ===============
TODO: I want to bold the links below. How the fuck do I bold links? ``**`link name`_**`` doesn't work! :(
:ref:`install` :ref:`install`
Instructions to download and install Theano on your system. Instructions to download and install Theano on your system.
:ref:`basictutorial` :ref:`basictutorial`
Getting started with Theano's basic features. Go there if you are new! Getting started with Theano's basic features. Go there if you are
new!
:ref:`advtutorial` :ref:`advtutorial`
This tutorial is for more advanced users who want to define their own This tutorial is for more advanced users who want to define their
operations and optimizations. It is recommended to go through the own operations and optimizations. It is recommended to go through
:ref:`basictutorial` first. the :ref:`basictutorial` first.
......
...@@ -39,6 +39,8 @@ efficient machine learning algorithms while minimizing human ...@@ -39,6 +39,8 @@ efficient machine learning algorithms while minimizing human
time. Theano was named after the `Greek mathematician`_ who may have time. Theano was named after the `Greek mathematician`_ who may have
been Pythagoras' wife. been Pythagoras' wife.
Theano is released under a BSD license (:ref:`link <license>`)
.. _usingtheano: .. _usingtheano:
......
...@@ -36,6 +36,8 @@ Now we're ready for the tour: ...@@ -36,6 +36,8 @@ Now we're ready for the tour:
`Using Module`_ `Using Module`_
Getting serious Getting serious
WRITEME: using modes?
`Wrapping up`_ `Wrapping up`_
A guide to what to look at next A guide to what to look at next
...@@ -58,3 +60,4 @@ Now we're ready for the tour: ...@@ -58,3 +60,4 @@ Now we're ready for the tour:
.. _More examples: examples.html .. _More examples: examples.html
.. _Using Module: module.html .. _Using Module: module.html
.. _Wrapping up: wrapup.html .. _Wrapping up: wrapup.html
...@@ -120,7 +120,7 @@ class BadDestroyMap(DebugModeError): ...@@ -120,7 +120,7 @@ class BadDestroyMap(DebugModeError):
print >> sio, " old val:", self.old_val print >> sio, " old val:", self.old_val
print >> sio, " new val:", self.new_val print >> sio, " new val:", self.new_val
print >> sio, "" print >> sio, ""
print >> sio, " Hint: this can be caused by a deficient values_eq_enough() or __eq__() implementation that compares node input values" print >> sio, " Hint: this can be caused by a deficient values_eq_approx() or __eq__() implementation that compares node input values"
return sio.getvalue() return sio.getvalue()
class StochasticOrder(DebugModeError): class StochasticOrder(DebugModeError):
...@@ -212,7 +212,7 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes): ...@@ -212,7 +212,7 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes):
destroyed_res_list = [node.inputs[i] for i in destroyed_idx_list] destroyed_res_list = [node.inputs[i] for i in destroyed_idx_list]
for r_idx, r in enumerate(node.inputs): for r_idx, r in enumerate(node.inputs):
if not r.type.values_eq_enough(r_vals[r], storage_map[r][0]): if not r.type.values_eq_approx(r_vals[r], storage_map[r][0]):
# some input node 'r' got changed by running the node # some input node 'r' got changed by running the node
# this may or may not be ok... # this may or may not be ok...
if r in destroyed_res_list: if r in destroyed_res_list:
...@@ -572,7 +572,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -572,7 +572,7 @@ class _Linker(gof.link.LocalLinker):
raise InvalidValueError(r, storage_map[r][0]) raise InvalidValueError(r, storage_map[r][0])
# compares the version from thunk_py (in r_vals) # compares the version from thunk_py (in r_vals)
# to the version produced by thunk_c (in storage_map) # to the version produced by thunk_c (in storage_map)
if not r.type.values_eq_enough(r_vals[r], storage_map[r][0]): if not r.type.values_eq_approx(r_vals[r], storage_map[r][0]):
raise BadClinkerOutput(r, val_py=r_vals[r], val_c=storage_map[r][0]) raise BadClinkerOutput(r, val_py=r_vals[r], val_c=storage_map[r][0])
storage_map[r][0] = None #clear the storage_map for the thunk_c storage_map[r][0] = None #clear the storage_map for the thunk_c
...@@ -596,7 +596,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -596,7 +596,7 @@ class _Linker(gof.link.LocalLinker):
r_val = r_vals[r] r_val = r_vals[r]
assert r.type == new_r.type assert r.type == new_r.type
if not r.type.values_eq_enough(r_val, new_r_val): if not r.type.values_eq_approx(r_val, new_r_val):
raise BadOptimization(old_r=r, raise BadOptimization(old_r=r,
new_r=new_r, new_r=new_r,
old_r_val=r_val, old_r_val=r_val,
......
...@@ -221,7 +221,11 @@ class PureType(object): ...@@ -221,7 +221,11 @@ class PureType(object):
def is_valid_value(self, a): def is_valid_value(self, a):
"""Required: Return True for any python object `a` that would be a legal value for a Result of this Type""" """Required: Return True for any python object `a` that would be a legal value for a Result of this Type"""
raise AbstractFunctionError() try:
self.filter(a, True)
return True
except TypeError:
return False
def make_result(self, name = None): def make_result(self, name = None):
"""Return a new `Result` instance of Type `self`. """Return a new `Result` instance of Type `self`.
...@@ -246,8 +250,17 @@ class PureType(object): ...@@ -246,8 +250,17 @@ class PureType(object):
r.tag.trace = traceback.extract_stack()[:-1] r.tag.trace = traceback.extract_stack()[:-1]
return r return r
def values_eq_enough(self, a, b): def values_eq(self, a, b):
"""Return True if a and b can be considered equal as Op outputs, else False. """
Return True if a and b can be considered exactly equal.
a and b are assumed to be valid values of this Type.
"""
return a == b
def values_eq_approx(self, a, b):
"""
Return True if a and b can be considered approximately equal.
:param a: a potential value for a Result of this Type. :param a: a potential value for a Result of this Type.
...@@ -255,12 +268,14 @@ class PureType(object): ...@@ -255,12 +268,14 @@ class PureType(object):
:rtype: Bool :rtype: Bool
This function is used by theano debugging tools to decide whether two values are This function is used by theano debugging tools to decide
equivalent, admitting a certain amount of numerical instability. For example, whether two values are equivalent, admitting a certain amount
for floating-point numbers this function should be an approximate comparison. of numerical instability. For example, for floating-point
numbers this function should be an approximate comparison.
By default, this does an exact comparison.
""" """
return (a == b) return self.values_eq(a, b)
_nothing = """ _nothing = """
......
...@@ -52,13 +52,8 @@ class Scalar(Type): ...@@ -52,13 +52,8 @@ class Scalar(Type):
except Exception, e: except Exception, e:
raise TypeError("Could not convert %s (value=%s) to %s" % (type(data), data, self.dtype), e) raise TypeError("Could not convert %s (value=%s) to %s" % (type(data), data, self.dtype), e)
def values_eq_enough(self, a, b): def values_eq_approx(self, a, b, tolerance = 1e-4):
return abs(a - b) / (a+b) < 1e-4 return abs(a - b) / (a+b) < tolerance
def is_valid_value(self, a):
_a = numpy.asarray(a)
rval = (_a.ndim == 0) and (str(_a.dtype) == self.dtype)
return rval
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and other.dtype == self.dtype return type(self) == type(other) and other.dtype == self.dtype
......
...@@ -222,21 +222,10 @@ class Tensor(Type): ...@@ -222,21 +222,10 @@ class Tensor(Type):
"""Compare True iff other is the same kind of Tensor""" """Compare True iff other is the same kind of Tensor"""
return type(self) == type(other) and other.dtype == self.dtype and other.broadcastable == self.broadcastable return type(self) == type(other) and other.dtype == self.dtype and other.broadcastable == self.broadcastable
def values_eq_enough(self, a, b): def values_eq_approx(self, a, b):
return type(a) is numpy.ndarray and type(b) is numpy.ndarray \ return type(a) is numpy.ndarray and type(b) is numpy.ndarray \
and (a.shape == b.shape) and numpy.allclose(a, b) and (a.shape == b.shape) and numpy.allclose(a, b)
def is_valid_value(self, a):
rval = (type(a) is numpy.ndarray) and (self.ndim == a.ndim) \
and (str(a.dtype) == self.dtype) \
and all([((si == 1) or not bi) for si, bi in zip(a.shape, self.broadcastable)])
if not rval:
print type(a),(type(a) is numpy.ndarray)
print a.ndim, (self.ndim == a.ndim)
print a.dtype, (str(a.dtype) == self.dtype)
print a.shape, self.broadcastable, ([(shp_i == 1) for shp_i in a.shape] == self.broadcastable)
return rval
def __hash__(self): def __hash__(self):
"""Hash equal for same kinds of Tensor""" """Hash equal for same kinds of Tensor"""
return hash(self.dtype) ^ hash(self.broadcastable) return hash(self.dtype) ^ hash(self.broadcastable)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论