cleaned up gof.op, added constant, same_properties and mergeable to Scalar

Parent commit: cae06547
import unittest
from gof import ResultBase
from gof import Op
from gof import Env
from gof import modes
from gof import ResultBase, Op, Env, modes
from scalar_ops import *
......
......@@ -11,7 +11,9 @@ import graph
from copy import copy
__all__ = ['Op']
__all__ = ['Op',
'GuardedOp',
]
class Op(object):
......@@ -47,95 +49,6 @@ class Op(object):
raise AbstractFunctionError("Op is an abstract class. Its constructor does nothing, you must override it.")
# def __init__(self, *inputs):
# self._inputs = None
# self._outputs = None
# self.set_inputs(inputs)
# self.validate_update()
# def __get_input(self, i):
# input = self._inputs[i]
# # if input.replaced:
# # raise BrokenLinkError()
# return input
# def __set_input(self, i, new_input):
# self._inputs[i] = new_input
# def __get_inputs(self):
# # for input in self._inputs:
# # if input.replaced:
# # raise BrokenLinkError()
# return self._inputs
# def __set_inputs(self, new_inputs):
# self._inputs = list(new_inputs)
# def __get_output(self, i):
# return self._outputs[i]
# def __set_output(self, i, new_output):
# raise Exception("Cannot change outputs.")
# # old_output = self._outputs[i]
# # if old_output != new_output:
# # old_output.replaced = True
# # try:
# # # We try to reuse the old storage, if there is one
# # new_output.data = old_output.data
# # except:
# # pass
# # new_output.role = (self, i)
# # self._outputs[i] = new_output
# def __get_outputs(self):
# return self._outputs
# def __set_outputs(self, new_outputs):
# if self._outputs is None:
# for i, output in enumerate(new_outputs):
# output.role = (self, i)
# self._outputs = new_outputs
# return True
# raise Exception("Cannot change outputs.")
# # if len(self._outputs) != len(new_outputs):
# # raise TypeError("The new outputs must be exactly as many as the previous outputs.")
# # for i, new_output in enumerate(new_outputs):
# # self.__set_output(i, new_output)
# def get_input(self, i):
# return self.__get_input(i)
# def set_input(self, i, new_input):
# old_input = self.__get_input(i)
# try:
# self.__set_input(i, new_input)
# return self.validate_update()
# except:
# self.__set_input(i, old_input)
# self.validate_update()
# raise
# def get_inputs(self):
# return self.__get_inputs()
# def set_inputs(self, new_inputs):
# old_inputs = self.__get_inputs()
# try:
# self.__set_inputs(new_inputs)
# return self.validate_update()
# except:
# self._inputs = old_inputs
# raise
# def get_output(self, i):
# return self.__get_output(i)
# def get_outputs(self):
# return self.__get_outputs()
def get_input(self, i):
    """Return this Op's i-th input Result."""
    inputs = self._inputs
    return inputs[i]
