提交 4e7211fc authored 作者: James Bergstra's avatar James Bergstra

Modified cast in tensor/basic to return its input if it already has the right dtype.

上级 e527fda1
...@@ -973,6 +973,9 @@ _cast_mapping = {'int8': _convert_to_int8, ...@@ -973,6 +973,9 @@ _cast_mapping = {'int8': _convert_to_int8,
@constructor @constructor
def cast(x, dtype): def cast(x, dtype):
"""Symbolically cast `x` to a Tensor of type `dtype`.""" """Symbolically cast `x` to a Tensor of type `dtype`."""
_x = as_tensor_variable(x)
if _x.type.dtype == dtype:
return _x
if x.type.dtype.startswith('complex') and not dtype.startswith('complex'): if x.type.dtype.startswith('complex') and not dtype.startswith('complex'):
raise TypeError('Casting from complex to real is ambiguous: consider real(), imag(), angle() or abs()') raise TypeError('Casting from complex to real is ambiguous: consider real(), imag(), angle() or abs()')
return _cast_mapping[dtype](x) return _cast_mapping[dtype](x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论