提交 51c892be authored 作者: Frederic Bastien's avatar Frederic Bastien

cast to/from floatX that are not needed as they don't change the type will not…

cast to/from floatX that are not needed as they don't change the type will not generate an OP. Added test for this.
上级 be54333b
......@@ -957,6 +957,8 @@ _cast_mapping = {
'complex128': convert_to_complex128}
def cast(x, dtype):
"""Symbolically cast `x` to a Scalar of given `dtype`."""
if dtype=='floatX': dtype = config.config.get('scalar.floatX')
_x = as_scalar(x)
if _x.type.dtype == dtype:
return _x
......
......@@ -1128,6 +1128,8 @@ _cast_mapping = {
@constructor
def cast(x, dtype):
"""Symbolically cast `x` to a Tensor of type `dtype`."""
if dtype=='floatX': dtype = config.config.get('scalar.floatX')
_x = as_tensor_variable(x)
if _x.type.dtype == dtype:
return _x
......
......@@ -1929,6 +1929,63 @@ def test_default_state():
assert f(1) == 4.8
assert f(2.2) == 7
def test_cast_floatX():
floatx=config.config.get('scalar.floatX')
#float64 cast to float64 should not generate an op
x = dvector('x')
f = function([x],[cast(x,'float64')])
# print f.maker.env.toposort()
assert len(f.maker.env.toposort())==0
#float32 cast to float32 should not generate an op
x = fvector('x')
f = function([x],[cast(x,'float32')])
# print f.maker.env.toposort()
assert len(f.maker.env.toposort())==0
#floatX cast to float64
x = xvector('x')
f = function([x],[cast(x,'float64')])
# print f.maker.env.toposort()
if floatx=='float64':
assert len(f.maker.env.toposort()) == 0
else:
assert len(f.maker.env.toposort()) == 1
#floatX cast to float32
x = xvector('x')
f = function([x],[cast(x,'float32')])
# print f.maker.env.toposort()
if floatx=='float32':
assert len(f.maker.env.toposort()) == 0
else:
assert len(f.maker.env.toposort()) == 1
#float64 cast to floatX
x = dvector('x')
f = function([x],[cast(x,'floatX')])
# print f.maker.env.toposort()
if floatx=='float64':
assert len(f.maker.env.toposort()) == 0
else:
assert len(f.maker.env.toposort()) == 1
#float32 cast to floatX
x = fvector('x')
f = function([x],[cast(x,'floatX')])
# print f.maker.env.toposort()
if floatx=='float32':
assert len(f.maker.env.toposort()) == 0
else:
assert len(f.maker.env.toposort()) == 1
#floatX cast to floatX
x = xvector('x')
f = function([x],[cast(x,'floatX')])
# print f.maker.env.toposort()
assert len(f.maker.env.toposort()) == 0
if __name__ == '__main__':
if len(sys.argv) >= 2 and sys.argv[1] == 'OPT':
default_mode = compile.Mode(linker = 'c&py',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论