def set_input(self, i, new):
......@@ -164,46 +77,6 @@ class Op(object):
outputs = property(get_outputs, set_outputs, doc = "The list of this Op's output Results.")
# def validate_update(self):
# """
# (Abstract) This function must do two things:
# * validate: check all the inputs in self.inputs to ensure
# that they have the right type for this Op, etc.
# If the validation fails, raise an exception.
# * update: if self.outputs is None, create output Results
# and set the Op's outputs. Else, fail or update
# the outputs in place.
# If any changes were made to the outputs, return True. Else,
# return False.
# """
# raise AbstractFunctionError()
# def repair(self):
# """
# Repairs all the inputs that are broken links to use what
# they were replaced with. Then, calls self.validate_update()
# to validate the new inputs and make new outputs.
# """
# changed = False
# repaired_inputs = []
# old_inputs = self._inputs
# for input in self._inputs:
# if input.replaced:
# changed = True
# role = input.role.old_role
# input = role[0].outputs[role[1]]
# repaired_inputs.append(input)
# if changed:
# try:
# self.__set_inputs(repaired_inputs)
# self.validate_update()
# except:
# self._inputs = old_inputs
# raise
# return changed
#
# copy
#
......@@ -324,238 +197,3 @@ class GuardedOp(Op):
raise TypeError("The new inputs are not as many as the previous ones.")
for i, new in enumerate(new):
self.set_input(i, new)
# def __init__(self, inputs, outputs, use_self_setters = False):
# """
# Initializes the '_inputs' and '_outputs' slots and sets the
# owner of all outputs to self.
# If use_self_setters is False, Op::set_input and Op::set_output
# are used, which do the minimum checks and manipulations. Else,
# the user defined set_input and set_output functions are
# called (in any case, all inputs and outputs are initialized
# to None).
# """
# self._inputs = [None] * len(inputs)
# self._outputs = [None] * len(outputs)
# if use_self_setters:
# for i, input in enumerate(inputs):
# self.set_input(i, input, validate = False)
# for i, output in enumerate(outputs):
# self.set_output(i, output, validate = False)
# self.validate()
# else:
# for i, input in enumerate(inputs):
# Op.set_input(self, i, input, validate = False)
# for i, output in enumerate(outputs):
# Op.set_output(self, i, output, validate = False)
# self.validate()
# self.validate()
# def set_input(self, i, input, allow_changes = False, validate = True):
# """
# Sets the ith input of self.inputs to input. i must be an
# integer in the range from 0 to len(self.inputs) - 1 and input
# must be a Result instance. The method may raise a GofTypeError
# or a GofValueError accordingly to the semantics of the Op, if
# the new input is of the wrong type or has the wrong
# properties.
# If i > len(self.inputs), an IndexError must be raised. If i ==
# len(self.inputs), it is allowed for the Op to extend the list
# of inputs if it is a vararg Op, else an IndexError should be
# raised.
# For a vararg Op, it is also allowed to have the input
# parameter set to None for 0 <= i < len(self.inputs), in which
# case the rest of the inputs will be shifted left. In any other
# situation, a ValueError should be raised.
# In some cases, set_input may change some outputs: for example,
# a change of an input from float to double might require the
# output's type to also change from float to double. If
# allow_changes is True, set_input is allowed to perform those
# changes and must return a list of pairs, each pair containing
# the old output and the output it was replaced with (they
# _must_ be different Result instances). See Op::set_output for
# important information about replacing outputs. If
# allow_changes is False and some change in the outputs is
# required for the change in input to be correct, a
# PropagationError must be raised.
# This default implementation sets the ith input to input and
# changes no outputs. It returns None.
# """
# previous = self.inputs[i]
# self.inputs[i] = input
# if validate:
# try:
# self.validate()
# except:
# # this call gives a subclass the chance to undo the set_outputs
# # that it may have triggered...
# # TODO: test this functionality!
# self.set_input(i, previous, True, False)
# def set_output(self, i, output, validate = True):
# """
# Sets the ith output to output. The previous output, which is
# being replaced, must be invalidated using Result::invalidate.
# The new output must not already have an owner, or its owner must
# be self. It cannot be a broken link, unless it used to be at this
# spot, in which case it can be reinstated.
# For Ops that have vararg output lists, see the regulations in
# Op::set_input.
# """
# if isinstance(output.owner, BrokenLink) \
# and output.owner.owner is self \
# and output.owner.index == i:
# output.revalidate()
# else:
# output.set_owner(self, i) # this checks for an already existing owner
# previous = self.outputs[i]
# if previous:
# previous.invalidate()
# self.outputs[i] = output
# if validate:
# try:
# self.validate()
# except:
# self.set_output(i, previous, False)
# def _dontuse_repair(self, allow_changes = False):
# """
# This function attempts to repair all inputs that are broken
# links by calling set_input on the new Result that replaced
# them. Note that if a set_input operation invalidates one or
# more outputs, new broken links might appear in the other ops
# that use this op's outputs.
# It is possible that the new inputs are inconsistent with this
# op, in which case an exception will be raised and the previous
# inputs (and outputs) will be restored.
# refresh returns a list of (old_output, new_output) pairs
# detailing the changes, if any.
# """
# backtrack = []
# try:
# for i, input in enumerate(self.inputs):
# link = input.owner
# if isinstance(link, BrokenLink):
# current = link.owner.outputs[link.index]
# dirt = self.set_input(i, current, allow_changes)
# backtrack.append((i, input, dirt))
# except:
# # Restore the inputs and outputs that were successfully changed.
# for i, input, dirt in backtrack:
# self.inputs[i] = input
# if dirt:
# for old, new in dirt:
# new.invalidate()
# old.revalidate()
# self.outputs[self.outputs.index(new)] = old
# raise
# all_dirt = []
# for i, input, dirt in backtrack:
# if dirt:
# all_dirt += dirt
# return all_dirt
# def perform(self):
# """
# Performs the computation on the inputs and stores the results
# in the outputs. This function should check for the validity of
# the inputs and raise appropriate errors for debugging (for
# executing without checks, override _perform).
# An Op may define additional ways to perform the computation
# that are more efficient (e.g. a piece of C code or a C struct
# with direct references to the inputs and outputs), but
# perform() should always be available in order to have a
# consistent interface to execute graphs.
# """
# raise NotImplementedError
# def _perform(self):
# """
# Performs the computation on the inputs and stores the results
# in the outputs, like perform(), but is not required to check
# the existence or the validity of the inputs.
# """
# return self.perform()
# @classmethod
# def require(cls):
# """
# Returns a set of Feature subclasses that must be used by any
# Env manipulating this kind of op. For instance, a Destroyer
# requires ext.DestroyHandler to guarantee that various
# destructive operations don't interfere.
# By default, this collates the __require__ field of this class
# and the __require__ fields of all classes that are directly or
# indirectly superclasses to this class into a set.
# """
# r = set()
# bases = all_bases(cls, lambda cls: hasattr(cls, '__env_require__'))
# for base in bases:
# req = base.__env_require__
# if isinstance(req, (list, tuple)):
# r.update(req)
# else:
# r.add(req)
# return r
# def validate(self):
# """
# This class's __validate__ function will be called, as well as
# the __validate__ functions of all base classes down the class
# tree. If you do not want to execute __validate__ from the base
# classes, set the class variable __validate_override__ to True.
# """
# vfns = all_bases_collect(self.__class__, 'validate')
# for vfn in vfns:
# vfn(self)
# def __copy__(self):
# """
# Copies the inputs list shallowly and copies all the outputs
# because of the one owner per output restriction.
# """
# new_inputs = copy(self.inputs)
# # We copy the outputs because they are tied to a single Op.
# new_outputs = [copy(output) for output in self.outputs]
# op = self.__class__(new_inputs, new_outputs)
# op._inputs = new_inputs
# op._outputs = new_outputs
# for i, output in enumerate(op.outputs):
# # We adjust _owner and _index manually since the copies
# # point to the previous op (self).
# output._owner = op
# output._index = i
# return op
# def __deepcopy__(self, memo):
# """
# Not implemented. Use gof.graph.clone(inputs, outputs) to copy
# a subgraph.
# """
# raise NotImplementedError("Use gof.graph.clone(inputs, outputs) to copy a subgraph.")
import gof
from gof.lib import compute_from, is_result
import core
class Undefined:
    """Sentinel representing a gradient of 0 (i.e. no gradient term recorded).

    NOTE(review): appears to be used as the class object itself rather than
    instantiated (e.g. returned from gradient lookups) — confirm against the
    Grad implementation.
    """
