提交 020003bc authored 作者: Olivier Breuleux's avatar Olivier Breuleux

redefined is_valid_value in terms of filter - renamed values_eq_enough to values_eq_approx

上级 fd071943
......@@ -35,27 +35,29 @@ efficient machine learning algorithms while minimizing human
time. Theano was named after the `Greek mathematician`_ who may have
been Pythagoras' wife.
Theano is released under a BSD license (:ref:`link <license>`)
You can keep reading from :ref:`here <usingtheano>`.
Getting started
===============
TODO: I want to bold the links below. How the fuck do I bold links? ``**`link name`_**`` doesn't work! :(
:ref:`install`
Instructions to download and install Theano on your system.
:ref:`basictutorial`
Getting started with Theano's basic features. Go there if you are new!
Getting started with Theano's basic features. Go there if you are
new!
:ref:`advtutorial`
This tutorial is for more advanced users who want to define their own
operations and optimizations. It is recommended to go through the
:ref:`basictutorial` first.
This tutorial is for more advanced users who want to define their
own operations and optimizations. It is recommended to go through
the :ref:`basictutorial` first.
......
......@@ -39,6 +39,8 @@ efficient machine learning algorithms while minimizing human
time. Theano was named after the `Greek mathematician`_ who may have
been Pythagoras' wife.
Theano is released under a BSD license (:ref:`link <license>`)
.. _usingtheano:
......
......@@ -36,6 +36,8 @@ Now we're ready for the tour:
`Using Module`_
Getting serious
WRITEME: using modes?
`Wrapping up`_
A guide to what to look at next
......@@ -58,3 +60,4 @@ Now we're ready for the tour:
.. _More examples: examples.html
.. _Using Module: module.html
.. _Wrapping up: wrapup.html
......@@ -120,7 +120,7 @@ class BadDestroyMap(DebugModeError):
print >> sio, " old val:", self.old_val
print >> sio, " new val:", self.new_val
print >> sio, ""
print >> sio, " Hint: this can be caused by a deficient values_eq_enough() or __eq__() implementation that compares node input values"
print >> sio, " Hint: this can be caused by a deficient values_eq_approx() or __eq__() implementation that compares node input values"
return sio.getvalue()
class StochasticOrder(DebugModeError):
......@@ -212,7 +212,7 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes):
destroyed_res_list = [node.inputs[i] for i in destroyed_idx_list]
for r_idx, r in enumerate(node.inputs):
if not r.type.values_eq_enough(r_vals[r], storage_map[r][0]):
if not r.type.values_eq_approx(r_vals[r], storage_map[r][0]):
# some input node 'r' got changed by running the node
# this may or may not be ok...
if r in destroyed_res_list:
......@@ -572,7 +572,7 @@ class _Linker(gof.link.LocalLinker):
raise InvalidValueError(r, storage_map[r][0])
# compares the version from thunk_py (in r_vals)
# to the version produced by thunk_c (in storage_map)
if not r.type.values_eq_enough(r_vals[r], storage_map[r][0]):
if not r.type.values_eq_approx(r_vals[r], storage_map[r][0]):
raise BadClinkerOutput(r, val_py=r_vals[r], val_c=storage_map[r][0])
storage_map[r][0] = None #clear the storage_map for the thunk_c
......@@ -596,7 +596,7 @@ class _Linker(gof.link.LocalLinker):
r_val = r_vals[r]
assert r.type == new_r.type
if not r.type.values_eq_enough(r_val, new_r_val):
if not r.type.values_eq_approx(r_val, new_r_val):
raise BadOptimization(old_r=r,
new_r=new_r,
old_r_val=r_val,
......
......@@ -221,7 +221,11 @@ class PureType(object):
def is_valid_value(self, a):
"""Required: Return True for any python object `a` that would be a legal value for a Result of this Type"""
raise AbstractFunctionError()
try:
self.filter(a, True)
return True
except TypeError:
return False
def make_result(self, name = None):
"""Return a new `Result` instance of Type `self`.
......@@ -246,8 +250,17 @@ class PureType(object):
r.tag.trace = traceback.extract_stack()[:-1]
return r
def values_eq_enough(self, a, b):
"""Return True if a and b can be considered equal as Op outputs, else False.
def values_eq(self, a, b):
"""
Return True if a and b can be considered exactly equal.
a and b are assumed to be valid values of this Type.
"""
return a == b
def values_eq_approx(self, a, b):
"""
Return True if a and b can be considered approximately equal.
:param a: a potential value for a Result of this Type.
......@@ -255,12 +268,14 @@ class PureType(object):
:rtype: Bool
This function is used by theano debugging tools to decide whether two values are
equivalent, admitting a certain amount of numerical instability. For example,
for floating-point numbers this function should be an approximate comparison.
This function is used by theano debugging tools to decide
whether two values are equivalent, admitting a certain amount
of numerical instability. For example, for floating-point
numbers this function should be an approximate comparison.
By default, this does an exact comparison.
"""
return (a == b)
return self.values_eq(a, b)
_nothing = """
......
......@@ -52,13 +52,8 @@ class Scalar(Type):
except Exception, e:
raise TypeError("Could not convert %s (value=%s) to %s" % (type(data), data, self.dtype), e)
def values_eq_enough(self, a, b):
return abs(a - b) / (a+b) < 1e-4
def is_valid_value(self, a):
_a = numpy.asarray(a)
rval = (_a.ndim == 0) and (str(_a.dtype) == self.dtype)
return rval
def values_eq_approx(self, a, b, tolerance = 1e-4):
return abs(a - b) / (a+b) < tolerance
def __eq__(self, other):
return type(self) == type(other) and other.dtype == self.dtype
......
......@@ -222,21 +222,10 @@ class Tensor(Type):
"""Compare True iff other is the same kind of Tensor"""
return type(self) == type(other) and other.dtype == self.dtype and other.broadcastable == self.broadcastable
def values_eq_enough(self, a, b):
def values_eq_approx(self, a, b):
return type(a) is numpy.ndarray and type(b) is numpy.ndarray \
and (a.shape == b.shape) and numpy.allclose(a, b)
def is_valid_value(self, a):
rval = (type(a) is numpy.ndarray) and (self.ndim == a.ndim) \
and (str(a.dtype) == self.dtype) \
and all([((si == 1) or not bi) for si, bi in zip(a.shape, self.broadcastable)])
if not rval:
print type(a),(type(a) is numpy.ndarray)
print a.ndim, (self.ndim == a.ndim)
print a.dtype, (str(a.dtype) == self.dtype)
print a.shape, self.broadcastable, ([(shp_i == 1) for shp_i in a.shape] == self.broadcastable)
return rval
def __hash__(self):
"""Hash equal for same kinds of Tensor"""
return hash(self.dtype) ^ hash(self.broadcastable)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论