提交 98dc329c authored 作者: James Bergstra's avatar James Bergstra

Modified the Scalar shared variable to use a 0-d tensor as the backend. This

makes it play nicer with the new rule that update values must have the same type as their shared vars. An alternative fix to several tests would have been to implement a Tensor->Scalar cast, but I didn't do that because it's nice to just use Tensors.
上级 78fdbd22
"""Provide a simple user friendly API """
__docformat__ = 'restructuredtext en'
import traceback
import copy
import numpy
......@@ -159,14 +160,35 @@ def tensor_constructor(value, name=None, strict=False, broadcastable=None):
type = TensorType(value.dtype, broadcastable=broadcastable)
return TensorSharedVariable(type=type, value=value, name=name, strict=strict)
# TensorSharedVariable brings in the tensor operators, is not ideal, but works as long as we
# dont do purely scalar-scalar operations
class ScalarSharedVariable(SharedVariable, theano.tensor.basic._tensor_py_operators):
pass
@shared_constructor
def scalar_constructor(value, name=None, strict=False, dtype=None):
"""SharedVariable constructor for scalar values. Defaults to int64 or float64"""
if not isinstance(value, (float,int)):
"""SharedVariable constructor for scalar values. Defaults to int64 or float64.
:note: We implement this using 0-d tensors for now.
"""
if not isinstance (value, (numpy.number, float, int)):
raise TypeError()
# use float64 and int64 by default, user can override
if not dtype:
dtype = 'int64' if isinstance(value,int) else 'float64'
type = Scalar(dtype)
return TensorSharedVariable(type=type, value=numpy.asarray(value), name=name, strict=strict)
if dtype is None:
if isinstance(value, float):
dtype = 'float64'
elif isinstance(value, int):
dtype = 'int64'
else:
dtype = type(value).__name__
type = TensorType(dtype=dtype, broadcastable=[])
try:
# don't pass the dtype to asarray because we want this to fail if strict is True and the
# types do not match
rval = ScalarSharedVariable(type=type, value=numpy.asarray(value), name=name, strict=strict)
return rval
except:
traceback.print_exc()
raise
......@@ -18,6 +18,7 @@ class NNet(object):
self.lr = shared(lr, 'learning_rate')
self.w1 = shared(numpy.zeros((n_hidden, n_input)), 'w1')
self.w2 = shared(numpy.zeros((n_output, n_hidden)), 'w2')
print self.lr.type
self.hidden = sigmoid(tensor.dot(self.w1, self.input))
self.output = tensor.dot(self.w2, self.hidden)
......
......@@ -172,7 +172,7 @@ class Test_pfunc(unittest.TestCase):
# Same but using a mutable constant to show how it can be used to
# modify the update value after the function is created.
x.value = 0
y = numpy.ones(())
y = numpy.ones((), dtype='int64')
assign_mutable = pfunc([], [], updates = {x: y})
assign_mutable()
self.failUnless(x.value == 1)
......
......@@ -10,9 +10,15 @@ class Test_SharedVariable(unittest.TestCase):
def test_ctors(self):
assert shared(7).type == Scalar('int64')
assert shared(7.0).type == Scalar('float64')
assert shared(7, dtype='float64').type == Scalar('float64')
if 0: #when using an implementation that handles scalars with Scalar type
assert shared(7).type == Scalar('int64')
assert shared(7.0).type == Scalar('float64')
assert shared(7, dtype='float64').type == Scalar('float64')
else:
assert shared(7).type == theano.tensor.lscalar
assert shared(7.0).type == theano.tensor.dscalar
assert shared(7, dtype='float64').type == theano.tensor.dscalar
# test tensor constructor
b = shared(numpy.zeros((5,5), dtype='int32'))
......@@ -107,13 +113,17 @@ class Test_SharedVariable(unittest.TestCase):
def test_strict(self):
def f(var, val): var.value = val
b = shared(7, strict=True)
self.failUnlessRaises(TypeError, f(b,8.23))
b = shared(7.234, strict=True)
self.failUnlessRaises(TypeError, f(b,8))
b = shared(numpy.int64(7), strict=True)
#assert b.type == Scalar('int64')
assert b.type == theano.tensor.lscalar
self.failUnlessRaises(TypeError, f, b, 8.23)
b = shared(numpy.float64(7.234), strict=True)
#assert b.type == Scalar('float64')
assert b.type == theano.tensor.dscalar
self.failUnlessRaises(TypeError, f, b, 8)
c = shared(numpy.zeros((5,5), dtype='float32'))
self.failUnlessRaises(TypeError, f(b, numpy.random.rand(5,5)))
self.failUnlessRaises(TypeError, f, b, numpy.random.rand(5,5))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论