Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7ef44dfd
提交
7ef44dfd
authored
10月 22, 2015
作者:
Frederic Bastien
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move nfunc_spec to Scalar op, This allow to always have it even when we use Elemwise(a_scalar_op).
This happend frequently in the grad computation of elemwise. This could speed up DebugMode. This fix tests crashX
上级
ed337d4e
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
153 行增加
和
60 行删除
+153
-60
basic.py
theano/scalar/basic.py
+95
-0
basic.py
theano/tensor/basic.py
+56
-60
elemwise.py
theano/tensor/elemwise.py
+2
-0
没有找到文件。
theano/scalar/basic.py
浏览文件 @
7ef44dfd
...
@@ -1031,6 +1031,7 @@ class LT(LogicalComparison):
...
@@ -1031,6 +1031,7 @@ class LT(LogicalComparison):
identity
=
False
identity
=
False
commutative
=
False
commutative
=
False
associative
=
False
associative
=
False
nfunc_spec
=
(
'less'
,
2
,
1
)
def
impl
(
self
,
x
,
y
):
def
impl
(
self
,
x
,
y
):
# built-in < don't support complex
# built-in < don't support complex
...
@@ -1049,6 +1050,7 @@ class GT(LogicalComparison):
...
@@ -1049,6 +1050,7 @@ class GT(LogicalComparison):
identity
=
False
identity
=
False
commutative
=
False
commutative
=
False
associative
=
False
associative
=
False
nfunc_spec
=
(
'greater'
,
2
,
1
)
def
impl
(
self
,
x
,
y
):
def
impl
(
self
,
x
,
y
):
# built-in > don't support complex
# built-in > don't support complex
...
@@ -1067,6 +1069,7 @@ class LE(LogicalComparison):
...
@@ -1067,6 +1069,7 @@ class LE(LogicalComparison):
identity
=
False
identity
=
False
commutative
=
False
commutative
=
False
associative
=
False
associative
=
False
nfunc_spec
=
(
'less_equal'
,
2
,
1
)
def
impl
(
self
,
x
,
y
):
def
impl
(
self
,
x
,
y
):
# built-in <= don't support complex
# built-in <= don't support complex
...
@@ -1085,6 +1088,7 @@ class GE(LogicalComparison):
...
@@ -1085,6 +1088,7 @@ class GE(LogicalComparison):
identity
=
False
identity
=
False
commutative
=
False
commutative
=
False
associative
=
False
associative
=
False
nfunc_spec
=
(
'greater_equal'
,
2
,
1
)
def
impl
(
self
,
x
,
y
):
def
impl
(
self
,
x
,
y
):
# built-in >= don't support complex
# built-in >= don't support complex
...
@@ -1103,6 +1107,7 @@ class EQ(LogicalComparison):
...
@@ -1103,6 +1107,7 @@ class EQ(LogicalComparison):
identity
=
False
identity
=
False
commutative
=
True
commutative
=
True
associative
=
False
associative
=
False
nfunc_spec
=
(
'equal'
,
2
,
1
)
def
impl
(
self
,
x
,
y
):
def
impl
(
self
,
x
,
y
):
return
x
==
y
return
x
==
y
...
@@ -1118,6 +1123,7 @@ class NEQ(LogicalComparison):
...
@@ -1118,6 +1123,7 @@ class NEQ(LogicalComparison):
identity
=
False
identity
=
False
commutative
=
True
commutative
=
True
associative
=
False
associative
=
False
nfunc_spec
=
(
'not_equal'
,
2
,
1
)
def
impl
(
self
,
x
,
y
):
def
impl
(
self
,
x
,
y
):
return
x
!=
y
return
x
!=
y
...
@@ -1132,6 +1138,8 @@ neq = NEQ()
...
@@ -1132,6 +1138,8 @@ neq = NEQ()
class
IsNan
(
FixedLogicalComparison
):
class
IsNan
(
FixedLogicalComparison
):
nfunc_spec
=
(
'isnan'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
isnan
(
x
)
return
numpy
.
isnan
(
x
)
...
@@ -1145,6 +1153,8 @@ isnan = IsNan()
...
@@ -1145,6 +1153,8 @@ isnan = IsNan()
class
IsInf
(
FixedLogicalComparison
):
class
IsInf
(
FixedLogicalComparison
):
nfunc_spec
=
(
'isinf'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
isinf
(
x
)
return
numpy
.
isinf
(
x
)
...
@@ -1223,6 +1233,7 @@ inclosedrange = InRange(False, False)
...
@@ -1223,6 +1233,7 @@ inclosedrange = InRange(False, False)
class
Switch
(
ScalarOp
):
class
Switch
(
ScalarOp
):
nin
=
3
nin
=
3
nfunc_spec
=
(
'where'
,
3
,
1
)
def
impl
(
self
,
cond
,
ift
,
iff
):
def
impl
(
self
,
cond
,
ift
,
iff
):
if
cond
:
if
cond
:
...
@@ -1296,6 +1307,7 @@ class OR(BinaryBitOp):
...
@@ -1296,6 +1307,7 @@ class OR(BinaryBitOp):
identity
=
0
identity
=
0
commutative
=
True
commutative
=
True
associative
=
True
associative
=
True
nfunc_spec
=
(
'bitwise_or'
,
2
,
1
)
def
impl
(
self
,
x
,
y
):
def
impl
(
self
,
x
,
y
):
return
x
|
y
return
x
|
y
...
@@ -1311,6 +1323,7 @@ class XOR(BinaryBitOp):
...
@@ -1311,6 +1323,7 @@ class XOR(BinaryBitOp):
identity
=
0
identity
=
0
commutative
=
True
commutative
=
True
associative
=
True
associative
=
True
nfunc_spec
=
(
'bitwise_xor'
,
2
,
1
)
def
impl
(
self
,
x
,
y
):
def
impl
(
self
,
x
,
y
):
return
x
^
y
return
x
^
y
...
@@ -1326,6 +1339,7 @@ class AND(BinaryBitOp):
...
@@ -1326,6 +1339,7 @@ class AND(BinaryBitOp):
identity
=
1
identity
=
1
commutative
=
True
commutative
=
True
associative
=
True
associative
=
True
nfunc_spec
=
(
'bitwise_and'
,
2
,
1
)
def
impl
(
self
,
x
,
y
):
def
impl
(
self
,
x
,
y
):
return
x
&
y
return
x
&
y
...
@@ -1338,6 +1352,8 @@ and_ = AND()
...
@@ -1338,6 +1352,8 @@ and_ = AND()
class
Invert
(
UnaryBitOp
):
class
Invert
(
UnaryBitOp
):
nfunc_spec
=
(
'invert'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
~
x
return
~
x
...
@@ -1354,6 +1370,7 @@ invert = Invert()
...
@@ -1354,6 +1370,7 @@ invert = Invert()
class
Maximum
(
BinaryScalarOp
):
class
Maximum
(
BinaryScalarOp
):
commutative
=
True
commutative
=
True
associative
=
True
associative
=
True
nfunc_spec
=
(
'maximum'
,
2
,
1
)
def
impl
(
self
,
*
inputs
):
def
impl
(
self
,
*
inputs
):
# The built-in max function don't support complex type
# The built-in max function don't support complex type
...
@@ -1392,6 +1409,7 @@ maximum = Maximum(upcast_out, name='maximum')
...
@@ -1392,6 +1409,7 @@ maximum = Maximum(upcast_out, name='maximum')
class
Minimum
(
BinaryScalarOp
):
class
Minimum
(
BinaryScalarOp
):
commutative
=
True
commutative
=
True
associative
=
True
associative
=
True
nfunc_spec
=
(
'minimum'
,
2
,
1
)
def
impl
(
self
,
*
inputs
):
def
impl
(
self
,
*
inputs
):
# The built-in min function don't support complex type
# The built-in min function don't support complex type
...
@@ -1427,6 +1445,7 @@ class Add(ScalarOp):
...
@@ -1427,6 +1445,7 @@ class Add(ScalarOp):
identity
=
0
identity
=
0
commutative
=
True
commutative
=
True
associative
=
True
associative
=
True
nfunc_spec
=
(
'add'
,
2
,
1
)
def
impl
(
self
,
*
inputs
):
def
impl
(
self
,
*
inputs
):
return
sum
(
inputs
)
return
sum
(
inputs
)
...
@@ -1465,6 +1484,7 @@ class Mul(ScalarOp):
...
@@ -1465,6 +1484,7 @@ class Mul(ScalarOp):
identity
=
1
identity
=
1
commutative
=
True
commutative
=
True
associative
=
True
associative
=
True
nfunc_spec
=
(
'multiply'
,
2
,
1
)
def
impl
(
self
,
*
inputs
):
def
impl
(
self
,
*
inputs
):
return
numpy
.
product
(
inputs
)
return
numpy
.
product
(
inputs
)
...
@@ -1516,6 +1536,8 @@ mul = Mul(upcast_out, name='mul')
...
@@ -1516,6 +1536,8 @@ mul = Mul(upcast_out, name='mul')
class
Sub
(
BinaryScalarOp
):
class
Sub
(
BinaryScalarOp
):
nfunc_spec
=
(
'subtract'
,
2
,
1
)
def
impl
(
self
,
x
,
y
):
def
impl
(
self
,
x
,
y
):
return
x
-
y
return
x
-
y
...
@@ -1604,6 +1626,8 @@ def div_proxy(x, y):
...
@@ -1604,6 +1626,8 @@ def div_proxy(x, y):
class
TrueDiv
(
BinaryScalarOp
):
class
TrueDiv
(
BinaryScalarOp
):
nfunc_spec
=
(
'true_divide'
,
2
,
1
)
def
output_types
(
self
,
types
):
def
output_types
(
self
,
types
):
if
all
(
t
in
discrete_types
for
t
in
types
):
if
all
(
t
in
discrete_types
for
t
in
types
):
return
[
get_scalar_type
(
config
.
floatX
)]
return
[
get_scalar_type
(
config
.
floatX
)]
...
@@ -1659,6 +1683,7 @@ true_div = TrueDiv(upcast_out, name='true_div')
...
@@ -1659,6 +1683,7 @@ true_div = TrueDiv(upcast_out, name='true_div')
class
IntDiv
(
BinaryScalarOp
):
class
IntDiv
(
BinaryScalarOp
):
nfunc_spec
=
(
'floor_divide'
,
2
,
1
)
complex_error
=
ComplexError
(
complex_error
=
ComplexError
(
"Theano does not support integer division (//) on "
"Theano does not support integer division (//) on "
"complex numbers, since numpy deprecated it."
)
"complex numbers, since numpy deprecated it."
)
...
@@ -1744,6 +1769,7 @@ def mod_check(x, y):
...
@@ -1744,6 +1769,7 @@ def mod_check(x, y):
class
Mod
(
BinaryScalarOp
):
class
Mod
(
BinaryScalarOp
):
nfunc_spec
=
(
'mod'
,
2
,
1
)
complex_error
=
ComplexError
(
complex_error
=
ComplexError
(
"Theano does not support the mod operator (
%
) on "
"Theano does not support the mod operator (
%
) on "
"complex numbers, since numpy deprecated it."
)
"complex numbers, since numpy deprecated it."
)
...
@@ -1828,6 +1854,8 @@ mod = Mod(upcast_out, name='mod')
...
@@ -1828,6 +1854,8 @@ mod = Mod(upcast_out, name='mod')
class
Pow
(
BinaryScalarOp
):
class
Pow
(
BinaryScalarOp
):
nfunc_spec
=
(
'power'
,
2
,
1
)
def
impl
(
self
,
x
,
y
):
def
impl
(
self
,
x
,
y
):
return
x
**
y
return
x
**
y
...
@@ -1903,6 +1931,8 @@ pow = Pow(upcast_out, name='pow')
...
@@ -1903,6 +1931,8 @@ pow = Pow(upcast_out, name='pow')
class
Clip
(
ScalarOp
):
class
Clip
(
ScalarOp
):
nin
=
3
nin
=
3
# The numpy.clip don't work correctly when the min is bigger then the max,
# So we do not use nfunc_spec = ('clip', 3, 1)
def
impl
(
self
,
x
,
min
,
max
):
def
impl
(
self
,
x
,
min
,
max
):
if
x
<
min
:
if
x
<
min
:
...
@@ -2086,6 +2116,8 @@ def cast(x, dtype):
...
@@ -2086,6 +2116,8 @@ def cast(x, dtype):
class
Abs
(
UnaryScalarOp
):
class
Abs
(
UnaryScalarOp
):
nfunc_spec
=
(
'abs'
,
1
,
1
)
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
inputs
=
[
as_scalar
(
input
)
for
input
in
[
x
]]
inputs
=
[
as_scalar
(
input
)
for
input
in
[
x
]]
if
inputs
[
0
]
.
type
==
complex64
:
if
inputs
[
0
]
.
type
==
complex64
:
...
@@ -2126,6 +2158,8 @@ abs_ = Abs(same_out)
...
@@ -2126,6 +2158,8 @@ abs_ = Abs(same_out)
class
Sgn
(
UnaryScalarOp
):
class
Sgn
(
UnaryScalarOp
):
nfunc_spec
=
(
'sign'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# casting to output type is handled by filter
# casting to output type is handled by filter
return
numpy
.
sign
(
x
)
return
numpy
.
sign
(
x
)
...
@@ -2162,6 +2196,8 @@ sgn = Sgn(same_out_nocomplex, name='sgn')
...
@@ -2162,6 +2196,8 @@ sgn = Sgn(same_out_nocomplex, name='sgn')
class
Ceil
(
UnaryScalarOp
):
class
Ceil
(
UnaryScalarOp
):
nfunc_spec
=
(
'ceil'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
ceil
(
x
)
return
numpy
.
ceil
(
x
)
...
@@ -2183,6 +2219,8 @@ ceil = Ceil(same_out_nocomplex, name='ceil')
...
@@ -2183,6 +2219,8 @@ ceil = Ceil(same_out_nocomplex, name='ceil')
class
Floor
(
UnaryScalarOp
):
class
Floor
(
UnaryScalarOp
):
nfunc_spec
=
(
'floor'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
floor
(
x
)
return
numpy
.
floor
(
x
)
...
@@ -2204,6 +2242,8 @@ floor = Floor(same_out_nocomplex, name='floor')
...
@@ -2204,6 +2242,8 @@ floor = Floor(same_out_nocomplex, name='floor')
class
Trunc
(
UnaryScalarOp
):
class
Trunc
(
UnaryScalarOp
):
nfunc_spec
=
(
'trunc'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
trunc
(
x
)
return
numpy
.
trunc
(
x
)
...
@@ -2227,6 +2267,8 @@ class RoundHalfToEven(UnaryScalarOp):
...
@@ -2227,6 +2267,8 @@ class RoundHalfToEven(UnaryScalarOp):
See http://en.wikipedia.org/wiki/Rounding for more details.
See http://en.wikipedia.org/wiki/Rounding for more details.
"""
"""
nfunc_spec
=
(
'around'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
round
(
x
)
return
numpy
.
round
(
x
)
...
@@ -2348,6 +2390,8 @@ round_half_away_from_zero = RoundHalfAwayFromZero(same_out_float_only)
...
@@ -2348,6 +2390,8 @@ round_half_away_from_zero = RoundHalfAwayFromZero(same_out_float_only)
class
Neg
(
UnaryScalarOp
):
class
Neg
(
UnaryScalarOp
):
nfunc_spec
=
(
'negative'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
-
x
return
-
x
...
@@ -2413,6 +2457,7 @@ class Log(UnaryScalarOp):
...
@@ -2413,6 +2457,7 @@ class Log(UnaryScalarOp):
log base e.
log base e.
"""
"""
nfunc_spec
=
(
'log'
,
1
,
1
)
amd_float32
=
"amd_vrsa_logf"
amd_float32
=
"amd_vrsa_logf"
amd_float64
=
"amd_vrda_log"
amd_float64
=
"amd_vrda_log"
...
@@ -2454,6 +2499,7 @@ class Log2(UnaryScalarOp):
...
@@ -2454,6 +2499,7 @@ class Log2(UnaryScalarOp):
log base 2.
log base 2.
"""
"""
nfunc_spec
=
(
'log2'
,
1
,
1
)
amd_float32
=
"amd_vrsa_log2f"
amd_float32
=
"amd_vrsa_log2f"
amd_float64
=
"amd_vrda_log2"
amd_float64
=
"amd_vrda_log2"
...
@@ -2492,6 +2538,7 @@ class Log10(UnaryScalarOp):
...
@@ -2492,6 +2538,7 @@ class Log10(UnaryScalarOp):
log base 10.
log base 10.
"""
"""
nfunc_spec
=
(
'log10'
,
1
,
1
)
amd_float32
=
"amd_vrsa_log10f"
amd_float32
=
"amd_vrsa_log10f"
amd_float64
=
"amd_vrda_log10"
amd_float64
=
"amd_vrda_log10"
...
@@ -2530,6 +2577,8 @@ class Log1p(UnaryScalarOp):
...
@@ -2530,6 +2577,8 @@ class Log1p(UnaryScalarOp):
log(1+x).
log(1+x).
"""
"""
nfunc_spec
=
(
'log1p'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.log1p will compute the result in
# If x is an int8 or uint8, numpy.log1p will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -2561,6 +2610,7 @@ log1p = Log1p(upgrade_to_float, name='log1p')
...
@@ -2561,6 +2610,7 @@ log1p = Log1p(upgrade_to_float, name='log1p')
class
Exp
(
UnaryScalarOp
):
class
Exp
(
UnaryScalarOp
):
nfunc_spec
=
(
'exp'
,
1
,
1
)
amd_float32
=
"amd_vrsa_expf"
amd_float32
=
"amd_vrsa_expf"
amd_float64
=
"amd_vrda_exp"
amd_float64
=
"amd_vrda_exp"
...
@@ -2595,6 +2645,8 @@ exp = Exp(upgrade_to_float, name='exp')
...
@@ -2595,6 +2645,8 @@ exp = Exp(upgrade_to_float, name='exp')
class
Exp2
(
UnaryScalarOp
):
class
Exp2
(
UnaryScalarOp
):
nfunc_spec
=
(
'exp2'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.exp2 will compute the result in
# If x is an int8 or uint8, numpy.exp2 will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -2626,6 +2678,8 @@ exp2 = Exp2(upgrade_to_float, name='exp2')
...
@@ -2626,6 +2678,8 @@ exp2 = Exp2(upgrade_to_float, name='exp2')
class
Expm1
(
UnaryScalarOp
):
class
Expm1
(
UnaryScalarOp
):
nfunc_spec
=
(
'expm1'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.expm1 will compute the result in
# If x is an int8 or uint8, numpy.expm1 will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -2660,6 +2714,8 @@ expm1 = Expm1(upgrade_to_float, name='expm1')
...
@@ -2660,6 +2714,8 @@ expm1 = Expm1(upgrade_to_float, name='expm1')
class
Sqr
(
UnaryScalarOp
):
class
Sqr
(
UnaryScalarOp
):
nfunc_spec
=
(
'square'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
x
*
x
return
x
*
x
...
@@ -2684,6 +2740,8 @@ sqr = Sqr(same_out, name='sqr')
...
@@ -2684,6 +2740,8 @@ sqr = Sqr(same_out, name='sqr')
class
Sqrt
(
UnaryScalarOp
):
class
Sqrt
(
UnaryScalarOp
):
nfunc_spec
=
(
'sqrt'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.sqrt will compute the result in
# If x is an int8 or uint8, numpy.sqrt will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -2715,6 +2773,8 @@ sqrt = Sqrt(upgrade_to_float, name='sqrt')
...
@@ -2715,6 +2773,8 @@ sqrt = Sqrt(upgrade_to_float, name='sqrt')
class
Deg2Rad
(
UnaryScalarOp
):
class
Deg2Rad
(
UnaryScalarOp
):
nfunc_spec
=
(
'deg2rad'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.deg2rad will compute the result in
# If x is an int8 or uint8, numpy.deg2rad will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -2746,6 +2806,8 @@ deg2rad = Deg2Rad(upgrade_to_float, name='deg2rad')
...
@@ -2746,6 +2806,8 @@ deg2rad = Deg2Rad(upgrade_to_float, name='deg2rad')
class
Rad2Deg
(
UnaryScalarOp
):
class
Rad2Deg
(
UnaryScalarOp
):
nfunc_spec
=
(
'rad2deg'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.rad2deg will compute the result in
# If x is an int8 or uint8, numpy.rad2deg will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -2777,6 +2839,7 @@ rad2deg = Rad2Deg(upgrade_to_float, name='rad2deg')
...
@@ -2777,6 +2839,7 @@ rad2deg = Rad2Deg(upgrade_to_float, name='rad2deg')
class
Cos
(
UnaryScalarOp
):
class
Cos
(
UnaryScalarOp
):
nfunc_spec
=
(
'cos'
,
1
,
1
)
amd_float32
=
"amd_vrsa_cosf"
amd_float32
=
"amd_vrsa_cosf"
amd_float64
=
"amd_vrda_cos"
amd_float64
=
"amd_vrda_cos"
...
@@ -2811,6 +2874,8 @@ cos = Cos(upgrade_to_float, name='cos')
...
@@ -2811,6 +2874,8 @@ cos = Cos(upgrade_to_float, name='cos')
class
ArcCos
(
UnaryScalarOp
):
class
ArcCos
(
UnaryScalarOp
):
nfunc_spec
=
(
'arccos'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.arccos will compute the result in
# If x is an int8 or uint8, numpy.arccos will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -2842,6 +2907,7 @@ arccos = ArcCos(upgrade_to_float, name='arccos')
...
@@ -2842,6 +2907,7 @@ arccos = ArcCos(upgrade_to_float, name='arccos')
class
Sin
(
UnaryScalarOp
):
class
Sin
(
UnaryScalarOp
):
nfunc_spec
=
(
'sin'
,
1
,
1
)
amd_float32
=
"amd_vrsa_sinf"
amd_float32
=
"amd_vrsa_sinf"
amd_float64
=
"amd_vrda_sin"
amd_float64
=
"amd_vrda_sin"
...
@@ -2876,6 +2942,8 @@ sin = Sin(upgrade_to_float, name='sin')
...
@@ -2876,6 +2942,8 @@ sin = Sin(upgrade_to_float, name='sin')
class
ArcSin
(
UnaryScalarOp
):
class
ArcSin
(
UnaryScalarOp
):
nfunc_spec
=
(
'arcsin'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.arcsin will compute the result in
# If x is an int8 or uint8, numpy.arcsin will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -2907,6 +2975,8 @@ arcsin = ArcSin(upgrade_to_float, name='arcsin')
...
@@ -2907,6 +2975,8 @@ arcsin = ArcSin(upgrade_to_float, name='arcsin')
class
Tan
(
UnaryScalarOp
):
class
Tan
(
UnaryScalarOp
):
nfunc_spec
=
(
'tan'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.tan will compute the result in
# If x is an int8 or uint8, numpy.tan will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -2938,6 +3008,8 @@ tan = Tan(upgrade_to_float, name='tan')
...
@@ -2938,6 +3008,8 @@ tan = Tan(upgrade_to_float, name='tan')
class
ArcTan
(
UnaryScalarOp
):
class
ArcTan
(
UnaryScalarOp
):
nfunc_spec
=
(
'arctan'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.arctan will compute the result in
# If x is an int8 or uint8, numpy.arctan will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -2969,6 +3041,8 @@ arctan = ArcTan(upgrade_to_float, name='arctan')
...
@@ -2969,6 +3041,8 @@ arctan = ArcTan(upgrade_to_float, name='arctan')
class
ArcTan2
(
BinaryScalarOp
):
class
ArcTan2
(
BinaryScalarOp
):
nfunc_spec
=
(
'arctan2'
,
1
,
1
)
def
impl
(
self
,
y
,
x
):
def
impl
(
self
,
y
,
x
):
# If x and y are int8 or uint8, numpy.arctan2 will compute the result
# If x and y are int8 or uint8, numpy.arctan2 will compute the result
# in half-precision (float16), where we want float32.
# in half-precision (float16), where we want float32.
...
@@ -3016,6 +3090,8 @@ class Cosh(UnaryScalarOp):
...
@@ -3016,6 +3090,8 @@ class Cosh(UnaryScalarOp):
cosh(x) = (exp(x) + exp(-x)) / 2.
cosh(x) = (exp(x) + exp(-x)) / 2.
"""
"""
nfunc_spec
=
(
'cosh'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.cosh will compute the result in
# If x is an int8 or uint8, numpy.cosh will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -3047,6 +3123,8 @@ cosh = Cosh(upgrade_to_float, name='cosh')
...
@@ -3047,6 +3123,8 @@ cosh = Cosh(upgrade_to_float, name='cosh')
class
ArcCosh
(
UnaryScalarOp
):
class
ArcCosh
(
UnaryScalarOp
):
nfunc_spec
=
(
'arccosh'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.arccosh will compute the result in
# If x is an int8 or uint8, numpy.arccosh will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -3082,6 +3160,8 @@ class Sinh(UnaryScalarOp):
...
@@ -3082,6 +3160,8 @@ class Sinh(UnaryScalarOp):
sinh(x) = (exp(x) - exp(-x)) / 2.
sinh(x) = (exp(x) - exp(-x)) / 2.
"""
"""
nfunc_spec
=
(
'sinh'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.sinh will compute the result in
# If x is an int8 or uint8, numpy.sinh will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -3113,6 +3193,8 @@ sinh = Sinh(upgrade_to_float, name='sinh')
...
@@ -3113,6 +3193,8 @@ sinh = Sinh(upgrade_to_float, name='sinh')
class
ArcSinh
(
UnaryScalarOp
):
class
ArcSinh
(
UnaryScalarOp
):
nfunc_spec
=
(
'arcsinh'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.arcsinh will compute the result in
# If x is an int8 or uint8, numpy.arcsinh will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -3149,6 +3231,8 @@ class Tanh(UnaryScalarOp):
...
@@ -3149,6 +3231,8 @@ class Tanh(UnaryScalarOp):
= (exp(2*x) - 1) / (exp(2*x) + 1).
= (exp(2*x) - 1) / (exp(2*x) + 1).
"""
"""
nfunc_spec
=
(
'tanh'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.tanh will compute the result in
# If x is an int8 or uint8, numpy.tanh will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -3180,6 +3264,8 @@ tanh = Tanh(upgrade_to_float, name='tanh')
...
@@ -3180,6 +3264,8 @@ tanh = Tanh(upgrade_to_float, name='tanh')
class
ArcTanh
(
UnaryScalarOp
):
class
ArcTanh
(
UnaryScalarOp
):
nfunc_spec
=
(
'arctanh'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
# If x is an int8 or uint8, numpy.arctanh will compute the result in
# If x is an int8 or uint8, numpy.arctanh will compute the result in
# half-precision (float16), where we want float32.
# half-precision (float16), where we want float32.
...
@@ -3215,6 +3301,9 @@ class Real(UnaryScalarOp):
...
@@ -3215,6 +3301,9 @@ class Real(UnaryScalarOp):
Extract the real coordinate of a complex number.
Extract the real coordinate of a complex number.
"""
"""
# numpy.real(float32) return a view on the inputs.
# nfunc_spec = ('real', 1, 1)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
real
(
x
)
return
numpy
.
real
(
x
)
...
@@ -3227,6 +3316,8 @@ real = Real(real_out, name='real')
...
@@ -3227,6 +3316,8 @@ real = Real(real_out, name='real')
class
Imag
(
UnaryScalarOp
):
class
Imag
(
UnaryScalarOp
):
nfunc_spec
=
(
'imag'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
imag
(
x
)
return
numpy
.
imag
(
x
)
...
@@ -3244,6 +3335,8 @@ imag = Imag(real_out, name='imag')
...
@@ -3244,6 +3335,8 @@ imag = Imag(real_out, name='imag')
class
Angle
(
UnaryScalarOp
):
class
Angle
(
UnaryScalarOp
):
nfunc_spec
=
(
'angle'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
angle
(
x
)
return
numpy
.
angle
(
x
)
...
@@ -3303,6 +3396,8 @@ complex = Complex(name='complex')
...
@@ -3303,6 +3396,8 @@ complex = Complex(name='complex')
class
Conj
(
UnaryScalarOp
):
class
Conj
(
UnaryScalarOp
):
nfunc_spec
=
(
'conj'
,
1
,
1
)
def
impl
(
self
,
x
):
def
impl
(
self
,
x
):
return
numpy
.
conj
(
x
)
return
numpy
.
conj
(
x
)
conj
=
Conj
(
same_out
,
name
=
'conj'
)
conj
=
Conj
(
same_out
,
name
=
'conj'
)
...
...
theano/tensor/basic.py
浏览文件 @
7ef44dfd
...
@@ -1738,42 +1738,42 @@ def largest(*args):
...
@@ -1738,42 +1738,42 @@ def largest(*args):
# Comparison
# Comparison
##########################
##########################
@_scal_elemwise
_with_nfunc
(
'less'
,
2
,
1
)
@_scal_elemwise
def
lt
(
a
,
b
):
def
lt
(
a
,
b
):
"""a < b"""
"""a < b"""
@_scal_elemwise
_with_nfunc
(
'greater'
,
2
,
1
)
@_scal_elemwise
def
gt
(
a
,
b
):
def
gt
(
a
,
b
):
"""a > b"""
"""a > b"""
@_scal_elemwise
_with_nfunc
(
'less_equal'
,
2
,
1
)
@_scal_elemwise
def
le
(
a
,
b
):
def
le
(
a
,
b
):
"""a <= b"""
"""a <= b"""
@_scal_elemwise
_with_nfunc
(
'greater_equal'
,
2
,
1
)
@_scal_elemwise
def
ge
(
a
,
b
):
def
ge
(
a
,
b
):
"""a >= b"""
"""a >= b"""
@_scal_elemwise
_with_nfunc
(
'equal'
,
2
,
1
)
@_scal_elemwise
def
eq
(
a
,
b
):
def
eq
(
a
,
b
):
"""a == b"""
"""a == b"""
@_scal_elemwise
_with_nfunc
(
'not_equal'
,
2
,
1
)
@_scal_elemwise
def
neq
(
a
,
b
):
def
neq
(
a
,
b
):
"""a != b"""
"""a != b"""
@_scal_elemwise
_with_nfunc
(
'isnan'
,
1
,
1
)
@_scal_elemwise
def
isnan
(
a
):
def
isnan
(
a
):
"""isnan(a)"""
"""isnan(a)"""
@_scal_elemwise
_with_nfunc
(
'isinf'
,
1
,
1
)
@_scal_elemwise
def
isinf
(
a
):
def
isinf
(
a
):
"""isinf(a)"""
"""isinf(a)"""
...
@@ -1922,7 +1922,7 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
...
@@ -1922,7 +1922,7 @@ def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
# Condition
# Condition
##########################
##########################
@_scal_elemwise
_with_nfunc
(
'where'
,
3
,
1
)
@_scal_elemwise
def
switch
(
cond
,
ift
,
iff
):
def
switch
(
cond
,
ift
,
iff
):
"""if cond then ift else iff"""
"""if cond then ift else iff"""
...
@@ -1932,25 +1932,25 @@ where = switch
...
@@ -1932,25 +1932,25 @@ where = switch
##########################
##########################
@_scal_elemwise
_with_nfunc
(
'bitwise_and'
,
2
,
1
)
@_scal_elemwise
def
and_
(
a
,
b
):
def
and_
(
a
,
b
):
"""bitwise a & b"""
"""bitwise a & b"""
bitwise_and
=
and_
# numpy name for it
bitwise_and
=
and_
# numpy name for it
@_scal_elemwise
_with_nfunc
(
'bitwise_or'
,
2
,
1
)
@_scal_elemwise
def
or_
(
a
,
b
):
def
or_
(
a
,
b
):
"""bitwise a | b"""
"""bitwise a | b"""
bitwise_or
=
or_
# numpy name for it
bitwise_or
=
or_
# numpy name for it
@_scal_elemwise
_with_nfunc
(
'bitwise_xor'
,
2
,
1
)
@_scal_elemwise
def
xor
(
a
,
b
):
def
xor
(
a
,
b
):
"""bitwise a ^ b"""
"""bitwise a ^ b"""
bitwise_xor
=
xor
# numpy name for it
bitwise_xor
=
xor
# numpy name for it
@_scal_elemwise
_with_nfunc
(
'invert'
,
1
,
1
)
@_scal_elemwise
def
invert
(
a
):
def
invert
(
a
):
"""bitwise ~a"""
"""bitwise ~a"""
bitwise_not
=
invert
# numpy alias for it
bitwise_not
=
invert
# numpy alias for it
...
@@ -1960,7 +1960,7 @@ bitwise_not = invert # numpy alias for it
...
@@ -1960,7 +1960,7 @@ bitwise_not = invert # numpy alias for it
# Math
# Math
##########################
##########################
@_scal_elemwise
_with_nfunc
(
'abs'
,
1
,
1
)
@_scal_elemwise
def
abs_
(
a
):
def
abs_
(
a
):
"""|`a`|
"""|`a`|
...
@@ -1972,22 +1972,22 @@ def abs_(a):
...
@@ -1972,22 +1972,22 @@ def abs_(a):
pprint
.
assign
(
abs_
,
printing
.
PatternPrinter
((
'|
%(0)
s|'
,
-
1000
)))
pprint
.
assign
(
abs_
,
printing
.
PatternPrinter
((
'|
%(0)
s|'
,
-
1000
)))
@_scal_elemwise
_with_nfunc
(
'exp'
,
1
,
1
)
@_scal_elemwise
def
exp
(
a
):
def
exp
(
a
):
"""e^`a`"""
"""e^`a`"""
@_scal_elemwise
_with_nfunc
(
'exp2'
,
1
,
1
)
@_scal_elemwise
def
exp2
(
a
):
def
exp2
(
a
):
"""2^`a`"""
"""2^`a`"""
@_scal_elemwise
_with_nfunc
(
'expm1'
,
1
,
1
)
@_scal_elemwise
def
expm1
(
a
):
def
expm1
(
a
):
"""e^`a` - 1"""
"""e^`a` - 1"""
@_scal_elemwise
_with_nfunc
(
'negative'
,
1
,
1
)
@_scal_elemwise
def
neg
(
a
):
def
neg
(
a
):
"""-a"""
"""-a"""
...
@@ -1999,42 +1999,42 @@ def inv(a):
...
@@ -1999,42 +1999,42 @@ def inv(a):
"""1.0/a"""
"""1.0/a"""
@_scal_elemwise
_with_nfunc
(
'log'
,
1
,
1
)
@_scal_elemwise
def
log
(
a
):
def
log
(
a
):
"""base e logarithm of a"""
"""base e logarithm of a"""
@_scal_elemwise
_with_nfunc
(
'log2'
,
1
,
1
)
@_scal_elemwise
def
log2
(
a
):
def
log2
(
a
):
"""base 2 logarithm of a"""
"""base 2 logarithm of a"""
@_scal_elemwise
_with_nfunc
(
'log10'
,
1
,
1
)
@_scal_elemwise
def
log10
(
a
):
def
log10
(
a
):
"""base 10 logarithm of a"""
"""base 10 logarithm of a"""
@_scal_elemwise
_with_nfunc
(
'log1p'
,
1
,
1
)
@_scal_elemwise
def
log1p
(
a
):
def
log1p
(
a
):
"""log(1+a)"""
"""log(1+a)"""
@_scal_elemwise
_with_nfunc
(
'sign'
,
1
,
1
)
@_scal_elemwise
def
sgn
(
a
):
def
sgn
(
a
):
"""sign of a"""
"""sign of a"""
@_scal_elemwise
_with_nfunc
(
'ceil'
,
1
,
1
)
@_scal_elemwise
def
ceil
(
a
):
def
ceil
(
a
):
"""ceiling of a"""
"""ceiling of a"""
@_scal_elemwise
_with_nfunc
(
'floor'
,
1
,
1
)
@_scal_elemwise
def
floor
(
a
):
def
floor
(
a
):
"""floor of a"""
"""floor of a"""
@_scal_elemwise
_with_nfunc
(
'trunc'
,
1
,
1
)
@_scal_elemwise
def
trunc
(
a
):
def
trunc
(
a
):
"""trunc of a"""
"""trunc of a"""
...
@@ -2056,7 +2056,7 @@ def round(a, mode="half_away_from_zero"):
...
@@ -2056,7 +2056,7 @@ def round(a, mode="half_away_from_zero"):
raise
Exception
(
"round mode
%
s is not implemented."
%
mode
)
raise
Exception
(
"round mode
%
s is not implemented."
%
mode
)
@_scal_elemwise
_with_nfunc
(
'around'
,
1
,
1
)
@_scal_elemwise
def
round_half_to_even
(
a
):
def
round_half_to_even
(
a
):
"""round_half_to_even(a)"""
"""round_half_to_even(a)"""
...
@@ -2066,7 +2066,7 @@ def round_half_away_from_zero(a):
...
@@ -2066,7 +2066,7 @@ def round_half_away_from_zero(a):
"""round_half_away_from_zero(a)"""
"""round_half_away_from_zero(a)"""
@_scal_elemwise
_with_nfunc
(
'square'
,
1
,
1
)
@_scal_elemwise
def
sqr
(
a
):
def
sqr
(
a
):
"""square of a"""
"""square of a"""
...
@@ -2075,82 +2075,82 @@ def sqr(a):
...
@@ -2075,82 +2075,82 @@ def sqr(a):
square
=
sqr
square
=
sqr
@_scal_elemwise
_with_nfunc
(
'sqrt'
,
1
,
1
)
@_scal_elemwise
def
sqrt
(
a
):
def
sqrt
(
a
):
"""square root of a"""
"""square root of a"""
@_scal_elemwise
_with_nfunc
(
'deg2rad'
,
1
,
1
)
@_scal_elemwise
def
deg2rad
(
a
):
def
deg2rad
(
a
):
"""convert degree a to radian"""
"""convert degree a to radian"""
@_scal_elemwise
_with_nfunc
(
'rad2deg'
,
1
,
1
)
@_scal_elemwise
def
rad2deg
(
a
):
def
rad2deg
(
a
):
"""convert radian a to degree"""
"""convert radian a to degree"""
@_scal_elemwise
_with_nfunc
(
'cos'
,
1
,
1
)
@_scal_elemwise
def
cos
(
a
):
def
cos
(
a
):
"""cosine of a"""
"""cosine of a"""
@_scal_elemwise
_with_nfunc
(
'arccos'
,
1
,
1
)
@_scal_elemwise
def
arccos
(
a
):
def
arccos
(
a
):
"""arccosine of a"""
"""arccosine of a"""
@_scal_elemwise
_with_nfunc
(
'sin'
,
1
,
1
)
@_scal_elemwise
def
sin
(
a
):
def
sin
(
a
):
"""sine of a"""
"""sine of a"""
@_scal_elemwise
_with_nfunc
(
'arcsin'
,
1
,
1
)
@_scal_elemwise
def
arcsin
(
a
):
def
arcsin
(
a
):
"""arcsine of a"""
"""arcsine of a"""
@_scal_elemwise
_with_nfunc
(
'tan'
,
1
,
1
)
@_scal_elemwise
def
tan
(
a
):
def
tan
(
a
):
"""tangent of a"""
"""tangent of a"""
@_scal_elemwise
_with_nfunc
(
'arctan'
,
1
,
1
)
@_scal_elemwise
def
arctan
(
a
):
def
arctan
(
a
):
"""arctangent of a"""
"""arctangent of a"""
@_scal_elemwise
_with_nfunc
(
'arctan2'
,
1
,
1
)
@_scal_elemwise
def
arctan2
(
a
,
b
):
def
arctan2
(
a
,
b
):
"""arctangent of a / b"""
"""arctangent of a / b"""
@_scal_elemwise
_with_nfunc
(
'cosh'
,
1
,
1
)
@_scal_elemwise
def
cosh
(
a
):
def
cosh
(
a
):
"""hyperbolic cosine of a"""
"""hyperbolic cosine of a"""
@_scal_elemwise
_with_nfunc
(
'arccosh'
,
1
,
1
)
@_scal_elemwise
def
arccosh
(
a
):
def
arccosh
(
a
):
"""hyperbolic arc cosine of a"""
"""hyperbolic arc cosine of a"""
@_scal_elemwise
_with_nfunc
(
'sinh'
,
1
,
1
)
@_scal_elemwise
def
sinh
(
a
):
def
sinh
(
a
):
"""hyperbolic sine of a"""
"""hyperbolic sine of a"""
@_scal_elemwise
_with_nfunc
(
'arcsinh'
,
1
,
1
)
@_scal_elemwise
def
arcsinh
(
a
):
def
arcsinh
(
a
):
"""hyperbolic arc sine of a"""
"""hyperbolic arc sine of a"""
@_scal_elemwise
_with_nfunc
(
'tanh'
,
1
,
1
)
@_scal_elemwise
def
tanh
(
a
):
def
tanh
(
a
):
"""hyperbolic tangent of a"""
"""hyperbolic tangent of a"""
@_scal_elemwise
_with_nfunc
(
'arctanh'
,
1
,
1
)
@_scal_elemwise
def
arctanh
(
a
):
def
arctanh
(
a
):
"""hyperbolic arc tangent of a"""
"""hyperbolic arc tangent of a"""
...
@@ -2200,21 +2200,19 @@ def chi2sf(x, k):
...
@@ -2200,21 +2200,19 @@ def chi2sf(x, k):
"""chi squared survival function"""
"""chi squared survival function"""
# numpy.real(float32) return a view on the inputs.
# @_scal_elemwise_with_nfunc('real', 1, 1)
@_scal_elemwise
@_scal_elemwise
def
real
(
z
):
def
real
(
z
):
"""Return real component of complex-valued tensor `z`"""
"""Return real component of complex-valued tensor `z`"""
_tensor_py_operators
.
real
=
property
(
real
)
_tensor_py_operators
.
real
=
property
(
real
)
@_scal_elemwise
_with_nfunc
(
'imag'
,
1
,
1
)
@_scal_elemwise
def
imag
(
z
):
def
imag
(
z
):
"""Return imaginary component of complex-valued tensor `z`"""
"""Return imaginary component of complex-valued tensor `z`"""
_tensor_py_operators
.
imag
=
property
(
imag
)
_tensor_py_operators
.
imag
=
property
(
imag
)
@_scal_elemwise
_with_nfunc
(
'angle'
,
1
,
1
)
@_scal_elemwise
def
angle
(
z
):
def
angle
(
z
):
"""Return polar-coordinate angle of complex-valued tensor `z`"""
"""Return polar-coordinate angle of complex-valued tensor `z`"""
...
@@ -2224,7 +2222,7 @@ def complex(real, imag):
...
@@ -2224,7 +2222,7 @@ def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components"""
"""Return complex-valued tensor with `real` and `imag` components"""
@_scal_elemwise
_with_nfunc
(
'conj'
,
1
,
1
)
@_scal_elemwise
def
conj
(
z
):
def
conj
(
z
):
"""Return the complex conjugate of `z`."""
"""Return the complex conjugate of `z`."""
...
@@ -3166,13 +3164,13 @@ setdefault = default # legacy
...
@@ -3166,13 +3164,13 @@ setdefault = default # legacy
##########################
##########################
# Arithmetics
# Arithmetics
##########################
##########################
@_scal_elemwise
_with_nfunc
(
'maximum'
,
2
,
1
)
@_scal_elemwise
def
maximum
(
x
,
y
):
def
maximum
(
x
,
y
):
"""elemwise maximum. See max for the maximum in one tensor"""
"""elemwise maximum. See max for the maximum in one tensor"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
_with_nfunc
(
'minimum'
,
2
,
1
)
@_scal_elemwise
def
minimum
(
x
,
y
):
def
minimum
(
x
,
y
):
"""elemwise minimum. See min for the minimum in one tensor"""
"""elemwise minimum. See min for the minimum in one tensor"""
# see decorator for function body
# see decorator for function body
...
@@ -3191,31 +3189,31 @@ def divmod(x, y):
...
@@ -3191,31 +3189,31 @@ def divmod(x, y):
return
floor_div
(
x
,
y
),
mod_check
(
x
,
y
)
return
floor_div
(
x
,
y
),
mod_check
(
x
,
y
)
@_scal_elemwise
_with_nfunc
(
'add'
,
2
,
1
)
@_scal_elemwise
def
add
(
a
,
*
other_terms
):
def
add
(
a
,
*
other_terms
):
"""elementwise addition"""
"""elementwise addition"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
_with_nfunc
(
'subtract'
,
2
,
1
)
@_scal_elemwise
def
sub
(
a
,
b
):
def
sub
(
a
,
b
):
"""elementwise subtraction"""
"""elementwise subtraction"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
_with_nfunc
(
'multiply'
,
2
,
1
)
@_scal_elemwise
def
mul
(
a
,
*
other_terms
):
def
mul
(
a
,
*
other_terms
):
"""elementwise multiplication"""
"""elementwise multiplication"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
_with_nfunc
(
'true_divide'
,
2
,
1
)
@_scal_elemwise
def
true_div
(
a
,
b
):
def
true_div
(
a
,
b
):
"""elementwise [true] division (inverse of multiplication)"""
"""elementwise [true] division (inverse of multiplication)"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
_with_nfunc
(
'floor_divide'
,
2
,
1
)
@_scal_elemwise
def
int_div
(
a
,
b
):
def
int_div
(
a
,
b
):
"""elementwise [floor] division (inverse of multiplication)"""
"""elementwise [floor] division (inverse of multiplication)"""
# see decorator for function body
# see decorator for function body
...
@@ -3256,20 +3254,18 @@ def mod_check(x, y):
...
@@ -3256,20 +3254,18 @@ def mod_check(x, y):
return
mod
(
x
,
y
)
return
mod
(
x
,
y
)
@_scal_elemwise
_with_nfunc
(
'mod'
,
2
,
1
)
@_scal_elemwise
def
mod
(
a
,
b
):
def
mod
(
a
,
b
):
"""elementwise modulo"""
"""elementwise modulo"""
# see decorator for function body
# see decorator for function body
@_scal_elemwise
_with_nfunc
(
'power'
,
2
,
1
)
@_scal_elemwise
def
pow
(
a
,
b
):
def
pow
(
a
,
b
):
"""elementwise power"""
"""elementwise power"""
# see decorator for function body
# see decorator for function body
# The numpy.clip don't work correctly when the min is bigger then the max,
# So we do not use @scal_elemwise_with_nfunc('clip', 3, 1)
@_scal_elemwise
@_scal_elemwise
def
clip
(
x
,
min
,
max
):
def
clip
(
x
,
min
,
max
):
"""
"""
...
...
theano/tensor/elemwise.py
浏览文件 @
7ef44dfd
...
@@ -503,6 +503,8 @@ class Elemwise(OpenMPOp):
...
@@ -503,6 +503,8 @@ class Elemwise(OpenMPOp):
self
.
ufunc
=
None
self
.
ufunc
=
None
self
.
nfunc
=
None
self
.
nfunc
=
None
if
nfunc_spec
is
None
:
nfunc_spec
=
getattr
(
scalar_op
,
'nfunc_spec'
,
None
)
self
.
nfunc_spec
=
nfunc_spec
self
.
nfunc_spec
=
nfunc_spec
if
nfunc_spec
:
if
nfunc_spec
:
self
.
nfunc
=
getattr
(
numpy
,
nfunc_spec
[
0
])
self
.
nfunc
=
getattr
(
numpy
,
nfunc_spec
[
0
])
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论