提交 65b9da6f authored 作者: Vikram's avatar Vikram

Check for 'bool' added. Tests also slightly modified

上级 b52f7a1b
...@@ -2225,7 +2225,7 @@ def local_cast_cast(node): ...@@ -2225,7 +2225,7 @@ def local_cast_cast(node):
# We don't need to copy over any stack traces here # We don't need to copy over any stack traces here
return [x] return [x]
if(upcast(base.dtype, type1.dtype)): if(is_an_upcast(base.dtype, type1.dtype)):
# Checking for further redundancy. Eg: int8 -> int32 -> int8 # Checking for further redundancy. Eg: int8 -> int32 -> int8
if(type2.dtype == base.dtype): if(type2.dtype == base.dtype):
return x.owner.inputs return x.owner.inputs
...@@ -2237,17 +2237,23 @@ def local_cast_cast(node): ...@@ -2237,17 +2237,23 @@ def local_cast_cast(node):
return [v] return [v]
def upcast(type1, type2): def is_an_upcast(type1, type2):
"""Given two data types (as strings), check if converting to """Given two data types (as strings), check if converting to
type2 from type1 constitutes an upcast. type2 from type1 constitutes an upcast.
Differs from theano.scalar.upcast
""" """
category = { category = {
# Pair of numbers : the 'super-index' and 'sub-index' # The first number in the pair is the dtype (bool, uint, int, float,
'uint8': (0, 0), 'uint16': (0, 1), 'uint32': (0, 2), 'uint64': (0, 3), # complex). Conversion from higher to lower is never an upcast.
'int8': (1, 0), 'int16': (1, 1), 'int32': (1, 2), 'int64': (1, 3), # The second number roughly indicates the precision. Again, conversion
'float16': (2, 0.5), 'float32': (2, 1.5), 'float64': (2, 2.5), # from higher to lower is never an upcast.
'complex64': (3, 2), 'complex128': (3, 3)
'bool': (0, 0),
'uint8': (1, 1), 'uint16': (1, 2), 'uint32': (1, 3), 'uint64': (1, 4),
'int8': (2, 1), 'int16': (2, 2), 'int32': (2, 3), 'int64': (2, 4),
'float16': (3, 1.5), 'float32': (3, 2.5), 'float64': (3, 3.5),
'complex64': (4, 3), 'complex128': (4, 4)
} }
cat1 = category[type1] cat1 = category[type1]
......
...@@ -4658,7 +4658,7 @@ class T_cast_cast(unittest.TestCase): ...@@ -4658,7 +4658,7 @@ class T_cast_cast(unittest.TestCase):
f(dx) f(dx)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
assert isinstance(topo[0].op, T.Elemwise) assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
x = T.dmatrix() x = T.dmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float32")))(x.astype("float32")) o = T.Elemwise(scal.Cast(scal.Scalar("float32")))(x.astype("float32"))
...@@ -4667,7 +4667,7 @@ class T_cast_cast(unittest.TestCase): ...@@ -4667,7 +4667,7 @@ class T_cast_cast(unittest.TestCase):
f(dx) f(dx)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
assert isinstance(topo[0].op, T.Elemwise) assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
def test_upcast(self): def test_upcast(self):
# Upcast followed by any other cast # Upcast followed by any other cast
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论