Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
27654022
提交
27654022
authored
10月 12, 2016
作者:
Arnaud Bergeron
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add stuff so that scalar tests pass with the new bool type.
上级
146ef971
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
74 行增加
和
87 行删除
+74
-87
gradient.py
theano/gradient.py
+6
-7
basic.py
theano/scalar/basic.py
+62
-80
basic.py
theano/tensor/basic.py
+5
-0
type.py
theano/tensor/type.py
+1
-0
没有找到文件。
theano/gradient.py
浏览文件 @
27654022
...
@@ -477,7 +477,7 @@ def grad(cost, wrt, consider_constant=None,
...
@@ -477,7 +477,7 @@ def grad(cost, wrt, consider_constant=None,
# function, sure, but nonetheless one we can and should support.
# function, sure, but nonetheless one we can and should support.
# So before we try to cast it make sure it even has a dtype
# So before we try to cast it make sure it even has a dtype
if
(
hasattr
(
g_cost
.
type
,
'dtype'
)
and
if
(
hasattr
(
g_cost
.
type
,
'dtype'
)
and
cost
.
type
.
dtype
not
in
tensor
.
discrete
_dtypes
):
cost
.
type
.
dtype
in
tensor
.
continuous
_dtypes
):
# Here we enforce the constraint that floating point variables
# Here we enforce the constraint that floating point variables
# have the same dtype as their gradient.
# have the same dtype as their gradient.
g_cost
=
g_cost
.
astype
(
cost
.
type
.
dtype
)
g_cost
=
g_cost
.
astype
(
cost
.
type
.
dtype
)
...
@@ -485,7 +485,7 @@ def grad(cost, wrt, consider_constant=None,
...
@@ -485,7 +485,7 @@ def grad(cost, wrt, consider_constant=None,
# This is to be enforced by the Op.grad method for the
# This is to be enforced by the Op.grad method for the
# Op that outputs cost.
# Op that outputs cost.
if
hasattr
(
g_cost
.
type
,
'dtype'
):
if
hasattr
(
g_cost
.
type
,
'dtype'
):
assert
g_cost
.
type
.
dtype
not
in
tensor
.
discrete
_dtypes
assert
g_cost
.
type
.
dtype
in
tensor
.
continuous
_dtypes
grad_dict
[
cost
]
=
g_cost
grad_dict
[
cost
]
=
g_cost
...
@@ -1334,12 +1334,11 @@ def _float_ones_like(x):
...
@@ -1334,12 +1334,11 @@ def _float_ones_like(x):
""" Like ones_like, but forces the object to have a
""" Like ones_like, but forces the object to have a
floating point dtype """
floating point dtype """
rval
=
tensor
.
ones_like
(
x
)
dtype
=
x
.
type
.
dtype
if
'float'
not
in
dtype
:
dtype
=
theano
.
config
.
floatX
if
rval
.
type
.
dtype
.
find
(
'float'
)
!=
-
1
:
return
tensor
.
ones_like
(
x
,
dtype
=
dtype
)
return
rval
return
rval
.
astype
(
theano
.
config
.
floatX
)
class
numeric_grad
(
object
):
class
numeric_grad
(
object
):
...
...
theano/scalar/basic.py
浏览文件 @
27654022
...
@@ -34,6 +34,7 @@ from theano.gradient import grad_undefined
...
@@ -34,6 +34,7 @@ from theano.gradient import grad_undefined
from
theano.printing
import
pprint
from
theano.printing
import
pprint
import
collections
import
collections
builtin_bool
=
bool
builtin_complex
=
complex
builtin_complex
=
complex
builtin_int
=
int
builtin_int
=
int
builtin_float
=
float
builtin_float
=
float
...
@@ -161,7 +162,7 @@ class Scalar(Type):
...
@@ -161,7 +162,7 @@ class Scalar(Type):
TODO: refactor to be named ScalarType for consistency with TensorType.
TODO: refactor to be named ScalarType for consistency with TensorType.
"""
"""
__props__
=
(
'dtype'
,)
ndim
=
0
ndim
=
0
def
__init__
(
self
,
dtype
):
def
__init__
(
self
,
dtype
):
...
@@ -200,6 +201,8 @@ class Scalar(Type):
...
@@ -200,6 +201,8 @@ class Scalar(Type):
def
values_eq_approx
(
self
,
a
,
b
,
tolerance
=
1e-4
):
def
values_eq_approx
(
self
,
a
,
b
,
tolerance
=
1e-4
):
# The addition have risk of overflow especially with [u]int8
# The addition have risk of overflow especially with [u]int8
if
self
.
dtype
==
'bool'
:
return
a
==
b
diff
=
a
-
b
diff
=
a
-
b
if
diff
==
0
:
if
diff
==
0
:
return
True
return
True
...
@@ -227,12 +230,6 @@ class Scalar(Type):
...
@@ -227,12 +230,6 @@ class Scalar(Type):
else
:
else
:
return
[]
return
[]
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
other
.
dtype
==
self
.
dtype
def
__hash__
(
self
):
return
hash
(
'theano.scalar.Scalar'
)
^
hash
(
self
.
dtype
)
def
dtype_specs
(
self
):
def
dtype_specs
(
self
):
try
:
try
:
# To help debug dtype/typenum problem, here is code to get
# To help debug dtype/typenum problem, here is code to get
...
@@ -244,7 +241,8 @@ class Scalar(Type):
...
@@ -244,7 +241,8 @@ class Scalar(Type):
# now, as Theano always expect the exact typenum that
# now, as Theano always expect the exact typenum that
# correspond to our supported dtype.
# correspond to our supported dtype.
"""
"""
for dtype in ['int8', 'uint8', 'short', 'ushort', 'intc', 'uintc',
for dtype in ['bool', 'int8', 'uint8', 'short', 'ushort', 'intc',
'uintc',
'longlong', 'ulonglong', 'single', 'double',
'longlong', 'ulonglong', 'single', 'double',
'longdouble', 'csingle', 'cdouble', 'clongdouble',
'longdouble', 'csingle', 'cdouble', 'clongdouble',
'float32', 'float64', 'int8', 'int16', 'int32',
'float32', 'float64', 'int8', 'int16', 'int32',
...
@@ -260,6 +258,7 @@ class Scalar(Type):
...
@@ -260,6 +258,7 @@ class Scalar(Type):
'complex128'
:
(
numpy
.
complex128
,
'theano_complex128'
,
'complex128'
:
(
numpy
.
complex128
,
'theano_complex128'
,
'Complex128'
),
'Complex128'
),
'complex64'
:
(
numpy
.
complex64
,
'theano_complex64'
,
'Complex64'
),
'complex64'
:
(
numpy
.
complex64
,
'theano_complex64'
,
'Complex64'
),
'bool'
:
(
numpy
.
bool_
,
'npy_bool'
,
'Bool'
),
'uint8'
:
(
numpy
.
uint8
,
'npy_uint8'
,
'UInt8'
),
'uint8'
:
(
numpy
.
uint8
,
'npy_uint8'
,
'UInt8'
),
'int8'
:
(
numpy
.
int8
,
'npy_int8'
,
'Int8'
),
'int8'
:
(
numpy
.
int8
,
'npy_int8'
,
'Int8'
),
'uint16'
:
(
numpy
.
uint16
,
'npy_uint16'
,
'UInt16'
),
'uint16'
:
(
numpy
.
uint16
,
'npy_uint16'
,
'UInt16'
),
...
@@ -288,12 +287,13 @@ class Scalar(Type):
...
@@ -288,12 +287,13 @@ class Scalar(Type):
def
c_literal
(
self
,
data
):
def
c_literal
(
self
,
data
):
if
'complex'
in
self
.
dtype
:
if
'complex'
in
self
.
dtype
:
raise
NotImplementedError
(
"No literal for complex values."
)
raise
NotImplementedError
(
"No literal for complex values."
)
if
self
.
dtype
==
'bool'
:
return
'1'
if
b
else
'0'
return
str
(
data
)
return
str
(
data
)
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
if
(
check_input
):
if
(
check_input
):
pre
=
"""
pre
=
"""
typedef
%(dtype)
s
%(name)
s_dtype; // Deprecated use dtype_
%(name)
s instead.
typedef
%(dtype)
s dtype_
%(name)
s;
typedef
%(dtype)
s dtype_
%(name)
s;
"""
%
dict
(
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
])
"""
%
dict
(
name
=
name
,
dtype
=
self
.
dtype_specs
()[
1
])
else
:
else
:
...
@@ -309,6 +309,7 @@ class Scalar(Type):
...
@@ -309,6 +309,7 @@ class Scalar(Type):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
):
if
self
.
dtype
==
'float16'
:
if
self
.
dtype
==
'float16'
:
# This doesn't work at the numpy level
raise
NotImplementedError
(
'float16'
)
raise
NotImplementedError
(
'float16'
)
specs
=
self
.
dtype_specs
()
specs
=
self
.
dtype_specs
()
if
(
check_input
):
if
(
check_input
):
...
@@ -517,6 +518,7 @@ theano.compile.register_view_op_c_code(
...
@@ -517,6 +518,7 @@ theano.compile.register_view_op_c_code(
1
)
1
)
bool
=
get_scalar_type
(
'bool'
)
int8
=
get_scalar_type
(
'int8'
)
int8
=
get_scalar_type
(
'int8'
)
int16
=
get_scalar_type
(
'int16'
)
int16
=
get_scalar_type
(
'int16'
)
int32
=
get_scalar_type
(
'int32'
)
int32
=
get_scalar_type
(
'int32'
)
...
@@ -538,7 +540,7 @@ complex_types = complex64, complex128
...
@@ -538,7 +540,7 @@ complex_types = complex64, complex128
discrete_types
=
int_types
+
uint_types
discrete_types
=
int_types
+
uint_types
continuous_types
=
float_types
+
complex_types
continuous_types
=
float_types
+
complex_types
all_types
=
discrete_types
+
continuous_types
all_types
=
(
bool
,)
+
discrete_types
+
continuous_types
class
_scalar_py_operators
:
class
_scalar_py_operators
:
...
@@ -681,38 +683,35 @@ complexs64 = _multi(complex64)
...
@@ -681,38 +683,35 @@ complexs64 = _multi(complex64)
complexs128
=
_multi
(
complex128
)
complexs128
=
_multi
(
complex128
)
# Using a class instead of a function makes it possible to deep-copy it in
# Using a class instead of a function makes it possible to deep-copy it.
# Python 2.4.
# Note that currently only a few functions use this mechanism, because
# Note that currently only a few functions use this mechanism, because it is
# it is enough to make the test-suite pass. However, it may prove
# enough to make the test-suite pass with Python 2.4. However, it may prove
# necessary to use this same mechanism in other places as well in the
# necessary to use this same mechanism in other places as well in the future.
# future.
class
upcast_out
(
object
):
def
upcast_out
(
*
types
):
def
__new__
(
self
,
*
types
):
dtype
=
Scalar
.
upcast
(
*
types
)
dtype
=
Scalar
.
upcast
(
*
types
)
return
get_scalar_type
(
dtype
),
return
get_scalar_type
(
dtype
),
class
upgrade_to_float
(
object
):
def
__new__
(
self
,
*
types
):
"""
Upgrade any int types to float32 or float64 to avoid losing precision.
"""
conv
=
{
int8
:
float32
,
int16
:
float32
,
int32
:
float64
,
int64
:
float64
,
uint8
:
float32
,
uint16
:
float32
,
uint32
:
float64
,
uint64
:
float64
}
return
get_scalar_type
(
Scalar
.
upcast
(
*
[
conv
.
get
(
type
,
type
)
for
type
in
types
])),
def
upgrade_to_float
(
*
types
):
"""
Upgrade any int types to float32 or float64 to avoid losing precision.
class
same_out
(
object
):
"""
def
__new__
(
self
,
type
):
conv
=
{
int8
:
float32
,
return
type
,
int16
:
float32
,
int32
:
float64
,
int64
:
float64
,
uint8
:
float32
,
uint16
:
float32
,
uint32
:
float64
,
uint64
:
float64
}
return
get_scalar_type
(
Scalar
.
upcast
(
*
[
conv
.
get
(
type
,
type
)
for
type
in
types
])),
def
same_out
(
type
):
return
type
,
def
upcast_out_no_complex
(
*
types
):
def
upcast_out_no_complex
(
*
types
):
...
@@ -728,6 +727,8 @@ def same_out_float_only(type):
...
@@ -728,6 +727,8 @@ def same_out_float_only(type):
class
transfer_type
(
gof
.
utils
.
object2
):
class
transfer_type
(
gof
.
utils
.
object2
):
__props__
=
(
'transfer'
,)
def
__init__
(
self
,
*
transfer
):
def
__init__
(
self
,
*
transfer
):
assert
all
(
type
(
x
)
in
[
int
,
str
]
or
x
is
None
for
x
in
transfer
)
assert
all
(
type
(
x
)
in
[
int
,
str
]
or
x
is
None
for
x
in
transfer
)
self
.
transfer
=
transfer
self
.
transfer
=
transfer
...
@@ -748,26 +749,16 @@ class transfer_type(gof.utils.object2):
...
@@ -748,26 +749,16 @@ class transfer_type(gof.utils.object2):
return
retval
return
retval
# return [upcast if i is None else types[i] for i in self.transfer]
# return [upcast if i is None else types[i] for i in self.transfer]
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
transfer
==
other
.
transfer
def
__hash__
(
self
):
return
hash
(
self
.
transfer
)
class
specific_out
(
gof
.
utils
.
object2
):
class
specific_out
(
gof
.
utils
.
object2
):
__props__
=
(
'spec'
,)
def
__init__
(
self
,
*
spec
):
def
__init__
(
self
,
*
spec
):
self
.
spec
=
spec
self
.
spec
=
spec
def
__call__
(
self
,
*
types
):
def
__call__
(
self
,
*
types
):
return
self
.
spec
return
self
.
spec
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
spec
==
other
.
spec
def
__hash__
(
self
):
return
hash
(
self
.
spec
)
def
int_out
(
*
types
):
def
int_out
(
*
types
):
return
int64
,
return
int64
,
...
@@ -1007,12 +998,12 @@ class BinaryScalarOp(ScalarOp):
...
@@ -1007,12 +998,12 @@ class BinaryScalarOp(ScalarOp):
class
LogicalComparison
(
BinaryScalarOp
):
class
LogicalComparison
(
BinaryScalarOp
):
def
output_types
(
self
,
*
input_dtypes
):
def
output_types
(
self
,
*
input_dtypes
):
return
[
int8
]
return
[
bool
]
def
grad
(
self
,
inputs
,
output_gradients
):
def
grad
(
self
,
inputs
,
output_gradients
):
x
,
y
=
inputs
x
,
y
=
inputs
out
=
self
(
x
,
y
)
out
=
self
(
x
,
y
)
assert
str
(
out
.
type
.
dtype
)
.
find
(
'int'
)
!=
-
1
assert
out
.
type
==
bool
return
[
x
.
zeros_like
()
.
astype
(
theano
.
config
.
floatX
),
return
[
x
.
zeros_like
()
.
astype
(
theano
.
config
.
floatX
),
y
.
zeros_like
()
.
astype
(
theano
.
config
.
floatX
)]
y
.
zeros_like
()
.
astype
(
theano
.
config
.
floatX
)]
...
@@ -1023,12 +1014,12 @@ class FixedLogicalComparison(UnaryScalarOp):
...
@@ -1023,12 +1014,12 @@ class FixedLogicalComparison(UnaryScalarOp):
"""
"""
def
output_types
(
self
,
*
input_dtypes
):
def
output_types
(
self
,
*
input_dtypes
):
return
[
int8
]
return
[
bool
]
def
grad
(
self
,
inputs
,
output_gradients
):
def
grad
(
self
,
inputs
,
output_gradients
):
x
,
=
inputs
x
,
=
inputs
out
=
self
(
x
)
out
=
self
(
x
)
assert
str
(
out
.
type
.
dtype
)
.
find
(
'int'
)
!=
-
1
assert
out
.
type
==
bool
return
[
x
.
zeros_like
()
.
astype
(
theano
.
config
.
floatX
)]
return
[
x
.
zeros_like
()
.
astype
(
theano
.
config
.
floatX
)]
...
@@ -1202,21 +1193,10 @@ class InRange(LogicalComparison):
...
@@ -1202,21 +1193,10 @@ class InRange(LogicalComparison):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
(
x
,
low
,
hi
)
=
inputs
(
x
,
low
,
hi
)
=
inputs
(
z
,)
=
outputs
(
z
,)
=
outputs
if
self
.
openlow
:
cmp1
=
'>'
else
:
cmp1
=
'>='
# backport
cmp1
=
'>'
if
self
.
openlow
else
'>='
# cmp1 = '>' if self.openlow else '>
='
cmp2
=
'<'
if
self
.
openhi
else
'<
='
if
self
.
openhi
:
cmp2
=
'<'
else
:
cmp2
=
'<='
# backport
# cmp2 = '<' if self.openhi else '<='
return
(
"
%(z)
s =
%(x)
s
%(cmp1)
s
%(low)
s &&"
return
(
"
%(z)
s =
%(x)
s
%(cmp1)
s
%(low)
s &&"
"
%(x)
s
%(cmp2)
s
%(hi)
s;"
%
locals
())
"
%(x)
s
%(cmp2)
s
%(hi)
s;"
%
locals
())
...
@@ -1247,13 +1227,8 @@ class Switch(ScalarOp):
...
@@ -1247,13 +1227,8 @@ class Switch(ScalarOp):
nfunc_spec
=
(
'where'
,
3
,
1
)
nfunc_spec
=
(
'where'
,
3
,
1
)
def
impl
(
self
,
cond
,
ift
,
iff
):
def
impl
(
self
,
cond
,
ift
,
iff
):
if
cond
:
return
ift
if
cond
else
iff
return
ift
else
:
return
iff
# backport
# return ift if cond else iff
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
(
cond
,
ift
,
iff
)
=
inputs
(
cond
,
ift
,
iff
)
=
inputs
(
z
,)
=
outputs
(
z
,)
=
outputs
...
@@ -1290,9 +1265,9 @@ switch = Switch()
...
@@ -1290,9 +1265,9 @@ switch = Switch()
class
UnaryBitOp
(
UnaryScalarOp
):
class
UnaryBitOp
(
UnaryScalarOp
):
def
output_types
(
self
,
*
input_types
):
def
output_types
(
self
,
*
input_types
):
for
i
in
input_types
[
0
]:
for
i
in
input_types
[
0
]:
if
i
not
in
(
int8
,
int16
,
int32
,
int64
):
if
i
not
in
(
(
bool
,)
+
discrete_types
):
raise
TypeError
(
'input to a BitOp must have type
int8,
'
raise
TypeError
(
'input to a BitOp must have type
(u)int8,
'
'
int16, int32 or int64...
not
%
s'
%
i
)
'
(u)int16, (u)int32 or (u)int64 or bool
not
%
s'
%
i
)
return
upcast_out
(
*
input_types
[
0
])
return
upcast_out
(
*
input_types
[
0
])
def
grad
(
self
,
inputs
,
output_gradients
):
def
grad
(
self
,
inputs
,
output_gradients
):
...
@@ -1302,10 +1277,13 @@ class UnaryBitOp(UnaryScalarOp):
...
@@ -1302,10 +1277,13 @@ class UnaryBitOp(UnaryScalarOp):
class
BinaryBitOp
(
BinaryScalarOp
):
class
BinaryBitOp
(
BinaryScalarOp
):
def
output_types
(
self
,
*
input_types
):
def
output_types
(
self
,
*
input_types
):
t0
,
t1
=
input_types
[
0
]
t0
,
t1
=
input_types
[
0
]
if
t0
==
bool
and
t1
==
bool
:
return
[
bool
]
for
i
in
input_types
[
0
]:
for
i
in
input_types
[
0
]:
if
i
not
in
(
int8
,
int16
,
int32
,
int64
):
if
i
not
in
discrete_types
:
raise
TypeError
(
'input to a BitOp must have type int8,'
raise
TypeError
(
'input to a BitOp must have type (u)int8, '
' int16, int32 or int64... not
%
s'
%
i
)
'(u)int16, (u)int32 or (u)int64 or '
'be all bools not
%
s'
%
i
)
return
upcast_out
(
*
input_types
[
0
])
return
upcast_out
(
*
input_types
[
0
])
def
grad
(
self
,
inputs
,
output_gradients
):
def
grad
(
self
,
inputs
,
output_gradients
):
...
@@ -1371,6 +1349,8 @@ class Invert(UnaryBitOp):
...
@@ -1371,6 +1349,8 @@ class Invert(UnaryBitOp):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
(
x
,)
=
inputs
(
x
,)
=
inputs
(
z
,)
=
outputs
(
z
,)
=
outputs
if
node
.
outputs
[
0
]
.
type
==
bool
:
return
"
%(z)
s = (!
%(x)
s);"
%
locals
()
return
"
%(z)
s = (~
%(x)
s);"
%
locals
()
return
"
%(z)
s = (~
%(x)
s);"
%
locals
()
invert
=
Invert
()
invert
=
Invert
()
...
@@ -2079,6 +2059,7 @@ class Cast(UnaryScalarOp):
...
@@ -2079,6 +2059,7 @@ class Cast(UnaryScalarOp):
else
:
else
:
return
s
return
s
convert_to_bool
=
Cast
(
bool
,
name
=
'convert_to_bool'
)
convert_to_int8
=
Cast
(
int8
,
name
=
'convert_to_int8'
)
convert_to_int8
=
Cast
(
int8
,
name
=
'convert_to_int8'
)
convert_to_int16
=
Cast
(
int16
,
name
=
'convert_to_int16'
)
convert_to_int16
=
Cast
(
int16
,
name
=
'convert_to_int16'
)
convert_to_int32
=
Cast
(
int32
,
name
=
'convert_to_int32'
)
convert_to_int32
=
Cast
(
int32
,
name
=
'convert_to_int32'
)
...
@@ -2094,6 +2075,7 @@ convert_to_complex64 = Cast(complex64, name='convert_to_complex64')
...
@@ -2094,6 +2075,7 @@ convert_to_complex64 = Cast(complex64, name='convert_to_complex64')
convert_to_complex128
=
Cast
(
complex128
,
name
=
'convert_to_complex128'
)
convert_to_complex128
=
Cast
(
complex128
,
name
=
'convert_to_complex128'
)
_cast_mapping
=
{
_cast_mapping
=
{
'bool'
:
convert_to_bool
,
'int8'
:
convert_to_int8
,
'int8'
:
convert_to_int8
,
'int16'
:
convert_to_int16
,
'int16'
:
convert_to_int16
,
'int32'
:
convert_to_int32
,
'int32'
:
convert_to_int32
,
...
...
theano/tensor/basic.py
浏览文件 @
27654022
...
@@ -1246,6 +1246,10 @@ def _conversion(real_value, name):
...
@@ -1246,6 +1246,10 @@ def _conversion(real_value, name):
# what types you are casting to what. That logic is implemented by the
# what types you are casting to what. That logic is implemented by the
# `cast()` function below.
# `cast()` function below.
_convert_to_bool
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
convert_to_bool
),
'bool'
)
"""Cast to boolean"""
_convert_to_int8
=
_conversion
(
_convert_to_int8
=
_conversion
(
elemwise
.
Elemwise
(
scal
.
convert_to_int8
),
'int8'
)
elemwise
.
Elemwise
(
scal
.
convert_to_int8
),
'int8'
)
"""Cast to 8-bit integer"""
"""Cast to 8-bit integer"""
...
@@ -1299,6 +1303,7 @@ _convert_to_complex128 = _conversion(
...
@@ -1299,6 +1303,7 @@ _convert_to_complex128 = _conversion(
"""Cast to double-precision complex"""
"""Cast to double-precision complex"""
_cast_mapping
=
{
_cast_mapping
=
{
'bool'
:
_convert_to_bool
,
'int8'
:
_convert_to_int8
,
'int8'
:
_convert_to_int8
,
'int16'
:
_convert_to_int16
,
'int16'
:
_convert_to_int16
,
'int32'
:
_convert_to_int32
,
'int32'
:
_convert_to_int32
,
...
...
theano/tensor/type.py
浏览文件 @
27654022
...
@@ -255,6 +255,7 @@ class TensorType(Type):
...
@@ -255,6 +255,7 @@ class TensorType(Type):
'float16'
:
(
float
,
'npy_float16'
,
'NPY_FLOAT16'
),
'float16'
:
(
float
,
'npy_float16'
,
'NPY_FLOAT16'
),
'float32'
:
(
float
,
'npy_float32'
,
'NPY_FLOAT32'
),
'float32'
:
(
float
,
'npy_float32'
,
'NPY_FLOAT32'
),
'float64'
:
(
float
,
'npy_float64'
,
'NPY_FLOAT64'
),
'float64'
:
(
float
,
'npy_float64'
,
'NPY_FLOAT64'
),
'bool'
:
(
bool
,
'npy_bool'
,
'NPY_BOOL'
),
'uint8'
:
(
int
,
'npy_uint8'
,
'NPY_UINT8'
),
'uint8'
:
(
int
,
'npy_uint8'
,
'NPY_UINT8'
),
'int8'
:
(
int
,
'npy_int8'
,
'NPY_INT8'
),
'int8'
:
(
int
,
'npy_int8'
,
'NPY_INT8'
),
'uint16'
:
(
int
,
'npy_uint16'
,
'NPY_UINT16'
),
'uint16'
:
(
int
,
'npy_uint16'
,
'NPY_UINT16'
),
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论