提交 1284d324 authored 作者: nouiz's avatar nouiz

Merge pull request #718 from goodfeli/remove_value

Remove Value
......@@ -59,7 +59,7 @@ from gof import \
CLinker, OpWiseCLinker, DualLinker, Linker, LocalLinker, PerformLinker, \
Container, \
InconsistencyError, Env, \
Apply, Variable, Constant, Value, \
Apply, Variable, Constant, \
Op, \
opt, \
toolbox, \
......
......@@ -44,7 +44,7 @@ class OpFromGraph(gof.Op):
if 'updates' in kwargs:
raise TypeError('updates are not allowed in kwargs')
# TODO: the graph may have implicit inputs like Value and
# TODO: the graph may have implicit inputs like
# SharedVariable instances.
# what impact to they have on the validity of this Op?
self.fn = orig_function(inputs, outputs, **kwargs)
......
......@@ -448,8 +448,8 @@ class Method(Component):
if input not in _inputs:
# Add this input to the inputs; we require that storage already exists for them,
# but otherwise they are immutable.
if isinstance(input, gof.Value): # and not isinstance(input, gof.Constant):
#input might be Value or Constant
if isinstance(input, gof.Constant):
#input might be Constant
storage = get_storage(input)
assert type(storage) is io.In
......
......@@ -9,7 +9,7 @@ from theano import config
from theano.compile import orig_function, In, Out
from theano.compile import UnusedInputError
from theano.compile.sharedvalue import SharedVariable, shared
from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano.gof import Container, Variable, generic, graph, Constant
from theano.gof.python25 import any
import logging
......@@ -477,10 +477,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
if isinstance(param, Constant):
raise TypeError('Constants not allowed in param list', param)
#if isinstance(param, Value):
#return In(variable=param)
#raise NotImplementedError()
if isinstance(param, Variable): # N.B. includes Value and SharedVariable
if isinstance(param, Variable): # N.B. includes SharedVariable
return In(variable=param, strict=strict, allow_downcast=allow_downcast)
elif isinstance(param, Param):
return In(
......
......@@ -22,7 +22,7 @@ class T_module(unittest.TestCase):
class Blah(Module):
def __init__(self, stepsize):
super(Blah, self).__init__()
self.stepsize = T.value(stepsize)
self.stepsize = T.constant(stepsize)
x = T.dscalar()
self.step = Method([x], x - self.stepsize)
......@@ -128,7 +128,7 @@ class T_module(unittest.TestCase):
assert i[0]==j
local_test(lambda:T.dscalar(),lambda:T.dscalar())
local_test(lambda:T.value(1),lambda:T.value(2))
local_test(lambda:T.constant(1),lambda:T.constant(2))
local_test(lambda:T.constant(1),lambda:T.constant(2))
def test_list_assign(self):
......@@ -151,7 +151,6 @@ class T_module(unittest.TestCase):
assert numpy.all(4 == m.g())
local_test(lambda:T.dscalar(),lambda:T.dscalar())
local_test(lambda:T.value(1),lambda:T.value(2))
def test_tuple_assign(self):
"""Test that list members can be assigned tuple-wise"""
......@@ -170,7 +169,6 @@ class T_module(unittest.TestCase):
assert 4 == m.g()
local_test(lambda:T.dscalar(),lambda:T.dscalar())
local_test(lambda:T.value(1),lambda:T.value(2))
def test_dict_assign(self):
"""Test that list members can be assigned dict-wise"""
......@@ -191,8 +189,6 @@ class T_module(unittest.TestCase):
#print 'dscalar test'
local_test(lambda:T.dscalar(),lambda:T.dscalar())
#print 'value test'
local_test(lambda:T.value(1),lambda:T.value(2))
def test_method_in_list_or_dict(self):
......@@ -452,16 +448,6 @@ class T_module(unittest.TestCase):
assert numpy.all(m.f(xval) == [1, 2.5])
assert numpy.all(xval == [-1, -1.5])
def test_member_value(self):
"""Test that module Members of Value work correctly. As Variable?"""
M = Module()
x = T.dscalar()
M.y = T.value(40)
M.f = Method([x], x + 2 * M.y)
m = M.make()
m.y = 80
assert m.f(20) == 180
def test_member_constant(self):
"""Test that module Members of Constant work correctly.
As Variable with more optimization?"""
......
......@@ -9,10 +9,10 @@ from env import \
InconsistencyError, MissingInputError, Env
from destroyhandler import \
DestroyHandler
DestroyHandler
from graph import \
Apply, Variable, Constant, Value, view_roots
Apply, Variable, Constant, view_roots
from link import \
Container, Linker, LocalLinker, PerformLinker, WrapLinker, WrapLinkerMany
......@@ -21,11 +21,11 @@ from op import \
Op, PureOp, ops_with_inner_function
from opt import (Optimizer, optimizer, SeqOptimizer,
MergeOptimizer, MergeOptMerge,
LocalOptimizer, local_optimizer, LocalOptGroup,
OpSub, OpRemove, PatternSub,
NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer,
InplaceOptimizer, PureThenInplaceOptimizer,
MergeOptimizer, MergeOptMerge,
LocalOptimizer, local_optimizer, LocalOptGroup,
OpSub, OpRemove, PatternSub,
NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer,
InplaceOptimizer, PureThenInplaceOptimizer,
OpKeyOptimizer)
from optdb import \
......
......@@ -431,7 +431,7 @@ class CLinker(link.Linker):
# The orphans field is listified to ensure a consistent order.
#list(env.orphans.difference(self.outputs))
self.orphans = list(r for r in self.variables
if isinstance(r, graph.Value) and
if isinstance(r, graph.Constant) and
r not in self.inputs)
self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans))
......@@ -497,8 +497,8 @@ class CLinker(link.Linker):
policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_extract, get_c_cleanup]]
elif variable in self.orphans:
if not isinstance(variable, graph.Value):
raise TypeError("All orphans to CLinker must be Value"
if not isinstance(variable, graph.Constant):
raise TypeError("All orphans to CLinker must be Constant"
" instances.", variable)
if isinstance(variable, graph.Constant):
try:
......
......@@ -28,7 +28,7 @@ class Env(utils.object2):
""" WRITEME
An Env represents a subgraph bound by a set of input variables and a
set of output variables. The inputs list should contain all the inputs
on which the outputs depend. Variables of type Value or Constant are
on which the outputs depend. Variables of type Constant are
not counted as inputs.
The Env supports the replace operation which allows to replace a
......@@ -240,7 +240,7 @@ class Env(utils.object2):
r_owner_done.add(node)
self.__import__(node)
for r in variables:
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
raise MissingInputError("Undeclared input", r)
if not getattr(r, 'env', None) is self:
self.__setup_r__(r)
......@@ -260,7 +260,7 @@ class Env(utils.object2):
for r in node.inputs:
if hasattr(r, 'env') and r.env is not self:
raise Exception("%s is already owned by another env" % r)
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
#Verbose error message
#Show a complete chain of variables from the missing input to an output
......@@ -610,7 +610,7 @@ class Env(utils.object2):
excess = self.variables.difference(variables)
raise Exception("The variables are inappropriately cached. missing, in excess: ", missing, excess)
for variable in variables:
if variable.owner is None and variable not in self.inputs and not isinstance(variable, graph.Value):
if variable.owner is None and variable not in self.inputs and not isinstance(variable, graph.Constant):
raise Exception("Undeclared input.", variable)
if variable.env is not self:
raise Exception("Variable should belong to the env.", variable)
......
......@@ -217,8 +217,6 @@ class Variable(utils.object2):
- `Variable` (this base type) is typically the output of a symbolic computation,
- `Value` (a subclass) adds a default :literal:`value`, and requires that owner is None
- `Constant` (a subclass) which adds a default and un-replaceable :literal:`value`, and
requires that owner is None
......@@ -325,23 +323,18 @@ class Variable(utils.object2):
raise NotImplementedError('Subclasses of Variable must provide __ge__',
self.__class__.__name__)
class Value(Variable):
class Constant(Variable):
"""
A :term:`Value` is a `Variable` with a default value.
Its owner field is always None. And since it has a default value, a `Value` instance need
not be named as an input to `compile.function`.
This kind of node is useful because when a value is known at compile time, more
optimizations are possible.
A :term:`Constant` is a `Variable` with a `value` field that cannot be changed at runtime.
Constant nodes make eligible numerous optimizations: constant inlining in C code, constant folding, etc.
"""
#__slots__ = ['data']
def __init__(self, type, data, name = None):
"""Initialize self.
:note:
The data field is filtered by what is provided in the constructor for the Value's
The data field is filtered by what is provided in the constructor for the Constant's
type field.
WRITEME
......@@ -349,45 +342,14 @@ class Value(Variable):
"""
Variable.__init__(self, type, None, None, name)
self.data = type.filter(data)
def __str__(self):
"""WRITEME"""
if self.name is not None:
return self.name
return "<" + str(self.data) + ">" #+ "::" + str(self.type)
def clone(self):
"""WRITEME"""
#return copy(self)
cp = self.__class__(self.type, copy(self.data), self.name)
cp.tag = copy(self.tag)
return cp
def __set_owner(self, value):
"""WRITEME
:Exceptions:
- `ValueError`: if `value` is not `None`
"""
if value is not None:
raise ValueError("Value instances cannot have an owner.")
owner = property(lambda self: None, __set_owner)
value = property(lambda self: self.data,
doc='read-only data access method')
# index is not defined, because the `owner` attribute must necessarily be None
class Constant(Value):
"""
A :term:`Constant` is a `Value` that cannot be changed at runtime.
Constant nodes make eligible numerous optimizations: constant inlining in C code, constant folding, etc.
"""
#__slots__ = ['data']
def __init__(self, type, data, name = None):
Value.__init__(self, type, data, name)
def equals(self, other):
# this does what __eq__ should do, but Variable and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature()
def signature(self):
return (self.type, self.data)
def __str__(self):
if self.name is not None:
return self.name
......@@ -396,6 +358,7 @@ class Constant(Value):
if len(name) > 20:
name = name[:10] + '...' + name[-10]
return 'Constant{%s}' % name
def clone(self):
"""
We clone this object, but we don't clone the data to lower memory requirement
......@@ -405,6 +368,21 @@ class Constant(Value):
cp.tag = copy(self.tag)
return cp
def __set_owner(self, value):
"""WRITEME
:Exceptions:
- `ValueError`: if `value` is not `None`
"""
if value is not None:
raise ValueError("Constant instances cannot have an owner.")
owner = property(lambda self: None, __set_owner)
value = property(lambda self: self.data,
doc='read-only data access method')
# index is not defined, because the `owner` attribute must necessarily be None
def stack_search(start, expand, mode='bfs', build_inv = False):
"""Search through a graph, either breadth- or depth-first
......
......@@ -286,7 +286,7 @@ def map_storage(env, order, input_storage, output_storage):
for node in order:
for r in node.inputs:
if r not in storage_map:
assert isinstance(r, graph.Value)
assert isinstance(r, graph.Constant)
storage_map[r] = [r.data]
for r in node.outputs:
storage_map.setdefault(r, [None])
......
......@@ -34,8 +34,8 @@ class MyType(Type):
def MyVariable(name):
return Variable(MyType(), None, None, name = name)
def MyValue(data):
return graph.Value(MyType(), data = data)
def MyConstant(data):
return graph.Constant(MyType(), data = data)
class MyOp(Op):
......@@ -385,7 +385,7 @@ def test_value_repl():
e = add_in_place(x, sy)
g = Env([x,y], [e], False)
consistent(g)
g.replace(sy, MyValue("abc"))
g.replace(sy, MyConstant("abc"))
consistent(g)
def test_value_repl_2():
......@@ -394,7 +394,7 @@ def test_value_repl_2():
e = add_in_place(x, sy)
g = Env([x,y], [e], False)
consistent(g)
g.replace(sy, transpose_view(MyValue("abc")))
g.replace(sy, transpose_view(MyConstant("abc")))
consistent(g)
......
......@@ -1175,7 +1175,10 @@ class Mul(ScalarOp):
# output is complex. The rest of this function make this supposition.
output_type = self.output_types([i.type for i in inputs])[0]
if output_type in complex_types:
assert gz.type in complex_types
if not gz.type in complex_types:
raise TypeError('Mul with output_type '+str(output_type)+\
' expected gz type to be complex, got gz with type '+\
str(gz.type))
for input in inputs:
if input.type in continuous_types:
......
......@@ -212,17 +212,6 @@ def constant(x, name=None):
except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x))
if 0:
def value(x):
if not isinstance(x, scipy.sparse.spmatrix):
raise TypeError("sparse.value must be called on a "
"scipy.sparse.spmatrix")
try:
return SparseValue(SparseType(format=x.format,
dtype=x.dtype), x)
except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x))
def sp_ones_like(x):
# TODO: don't restrict to CSM formats
......@@ -365,12 +354,6 @@ class SparseConstant(gof.Constant, _sparse_py_operators):
def __repr__(self):
return str(self)
class SparseValue(gof.Value, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format)
class SparseType(gof.Type):
"""
@type dtype: numpy dtype string such as 'int64' or 'float64' (among others)
......@@ -760,19 +743,19 @@ class CSMGrad(gof.op.Op):
sp_dim = x_shape[1]
else:
sp_dim = x_shape[0]
g_row = numpy.zeros(sp_dim, dtype=g_data.dtype)
gout_data = numpy.zeros_like(x_data)
for i in range(len(x_indptr) - 1):
for j_ptr in range(g_indptr[i], g_indptr[i + 1]):
g_row[g_indices[j_ptr]] += g_data[j_ptr]
for j_ptr in range(x_indptr[i], x_indptr[i + 1]):
gout_data[j_ptr] = g_row[x_indices[j_ptr]]
for j_ptr in range(g_indptr[i], g_indptr[i + 1]):
g_row[g_indices[j_ptr]] = 0
if self.kmap is None:
g_out[0] = gout_data
else:
......@@ -811,7 +794,7 @@ class CSMGradC(gof.Op):
raise NotImplementedError('Complex types are not supported for a_val')
if node.inputs[3].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for b_val')
return """
if (%(a_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;}
if (%(a_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;}
......@@ -825,22 +808,22 @@ class CSMGradC(gof.Op):
if (%(a_ptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;}
if (%(b_ind)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "b_ind dtype not INT32"); %(fail)s;}
if (%(b_ptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "b_ptr dtype not INT32"); %(fail)s;}
if (%(a_val)s->dimensions[0] != %(a_ind)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;}
if (%(b_val)s->dimensions[0] != %(b_ind)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "b_val and b_ind have different lengths"); %(fail)s;}
if (%(a_ptr)s->dimensions[0] != %(b_ptr)s->dimensions[0])
{PyErr_SetString(PyExc_NotImplementedError, "a_ptr and b_ptr have different lengths"); %(fail)s;}
if ((!%(z)s) || (%(z)s->dimensions[0] != %(a_val)s->dimensions[0]))
{
{Py_XDECREF(%(z)s);}
......@@ -854,9 +837,9 @@ class CSMGradC(gof.Op):
npy_intp M = %(a_ptr)s->dimensions[0] - 1;
npy_intp a_dim_0 = ((npy_int32 *)%(a_dim)s->data)[0];
npy_intp a_dim_1 = ((npy_int32 *)%(a_dim)s->data)[1];
npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0;
// strides tell you how many bytes to skip to go to next column/row entry
npy_intp Sz = %(z)s->strides[0] / %(z)s->descr->elsize;
npy_intp Sa_val = %(a_val)s->strides[0] / %(a_val)s->descr->elsize;
......@@ -876,9 +859,9 @@ class CSMGradC(gof.Op):
const npy_int32 * __restrict__ Db_ptr = (npy_int32*)%(b_ptr)s->data;
npy_intp nnz = %(a_ind)s->dimensions[0];
dtype_%(b_val)s b_row[sp_dim];
//clear the output array
for (npy_int64 i = 0; i < nnz; ++i)
{
......@@ -893,12 +876,12 @@ class CSMGradC(gof.Op):
j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) {
b_row[Db_ind[j_ptr * Sb_ind]] += Db_val[j_ptr*Sb_val];
}
for (npy_int32 j_ptr = Da_ptr[m * Sa_ptr];
j_ptr < Da_ptr[(m + 1) * Sa_ptr]; j_ptr++) {
Dz[j_ptr*Sz] = b_row[Da_ind[j_ptr * Sa_ind]];
}
for (npy_int32 j_ptr = Db_ptr[m * Sb_ptr];
j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) {
b_row[Db_ind[j_ptr * Sb_ind]] = 0;
......
......@@ -12,7 +12,7 @@ import numpy
import theano
from theano.configparser import config
from theano import gof
from theano.gof import Apply, Constant, Op, Type, Value, Variable
from theano.gof import Apply, Constant, Op, Type, Variable
import elemwise
from theano import scalar as scal
......@@ -390,12 +390,6 @@ def constant(x, name=None, ndim=None, dtype=None):
return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim,
dtype=dtype)
def value(x, name=None, ndim=None, dtype=None):
return constant_or_value(x, rtype=TensorValue, name=name,
ndim=ndim, dtype=dtype)
def _obj_is_wrappable_as_tensor(x):
try:
constant(x)
......@@ -1784,14 +1778,6 @@ class TensorConstant(_tensor_py_operators, Constant):
TensorType.Constant = TensorConstant
class TensorValue(_tensor_py_operators, Value):
"""Subclass to add the tensor operators to the basic `Value` class.
To create a TensorValue, use the `value` function in this module.
:note: Value is deprecated by SharedVariable
"""
Tensor = TensorType
......@@ -1801,8 +1787,6 @@ elemwise.as_tensor_variable = as_tensor_variable
elemwise.TensorType = TensorType
elemwise.TensorVariable = TensorVariable
elemwise.TensorConstant = TensorConstant
elemwise.TensorValue = TensorValue
#########################
# Utilities
......@@ -2278,7 +2262,7 @@ class MaxAndArgmax(Op):
# not calculated here for it is not defined at every point where some
# coordinates are identical. However, since the latter set has null
# Lebesgue measure, the result may be interpreted as weak gradient.
# @note: This function should work correctly for L{vector}s.
# (x, y), (gz, gw)
# gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
......@@ -2314,7 +2298,7 @@ class MaxAndArgmax(Op):
def __str__(self):
return self.__class__.__name__
_max_and_argmax = MaxAndArgmax()
......
import numpy
import theano
from theano import gof
from theano.gof import Apply, Constant, Generic, Op, Type, Value, Variable
from theano.gof import Apply, Constant, Generic, Op, Type, Variable
from basic import tensor
##########################
# Disk Access
......
......@@ -30,12 +30,13 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq,
Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor,
tensor_copy, tensordot, tensordot_grad, TensorType, unbroadcast,
var, value, Join, shape, MaxAndArgmax, lscalar, zvector, exp,
var, Join, shape, MaxAndArgmax, lscalar, zvector, exp,
get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast)
from theano.tests import unittest_tools as utt
from theano.printing import debugprint
imported_scipy_special = False
......@@ -210,9 +211,10 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise SkipTest(skip)
for testname, inputs in self.good.items():
inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs]
inputrs = [ TensorType( dtype = input.dtype, broadcastable =
[ shape_elem == 1 for shape_elem in input.shape]
)() for input in inputs]
try:
#node = self.op.make_node(*inputrs)
node = safe_make_node(self.op, *inputrs)
except Exception, exc:
err_msg = ("Test %s::%s: Error occurred while"
......@@ -287,7 +289,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise SkipTest(skip)
for testname, inputs in self.bad_build.items():
inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs]
inputrs = [shared(input) for input in inputs]
self.assertRaises(Exception,
safe_make_node, self.op, *inputrs)
# The old error string was ("Test %s::%s: %s was successfully
......@@ -299,7 +301,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise SkipTest(skip)
for testname, inputs in self.bad_runtime.items():
inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs]
inputrs = [shared(input) for input in inputs]
try:
node = safe_make_node(self.op, *inputrs)
except Exception, exc:
......@@ -310,7 +312,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
raise
try:
f = inplace_func(inputrs, node.outputs, mode=mode)
f = inplace_func([], node.outputs, mode=mode)
except Exception, exc:
err_msg = ("Test %s::%s: Error occurred while trying"
" to make a Function") % (self.op, testname)
......@@ -321,7 +323,7 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
# one?
# TODO: test that only this one is raised and catch only this
# one or the subset that get raised.
self.assertRaises(Exception, f, *inputs)
self.assertRaises(Exception, f, [])
def test_grad(self):
if skip:
......@@ -332,7 +334,6 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
try:
for testname, inputs in self.grad.items():
inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs]
try:
utt.verify_grad(self.op, inputs,
mode=self.mode,
......@@ -3796,17 +3797,17 @@ class T_add(unittest.TestCase):
def test_complex_all_ops(self):
for nbits in (64, 128):
a = value(numpy.ones(3, dtype='complex%i' % nbits)+0.5j)
b = value(numpy.ones(3, dtype='complex%i' % nbits)+1.5j)
a = shared(numpy.ones(3, dtype='complex%i' % nbits)+0.5j)
b = shared(numpy.ones(3, dtype='complex%i' % nbits)+1.5j)
tests = (("+", lambda x,y: x+y),
("-", lambda x,y: x-y),
("*", lambda x,y: x*y),
("/", lambda x,y: x/y))
for s, fn in tests:
f = inplace_func([a,b], fn(a, b))
f = inplace_func([], fn(a, b))
#print 'valid output:', fn(a.data, b.data)
#print 'theano output:', f(a.data, b.data)
self.assertTrue(a.type.values_eq_approx(fn(a.data, b.data), f(a.data, b.data)))
self.assertTrue(a.type.values_eq_approx(fn(a.get_value(), b.get_value()), f()))
def test_grad_scalar_l(self):
utt.verify_grad(add, [numpy.asarray([3.0]), rand(3)])
......
......@@ -39,43 +39,45 @@ class test_casting(unittest.TestCase):
self.assertTrue(numpy.all(b == numpy.arange(10, dtype = type2)))
def test_convert_to_complex(self):
a = value(numpy.ones(3, dtype='complex64')+0.5j)
b = value(numpy.ones(3, dtype='complex128')+0.5j)
val64 = numpy.ones(3, dtype='complex64') + 0.5j
val128 = numpy.ones(3, dtype='complex128') + 0.5j
f = function([a],basic._convert_to_complex128(a))
vec64 = TensorType('complex64',(False,))()
vec128 = TensorType('complex128',(False,))()
f = function([vec64],basic._convert_to_complex128(vec64))
#we need to compare with the same type.
assert a.type.values_eq_approx(b.data, f(a.data))
f = function([b],basic._convert_to_complex128(b))
assert b.type.values_eq_approx(b.data, f(b.data))
f = function([a],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(a.data, f(a.data))
f = function([b],basic._convert_to_complex64(b))
assert b.type.values_eq_approx(a.data, f(b.data))
for nbits in (64, 128):
# upcasting to complex128
for t in ['int8','int16','int32','int64','float32','float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex128'))
f = function([a],basic._convert_to_complex128(a))
assert a.type.values_eq_approx(b.data, f(a.data))
# upcasting to complex64
for t in ['int8','int16','int32','int64','float32']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex64'))
f = function([a],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data))
# downcast to complex64
for t in ['float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex64'))
f = function([a],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data))
assert vec64.type.values_eq_approx(val128, f(val64))
f = function([vec128],basic._convert_to_complex128(vec128))
assert vec64.type.values_eq_approx(val128, f(val128))
f = function([vec64],basic._convert_to_complex64(vec64))
assert vec64.type.values_eq_approx(val64, f(val64))
f = function([vec128],basic._convert_to_complex64(vec128))
assert vec128.type.values_eq_approx(val64, f(val128))
# upcasting to complex128
for t in ['int8','int16','int32','int64','float32','float64']:
a = shared(numpy.ones(3, dtype=t))
b = shared(numpy.ones(3, dtype='complex128'))
f = function([],basic._convert_to_complex128(a))
assert a.type.values_eq_approx(b.get_value(), f())
# upcasting to complex64
for t in ['int8','int16','int32','int64','float32']:
a = shared(numpy.ones(3, dtype=t))
b = shared(numpy.ones(3, dtype='complex64'))
f = function([],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.get_value(), f())
# downcast to complex64
for t in ['float64']:
a = shared(numpy.ones(3, dtype=t))
b = shared(numpy.ones(3, dtype='complex64'))
f = function([],basic._convert_to_complex64(a))
assert a.type.values_eq_approx(b.get_value(), f())
def test_bug_complext_10_august_09(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论