class Grad(object):
"""A dictionary-like class, into which derivative expressions may be added.
......@@ -182,189 +177,380 @@ class update_gradient_via_grad:
for input, inputg in zip(self.inputs, inputgs):
grad_d.add(input, inputg)
#
# UNIT TEST
#
import unittest
import numpy
import compile
class _testCase (unittest.TestCase):
class posneg(core.omega_op):
nout=2
def impl(x): return x, -x
def grad(x, gpos, gneg): return gpos - gneg
class posnegzero(core.omega_op):
nout=3
def impl(x): return x, -x, 0.0
def grad(x, gpos, gneg, gzero): return gpos - gneg
def setUp(self):
numpy.random.seed(1)
core.build_eval_mode()
def matinv(self,dim):
w = core.wrap(numpy.random.rand(dim,dim))
wi = core.wrap(numpy.random.rand(dim,dim))
ident = core.wrap(numpy.identity(dim))
for i in xrange(300):
wwi = core.dot(w, wi)
diff = wwi - ident
ssdiff = core.sum((diff**2))
if i == 0:
str0 = str_ssdiff = str(ssdiff.data)
#print ssdiff
g = grad(ssdiff)
gw = g(w)
w.data[:] += -0.4 * gw.data
return str0, str(ssdiff.data)
def matinv_compiled(self, dim):
w = core.wrap(numpy.random.rand(dim,dim))
wi = core.wrap(numpy.random.rand(dim,dim))
ident = core.wrap(numpy.identity(dim))
wwi = core.dot(w, wi)
diff = wwi - ident
ssdiff = core.sum((diff**2))
str0 = str_ssdiff = str(ssdiff.data)
#print ssdiff
g = grad(ssdiff)
gw = g(w)
prog = compile.single(g(w),ssdiff)
for i in xrange(300):
prog()
w.data[:] += -0.4 * gw.data
return str0, str(ssdiff.data)
def test0(self):
"""Matrix inversion by gradient descent (eval mode)"""
self.assertEqual(('2.67327580893', '0.000438649434819'), self.matinv(3))
def test1(self):
"""Matrix inversion by gradient descent (compiled mode)"""
self.assertEqual(('2.67327580893', '0.000438649434819'),
self.matinv_compiled(3))
def test_grad_wrt_ndarray_pointer(self):
"""Grad indexing by un-wrapped ndarray"""
a = numpy.ones((4, 4))
b = numpy.ones((4, 4))
c = numpy.ones((4, 4))
expr = core.sum(core.dot(core.add(a, b), c))
g = grad(expr)
g[a]
def test_bprop_call_order(self):
"""Ensure call before bprop is illegal"""
a = numpy.ones((3,3,3))
b = core.exp(a)
gb = Grad({b:core.wrap(a)})
try:
gb(a)
self.assertEqual('should have raised',0)
except Exception, e:
self.assertEqual(str(e), 'Grad.__call__ only makes sense after a bprop')
return
self.assertEqual('should have caught, returned',0)
def test_undefined_grad0(self):
"""Make sure posneg works with fully specified gradients"""
a = numpy.ones((3,3,3))
b,c = _testCase.posneg(a)
g = Grad({b:core.wrap(a),c:core.wrap(a)})
g.bprop()
max = numpy.max(g(a))
min = numpy.min(g(a))
self.assertEqual(max, min)
self.assertEqual(max, 0.0)
def test_undefined_grad1(self):
"""Propagate undefined values through posneg's first gradient"""
a = numpy.ones((3,3,3))
b,c = _testCase.posneg(a)
gb = Grad({b:core.wrap(a)})
try:
gb.bprop()
self.assertEqual('should have raised',0)
except AttributeError, e:
self.assertEqual(str(e), "class Undefined has no attribute 'shape'")
return
self.assertEqual("Should have been error", 0)
def test_undefined_grad2(self):
"""Propagate undefined values through posneg's second gradient"""
a = numpy.ones((3,3,3))
b,c = _testCase.posneg(a)
gc = Grad({c:core.wrap(a)})
try:
gc.bprop()
self.assertEqual('should have raised',0)
except AttributeError, e:
self.assertEqual(str(e), "class Undefined has no attribute 'shape'")
return
self.assertEqual("Should have been error", 0)
def test_undefined_grad3(self):
"""Ignore undefined values properly"""
a = numpy.ones((3,3,3))
b,c,d = _testCase.posnegzero(a)
#print b, c, d
g = Grad({b:core.wrap(a), c:core.wrap(a)})
g.bprop()
max = numpy.max(g(a))
min = numpy.min(g(a))
self.assertEqual(max, min)
self.assertEqual(max, 0.0)
def test_repeat_bprop(self):
"""Refuse to repeat bprop"""
a = numpy.ones((3,3,3))
b,c,d = _testCase.posnegzero(a)
#print b, c, d
g = Grad({b:core.wrap(a), c:core.wrap(a)})
g.bprop()
try:
g.bprop()
self.assertEqual('should have raised')
except Exception, e:
self.assertEqual(str(e), 'bprop has already been done. Consider calling with maybe_redo=True.')
return
self.assertEqual('should have caught')
def test_repeat_bprop1(self):
"""Force repeat bprop"""
a = numpy.ones((3,3,3))
z = numpy.zeros((3,3,3))
b,c,d = _testCase.posnegzero(a)
#print b, c, d
g = Grad({b:core.wrap(a), c:core.wrap(z)})
g.bprop()
g.bprop(maybe_redo=True)
max = numpy.max(g(a))
min = numpy.min(g(a))
self.assertEqual(max, min)
self.assertEqual(max, 2.0)
def tearDown(self):
core.pop_mode()
if __name__ == '__main__':
    # Run the Grad/bprop unit tests when this module is executed directly.
    unittest.main()
