提交 1284d324 authored 作者: nouiz's avatar nouiz

Merge pull request #718 from goodfeli/remove_value

Remove Value
...@@ -59,7 +59,7 @@ from gof import \ ...@@ -59,7 +59,7 @@ from gof import \
CLinker, OpWiseCLinker, DualLinker, Linker, LocalLinker, PerformLinker, \ CLinker, OpWiseCLinker, DualLinker, Linker, LocalLinker, PerformLinker, \
Container, \ Container, \
InconsistencyError, Env, \ InconsistencyError, Env, \
Apply, Variable, Constant, Value, \ Apply, Variable, Constant, \
Op, \ Op, \
opt, \ opt, \
toolbox, \ toolbox, \
......
...@@ -44,7 +44,7 @@ class OpFromGraph(gof.Op): ...@@ -44,7 +44,7 @@ class OpFromGraph(gof.Op):
if 'updates' in kwargs: if 'updates' in kwargs:
raise TypeError('updates are not allowed in kwargs') raise TypeError('updates are not allowed in kwargs')
# TODO: the graph may have implicit inputs like Value and # TODO: the graph may have implicit inputs like
# SharedVariable instances. # SharedVariable instances.
# what impact to they have on the validity of this Op? # what impact to they have on the validity of this Op?
self.fn = orig_function(inputs, outputs, **kwargs) self.fn = orig_function(inputs, outputs, **kwargs)
......
...@@ -448,8 +448,8 @@ class Method(Component): ...@@ -448,8 +448,8 @@ class Method(Component):
if input not in _inputs: if input not in _inputs:
# Add this input to the inputs; we require that storage already exists for them, # Add this input to the inputs; we require that storage already exists for them,
# but otherwise they are immutable. # but otherwise they are immutable.
if isinstance(input, gof.Value): # and not isinstance(input, gof.Constant): if isinstance(input, gof.Constant):
#input might be Value or Constant #input might be Constant
storage = get_storage(input) storage = get_storage(input)
assert type(storage) is io.In assert type(storage) is io.In
......
...@@ -9,7 +9,7 @@ from theano import config ...@@ -9,7 +9,7 @@ from theano import config
from theano.compile import orig_function, In, Out from theano.compile import orig_function, In, Out
from theano.compile import UnusedInputError from theano.compile import UnusedInputError
from theano.compile.sharedvalue import SharedVariable, shared from theano.compile.sharedvalue import SharedVariable, shared
from theano.gof import Container, Variable, generic, graph, Constant, Value from theano.gof import Container, Variable, generic, graph, Constant
from theano.gof.python25 import any from theano.gof.python25 import any
import logging import logging
...@@ -477,10 +477,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -477,10 +477,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
if isinstance(param, Constant): if isinstance(param, Constant):
raise TypeError('Constants not allowed in param list', param) raise TypeError('Constants not allowed in param list', param)
#if isinstance(param, Value): if isinstance(param, Variable): # N.B. includes SharedVariable
#return In(variable=param)
#raise NotImplementedError()
if isinstance(param, Variable): # N.B. includes Value and SharedVariable
return In(variable=param, strict=strict, allow_downcast=allow_downcast) return In(variable=param, strict=strict, allow_downcast=allow_downcast)
elif isinstance(param, Param): elif isinstance(param, Param):
return In( return In(
......
...@@ -22,7 +22,7 @@ class T_module(unittest.TestCase): ...@@ -22,7 +22,7 @@ class T_module(unittest.TestCase):
class Blah(Module): class Blah(Module):
def __init__(self, stepsize): def __init__(self, stepsize):
super(Blah, self).__init__() super(Blah, self).__init__()
self.stepsize = T.value(stepsize) self.stepsize = T.constant(stepsize)
x = T.dscalar() x = T.dscalar()
self.step = Method([x], x - self.stepsize) self.step = Method([x], x - self.stepsize)
...@@ -128,7 +128,7 @@ class T_module(unittest.TestCase): ...@@ -128,7 +128,7 @@ class T_module(unittest.TestCase):
assert i[0]==j assert i[0]==j
local_test(lambda:T.dscalar(),lambda:T.dscalar()) local_test(lambda:T.dscalar(),lambda:T.dscalar())
local_test(lambda:T.value(1),lambda:T.value(2)) local_test(lambda:T.constant(1),lambda:T.constant(2))
local_test(lambda:T.constant(1),lambda:T.constant(2)) local_test(lambda:T.constant(1),lambda:T.constant(2))
def test_list_assign(self): def test_list_assign(self):
...@@ -151,7 +151,6 @@ class T_module(unittest.TestCase): ...@@ -151,7 +151,6 @@ class T_module(unittest.TestCase):
assert numpy.all(4 == m.g()) assert numpy.all(4 == m.g())
local_test(lambda:T.dscalar(),lambda:T.dscalar()) local_test(lambda:T.dscalar(),lambda:T.dscalar())
local_test(lambda:T.value(1),lambda:T.value(2))
def test_tuple_assign(self): def test_tuple_assign(self):
"""Test that list members can be assigned tuple-wise""" """Test that list members can be assigned tuple-wise"""
...@@ -170,7 +169,6 @@ class T_module(unittest.TestCase): ...@@ -170,7 +169,6 @@ class T_module(unittest.TestCase):
assert 4 == m.g() assert 4 == m.g()
local_test(lambda:T.dscalar(),lambda:T.dscalar()) local_test(lambda:T.dscalar(),lambda:T.dscalar())
local_test(lambda:T.value(1),lambda:T.value(2))
def test_dict_assign(self): def test_dict_assign(self):
"""Test that list members can be assigned dict-wise""" """Test that list members can be assigned dict-wise"""
...@@ -191,8 +189,6 @@ class T_module(unittest.TestCase): ...@@ -191,8 +189,6 @@ class T_module(unittest.TestCase):
#print 'dscalar test' #print 'dscalar test'
local_test(lambda:T.dscalar(),lambda:T.dscalar()) local_test(lambda:T.dscalar(),lambda:T.dscalar())
#print 'value test'
local_test(lambda:T.value(1),lambda:T.value(2))
def test_method_in_list_or_dict(self): def test_method_in_list_or_dict(self):
...@@ -452,16 +448,6 @@ class T_module(unittest.TestCase): ...@@ -452,16 +448,6 @@ class T_module(unittest.TestCase):
assert numpy.all(m.f(xval) == [1, 2.5]) assert numpy.all(m.f(xval) == [1, 2.5])
assert numpy.all(xval == [-1, -1.5]) assert numpy.all(xval == [-1, -1.5])
def test_member_value(self):
"""Test that module Members of Value work correctly. As Variable?"""
M = Module()
x = T.dscalar()
M.y = T.value(40)
M.f = Method([x], x + 2 * M.y)
m = M.make()
m.y = 80
assert m.f(20) == 180
def test_member_constant(self): def test_member_constant(self):
"""Test that module Members of Constant work correctly. """Test that module Members of Constant work correctly.
As Variable with more optimization?""" As Variable with more optimization?"""
......
...@@ -9,10 +9,10 @@ from env import \ ...@@ -9,10 +9,10 @@ from env import \
InconsistencyError, MissingInputError, Env InconsistencyError, MissingInputError, Env
from destroyhandler import \ from destroyhandler import \
DestroyHandler DestroyHandler
from graph import \ from graph import \
Apply, Variable, Constant, Value, view_roots Apply, Variable, Constant, view_roots
from link import \ from link import \
Container, Linker, LocalLinker, PerformLinker, WrapLinker, WrapLinkerMany Container, Linker, LocalLinker, PerformLinker, WrapLinker, WrapLinkerMany
...@@ -21,11 +21,11 @@ from op import \ ...@@ -21,11 +21,11 @@ from op import \
Op, PureOp, ops_with_inner_function Op, PureOp, ops_with_inner_function
from opt import (Optimizer, optimizer, SeqOptimizer, from opt import (Optimizer, optimizer, SeqOptimizer,
MergeOptimizer, MergeOptMerge, MergeOptimizer, MergeOptMerge,
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOptimizer, local_optimizer, LocalOptGroup,
OpSub, OpRemove, PatternSub, OpSub, OpRemove, PatternSub,
NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer, NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer,
InplaceOptimizer, PureThenInplaceOptimizer, InplaceOptimizer, PureThenInplaceOptimizer,
OpKeyOptimizer) OpKeyOptimizer)
from optdb import \ from optdb import \
......
...@@ -431,7 +431,7 @@ class CLinker(link.Linker): ...@@ -431,7 +431,7 @@ class CLinker(link.Linker):
# The orphans field is listified to ensure a consistent order. # The orphans field is listified to ensure a consistent order.
#list(env.orphans.difference(self.outputs)) #list(env.orphans.difference(self.outputs))
self.orphans = list(r for r in self.variables self.orphans = list(r for r in self.variables
if isinstance(r, graph.Value) and if isinstance(r, graph.Constant) and
r not in self.inputs) r not in self.inputs)
self.temps = list(set(self.variables).difference( self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans)) self.inputs).difference(self.outputs).difference(self.orphans))
...@@ -497,8 +497,8 @@ class CLinker(link.Linker): ...@@ -497,8 +497,8 @@ class CLinker(link.Linker):
policy = [[get_nothing, get_nothing, get_nothing], policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_extract, get_c_cleanup]] [get_c_declare, get_c_extract, get_c_cleanup]]
elif variable in self.orphans: elif variable in self.orphans:
if not isinstance(variable, graph.Value): if not isinstance(variable, graph.Constant):
raise TypeError("All orphans to CLinker must be Value" raise TypeError("All orphans to CLinker must be Constant"
" instances.", variable) " instances.", variable)
if isinstance(variable, graph.Constant): if isinstance(variable, graph.Constant):
try: try:
......
...@@ -28,7 +28,7 @@ class Env(utils.object2): ...@@ -28,7 +28,7 @@ class Env(utils.object2):
""" WRITEME """ WRITEME
An Env represents a subgraph bound by a set of input variables and a An Env represents a subgraph bound by a set of input variables and a
set of output variables. The inputs list should contain all the inputs set of output variables. The inputs list should contain all the inputs
on which the outputs depend. Variables of type Value or Constant are on which the outputs depend. Variables of type Constant are
not counted as inputs. not counted as inputs.
The Env supports the replace operation which allows to replace a The Env supports the replace operation which allows to replace a
...@@ -240,7 +240,7 @@ class Env(utils.object2): ...@@ -240,7 +240,7 @@ class Env(utils.object2):
r_owner_done.add(node) r_owner_done.add(node)
self.__import__(node) self.__import__(node)
for r in variables: for r in variables:
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
raise MissingInputError("Undeclared input", r) raise MissingInputError("Undeclared input", r)
if not getattr(r, 'env', None) is self: if not getattr(r, 'env', None) is self:
self.__setup_r__(r) self.__setup_r__(r)
...@@ -260,7 +260,7 @@ class Env(utils.object2): ...@@ -260,7 +260,7 @@ class Env(utils.object2):
for r in node.inputs: for r in node.inputs:
if hasattr(r, 'env') and r.env is not self: if hasattr(r, 'env') and r.env is not self:
raise Exception("%s is already owned by another env" % r) raise Exception("%s is already owned by another env" % r)
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
#Verbose error message #Verbose error message
#Show a complete chain of variables from the missing input to an output #Show a complete chain of variables from the missing input to an output
...@@ -610,7 +610,7 @@ class Env(utils.object2): ...@@ -610,7 +610,7 @@ class Env(utils.object2):
excess = self.variables.difference(variables) excess = self.variables.difference(variables)
raise Exception("The variables are inappropriately cached. missing, in excess: ", missing, excess) raise Exception("The variables are inappropriately cached. missing, in excess: ", missing, excess)
for variable in variables: for variable in variables:
if variable.owner is None and variable not in self.inputs and not isinstance(variable, graph.Value): if variable.owner is None and variable not in self.inputs and not isinstance(variable, graph.Constant):
raise Exception("Undeclared input.", variable) raise Exception("Undeclared input.", variable)
if variable.env is not self: if variable.env is not self:
raise Exception("Variable should belong to the env.", variable) raise Exception("Variable should belong to the env.", variable)
......
...@@ -217,8 +217,6 @@ class Variable(utils.object2): ...@@ -217,8 +217,6 @@ class Variable(utils.object2):
- `Variable` (this base type) is typically the output of a symbolic computation, - `Variable` (this base type) is typically the output of a symbolic computation,
- `Value` (a subclass) adds a default :literal:`value`, and requires that owner is None
- `Constant` (a subclass) which adds a default and un-replaceable :literal:`value`, and - `Constant` (a subclass) which adds a default and un-replaceable :literal:`value`, and
requires that owner is None requires that owner is None
...@@ -325,23 +323,18 @@ class Variable(utils.object2): ...@@ -325,23 +323,18 @@ class Variable(utils.object2):
raise NotImplementedError('Subclasses of Variable must provide __ge__', raise NotImplementedError('Subclasses of Variable must provide __ge__',
self.__class__.__name__) self.__class__.__name__)
class Value(Variable): class Constant(Variable):
""" """
A :term:`Value` is a `Variable` with a default value. A :term:`Constant` is a `Variable` with a `value` field that cannot be changed at runtime.
Its owner field is always None. And since it has a default value, a `Value` instance need
not be named as an input to `compile.function`.
This kind of node is useful because when a value is known at compile time, more
optimizations are possible.
Constant nodes make eligible numerous optimizations: constant inlining in C code, constant folding, etc.
""" """
#__slots__ = ['data'] #__slots__ = ['data']
def __init__(self, type, data, name = None): def __init__(self, type, data, name = None):
"""Initialize self. """Initialize self.
:note: :note:
The data field is filtered by what is provided in the constructor for the Value's The data field is filtered by what is provided in the constructor for the Constant's
type field. type field.
WRITEME WRITEME
...@@ -349,45 +342,14 @@ class Value(Variable): ...@@ -349,45 +342,14 @@ class Value(Variable):
""" """
Variable.__init__(self, type, None, None, name) Variable.__init__(self, type, None, None, name)
self.data = type.filter(data) self.data = type.filter(data)
def __str__(self):
"""WRITEME"""
if self.name is not None:
return self.name
return "<" + str(self.data) + ">" #+ "::" + str(self.type)
def clone(self):
"""WRITEME"""
#return copy(self)
cp = self.__class__(self.type, copy(self.data), self.name)
cp.tag = copy(self.tag)
return cp
def __set_owner(self, value):
"""WRITEME
:Exceptions:
- `ValueError`: if `value` is not `None`
"""
if value is not None:
raise ValueError("Value instances cannot have an owner.")
owner = property(lambda self: None, __set_owner)
value = property(lambda self: self.data,
doc='read-only data access method')
# index is not defined, because the `owner` attribute must necessarily be None
class Constant(Value):
"""
A :term:`Constant` is a `Value` that cannot be changed at runtime.
Constant nodes make eligible numerous optimizations: constant inlining in C code, constant folding, etc.
"""
#__slots__ = ['data']
def __init__(self, type, data, name = None):
Value.__init__(self, type, data, name)
def equals(self, other): def equals(self, other):
# this does what __eq__ should do, but Variable and Apply should always be hashable by id # this does what __eq__ should do, but Variable and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature() return isinstance(other, Constant) and self.signature() == other.signature()
def signature(self): def signature(self):
return (self.type, self.data) return (self.type, self.data)
def __str__(self): def __str__(self):
if self.name is not None: if self.name is not None:
return self.name return self.name
...@@ -396,6 +358,7 @@ class Constant(Value): ...@@ -396,6 +358,7 @@ class Constant(Value):
if len(name) > 20: if len(name) > 20:
name = name[:10] + '...' + name[-10] name = name[:10] + '...' + name[-10]
return 'Constant{%s}' % name return 'Constant{%s}' % name
def clone(self): def clone(self):
""" """
We clone this object, but we don't clone the data to lower memory requirement We clone this object, but we don't clone the data to lower memory requirement
...@@ -405,6 +368,21 @@ class Constant(Value): ...@@ -405,6 +368,21 @@ class Constant(Value):
cp.tag = copy(self.tag) cp.tag = copy(self.tag)
return cp return cp
def __set_owner(self, value):
"""WRITEME
:Exceptions:
- `ValueError`: if `value` is not `None`
"""
if value is not None:
raise ValueError("Constant instances cannot have an owner.")
owner = property(lambda self: None, __set_owner)
value = property(lambda self: self.data,
doc='read-only data access method')
# index is not defined, because the `owner` attribute must necessarily be None
def stack_search(start, expand, mode='bfs', build_inv = False): def stack_search(start, expand, mode='bfs', build_inv = False):
"""Search through a graph, either breadth- or depth-first """Search through a graph, either breadth- or depth-first
......
...@@ -286,7 +286,7 @@ def map_storage(env, order, input_storage, output_storage): ...@@ -286,7 +286,7 @@ def map_storage(env, order, input_storage, output_storage):
for node in order: for node in order:
for r in node.inputs: for r in node.inputs:
if r not in storage_map: if r not in storage_map:
assert isinstance(r, graph.Value) assert isinstance(r, graph.Constant)
storage_map[r] = [r.data] storage_map[r] = [r.data]
for r in node.outputs: for r in node.outputs:
storage_map.setdefault(r, [None]) storage_map.setdefault(r, [None])
......
...@@ -34,8 +34,8 @@ class MyType(Type): ...@@ -34,8 +34,8 @@ class MyType(Type):
def MyVariable(name): def MyVariable(name):
return Variable(MyType(), None, None, name = name) return Variable(MyType(), None, None, name = name)
def MyValue(data): def MyConstant(data):
return graph.Value(MyType(), data = data) return graph.Constant(MyType(), data = data)
class MyOp(Op): class MyOp(Op):
...@@ -385,7 +385,7 @@ def test_value_repl(): ...@@ -385,7 +385,7 @@ def test_value_repl():
e = add_in_place(x, sy) e = add_in_place(x, sy)
g = Env([x,y], [e], False) g = Env([x,y], [e], False)
consistent(g) consistent(g)
g.replace(sy, MyValue("abc")) g.replace(sy, MyConstant("abc"))
consistent(g) consistent(g)
def test_value_repl_2(): def test_value_repl_2():
...@@ -394,7 +394,7 @@ def test_value_repl_2(): ...@@ -394,7 +394,7 @@ def test_value_repl_2():
e = add_in_place(x, sy) e = add_in_place(x, sy)
g = Env([x,y], [e], False) g = Env([x,y], [e], False)
consistent(g) consistent(g)
g.replace(sy, transpose_view(MyValue("abc"))) g.replace(sy, transpose_view(MyConstant("abc")))
consistent(g) consistent(g)
......
...@@ -1175,7 +1175,10 @@ class Mul(ScalarOp): ...@@ -1175,7 +1175,10 @@ class Mul(ScalarOp):
# output is complex. The rest of this function make this supposition. # output is complex. The rest of this function make this supposition.
output_type = self.output_types([i.type for i in inputs])[0] output_type = self.output_types([i.type for i in inputs])[0]
if output_type in complex_types: if output_type in complex_types:
assert gz.type in complex_types if not gz.type in complex_types:
raise TypeError('Mul with output_type '+str(output_type)+\
' expected gz type to be complex, got gz with type '+\
str(gz.type))
for input in inputs: for input in inputs:
if input.type in continuous_types: if input.type in continuous_types:
......
...@@ -212,17 +212,6 @@ def constant(x, name=None): ...@@ -212,17 +212,6 @@ def constant(x, name=None):
except TypeError: except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x)) raise TypeError("Could not convert %s to SparseType" % x, type(x))
if 0:
def value(x):
if not isinstance(x, scipy.sparse.spmatrix):
raise TypeError("sparse.value must be called on a "
"scipy.sparse.spmatrix")
try:
return SparseValue(SparseType(format=x.format,
dtype=x.dtype), x)
except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x))
def sp_ones_like(x): def sp_ones_like(x):
# TODO: don't restrict to CSM formats # TODO: don't restrict to CSM formats
...@@ -365,12 +354,6 @@ class SparseConstant(gof.Constant, _sparse_py_operators): ...@@ -365,12 +354,6 @@ class SparseConstant(gof.Constant, _sparse_py_operators):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
class SparseValue(gof.Value, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format)
class SparseType(gof.Type): class SparseType(gof.Type):
""" """
@type dtype: numpy dtype string such as 'int64' or 'float64' (among others) @type dtype: numpy dtype string such as 'int64' or 'float64' (among others)
...@@ -760,19 +743,19 @@ class CSMGrad(gof.op.Op): ...@@ -760,19 +743,19 @@ class CSMGrad(gof.op.Op):
sp_dim = x_shape[1] sp_dim = x_shape[1]
else: else:
sp_dim = x_shape[0] sp_dim = x_shape[0]
g_row = numpy.zeros(sp_dim, dtype=g_data.dtype) g_row = numpy.zeros(sp_dim, dtype=g_data.dtype)
gout_data = numpy.zeros_like(x_data) gout_data = numpy.zeros_like(x_data)
for i in range(len(x_indptr) - 1): for i in range(len(x_indptr) - 1):
for j_ptr in range(g_indptr[i], g_indptr[i + 1]): for j_ptr in range(g_indptr[i], g_indptr[i + 1]):
g_row[g_indices[j_ptr]] += g_data[j_ptr] g_row[g_indices[j_ptr]] += g_data[j_ptr]
for j_ptr in range(x_indptr[i], x_indptr[i + 1]): for j_ptr in range(x_indptr[i], x_indptr[i + 1]):
gout_data[j_ptr] = g_row[x_indices[j_ptr]] gout_data[j_ptr] = g_row[x_indices[j_ptr]]
for j_ptr in range(g_indptr[i], g_indptr[i + 1]): for j_ptr in range(g_indptr[i], g_indptr[i + 1]):
g_row[g_indices[j_ptr]] = 0 g_row[g_indices[j_ptr]] = 0
if self.kmap is None: if self.kmap is None:
g_out[0] = gout_data g_out[0] = gout_data
else: else:
...@@ -811,7 +794,7 @@ class CSMGradC(gof.Op): ...@@ -811,7 +794,7 @@ class CSMGradC(gof.Op):
raise NotImplementedError('Complex types are not supported for a_val') raise NotImplementedError('Complex types are not supported for a_val')
if node.inputs[3].type.dtype in ('complex64', 'complex128'): if node.inputs[3].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for b_val') raise NotImplementedError('Complex types are not supported for b_val')
return """ return """
if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;} if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;}
if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;} if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;}
...@@ -825,22 +808,22 @@ class CSMGradC(gof.Op): ...@@ -825,22 +808,22 @@ class CSMGradC(gof.Op):
if (%(a_ptr)s->descr->type_num != PyArray_INT32) if (%(a_ptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;}
if (%(b_ind)s->descr->type_num != PyArray_INT32) { if (%(b_ind)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "b_ind dtype not INT32"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "b_ind dtype not INT32"); %(fail)s;}
if (%(b_ptr)s->descr->type_num != PyArray_INT32) if (%(b_ptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "b_ptr dtype not INT32"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "b_ptr dtype not INT32"); %(fail)s;}
if (%(a_val)s->dimensions[0] != %(a_ind)s->dimensions[0]) if (%(a_val)s->dimensions[0] != %(a_ind)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;}
if (%(b_val)s->dimensions[0] != %(b_ind)s->dimensions[0]) if (%(b_val)s->dimensions[0] != %(b_ind)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "b_val and b_ind have different lengths"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "b_val and b_ind have different lengths"); %(fail)s;}
if (%(a_ptr)s->dimensions[0] != %(b_ptr)s->dimensions[0]) if (%(a_ptr)s->dimensions[0] != %(b_ptr)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a_ptr and b_ptr have different lengths"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a_ptr and b_ptr have different lengths"); %(fail)s;}
if ((!%(z)s) || (%(z)s->dimensions[0] != %(a_val)s->dimensions[0])) if ((!%(z)s) || (%(z)s->dimensions[0] != %(a_val)s->dimensions[0]))
{ {
{Py_XDECREF(%(z)s);} {Py_XDECREF(%(z)s);}
...@@ -854,9 +837,9 @@ class CSMGradC(gof.Op): ...@@ -854,9 +837,9 @@ class CSMGradC(gof.Op):
npy_intp M = %(a_ptr)s->dimensions[0] - 1; npy_intp M = %(a_ptr)s->dimensions[0] - 1;
npy_intp a_dim_0 = ((npy_int32 *)%(a_dim)s->data)[0]; npy_intp a_dim_0 = ((npy_int32 *)%(a_dim)s->data)[0];
npy_intp a_dim_1 = ((npy_int32 *)%(a_dim)s->data)[1]; npy_intp a_dim_1 = ((npy_int32 *)%(a_dim)s->data)[1];
npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0; npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0;
// strides tell you how many bytes to skip to go to next column/row entry // strides tell you how many bytes to skip to go to next column/row entry
npy_intp Sz = %(z)s->strides[0] / %(z)s->descr->elsize; npy_intp Sz = %(z)s->strides[0] / %(z)s->descr->elsize;
npy_intp Sa_val = %(a_val)s->strides[0] / %(a_val)s->descr->elsize; npy_intp Sa_val = %(a_val)s->strides[0] / %(a_val)s->descr->elsize;
...@@ -876,9 +859,9 @@ class CSMGradC(gof.Op): ...@@ -876,9 +859,9 @@ class CSMGradC(gof.Op):
const npy_int32 * __restrict__ Db_ptr = (npy_int32*)%(b_ptr)s->data; const npy_int32 * __restrict__ Db_ptr = (npy_int32*)%(b_ptr)s->data;
npy_intp nnz = %(a_ind)s->dimensions[0]; npy_intp nnz = %(a_ind)s->dimensions[0];
dtype_%(b_val)s b_row[sp_dim]; dtype_%(b_val)s b_row[sp_dim];
//clear the output array //clear the output array
for (npy_int64 i = 0; i < nnz; ++i) for (npy_int64 i = 0; i < nnz; ++i)
{ {
...@@ -893,12 +876,12 @@ class CSMGradC(gof.Op): ...@@ -893,12 +876,12 @@ class CSMGradC(gof.Op):
j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) { j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) {
b_row[Db_ind[j_ptr * Sb_ind]] += Db_val[j_ptr*Sb_val]; b_row[Db_ind[j_ptr * Sb_ind]] += Db_val[j_ptr*Sb_val];
} }
for (npy_int32 j_ptr = Da_ptr[m * Sa_ptr]; for (npy_int32 j_ptr = Da_ptr[m * Sa_ptr];
j_ptr < Da_ptr[(m + 1) * Sa_ptr]; j_ptr++) { j_ptr < Da_ptr[(m + 1) * Sa_ptr]; j_ptr++) {
Dz[j_ptr*Sz] = b_row[Da_ind[j_ptr * Sa_ind]]; Dz[j_ptr*Sz] = b_row[Da_ind[j_ptr * Sa_ind]];
} }
for (npy_int32 j_ptr = Db_ptr[m * Sb_ptr]; for (npy_int32 j_ptr = Db_ptr[m * Sb_ptr];
j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) { j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) {
b_row[Db_ind[j_ptr * Sb_ind]] = 0; b_row[Db_ind[j_ptr * Sb_ind]] = 0;
......
...@@ -12,7 +12,7 @@ import numpy ...@@ -12,7 +12,7 @@ import numpy
import theano import theano
from theano.configparser import config from theano.configparser import config
from theano import gof from theano import gof
from theano.gof import Apply, Constant, Op, Type, Value, Variable from theano.gof import Apply, Constant, Op, Type, Variable
import elemwise import elemwise
from theano import scalar as scal from theano import scalar as scal
...@@ -390,12 +390,6 @@ def constant(x, name=None, ndim=None, dtype=None): ...@@ -390,12 +390,6 @@ def constant(x, name=None, ndim=None, dtype=None):
return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim, return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim,
dtype=dtype) dtype=dtype)
def value(x, name=None, ndim=None, dtype=None):
return constant_or_value(x, rtype=TensorValue, name=name,
ndim=ndim, dtype=dtype)
def _obj_is_wrappable_as_tensor(x): def _obj_is_wrappable_as_tensor(x):
try: try:
constant(x) constant(x)
...@@ -1784,14 +1778,6 @@ class TensorConstant(_tensor_py_operators, Constant): ...@@ -1784,14 +1778,6 @@ class TensorConstant(_tensor_py_operators, Constant):
TensorType.Constant = TensorConstant TensorType.Constant = TensorConstant
class TensorValue(_tensor_py_operators, Value):
"""Subclass to add the tensor operators to the basic `Value` class.
To create a TensorValue, use the `value` function in this module.
:note: Value is deprecated by SharedVariable
"""
Tensor = TensorType Tensor = TensorType
...@@ -1801,8 +1787,6 @@ elemwise.as_tensor_variable = as_tensor_variable ...@@ -1801,8 +1787,6 @@ elemwise.as_tensor_variable = as_tensor_variable
elemwise.TensorType = TensorType elemwise.TensorType = TensorType
elemwise.TensorVariable = TensorVariable elemwise.TensorVariable = TensorVariable
elemwise.TensorConstant = TensorConstant elemwise.TensorConstant = TensorConstant
elemwise.TensorValue = TensorValue
######################### #########################
# Utilities # Utilities
...@@ -2278,7 +2262,7 @@ class MaxAndArgmax(Op): ...@@ -2278,7 +2262,7 @@ class MaxAndArgmax(Op):
# not calculated here for it is not defined at every point where some # not calculated here for it is not defined at every point where some
# coordinates are identical. However, since the latter set has null # coordinates are identical. However, since the latter set has null
# Lebesgue measure, the result may be interpreted as weak gradient. # Lebesgue measure, the result may be interpreted as weak gradient.
# @note: This function should work correctly for L{vector}s. # @note: This function should work correctly for L{vector}s.
# (x, y), (gz, gw) # (x, y), (gz, gw)
# gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
...@@ -2314,7 +2298,7 @@ class MaxAndArgmax(Op): ...@@ -2314,7 +2298,7 @@ class MaxAndArgmax(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
_max_and_argmax = MaxAndArgmax() _max_and_argmax = MaxAndArgmax()
......
import numpy import numpy
import theano import theano
from theano import gof from theano import gof
from theano.gof import Apply, Constant, Generic, Op, Type, Value, Variable from theano.gof import Apply, Constant, Generic, Op, Type, Variable
from basic import tensor from basic import tensor
########################## ##########################
# Disk Access # Disk Access
......
...@@ -30,12 +30,13 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -30,12 +30,13 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq, inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq,
Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor, Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor,
tensor_copy, tensordot, tensordot_grad, TensorType, unbroadcast, tensor_copy, tensordot, tensordot_grad, TensorType, unbroadcast,
var, value, Join, shape, MaxAndArgmax, lscalar, zvector, exp, var, Join, shape, MaxAndArgmax, lscalar, zvector, exp,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast) tile, patternbroadcast)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint
imported_scipy_special = False imported_scipy_special = False
...@@ -210,9 +211,10 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -210,9 +211,10 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise SkipTest(skip) raise SkipTest(skip)
for testname, inputs in self.good.items(): for testname, inputs in self.good.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs] inputrs = [ TensorType( dtype = input.dtype, broadcastable =
[ shape_elem == 1 for shape_elem in input.shape]
)() for input in inputs]
try: try:
#node = self.op.make_node(*inputrs)
node = safe_make_node(self.op, *inputrs) node = safe_make_node(self.op, *inputrs)
except Exception, exc: except Exception, exc:
err_msg = ("Test %s::%s: Error occurred while" err_msg = ("Test %s::%s: Error occurred while"
...@@ -287,7 +289,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -287,7 +289,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise SkipTest(skip) raise SkipTest(skip)
for testname, inputs in self.bad_build.items(): for testname, inputs in self.bad_build.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs] inputrs = [shared(input) for input in inputs]
self.assertRaises(Exception, self.assertRaises(Exception,
safe_make_node, self.op, *inputrs) safe_make_node, self.op, *inputrs)
# The old error string was ("Test %s::%s: %s was successfully # The old error string was ("Test %s::%s: %s was successfully
...@@ -299,7 +301,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -299,7 +301,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise SkipTest(skip) raise SkipTest(skip)
for testname, inputs in self.bad_runtime.items(): for testname, inputs in self.bad_runtime.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs] inputrs = [shared(input) for input in inputs]
try: try:
node = safe_make_node(self.op, *inputrs) node = safe_make_node(self.op, *inputrs)
except Exception, exc: except Exception, exc:
...@@ -310,7 +312,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -310,7 +312,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise raise
try: try:
f = inplace_func(inputrs, node.outputs, mode=mode) f = inplace_func([], node.outputs, mode=mode)
except Exception, exc: except Exception, exc:
err_msg = ("Test %s::%s: Error occurred while trying" err_msg = ("Test %s::%s: Error occurred while trying"
" to make a Function") % (self.op, testname) " to make a Function") % (self.op, testname)
...@@ -321,7 +323,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -321,7 +323,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
# one? # one?
# TODO: test that only this one is raised and catch only this # TODO: test that only this one is raised and catch only this
# one or the subset that get raised. # one or the subset that get raised.
self.assertRaises(Exception, f, *inputs) self.assertRaises(Exception, f, [])
def test_grad(self): def test_grad(self):
if skip: if skip:
...@@ -332,7 +334,6 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -332,7 +334,6 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
try: try:
for testname, inputs in self.grad.items(): for testname, inputs in self.grad.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs]
try: try:
utt.verify_grad(self.op, inputs, utt.verify_grad(self.op, inputs,
mode=self.mode, mode=self.mode,
...@@ -3796,17 +3797,17 @@ class T_add(unittest.TestCase): ...@@ -3796,17 +3797,17 @@ class T_add(unittest.TestCase):
def test_complex_all_ops(self): def test_complex_all_ops(self):
for nbits in (64, 128): for nbits in (64, 128):
a = value(numpy.ones(3, dtype='complex%i' % nbits)+0.5j) a = shared(numpy.ones(3, dtype='complex%i' % nbits)+0.5j)
b = value(numpy.ones(3, dtype='complex%i' % nbits)+1.5j) b = shared(numpy.ones(3, dtype='complex%i' % nbits)+1.5j)
tests = (("+", lambda x,y: x+y), tests = (("+", lambda x,y: x+y),
("-", lambda x,y: x-y), ("-", lambda x,y: x-y),
("*", lambda x,y: x*y), ("*", lambda x,y: x*y),
("/", lambda x,y: x/y)) ("/", lambda x,y: x/y))
for s, fn in tests: for s, fn in tests:
f = inplace_func([a,b], fn(a, b)) f = inplace_func([], fn(a, b))
#print 'valid output:', fn(a.data, b.data) #print 'valid output:', fn(a.data, b.data)
#print 'theano output:', f(a.data, b.data) #print 'theano output:', f(a.data, b.data)
self.assertTrue(a.type.values_eq_approx(fn(a.data, b.data), f(a.data, b.data))) self.assertTrue(a.type.values_eq_approx(fn(a.get_value(), b.get_value()), f()))
def test_grad_scalar_l(self): def test_grad_scalar_l(self):
utt.verify_grad(add, [numpy.asarray([3.0]), rand(3)]) utt.verify_grad(add, [numpy.asarray([3.0]), rand(3)])
......
...@@ -39,43 +39,45 @@ class test_casting(unittest.TestCase): ...@@ -39,43 +39,45 @@ class test_casting(unittest.TestCase):
self.assertTrue(numpy.all(b == numpy.arange(10, dtype = type2))) self.assertTrue(numpy.all(b == numpy.arange(10, dtype = type2)))
def test_convert_to_complex(self): def test_convert_to_complex(self):
a = value(numpy.ones(3, dtype='complex64')+0.5j) val64 = numpy.ones(3, dtype='complex64') + 0.5j
b = value(numpy.ones(3, dtype='complex128')+0.5j) val128 = numpy.ones(3, dtype='complex128') + 0.5j
f = function([a],basic._convert_to_complex128(a)) vec64 = TensorType('complex64',(False,))()
vec128 = TensorType('complex128',(False,))()
f = function([vec64],basic._convert_to_complex128(vec64))
#we need to compare with the same type. #we need to compare with the same type.
assert a.type.values_eq_approx(b.data, f(a.data)) assert vec64.type.values_eq_approx(val128, f(val64))
f = function([b],basic._convert_to_complex128(b)) f = function([vec128],basic._convert_to_complex128(vec128))
assert b.type.values_eq_approx(b.data, f(b.data)) assert vec64.type.values_eq_approx(val128, f(val128))
f = function([a],basic._convert_to_complex64(a)) f = function([vec64],basic._convert_to_complex64(vec64))
assert a.type.values_eq_approx(a.data, f(a.data)) assert vec64.type.values_eq_approx(val64, f(val64))
f = function([b],basic._convert_to_complex64(b)) f = function([vec128],basic._convert_to_complex64(vec128))
assert b.type.values_eq_approx(a.data, f(b.data)) assert vec128.type.values_eq_approx(val64, f(val128))
for nbits in (64, 128): # upcasting to complex128
# upcasting to complex128 for t in ['int8','int16','int32','int64','float32','float64']:
for t in ['int8','int16','int32','int64','float32','float64']: a = shared(numpy.ones(3, dtype=t))
a = value(numpy.ones(3, dtype=t)) b = shared(numpy.ones(3, dtype='complex128'))
b = value(numpy.ones(3, dtype='complex128')) f = function([],basic._convert_to_complex128(a))
f = function([a],basic._convert_to_complex128(a)) assert a.type.values_eq_approx(b.get_value(), f())
assert a.type.values_eq_approx(b.data, f(a.data))
# upcasting to complex64
# upcasting to complex64 for t in ['int8','int16','int32','int64','float32']:
for t in ['int8','int16','int32','int64','float32']: a = shared(numpy.ones(3, dtype=t))
a = value(numpy.ones(3, dtype=t)) b = shared(numpy.ones(3, dtype='complex64'))
b = value(numpy.ones(3, dtype='complex64')) f = function([],basic._convert_to_complex64(a))
f = function([a],basic._convert_to_complex64(a)) assert a.type.values_eq_approx(b.get_value(), f())
assert a.type.values_eq_approx(b.data, f(a.data))
# downcast to complex64
# downcast to complex64 for t in ['float64']:
for t in ['float64']: a = shared(numpy.ones(3, dtype=t))
a = value(numpy.ones(3, dtype=t)) b = shared(numpy.ones(3, dtype='complex64'))
b = value(numpy.ones(3, dtype='complex64')) f = function([],basic._convert_to_complex64(a))
f = function([a],basic._convert_to_complex64(a)) assert a.type.values_eq_approx(b.get_value(), f())
assert a.type.values_eq_approx(b.data, f(a.data))
def test_bug_complext_10_august_09(self): def test_bug_complext_10_august_09(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论