提交 033b4eab authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added casting Ops

上级 fa874c3a
......@@ -586,6 +586,19 @@ class T_Shape(unittest.TestCase):
s = shape(numpy.ones((5, 3, 10)))
self.failUnless((eval_outputs([s]) == [5, 3, 10]).all())
class T_Cast(unittest.TestCase):
def test_basic(self):
for type1 in ['int8', 'int16', 'int32', 'int64', 'float32', 'float64']:
x = Tensor(dtype = type1, broadcastable = (False, )).make_result()
for type2, converter in zip(['int8', 'int16', 'int32', 'int64', 'float32', 'float64'],
[convert_to_int8, convert_to_int16, convert_to_int32, convert_to_int64,
convert_to_float32, convert_to_float64]):
y = converter(x)
f = function([x], [y], strict = True, linker = 'c&py')
a = numpy.arange(10, dtype = type1)
b = f(a)
self.failUnless(numpy.all(b == numpy.arange(10, dtype = type2)))
class T_argmax(unittest.TestCase):
def setUp(self):
numpy.random.seed(123784)
......@@ -1624,7 +1637,7 @@ class _test_grad(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
unittest.main()
# suite = unittest.TestLoader()
# suite = suite.loadTestsFromTestCase(T_subtensor)
# suite = suite.loadTestsFromTestCase(T_Cast)
# unittest.TextTestRunner(verbosity=2).run(suite)
......@@ -390,6 +390,19 @@ s2t.TensorConstant = TensorConstant
s2t.TensorValue = TensorValue
#########################
# Utilities
#########################
def _elemwise(scalar_op, name):
straight = s2t.Elemwise(scalar_op)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = s2t.Elemwise(inplace_scalar_op, {0: 0})
return straight, inplace
#########################
# Casting Operations
#########################
......@@ -421,6 +434,28 @@ class ScalarFromTensor(Op):
scalar_from_tensor = ScalarFromTensor()
def cast(t, dtype):
mapping = {'int8': convert_to_int8,
'int16': convert_to_int16,
'int32': convert_to_int32,
'int64': convert_to_int64,
'float32': convert_to_float32,
'float64': convert_to_float64,
'complex64': convert_to_complex64,
'complex128': convert_to_complex128}
return mapping[dtype](t)
convert_to_int8 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int8)))
convert_to_int16 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int16)))
convert_to_int32 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int32)))
convert_to_int64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int64)))
convert_to_float32 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.float32)))
convert_to_float64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.float64)))
convert_to_complex64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.complex64)))
convert_to_complex128 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.complex128)))
##########################
# Unary Operations
##########################
......@@ -471,12 +506,6 @@ def max(x, axis=None):
return argmax(x,axis)[0]
def _elemwise(scalar_op, name):
straight = s2t.Elemwise(scalar_op)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = s2t.Elemwise(inplace_scalar_op, {0: 0})
return straight, inplace
##########################
# Comparison
##########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论