提交 47137e3f authored 作者: Vikram's avatar Vikram

Cleaned up upcast check. Copied stack trace. Added tests in T_cast_cast of test_opt.py

上级 c5ceabcf
......@@ -2232,6 +2232,8 @@ def local_cast_cast(node):
else:
# Apply the second cast only
v = node.op(base)
# Copy stack trace from the output of the original cast
copy_stack_trace(node.outputs[0], v)
return [v]
......@@ -2240,39 +2242,21 @@ def upcast(type1, type2):
type2 from type1 constitutes an upcast.
"""
upcast_pairs = (
('int8', 'int16'), ('int8', 'int32'), ('int8', 'int64'),
('int16', 'int32'), ('int16', 'int64'),
('int32', 'int64'),
('uint8', 'uint16'), ('uint8', 'uint32'), ('uint8', 'uint64'),
('uint16', 'uint32'), ('uint16', 'uint64'),
('uint32', 'uint64'),
('float16', 'float32'), ('float16', 'float32'), ('float16', 'float64'),
('float32', 'float64'),
('complex64', 'complex128'),
('uint8', 'int16'), ('uint8', 'int32'), ('uint8', 'int64'),
('uint16', 'int32'), ('uint16', 'int64'),
('uint32', 'int64'),
('int8', 'float16'), ('int8', 'float32'), ('int8', 'float64'),
('int16', 'float32'), ('int16', 'float64'),
('int32', 'float64'),
('uint8', 'float16'), ('uint8', 'float32'), ('uint8', 'float64'),
('uint16', 'float32'), ('uint16', 'float64'),
('uint32', 'float64'),
('int8', 'complex64'), ('int16', 'complex64'),
('uint8', 'complex64'), ('uint16', 'complex64'),
('float32', 'complex64'),
('int8', 'complex128'), ('int16', 'complex128'), ('int32', 'complex128'),
('uint8', 'complex128'), ('uint16', 'complex128'), ('uint32', 'complex128'),
('float32', 'complex128'), ('float64', 'complex128')
)
category = {
# Pair of numbers : the 'super-index' and 'sub-index'
'uint8': (0, 0), 'uint16': (0, 1), 'uint32': (0, 2), 'uint64': (0, 3),
'int8': (1, 0), 'int16': (1, 1), 'int32': (1, 2), 'int64': (1, 3),
'float16': (2, 0.5), 'float32': (2, 1.5), 'float64': (2, 2.5),
'complex64': (3, 2), 'complex128': (3, 3)
}
for pair in upcast_pairs:
if(type1 == pair[0] and type2 == pair[1]):
return True
return False
cat1 = category[type1]
cat2 = category[type2]
if(cat2[0] >= cat1[0] and cat2[1] > cat1[1]):
return True
else:
return False
@register_canonicalize
......
......@@ -4650,7 +4650,7 @@ class T_cast_cast(unittest.TestCase):
mode = theano.compile.get_default_mode()
self.mode = mode.including('local_cast_cast')
def test(self):
def test_consecutive(self):
x = T.fmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float64")))(x.astype("float64"))
f = theano.function([x], o, mode=self.mode)
......@@ -4658,7 +4658,7 @@ class T_cast_cast(unittest.TestCase):
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, T.Elemwise)
assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
x = T.dmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float32")))(x.astype("float32"))
......@@ -4667,7 +4667,38 @@ class T_cast_cast(unittest.TestCase):
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, T.Elemwise)
assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
def test_upcast(self):
# Upcast followed by any other cast
x = T.fmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("complex128")))(x.astype("complex64"))
f = theano.function([x], o, mode=self.mode)
dx = numpy.random.rand(5, 4).astype("float32")
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
# Upcast followed by a downcast back to the base type
x = T.fmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float32")))(x.astype("float64"))
f = theano.function([x], o, mode=self.mode)
dx = numpy.random.rand(5, 4).astype('float32')
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, DeepCopyOp)
# Downcast followed by an upcast back to the base type
x = T.dmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float64")))(x.astype("float32"))
f = theano.function([x], o, mode=self.mode)
dx = numpy.random.rand(5, 4)
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op.scalar_op, scal.basic.Composite)
class T_func_inverse(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论