# import gof
# from gof.lib import compute_from, is_result
# import core
# class Undefined:
# """A special class representing a gradient of 0"""
# class Grad(object):
# """A dictionary-like class, into which derivative expressions may be added.
# This class maps keys to their ids to deal with the ndarray, which is not
# hashable.
# Attributes: None
# Methods:
# add()
# bprop()
# __call__()
# __getitem__()
# """
# def __init__(self, dct={}):
# self.map = {}
# self.outputs = []
# self._compute_history = set([])
# self.did_bprop = False
# for key,val in dct.items():
# self.add_output(key,val)
# def __contains__(self, item):
# return item in self.map
# def __getitem__(self, item):
# """Map item to its id and retrieve it."""
# key = core.wrap(item)
# try:
# return self.map[key]
# except KeyError:
# return Undefined
# def __setitem__(self, item, val):
# """Map item to its id and store internally."""
# self.map[item] = val
# def add_output(self, r, dr):
# self.add(r, dr)
# self.outputs.append(r)
# def add(self, r, dr):
# """Add dr to the sum of gradients associated with r.
# This function should be fed as follows:
# if dr is undefined:
# r could be anything
# else dr might be core.UNCOMPUTED:
# r may be uncomputed or NumpyR
# else dr will be isinstance(NumpyR):
# r may be uncomputed or NumpyR
# """
# if dr is Undefined:
# # nothing to do
# return
# if r.data is not None and dr.data is not None:
# if not hasattr(r, 'shape'):
# raise ValueError(('Grad::add r lacks shape: type=',
# type(r)))
# if not hasattr(dr, 'shape'):
# raise ValueError(('Grad::add dr lacks shape: type=',
# type(dr)))
# if r.shape != dr.shape:
# raise ValueError(('Grad::add r, dr shape mismatch',
# r.shape, dr.shape))
# # prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
# if r.state is gof.result.Computed:
# self._compute_history.add(r)
# # add dr to self[r]
# if r in self:
# self[r] = self[r] + dr
# else:
# self[r] = dr
# def bprop(self, maybe_redo=False):
# """Build a backpropagation graph.
# The gradient associated with each value is stored in <self> which
# inherits from dictionary. The idea is that when we call
# op.update_gradient(self), that the op's update_gradient function calls
# back into <self>.add(), and says what gradient term goes with each of
# its inputs. Most of the time, the gradients of the op's outputs are
# necessary for the op to compute the gradient wrt its inputs, so
# op.update_gradient will usually call <self>.__getitem__, (via the
# [] notation).
# It is essential that the gradient of an op's outputs be fully computed
# before op.update_gradient is called, or else key errors may be raised
# and incorrect gradients will be computed.
# bprop sets the omega evaluation mode to be 'build', so no computations
# or allocations are done by bprop.
# """
# if not maybe_redo and self.did_bprop:
# raise Exception('bprop has already been done. Consider calling with maybe_redo=True.')
# core.build_mode()
# try:
# outputs = self.outputs
# inputs = gof.graph.inputs(outputs)
# for op in gof.graph.io_toposort(inputs, outputs).__reversed__():
# op.update_gradient(self)
# finally:
# core.pop_mode()
# self.did_bprop = True
# def __call__(self, item):
# """Return a derivative term.
# If the current omega evaluation mode is 'build_eval' then the node is
# computed if necessary.
# """
# if not self.did_bprop:
# raise Exception('Grad.__call__ only makes sense after a bprop')
# rval = self[item]
# if rval is not Undefined \
# and core.current_mode() == 'build_eval':
# compute_from([rval], self._compute_history)
# return rval
# def grad(cost, param=None, cost_grad = 1.0):
# """Return symbolic expression of gradient of <cost> wrt <param>.
# If <param> is None, then return a Grad instance, from which the gradients of
# multiple objects can be retrieved using the __getitem__ or __call__ methods
# (as in function currying in languages such as scheme and OCaML).
# If <param> is not None, then return the gradient expression for
# d cost / d param.
# """
# if core.current_mode() == 'eval':
# raise NotImplementedError('Gradient-related functions are not available in eval mode')
# rval = Grad({cost:core.wrap(cost_grad)})
# rval.bprop()
# if param is None:
# return rval
# else:
# return rval(param)
# class update_gradient_via_grad:
# """Inherit from this class to add a convenient self.update_gradient function"""
# def update_gradient(self, grad_d):
# """Call self.grad() and add the result to grad_d
# This function is called by grad.Grad.bprop() to construct a symbolic gradient graph.
# self.grad is called like this:
# self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
# In general, grad() should return a list of ResultValue instances whose
# length matches that of self.inputs, and whose elements are the
# gradients of self.inputs.
# There is a (but often used) special feature in place to automatically
# wrap the return value of grad() in a list if it is a ResultValue instance
# and the op is unary. This makes many grad implementations a little
# cuter.
# """
# inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
# if len(self.inputs) == 1 and is_result(inputgs):
# inputgs = [inputgs]
# else:
# assert len(inputgs) == len(self.inputs)
# for input, inputg in zip(self.inputs, inputgs):
# grad_d.add(input, inputg)
# #
# # UNIT TEST
# #
# import unittest
# import numpy
# import compile
# class _testCase (unittest.TestCase):
# class posneg(core.omega_op):
# nout=2
# def impl(x): return x, -x
# def grad(x, gpos, gneg): return gpos - gneg
# class posnegzero(core.omega_op):
# nout=3
# def impl(x): return x, -x, 0.0
# def grad(x, gpos, gneg, gzero): return gpos - gneg
# def setUp(self):
# numpy.random.seed(1)
# core.build_eval_mode()
# def matinv(self,dim):
# w = core.wrap(numpy.random.rand(dim,dim))
# wi = core.wrap(numpy.random.rand(dim,dim))
# ident = core.wrap(numpy.identity(dim))
# for i in xrange(300):
# wwi = core.dot(w, wi)
# diff = wwi - ident
# ssdiff = core.sum((diff**2))
# if i == 0:
# str0 = str_ssdiff = str(ssdiff.data)
# #print ssdiff
# g = grad(ssdiff)
# gw = g(w)
# w.data[:] += -0.4 * gw.data
# return str0, str(ssdiff.data)
# def matinv_compiled(self, dim):
# w = core.wrap(numpy.random.rand(dim,dim))
# wi = core.wrap(numpy.random.rand(dim,dim))
# ident = core.wrap(numpy.identity(dim))
# wwi = core.dot(w, wi)
# diff = wwi - ident
# ssdiff = core.sum((diff**2))
# str0 = str_ssdiff = str(ssdiff.data)
# #print ssdiff
# g = grad(ssdiff)
# gw = g(w)
# prog = compile.single(g(w),ssdiff)
# for i in xrange(300):
# prog()
# w.data[:] += -0.4 * gw.data
# return str0, str(ssdiff.data)
# def test0(self):
# """Matrix inversion by gradient descent (eval mode)"""
# self.assertEqual(('2.67327580893', '0.000438649434819'), self.matinv(3))
# def test1(self):
# """Matrix inversion by gradient descent (compiled mode)"""
# self.assertEqual(('2.67327580893', '0.000438649434819'),
# self.matinv_compiled(3))
# def test_grad_wrt_ndarray_pointer(self):
# """Grad indexing by un-wrapped ndarray"""
# a = numpy.ones((4, 4))
# b = numpy.ones((4, 4))
# c = numpy.ones((4, 4))
# expr = core.sum(core.dot(core.add(a, b), c))
# g = grad(expr)
# g[a]
# def test_bprop_call_order(self):
# """Ensure call before bprop is illegal"""
# a = numpy.ones((3,3,3))
# b = core.exp(a)
# gb = Grad({b:core.wrap(a)})
# try:
# gb(a)
# self.assertEqual('should have raised',0)
# except Exception, e:
# self.assertEqual(str(e), 'Grad.__call__ only makes sense after a bprop')
# return
# self.assertEqual('should have caught, returned',0)
# def test_undefined_grad0(self):
# """Make sure posneg works with fully specified gradients"""
# a = numpy.ones((3,3,3))
# b,c = _testCase.posneg(a)
# g = Grad({b:core.wrap(a),c:core.wrap(a)})
# g.bprop()
# max = numpy.max(g(a))
# min = numpy.min(g(a))
# self.assertEqual(max, min)
# self.assertEqual(max, 0.0)
# def test_undefined_grad1(self):
# """Propagate undefined values through posneg's first gradient"""
# a = numpy.ones((3,3,3))
# b,c = _testCase.posneg(a)
# gb = Grad({b:core.wrap(a)})
# try:
# gb.bprop()
# self.assertEqual('should have raised',0)
# except AttributeError, e:
# self.assertEqual(str(e), "class Undefined has no attribute 'shape'")
# return
# self.assertEqual("Should have been error", 0)
# def test_undefined_grad2(self):
# """Propagate undefined values through posneg's second gradient"""
# a = numpy.ones((3,3,3))
# b,c = _testCase.posneg(a)
# gc = Grad({c:core.wrap(a)})
# try:
# gc.bprop()
# self.assertEqual('should have raised',0)
# except AttributeError, e:
# self.assertEqual(str(e), "class Undefined has no attribute 'shape'")
# return
# self.assertEqual("Should have been error", 0)
# def test_undefined_grad3(self):
# """Ignore undefined values properly"""
# a = numpy.ones((3,3,3))
# b,c,d = _testCase.posnegzero(a)
# #print b, c, d
# g = Grad({b:core.wrap(a), c:core.wrap(a)})
# g.bprop()
# max = numpy.max(g(a))
# min = numpy.min(g(a))
# self.assertEqual(max, min)
# self.assertEqual(max, 0.0)
# def test_repeat_bprop(self):
# """Refuse to repeat bprop"""
# a = numpy.ones((3,3,3))
# b,c,d = _testCase.posnegzero(a)
# #print b, c, d
# g = Grad({b:core.wrap(a), c:core.wrap(a)})
# g.bprop()
# try:
# g.bprop()
# self.assertEqual('should have raised')
# except Exception, e:
# self.assertEqual(str(e), 'bprop has already been done. Consider calling with maybe_redo=True.')
# return
# self.assertEqual('should have caught')
# def test_repeat_bprop1(self):
# """Force repeat bprop"""
# a = numpy.ones((3,3,3))
# z = numpy.zeros((3,3,3))
# b,c,d = _testCase.posnegzero(a)
# #print b, c, d
# g = Grad({b:core.wrap(a), c:core.wrap(z)})
# g.bprop()
# g.bprop(maybe_redo=True)
# max = numpy.max(g(a))
# min = numpy.min(g(a))
# self.assertEqual(max, min)
# self.assertEqual(max, 2.0)
# def tearDown(self):
# core.pop_mode()
# if __name__ == '__main__':
# unittest.main()
......@@ -3,9 +3,7 @@ import numpy
from copy import copy
from gof import ResultBase
from gof import Op
from gof import utils
from gof import ResultBase, GuardedOp, utils
def as_scalar(x, name = None):
......@@ -21,13 +19,32 @@ class Scalar(ResultBase):
def __init__(self, dtype, name=None):
    """Create a Scalar result of the given dtype.

    :param dtype: dtype string (e.g. 'float64', the only key py_type()
        currently maps) identifying the Python type of self.data.
    :param name: optional name forwarded to ResultBase.
    """
    self.dtype = dtype
    # Goes through the `constant` property setter; new results are
    # non-constant (and therefore destructible) by default.
    self.constant = False
    ResultBase.__init__(self, role = None, data = None, name = name)
