提交 65b9da6f authored 作者: Vikram's avatar Vikram

Check for 'bool' added. Tests also slightly modified

上级 b52f7a1b
......@@ -2225,7 +2225,7 @@ def local_cast_cast(node):
# We don't need to copy over any stack traces here
return [x]
if(upcast(base.dtype, type1.dtype)):
if(is_an_upcast(base.dtype, type1.dtype)):
# Checking for further redundancy. Eg: int8 -> int32 -> int8
if(type2.dtype == base.dtype):
return x.owner.inputs
......@@ -2237,17 +2237,23 @@ def local_cast_cast(node):
return [v]
def upcast(type1, type2):
def is_an_upcast(type1, type2):
"""Given two data types (as strings), check if converting to
type2 from type1 constitutes an upcast.
Differs from theano.scalar.upcast
"""
category = {
# Pair of numbers : the 'super-index' and 'sub-index'
'uint8': (0, 0), 'uint16': (0, 1), 'uint32': (0, 2), 'uint64': (0, 3),
'int8': (1, 0), 'int16': (1, 1), 'int32': (1, 2), 'int64': (1, 3),
'float16': (2, 0.5), 'float32': (2, 1.5), 'float64': (2, 2.5),
'complex64': (3, 2), 'complex128': (3, 3)
# The first number in the pair is the dtype (bool, uint, int, float,
# complex). Conversion from higher to lower is never an upcast.
# The second number roughly indicates the precision. Again, conversion
# from higher to lower is never an upcast.
'bool': (0, 0),
'uint8': (1, 1), 'uint16': (1, 2), 'uint32': (1, 3), 'uint64': (1, 4),
'int8': (2, 1), 'int16': (2, 2), 'int32': (2, 3), 'int64': (2, 4),
'float16': (3, 1.5), 'float32': (3, 2.5), 'float64': (3, 3.5),
'complex64': (4, 3), 'complex128': (4, 4)
}
cat1 = category[type1]
......
......@@ -4658,7 +4658,7 @@ class T_cast_cast(unittest.TestCase):
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, T.Elemwise)
assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
x = T.dmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float32")))(x.astype("float32"))
......@@ -4667,7 +4667,7 @@ class T_cast_cast(unittest.TestCase):
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, T.Elemwise)
assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
def test_upcast(self):
# Upcast followed by any other cast
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论