提交 033b4eab authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added casting Ops

上级 fa874c3a
...@@ -586,6 +586,19 @@ class T_Shape(unittest.TestCase): ...@@ -586,6 +586,19 @@ class T_Shape(unittest.TestCase):
s = shape(numpy.ones((5, 3, 10))) s = shape(numpy.ones((5, 3, 10)))
self.failUnless((eval_outputs([s]) == [5, 3, 10]).all()) self.failUnless((eval_outputs([s]) == [5, 3, 10]).all())
class T_Cast(unittest.TestCase):
def test_basic(self):
for type1 in ['int8', 'int16', 'int32', 'int64', 'float32', 'float64']:
x = Tensor(dtype = type1, broadcastable = (False, )).make_result()
for type2, converter in zip(['int8', 'int16', 'int32', 'int64', 'float32', 'float64'],
[convert_to_int8, convert_to_int16, convert_to_int32, convert_to_int64,
convert_to_float32, convert_to_float64]):
y = converter(x)
f = function([x], [y], strict = True, linker = 'c&py')
a = numpy.arange(10, dtype = type1)
b = f(a)
self.failUnless(numpy.all(b == numpy.arange(10, dtype = type2)))
class T_argmax(unittest.TestCase): class T_argmax(unittest.TestCase):
def setUp(self): def setUp(self):
numpy.random.seed(123784) numpy.random.seed(123784)
...@@ -1626,5 +1639,5 @@ class _test_grad(unittest.TestCase): ...@@ -1626,5 +1639,5 @@ class _test_grad(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# suite = unittest.TestLoader() # suite = unittest.TestLoader()
# suite = suite.loadTestsFromTestCase(T_subtensor) # suite = suite.loadTestsFromTestCase(T_Cast)
# unittest.TextTestRunner(verbosity=2).run(suite) # unittest.TextTestRunner(verbosity=2).run(suite)
...@@ -390,6 +390,19 @@ s2t.TensorConstant = TensorConstant ...@@ -390,6 +390,19 @@ s2t.TensorConstant = TensorConstant
s2t.TensorValue = TensorValue s2t.TensorValue = TensorValue
#########################
# Utilities
#########################
def _elemwise(scalar_op, name):
straight = s2t.Elemwise(scalar_op)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = s2t.Elemwise(inplace_scalar_op, {0: 0})
return straight, inplace
######################### #########################
# Casting Operations # Casting Operations
######################### #########################
...@@ -421,6 +434,28 @@ class ScalarFromTensor(Op): ...@@ -421,6 +434,28 @@ class ScalarFromTensor(Op):
scalar_from_tensor = ScalarFromTensor() scalar_from_tensor = ScalarFromTensor()
def cast(t, dtype):
mapping = {'int8': convert_to_int8,
'int16': convert_to_int16,
'int32': convert_to_int32,
'int64': convert_to_int64,
'float32': convert_to_float32,
'float64': convert_to_float64,
'complex64': convert_to_complex64,
'complex128': convert_to_complex128}
return mapping[dtype](t)
convert_to_int8 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int8)))
convert_to_int16 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int16)))
convert_to_int32 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int32)))
convert_to_int64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int64)))
convert_to_float32 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.float32)))
convert_to_float64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.float64)))
convert_to_complex64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.complex64)))
convert_to_complex128 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.complex128)))
########################## ##########################
# Unary Operations # Unary Operations
########################## ##########################
...@@ -471,12 +506,6 @@ def max(x, axis=None): ...@@ -471,12 +506,6 @@ def max(x, axis=None):
return argmax(x,axis)[0] return argmax(x,axis)[0]
def _elemwise(scalar_op, name):
straight = s2t.Elemwise(scalar_op)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = s2t.Elemwise(inplace_scalar_op, {0: 0})
return straight, inplace
########################## ##########################
# Comparison # Comparison
########################## ##########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论