提交 47137e3f authored 作者: Vikram's avatar Vikram

Cleaned up upcast check. Copied stack trace. Added tests in T_cast_cast of test_opt.py

上级 c5ceabcf
...@@ -2232,6 +2232,8 @@ def local_cast_cast(node): ...@@ -2232,6 +2232,8 @@ def local_cast_cast(node):
else: else:
# Apply the second cast only # Apply the second cast only
v = node.op(base) v = node.op(base)
# Copy stack trace from the output of the original cast
copy_stack_trace(node.outputs[0], v)
return [v] return [v]
...@@ -2240,38 +2242,20 @@ def upcast(type1, type2): ...@@ -2240,38 +2242,20 @@ def upcast(type1, type2):
type2 from type1 constitutes an upcast. type2 from type1 constitutes an upcast.
""" """
upcast_pairs = ( category = {
('int8', 'int16'), ('int8', 'int32'), ('int8', 'int64'), # Pair of numbers : the 'super-index' and 'sub-index'
('int16', 'int32'), ('int16', 'int64'), 'uint8': (0, 0), 'uint16': (0, 1), 'uint32': (0, 2), 'uint64': (0, 3),
('int32', 'int64'), 'int8': (1, 0), 'int16': (1, 1), 'int32': (1, 2), 'int64': (1, 3),
('uint8', 'uint16'), ('uint8', 'uint32'), ('uint8', 'uint64'), 'float16': (2, 0.5), 'float32': (2, 1.5), 'float64': (2, 2.5),
('uint16', 'uint32'), ('uint16', 'uint64'), 'complex64': (3, 2), 'complex128': (3, 3)
('uint32', 'uint64'), }
('float16', 'float32'), ('float16', 'float32'), ('float16', 'float64'),
('float32', 'float64'),
('complex64', 'complex128'),
('uint8', 'int16'), ('uint8', 'int32'), ('uint8', 'int64'),
('uint16', 'int32'), ('uint16', 'int64'),
('uint32', 'int64'),
('int8', 'float16'), ('int8', 'float32'), ('int8', 'float64'),
('int16', 'float32'), ('int16', 'float64'),
('int32', 'float64'),
('uint8', 'float16'), ('uint8', 'float32'), ('uint8', 'float64'),
('uint16', 'float32'), ('uint16', 'float64'),
('uint32', 'float64'),
('int8', 'complex64'), ('int16', 'complex64'),
('uint8', 'complex64'), ('uint16', 'complex64'),
('float32', 'complex64'),
('int8', 'complex128'), ('int16', 'complex128'), ('int32', 'complex128'),
('uint8', 'complex128'), ('uint16', 'complex128'), ('uint32', 'complex128'),
('float32', 'complex128'), ('float64', 'complex128')
)
for pair in upcast_pairs: cat1 = category[type1]
if(type1 == pair[0] and type2 == pair[1]): cat2 = category[type2]
if(cat2[0] >= cat1[0] and cat2[1] > cat1[1]):
return True return True
else:
return False return False
......
...@@ -4650,7 +4650,7 @@ class T_cast_cast(unittest.TestCase): ...@@ -4650,7 +4650,7 @@ class T_cast_cast(unittest.TestCase):
mode = theano.compile.get_default_mode() mode = theano.compile.get_default_mode()
self.mode = mode.including('local_cast_cast') self.mode = mode.including('local_cast_cast')
def test(self): def test_consecutive(self):
x = T.fmatrix() x = T.fmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float64")))(x.astype("float64")) o = T.Elemwise(scal.Cast(scal.Scalar("float64")))(x.astype("float64"))
f = theano.function([x], o, mode=self.mode) f = theano.function([x], o, mode=self.mode)
...@@ -4658,7 +4658,7 @@ class T_cast_cast(unittest.TestCase): ...@@ -4658,7 +4658,7 @@ class T_cast_cast(unittest.TestCase):
f(dx) f(dx)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
assert isinstance(topo[0].op, T.Elemwise) assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
x = T.dmatrix() x = T.dmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float32")))(x.astype("float32")) o = T.Elemwise(scal.Cast(scal.Scalar("float32")))(x.astype("float32"))
...@@ -4667,7 +4667,38 @@ class T_cast_cast(unittest.TestCase): ...@@ -4667,7 +4667,38 @@ class T_cast_cast(unittest.TestCase):
f(dx) f(dx)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
assert isinstance(topo[0].op, T.Elemwise) assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
def test_upcast(self):
# Upcast followed by any other cast
x = T.fmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("complex128")))(x.astype("complex64"))
f = theano.function([x], o, mode=self.mode)
dx = numpy.random.rand(5, 4).astype("float32")
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
# Upcast followed by a downcast back to the base type
x = T.fmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float32")))(x.astype("float64"))
f = theano.function([x], o, mode=self.mode)
dx = numpy.random.rand(5, 4).astype('float32')
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, DeepCopyOp)
# Downcast followed by an upcast back to the base type
x = T.dmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float64")))(x.astype("float32"))
f = theano.function([x], o, mode=self.mode)
dx = numpy.random.rand(5, 4)
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op.scalar_op, scal.basic.Composite)
class T_func_inverse(unittest.TestCase): class T_func_inverse(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论