def __get_constant(self):
    """Getter for the ``constant`` property: is this result a constant?"""
    return self._constant

def __set_constant(self, value):
    """Setter for the ``constant`` property.

    Marking a result constant also makes it indestructible, so
    destructive (in-place) operations will not overwrite it.
    """
    if value:
        self.indestructible = True
    # BUG FIX: this previously assigned to self.constant, which re-enters
    # this setter through the property below and recurses forever. Store
    # on the backing attribute that __get_constant reads instead.
    self._constant = value

constant = property(__get_constant, __set_constant)
def validate(self, data):
    """Raise TypeError unless ``data`` is an instance of this result's Python type."""
    expected = self.py_type()
    if isinstance(data, expected):
        return
    raise TypeError("Expected %s instance." % expected)
def same_properties(self, other):
    """Two Scalar results share properties exactly when their dtypes agree."""
    return self.dtype == other.dtype
def mergeable(self, other):
    """A pair of results may be merged when both are constants holding equal data.

    getattr defaults guard against results created before the `constant`
    attribute existed.
    """
    both_constant = (getattr(self, 'constant', False)
                     and getattr(other, 'constant', False))
    return both_constant and self.data == other.data
def py_type(self):
    """Map this result's dtype string to the corresponding Python scalar type.

    Raises KeyError for dtypes without a registered Python type.
    """
    dtype_to_py = {'float64': float}
    return dtype_to_py[self.dtype]
......@@ -74,7 +91,7 @@ class Scalar(ResultBase):
class ScalarMixedOp(Op):
class ScalarMixedOp(GuardedOp):
nin = -1
nout = 1
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment