Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
639b0871
提交
639b0871
authored
5月 31, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
6月 06, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Get rid of redundant checks in tracked node_rewriters
上级
cb2c40ba
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
369 行增加
和
438 行删除
+369
-438
math.py
pytensor/tensor/rewriting/math.py
+369
-438
没有找到文件。
pytensor/tensor/rewriting/math.py
浏览文件 @
639b0871
...
@@ -328,7 +328,7 @@ def local_func_inv(fgraph, node):
...
@@ -328,7 +328,7 @@ def local_func_inv(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@node_rewriter
([
Elemwise
])
@node_rewriter
([
log
,
log1p
,
exp
,
expm1
])
def
local_exp_log
(
fgraph
,
node
):
def
local_exp_log
(
fgraph
,
node
):
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
...
@@ -368,7 +368,7 @@ def local_exp_log(fgraph, node):
...
@@ -368,7 +368,7 @@ def local_exp_log(fgraph, node):
@register_specialize
@register_specialize
@node_rewriter
([
Elemwise
])
@node_rewriter
([
exp
,
expm1
])
def
local_exp_log_nan_switch
(
fgraph
,
node
):
def
local_exp_log_nan_switch
(
fgraph
,
node
):
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
...
@@ -431,11 +431,7 @@ def local_sumsqr2dot(fgraph, node):
...
@@ -431,11 +431,7 @@ def local_sumsqr2dot(fgraph, node):
``pt.sqr(W.dimshuffle("x", 0, 1) * G.dimshuffle(0, "x", 1) ).sum(axis=(1, 2))``
``pt.sqr(W.dimshuffle("x", 0, 1) * G.dimshuffle(0, "x", 1) ).sum(axis=(1, 2))``
and converts it to ``pt.dot(pt.sqr(G), pt.sqr(W).sum(axis=0))``.
and converts it to ``pt.dot(pt.sqr(G), pt.sqr(W).sum(axis=0))``.
"""
"""
if
(
if
node
.
op
.
axis
==
(
1
,
2
):
isinstance
(
node
.
op
,
Sum
)
and
isinstance
(
node
.
op
.
scalar_op
,
ps
.
Add
)
and
node
.
op
.
axis
==
(
1
,
2
)
):
in1
=
node
.
inputs
[
0
]
in1
=
node
.
inputs
[
0
]
out
=
node
.
outputs
[
0
]
out
=
node
.
outputs
[
0
]
...
@@ -479,7 +475,7 @@ def local_mul_exp_to_exp_add(fgraph, node):
...
@@ -479,7 +475,7 @@ def local_mul_exp_to_exp_add(fgraph, node):
n
.
owner
.
inputs
[
0
]
n
.
owner
.
inputs
[
0
]
for
n
in
node
.
inputs
for
n
in
node
.
inputs
if
n
.
owner
if
n
.
owner
and
hasattr
(
n
.
owner
.
op
,
"scalar_op"
)
and
isinstance
(
n
.
owner
.
op
,
Elemwise
)
and
isinstance
(
n
.
owner
.
op
.
scalar_op
,
ps
.
Exp
)
and
isinstance
(
n
.
owner
.
op
.
scalar_op
,
ps
.
Exp
)
]
]
# Can only do any rewrite if there are at least two exp-s
# Can only do any rewrite if there are at least two exp-s
...
@@ -523,7 +519,7 @@ def local_mul_pow_to_pow_add(fgraph, node):
...
@@ -523,7 +519,7 @@ def local_mul_pow_to_pow_add(fgraph, node):
for
n
in
node
.
inputs
:
for
n
in
node
.
inputs
:
if
(
if
(
n
.
owner
n
.
owner
and
hasattr
(
n
.
owner
.
op
,
"scalar_op"
)
and
isinstance
(
n
.
owner
.
op
,
Elemwise
)
and
isinstance
(
n
.
owner
.
op
.
scalar_op
,
ps
.
Pow
)
and
isinstance
(
n
.
owner
.
op
.
scalar_op
,
ps
.
Pow
)
):
):
base_node
=
n
.
owner
.
inputs
[
0
]
base_node
=
n
.
owner
.
inputs
[
0
]
...
@@ -567,28 +563,27 @@ def local_mul_pow_to_pow_add(fgraph, node):
...
@@ -567,28 +563,27 @@ def local_mul_pow_to_pow_add(fgraph, node):
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@node_rewriter
([
Elemwise
])
@node_rewriter
([
sub
])
def
local_expm1
(
fgraph
,
node
):
def
local_expm1
(
fgraph
,
node
):
"""Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
"""Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
if
isinstance
(
node
.
op
,
Elemwise
)
and
isinstance
(
node
.
op
.
scalar_op
,
ps
.
Sub
):
in1
,
in2
=
node
.
inputs
in1
,
in2
=
node
.
inputs
out
=
node
.
outputs
[
0
]
out
=
node
.
outputs
[
0
]
if
(
if
(
in1
.
owner
in1
.
owner
and
isinstance
(
in1
.
owner
.
op
,
Elemwise
)
and
isinstance
(
in1
.
owner
.
op
,
Elemwise
)
and
isinstance
(
in1
.
owner
.
op
.
scalar_op
,
ps
.
Exp
)
and
isinstance
(
in1
.
owner
.
op
.
scalar_op
,
ps
.
Exp
)
and
extract_constant
(
in2
,
only_process_constants
=
False
)
==
1
and
extract_constant
(
in2
,
only_process_constants
=
False
)
==
1
):
):
in11
=
in1
.
owner
.
inputs
[
0
]
in11
=
in1
.
owner
.
inputs
[
0
]
new_out
=
expm1
(
in11
)
new_out
=
expm1
(
in11
)
if
new_out
.
dtype
!=
out
.
dtype
:
if
new_out
.
dtype
!=
out
.
dtype
:
new_out
=
cast
(
new_out
,
dtype
=
out
.
dtype
)
new_out
=
cast
(
new_out
,
dtype
=
out
.
dtype
)
if
not
out
.
type
.
is_super
(
new_out
.
type
):
if
not
out
.
type
.
is_super
(
new_out
.
type
):
return
return
return
[
new_out
]
return
[
new_out
]
@register_specialize
@register_specialize
...
@@ -625,8 +620,6 @@ def local_mul_switch_sink(fgraph, node):
...
@@ -625,8 +620,6 @@ def local_mul_switch_sink(fgraph, node):
part of the graph.
part of the graph.
"""
"""
if
node
.
op
!=
mul
:
return
False
for
idx
,
i
in
enumerate
(
node
.
inputs
):
for
idx
,
i
in
enumerate
(
node
.
inputs
):
if
i
.
owner
and
i
.
owner
.
op
==
switch
:
if
i
.
owner
and
i
.
owner
.
op
==
switch
:
switch_node
=
i
.
owner
switch_node
=
i
.
owner
...
@@ -705,8 +698,6 @@ def local_div_switch_sink(fgraph, node):
...
@@ -705,8 +698,6 @@ def local_div_switch_sink(fgraph, node):
See `local_mul_switch_sink` for more details.
See `local_mul_switch_sink` for more details.
"""
"""
if
node
.
op
!=
true_div
and
node
.
op
!=
int_div
:
return
False
op
=
node
.
op
op
=
node
.
op
if
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
.
op
==
switch
:
if
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
.
op
==
switch
:
switch_node
=
node
.
inputs
[
0
]
.
owner
switch_node
=
node
.
inputs
[
0
]
.
owner
...
@@ -1235,8 +1226,7 @@ register_canonicalize(local_mul_canonizer, "shape_unsafe", name="local_mul_canon
...
@@ -1235,8 +1226,7 @@ register_canonicalize(local_mul_canonizer, "shape_unsafe", name="local_mul_canon
@register_canonicalize
@register_canonicalize
@node_rewriter
([
neg
])
@node_rewriter
([
neg
])
def
local_neg_to_mul
(
fgraph
,
node
):
def
local_neg_to_mul
(
fgraph
,
node
):
if
node
.
op
==
neg
:
return
[
mul
(
np
.
array
(
-
1
,
dtype
=
node
.
inputs
[
0
]
.
dtype
),
node
.
inputs
[
0
])]
return
[
mul
(
np
.
array
(
-
1
,
dtype
=
node
.
inputs
[
0
]
.
dtype
),
node
.
inputs
[
0
])]
@register_specialize
@register_specialize
...
@@ -1347,17 +1337,12 @@ def local_sum_of_neg_to_neg_of_sum(fgraph, node):
...
@@ -1347,17 +1337,12 @@ def local_sum_of_neg_to_neg_of_sum(fgraph, node):
@register_specialize
@register_specialize
@node_rewriter
([
Elemwise
])
@node_rewriter
([
sub
])
def
local_elemwise_sub_zeros
(
fgraph
,
node
):
def
local_elemwise_sub_zeros
(
fgraph
,
node
):
"""
"""
Elemwise{sub}(X,X) -> zeros_like(X)
Elemwise{sub}(X,X) -> zeros_like(X)
"""
"""
if
(
if
node
.
inputs
[
0
]
==
node
.
inputs
[
1
]:
isinstance
(
node
.
op
,
Elemwise
)
and
node
.
op
.
scalar_op
.
nin
==
2
and
node
.
op
.
scalar_op
==
ps
.
sub
and
node
.
inputs
[
0
]
==
node
.
inputs
[
1
]
):
res
=
zeros_like
(
node
.
inputs
[
0
])
res
=
zeros_like
(
node
.
inputs
[
0
])
# Copy over stacktrace from previous output.
# Copy over stacktrace from previous output.
# This could help for failures due to out-of-memory.
# This could help for failures due to out-of-memory.
...
@@ -1400,8 +1385,6 @@ def local_useless_elemwise_comparison(fgraph, node):
...
@@ -1400,8 +1385,6 @@ def local_useless_elemwise_comparison(fgraph, node):
the graph easier to read.
the graph easier to read.
"""
"""
if
not
isinstance
(
node
.
op
,
Elemwise
):
return
if
node
.
op
.
scalar_op
.
nin
!=
2
:
if
node
.
op
.
scalar_op
.
nin
!=
2
:
return
return
...
@@ -1590,14 +1573,13 @@ def local_sum_prod_all_to_none(fgraph, node):
...
@@ -1590,14 +1573,13 @@ def local_sum_prod_all_to_none(fgraph, node):
Prod{0,1,...N} -> Prod{}
Prod{0,1,...N} -> Prod{}
"""
"""
if
isinstance
(
node
.
op
,
Sum
)
or
isinstance
(
node
.
op
,
Prod
):
op_type
=
Sum
if
isinstance
(
node
.
op
,
Sum
)
else
Prod
op_type
=
Sum
if
isinstance
(
node
.
op
,
Sum
)
else
Prod
# if all the axes are named, then use None as a shorthand
# if all the axes are named, then use None as a shorthand
# this permits more merging
# this permits more merging
if
node
.
op
.
axis
is
None
:
if
node
.
op
.
axis
is
None
:
return
return
if
set
(
node
.
op
.
axis
)
==
set
(
range
(
node
.
inputs
[
0
]
.
type
.
ndim
)):
if
set
(
node
.
op
.
axis
)
==
set
(
range
(
node
.
inputs
[
0
]
.
type
.
ndim
)):
return
[
op_type
(
axis
=
None
,
dtype
=
node
.
op
.
dtype
)(
node
.
inputs
[
0
])]
return
[
op_type
(
axis
=
None
,
dtype
=
node
.
op
.
dtype
)(
node
.
inputs
[
0
])]
@register_canonicalize
@register_canonicalize
...
@@ -1609,35 +1591,34 @@ def local_op_of_op(fgraph, node):
...
@@ -1609,35 +1591,34 @@ def local_op_of_op(fgraph, node):
Sum(Sum()) -> single Sum()
Sum(Sum()) -> single Sum()
"""
"""
if
isinstance
(
node
.
op
,
Prod
)
or
isinstance
(
node
.
op
,
Sum
):
op_type
=
Sum
if
isinstance
(
node
.
op
,
Sum
)
else
Prod
op_type
=
Sum
if
isinstance
(
node
.
op
,
Sum
)
else
Prod
(
node_inps
,)
=
node
.
inputs
(
node_inps
,)
=
node
.
inputs
out_dtype
=
node
.
op
.
dtype
out_dtype
=
node
.
op
.
dtype
# This is done to make sure the rewrite doesn't affect other
# This is done to make sure the rewrite doesn't affect other
# computations.
# computations.
if
len
(
fgraph
.
clients
[
node_inps
])
==
1
:
if
len
(
fgraph
.
clients
[
node_inps
])
==
1
:
if
node_inps
.
owner
and
(
isinstance
(
node_inps
.
owner
.
op
,
node
.
op
.
__class__
)):
if
node_inps
.
owner
and
(
isinstance
(
node_inps
.
owner
.
op
,
node
.
op
.
__class__
)):
# check to see either the inner or outer prod is doing a
# check to see either the inner or outer prod is doing a
# product over all axis, in which case we can remove it
# product over all axis, in which case we can remove it
if
node_inps
.
owner
.
op
.
axis
is
None
or
node
.
op
.
axis
is
None
:
if
node_inps
.
owner
.
op
.
axis
is
None
or
node
.
op
.
axis
is
None
:
return
[
op_type
(
None
,
dtype
=
out_dtype
)(
node_inps
.
owner
.
inputs
[
0
])]
return
[
op_type
(
None
,
dtype
=
out_dtype
)(
node_inps
.
owner
.
inputs
[
0
])]
# figure out which axes were in the original sum
# figure out which axes were in the original sum
newaxis
=
list
(
node_inps
.
owner
.
op
.
axis
)
newaxis
=
list
(
node_inps
.
owner
.
op
.
axis
)
for
i
in
node
.
op
.
axis
:
for
i
in
node
.
op
.
axis
:
new_i
=
i
new_i
=
i
for
ii
in
node_inps
.
owner
.
op
.
axis
:
for
ii
in
node_inps
.
owner
.
op
.
axis
:
if
new_i
>=
ii
:
if
new_i
>=
ii
:
new_i
+=
1
new_i
+=
1
assert
new_i
not
in
newaxis
assert
new_i
not
in
newaxis
newaxis
.
append
(
new_i
)
newaxis
.
append
(
new_i
)
assert
len
(
newaxis
)
==
len
(
assert
len
(
newaxis
)
==
len
(
list
(
node_inps
.
owner
.
op
.
axis
)
+
list
(
node
.
op
.
axis
)
list
(
node_inps
.
owner
.
op
.
axis
)
+
list
(
node
.
op
.
axis
)
)
)
combined
=
op_type
(
newaxis
,
dtype
=
out_dtype
)
combined
=
op_type
(
newaxis
,
dtype
=
out_dtype
)
return
[
combined
(
node_inps
.
owner
.
inputs
[
0
])]
return
[
combined
(
node_inps
.
owner
.
inputs
[
0
])]
ALL_REDUCE
=
[
ALL_REDUCE
=
[
...
@@ -1669,11 +1650,7 @@ def local_reduce_join(fgraph, node):
...
@@ -1669,11 +1650,7 @@ def local_reduce_join(fgraph, node):
where we join and reduce on the same set of axis.
where we join and reduce on the same set of axis.
"""
"""
if
(
if
node
.
inputs
[
0
]
.
owner
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
Join
):
isinstance
(
node
.
op
,
CAReduce
)
and
node
.
inputs
[
0
]
.
owner
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
Join
)
):
join_node
=
node
.
inputs
[
0
]
.
owner
join_node
=
node
.
inputs
[
0
]
.
owner
if
extract_constant
(
join_node
.
inputs
[
0
],
only_process_constants
=
True
)
!=
0
:
if
extract_constant
(
join_node
.
inputs
[
0
],
only_process_constants
=
True
)
!=
0
:
return
return
...
@@ -1732,11 +1709,10 @@ def local_reduce_join(fgraph, node):
...
@@ -1732,11 +1709,10 @@ def local_reduce_join(fgraph, node):
@node_rewriter
(
ALL_REDUCE
)
@node_rewriter
(
ALL_REDUCE
)
def
local_useless_reduce
(
fgraph
,
node
):
def
local_useless_reduce
(
fgraph
,
node
):
"""Sum(a, axis=[]) -> a"""
"""Sum(a, axis=[]) -> a"""
if
isinstance
(
node
.
op
,
CAReduce
):
(
summed
,)
=
node
.
inputs
(
summed
,)
=
node
.
inputs
# if reduce were doing anything, the output ndim would be reduced
# if reduce were doing anything, the output ndim would be reduced
if
summed
.
type
==
node
.
outputs
[
0
]
.
type
:
if
summed
.
type
==
node
.
outputs
[
0
]
.
type
:
return
[
summed
]
return
[
summed
]
@register_canonicalize
@register_canonicalize
...
@@ -1745,42 +1721,41 @@ def local_useless_reduce(fgraph, node):
...
@@ -1745,42 +1721,41 @@ def local_useless_reduce(fgraph, node):
@node_rewriter
(
ALL_REDUCE
)
@node_rewriter
(
ALL_REDUCE
)
def
local_reduce_broadcastable
(
fgraph
,
node
):
def
local_reduce_broadcastable
(
fgraph
,
node
):
"""Remove reduction over broadcastable dimensions."""
"""Remove reduction over broadcastable dimensions."""
if
isinstance
(
node
.
op
,
CAReduce
):
(
reduced
,)
=
node
.
inputs
(
reduced
,)
=
node
.
inputs
odtype
=
node
.
outputs
[
0
]
.
dtype
odtype
=
node
.
outputs
[
0
]
.
dtype
if
node
.
op
.
axis
is
None
:
if
node
.
op
.
axis
is
None
:
if
all
(
reduced
.
broadcastable
):
if
all
(
reduced
.
broadcastable
):
return
[
reduced
.
dimshuffle
()
.
astype
(
odtype
)]
return
[
reduced
.
dimshuffle
()
.
astype
(
odtype
)]
else
:
else
:
axis
=
list
(
node
.
op
.
axis
)
axis
=
list
(
node
.
op
.
axis
)
cuttable
=
[
a
for
a
in
axis
if
reduced
.
broadcastable
[
a
]]
cuttable
=
[
a
for
a
in
axis
if
reduced
.
broadcastable
[
a
]]
if
cuttable
:
if
cuttable
:
# -- we can remove some axes of summation.
# -- we can remove some axes of summation.
new_axis
=
[]
new_axis
=
[]
pattern
=
[]
pattern
=
[]
ii
=
0
ii
=
0
for
p
in
range
(
reduced
.
ndim
):
for
p
in
range
(
reduced
.
ndim
):
if
p
not
in
cuttable
:
if
p
not
in
cuttable
:
if
p
in
axis
:
if
p
in
axis
:
new_axis
.
append
(
ii
)
new_axis
.
append
(
ii
)
pattern
.
append
(
p
)
pattern
.
append
(
p
)
ii
+=
1
ii
+=
1
new_reduced
=
reduced
.
dimshuffle
(
*
pattern
)
new_reduced
=
reduced
.
dimshuffle
(
*
pattern
)
if
new_axis
:
if
new_axis
:
if
type
(
node
.
op
)
==
CAReduce
:
if
type
(
node
.
op
)
==
CAReduce
:
# This case handles `CAReduce` instances
# This case handles `CAReduce` instances
# (e.g. generated by `scalar_elemwise`), and not the
# (e.g. generated by `scalar_elemwise`), and not the
# scalar `Op`-specific subclasses
# scalar `Op`-specific subclasses
# TODO FIXME: This highlights a major design flaw in
# TODO FIXME: This highlights a major design flaw in
# `CAReduce` (or at least our use of it), and it needs
# `CAReduce` (or at least our use of it), and it needs
# to be fixed
# to be fixed
new_op
=
node
.
op
.
__class__
(
node
.
op
.
scalar_op
,
axis
=
new_axis
)
new_op
=
node
.
op
.
__class__
(
node
.
op
.
scalar_op
,
axis
=
new_axis
)
else
:
new_op
=
node
.
op
.
__class__
(
axis
=
new_axis
)
return
[
new_op
(
new_reduced
)]
else
:
else
:
# -- in this case we can remove the reduction completely
new_op
=
node
.
op
.
__class__
(
axis
=
new_axis
)
return
[
new_reduced
.
astype
(
odtype
)]
return
[
new_op
(
new_reduced
)]
else
:
# -- in this case we can remove the reduction completely
return
[
new_reduced
.
astype
(
odtype
)]
@register_specialize
@register_specialize
...
@@ -1792,61 +1767,54 @@ def local_opt_alloc(fgraph, node):
...
@@ -1792,61 +1767,54 @@ def local_opt_alloc(fgraph, node):
prod(alloc(constant,shapes...)) => constant**prod(shapes)
prod(alloc(constant,shapes...)) => constant**prod(shapes)
"""
"""
if
isinstance
(
node
.
op
,
Sum
)
or
isinstance
(
node
.
op
,
Prod
):
(
node_inps
,)
=
node
.
inputs
(
node_inps
,)
=
node
.
inputs
if
node_inps
.
owner
and
isinstance
(
node_inps
.
owner
.
op
,
Alloc
):
if
node_inps
.
owner
and
isinstance
(
node_inps
.
owner
.
op
,
Alloc
):
inp
=
node_inps
.
owner
.
inputs
[
0
]
inp
=
node_inps
.
owner
.
inputs
[
0
]
shapes
=
node_inps
.
owner
.
inputs
[
1
:]
shapes
=
node_inps
.
owner
.
inputs
[
1
:]
try
:
try
:
val
=
get_underlying_scalar_constant_value
(
inp
,
only_process_constants
=
True
)
val
=
get_underlying_scalar_constant_value
(
assert
val
.
size
==
1
inp
,
only_process_constants
=
True
val
=
val
.
reshape
(
1
)[
0
]
)
# check which type of op
assert
val
.
size
==
1
size
=
mul
(
*
shapes
)
val
=
val
.
reshape
(
1
)[
0
]
if
inp
.
dtype
in
(
"float16"
,
"float32"
):
# check which type of op
# shapes are ints and normally int64.
size
=
mul
(
*
shapes
)
# We don't want to have a float64 upcast
if
inp
.
dtype
in
(
"float16"
,
"float32"
):
# We don't want to downcast to float16
# shapes are ints and normally int64.
# as we fear it could loose too much precision
# We don't want to have a float64 upcast
# that will be amplified by the mul/pow below.
# We don't want to downcast to float16
size
=
size
.
astype
(
"float32"
)
# as we fear it could loose too much precision
if
node
.
op
.
axis
is
None
or
node
.
op
.
axis
==
tuple
(
range
(
inp
.
ndim
)):
# that will be amplified by the mul/pow below.
if
isinstance
(
node
.
op
,
Sum
):
size
=
size
.
astype
(
"float32"
)
val
=
val
*
size
if
node
.
op
.
axis
is
None
or
node
.
op
.
axis
==
tuple
(
range
(
inp
.
ndim
)):
else
:
if
isinstance
(
node
.
op
,
Sum
):
val
=
val
**
size
val
=
val
*
size
# Sum can change the input dtype (upcast or bool
else
:
# -> float32) by default or by user request.
val
=
val
**
size
# We can ignore the acc_dtype, as there is only 1
# Sum can change the input dtype (upcast or bool
# elemwise we will do and not a sequence, so there is no
# -> float32) by default or by user request.
# accumulation of errors.
# We can ignore the acc_dtype, as there is only 1
# So mostly, we just need to cast the output to the old
# elemwise we will do and not a sequence, so there is no
# dtype.
# accumulation of errors.
# So mostly, we just need to cast the output to the old
# dtype.
val
=
val
.
astype
(
node
.
outputs
[
0
]
.
dtype
)
return
[
val
]
to_prod
=
[
shapes
[
i
]
for
i
in
range
(
len
(
shapes
))
if
i
in
node
.
op
.
axis
]
if
to_prod
:
size
=
mul
(
*
to_prod
)
if
isinstance
(
node
.
op
,
Sum
):
val
*=
size
else
:
val
=
val
**
size
# See comments above.
val
=
val
.
astype
(
node
.
outputs
[
0
]
.
dtype
)
val
=
val
.
astype
(
node
.
outputs
[
0
]
.
dtype
)
return
[
return
[
val
]
alloc
(
to_prod
=
[
shapes
[
i
]
for
i
in
range
(
len
(
shapes
))
if
i
in
node
.
op
.
axis
]
val
,
if
to_prod
:
*
[
size
=
mul
(
*
to_prod
)
shapes
[
i
]
if
isinstance
(
node
.
op
,
Sum
):
for
i
in
range
(
len
(
shapes
))
val
*=
size
if
i
not
in
node
.
op
.
axis
else
:
],
val
=
val
**
size
)
# See comments above.
]
val
=
val
.
astype
(
node
.
outputs
[
0
]
.
dtype
)
except
NotScalarConstantError
:
return
[
pass
alloc
(
val
,
*
[
shapes
[
i
]
for
i
in
range
(
len
(
shapes
))
if
i
not
in
node
.
op
.
axis
],
)
]
except
NotScalarConstantError
:
pass
@register_specialize
@register_specialize
...
@@ -1858,19 +1826,18 @@ def local_neg_div_neg(fgraph, node):
...
@@ -1858,19 +1826,18 @@ def local_neg_div_neg(fgraph, node):
Also performs - (c / b) -> ((-c) / b) when c is a scalar constant.
Also performs - (c / b) -> ((-c) / b) when c is a scalar constant.
"""
"""
if
node
.
op
==
neg
:
if
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
.
op
==
true_div
:
if
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
.
op
==
true_div
:
frac
=
node
.
inputs
[
0
]
frac
=
node
.
inputs
[
0
]
num
,
denom
=
frac
.
owner
.
inputs
num
,
denom
=
frac
.
owner
.
inputs
if
num
.
owner
and
num
.
owner
.
op
==
neg
:
if
num
.
owner
and
num
.
owner
.
op
==
neg
:
if
len
(
fgraph
.
clients
[
frac
])
==
1
:
if
len
(
fgraph
.
clients
[
frac
])
==
1
:
# No other clients of the original division
# No other clients of the original division
new_num
=
num
.
owner
.
inputs
[
0
]
new_num
=
num
.
owner
.
inputs
[
0
]
return
[
true_div
(
new_num
,
denom
)]
return
[
true_div
(
new_num
,
denom
)]
elif
all
(
num
.
broadcastable
)
and
isinstance
(
num
,
Constant
):
elif
all
(
num
.
broadcastable
)
and
isinstance
(
num
,
Constant
):
if
len
(
fgraph
.
clients
[
frac
])
==
1
:
if
len
(
fgraph
.
clients
[
frac
])
==
1
:
new_num
=
-
num
.
data
new_num
=
-
num
.
data
return
[
true_div
(
new_num
,
denom
)]
return
[
true_div
(
new_num
,
denom
)]
@register_canonicalize
@register_canonicalize
...
@@ -1881,14 +1848,13 @@ def local_sub_neg_to_add(fgraph, node):
...
@@ -1881,14 +1848,13 @@ def local_sub_neg_to_add(fgraph, node):
x - (-y) -> x + y
x - (-y) -> x + y
"""
"""
if
node
.
op
==
sub
:
minuend
,
subtrahend
=
node
.
inputs
minuend
,
subtrahend
=
node
.
inputs
if
subtrahend
.
owner
:
if
subtrahend
.
owner
:
if
subtrahend
.
owner
.
op
==
neg
:
if
subtrahend
.
owner
.
op
==
neg
:
pre_neg
=
subtrahend
.
owner
.
inputs
[
0
]
pre_neg
=
subtrahend
.
owner
.
inputs
[
0
]
new_out
=
add
(
minuend
,
pre_neg
)
new_out
=
add
(
minuend
,
pre_neg
)
return
[
new_out
]
return
[
new_out
]
@register_specialize
@register_specialize
...
@@ -1903,7 +1869,7 @@ def local_add_neg_to_sub(fgraph, node):
...
@@ -1903,7 +1869,7 @@ def local_add_neg_to_sub(fgraph, node):
# `local_neg_to_mul` rewrite modifies the relevant pattern during canonicalization
# `local_neg_to_mul` rewrite modifies the relevant pattern during canonicalization
# Rewrite is only applicable when there are two inputs to add
# Rewrite is only applicable when there are two inputs to add
if
node
.
op
==
add
and
len
(
node
.
inputs
)
==
2
:
if
len
(
node
.
inputs
)
==
2
:
# Look for pattern with either input order
# Look for pattern with either input order
for
first
,
second
in
(
node
.
inputs
,
reversed
(
node
.
inputs
)):
for
first
,
second
in
(
node
.
inputs
,
reversed
(
node
.
inputs
)):
if
second
.
owner
:
if
second
.
owner
:
...
@@ -1927,27 +1893,24 @@ def local_mul_zero(fgraph, node):
...
@@ -1927,27 +1893,24 @@ def local_mul_zero(fgraph, node):
with zero.
with zero.
"""
"""
if
node
.
op
==
mul
:
otype
=
node
.
outputs
[
0
]
.
type
otype
=
node
.
outputs
[
0
]
.
type
for
i
in
node
.
inputs
:
for
i
in
node
.
inputs
:
try
:
try
:
value
=
get_underlying_scalar_constant_value
(
i
)
value
=
get_underlying_scalar_constant_value
(
i
)
except
NotScalarConstantError
:
except
NotScalarConstantError
:
continue
continue
# print 'MUL by value', value, node.inputs
# print 'MUL by value', value, node.inputs
if
value
==
0
:
if
value
==
0
:
# print '... returning zeros'
# print '... returning zeros'
return
[
return
[
broadcast_arrays
(
_asarray
(
0
,
dtype
=
otype
.
dtype
),
*
node
.
inputs
)[
0
]]
broadcast_arrays
(
_asarray
(
0
,
dtype
=
otype
.
dtype
),
*
node
.
inputs
)[
0
]
]
# TODO: Add this to the canonicalization to reduce redundancy.
# TODO: Add this to the canonicalization to reduce redundancy.
@register_specialize
@register_specialize
@node_rewriter
([
true_div
])
@node_rewriter
([
true_div
])
def
local_div_to_reciprocal
(
fgraph
,
node
):
def
local_div_to_reciprocal
(
fgraph
,
node
):
if
n
ode
.
op
==
true_div
and
n
p
.
all
(
get_constant
(
node
.
inputs
[
0
])
==
1.0
):
if
np
.
all
(
get_constant
(
node
.
inputs
[
0
])
==
1.0
):
out
=
node
.
outputs
[
0
]
out
=
node
.
outputs
[
0
]
new_out
=
reciprocal
(
local_mul_canonizer
.
merge_num_denum
(
node
.
inputs
[
1
:],
[]))
new_out
=
reciprocal
(
local_mul_canonizer
.
merge_num_denum
(
node
.
inputs
[
1
:],
[]))
# The ones could have forced upcasting
# The ones could have forced upcasting
...
@@ -1957,30 +1920,22 @@ def local_div_to_reciprocal(fgraph, node):
...
@@ -1957,30 +1920,22 @@ def local_div_to_reciprocal(fgraph, node):
if
not
out
.
type
.
is_super
(
new_out
.
type
):
if
not
out
.
type
.
is_super
(
new_out
.
type
):
new_out
=
alloc_like
(
new_out
,
out
,
fgraph
)
new_out
=
alloc_like
(
new_out
,
out
,
fgraph
)
return
[
new_out
]
return
[
new_out
]
else
:
return
False
@register_canonicalize
@register_canonicalize
@node_rewriter
([
reciprocal
])
@node_rewriter
([
reciprocal
])
def
local_reciprocal_canon
(
fgraph
,
node
):
def
local_reciprocal_canon
(
fgraph
,
node
):
if
node
.
op
==
reciprocal
:
return
[
pt_pow
(
node
.
inputs
[
0
],
-
1.0
)]
return
[
pt_pow
(
node
.
inputs
[
0
],
-
1.0
)]
else
:
return
False
@register_canonicalize
@register_canonicalize
@node_rewriter
([
pt_pow
])
@node_rewriter
([
pt_pow
])
def
local_pow_canonicalize
(
fgraph
,
node
):
def
local_pow_canonicalize
(
fgraph
,
node
):
if
node
.
op
==
pt_pow
:
cst
=
get_constant
(
node
.
inputs
[
1
])
cst
=
get_constant
(
node
.
inputs
[
1
])
if
cst
==
0
:
if
cst
==
0
:
return
[
alloc_like
(
1
,
node
.
outputs
[
0
],
fgraph
)]
return
[
alloc_like
(
1
,
node
.
outputs
[
0
],
fgraph
)]
if
cst
==
1
:
if
cst
==
1
:
return
[
alloc_like
(
node
.
inputs
[
0
],
node
.
outputs
[
0
],
fgraph
)]
return
[
alloc_like
(
node
.
inputs
[
0
],
node
.
outputs
[
0
],
fgraph
)]
else
:
return
False
@register_specialize
@register_specialize
...
@@ -1989,21 +1944,17 @@ def local_mul_to_sqr(fgraph, node):
...
@@ -1989,21 +1944,17 @@ def local_mul_to_sqr(fgraph, node):
"""
"""
x*x -> sqr(x)
x*x -> sqr(x)
"""
"""
if
node
.
op
==
mul
:
if
len
(
node
.
inputs
)
==
2
:
if
len
(
node
.
inputs
)
==
2
:
if
node
.
inputs
[
0
]
is
node
.
inputs
[
1
]:
if
node
.
inputs
[
0
]
is
node
.
inputs
[
1
]:
return
[
sqr
(
node
.
inputs
[
0
])]
return
[
sqr
(
node
.
inputs
[
0
])]
@register_canonicalize
@register_canonicalize
@node_rewriter
([
int_div
])
@node_rewriter
([
int_div
])
def
local_intdiv_by_one
(
fgraph
,
node
):
def
local_intdiv_by_one
(
fgraph
,
node
):
"""x // 1 -> x"""
"""x // 1 -> x"""
if
node
.
op
in
[
int_div
]:
if
isinstance
(
node
.
inputs
[
1
],
TensorConstant
)
and
np
.
all
(
node
.
inputs
[
1
]
.
value
==
1
):
if
isinstance
(
node
.
inputs
[
1
],
TensorConstant
)
and
np
.
all
(
return
[
node
.
inputs
[
0
]
.
astype
(
node
.
outputs
[
0
]
.
dtype
)]
node
.
inputs
[
1
]
.
value
==
1
):
return
[
node
.
inputs
[
0
]
.
astype
(
node
.
outputs
[
0
]
.
dtype
)]
@register_canonicalize
@register_canonicalize
...
@@ -2011,49 +1962,43 @@ def local_intdiv_by_one(fgraph, node):
...
@@ -2011,49 +1962,43 @@ def local_intdiv_by_one(fgraph, node):
@node_rewriter
([
int_div
,
true_div
])
@node_rewriter
([
int_div
,
true_div
])
def
local_zero_div
(
fgraph
,
node
):
def
local_zero_div
(
fgraph
,
node
):
"""0 / x -> 0"""
"""0 / x -> 0"""
if
isinstance
(
node
.
op
,
Elemwise
)
and
isinstance
(
if
get_constant
(
node
.
inputs
[
0
])
==
0
:
node
.
op
.
scalar_op
,
ps
.
IntDiv
|
ps
.
TrueDiv
ret
=
alloc_like
(
0
,
node
.
outputs
[
0
],
fgraph
)
):
ret
.
tag
.
values_eq_approx
=
values_eq_approx_remove_nan
if
get_constant
(
node
.
inputs
[
0
])
==
0
:
return
[
ret
]
ret
=
alloc_like
(
0
,
node
.
outputs
[
0
],
fgraph
)
ret
.
tag
.
values_eq_approx
=
values_eq_approx_remove_nan
return
[
ret
]
@register_specialize
@register_specialize
@node_rewriter
([
pt_pow
])
@node_rewriter
([
pt_pow
])
def
local_pow_specialize
(
fgraph
,
node
):
def
local_pow_specialize
(
fgraph
,
node
):
if
node
.
op
==
pt_pow
:
# the idea here is that we have pow(x, y)
# the idea here is that we have pow(x, y)
odtype
=
node
.
outputs
[
0
]
.
dtype
odtype
=
node
.
outputs
[
0
]
.
dtype
xsym
=
node
.
inputs
[
0
]
xsym
=
node
.
inputs
[
0
]
ysym
=
node
.
inputs
[
1
]
ysym
=
node
.
inputs
[
1
]
y
=
get_constant
(
ysym
)
y
=
get_constant
(
ysym
)
if
(
y
is
not
None
)
and
not
broadcasted_by
(
xsym
,
ysym
):
if
(
y
is
not
None
)
and
not
broadcasted_by
(
xsym
,
ysym
):
rval
=
None
rval
=
None
if
np
.
all
(
y
==
2
):
if
np
.
all
(
y
==
2
):
rval
=
[
sqr
(
xsym
)]
rval
=
[
sqr
(
xsym
)]
if
np
.
all
(
y
==
1
):
if
np
.
all
(
y
==
1
):
rval
=
[
xsym
]
rval
=
[
xsym
]
if
np
.
all
(
y
==
0
):
if
np
.
all
(
y
==
0
):
rval
=
[
alloc_like
(
1
,
xsym
,
fgraph
)]
rval
=
[
alloc_like
(
1
,
xsym
,
fgraph
)]
if
np
.
all
(
y
==
0.5
):
if
np
.
all
(
y
==
0.5
):
rval
=
[
sqrt
(
xsym
)]
rval
=
[
sqrt
(
xsym
)]
if
np
.
all
(
y
==
-
0.5
):
if
np
.
all
(
y
==
-
0.5
):
rval
=
[
reciprocal
(
sqrt
(
xsym
))]
rval
=
[
reciprocal
(
sqrt
(
xsym
))]
if
np
.
all
(
y
==
-
1
):
if
np
.
all
(
y
==
-
1
):
rval
=
[
reciprocal
(
xsym
)]
rval
=
[
reciprocal
(
xsym
)]
if
np
.
all
(
y
==
-
2
):
if
np
.
all
(
y
==
-
2
):
rval
=
[
reciprocal
(
sqr
(
xsym
))]
rval
=
[
reciprocal
(
sqr
(
xsym
))]
if
rval
:
if
rval
:
if
not
rval
[
0
]
.
type
.
broadcastable
==
node
.
outputs
[
0
]
.
type
.
broadcastable
:
if
not
rval
[
0
]
.
type
.
broadcastable
==
node
.
outputs
[
0
]
.
type
.
broadcastable
:
return
None
return
None
rval
[
0
]
=
cast
(
rval
[
0
],
odtype
)
rval
[
0
]
=
cast
(
rval
[
0
],
odtype
)
assert
rval
[
0
]
.
type
.
dtype
==
node
.
outputs
[
0
]
.
type
.
dtype
assert
rval
[
0
]
.
type
.
dtype
==
node
.
outputs
[
0
]
.
type
.
dtype
return
rval
return
rval
else
:
return
False
@register_specialize
@register_specialize
...
@@ -2138,61 +2083,60 @@ def local_mul_specialize(fgraph, node):
...
@@ -2138,61 +2083,60 @@ def local_mul_specialize(fgraph, node):
"""
"""
# at this point [post canonicalize], mul() may have many inputs.
# at this point [post canonicalize], mul() may have many inputs.
if
node
.
op
==
mul
:
# the idea here is that we have pow(x, y)
# the idea here is that we have pow(x, y)
has_neg
=
False
has_neg
=
False
new_inputs
=
[]
new_inputs
=
[]
nb_neg_node
=
0
nb_neg_node
=
0
nb_cst
=
0
nb_cst
=
0
for
inp
in
node
.
inputs
:
for
inp
in
node
.
inputs
:
# remove any neg arguments
# remove any neg arguments
while
inp
.
owner
and
inp
.
owner
.
op
==
neg
:
while
inp
.
owner
and
inp
.
owner
.
op
==
neg
:
has_neg
^=
True
has_neg
^=
True
inp
=
inp
.
owner
.
inputs
[
0
]
inp
=
inp
.
owner
.
inputs
[
0
]
nb_neg_node
+=
1
nb_neg_node
+=
1
# remove special case arguments of 1, -1 or 0
# remove special case arguments of 1, -1 or 0
y
=
get_constant
(
inp
)
y
=
get_constant
(
inp
)
if
y
==
1.0
:
if
y
==
1.0
:
nb_cst
+=
1
nb_cst
+=
1
elif
y
==
-
1.0
:
elif
y
==
-
1.0
:
nb_cst
+=
1
nb_cst
+=
1
has_neg
^=
True
# toggles
has_neg
^=
True
# toggles
elif
y
==
0.0
:
elif
y
==
0.0
:
# if we find any zero, we just return right away
# if we find any zero, we just return right away
return
[
alloc_like
(
0
,
node
.
outputs
[
0
],
fgraph
)]
return
[
alloc_like
(
0
,
node
.
outputs
[
0
],
fgraph
)]
else
:
else
:
new_inputs
.
append
(
inp
)
new_inputs
.
append
(
inp
)
if
new_inputs
!=
node
.
inputs
:
if
new_inputs
!=
node
.
inputs
:
if
new_inputs
:
if
new_inputs
:
if
len
(
new_inputs
)
==
1
:
if
len
(
new_inputs
)
==
1
:
if
has_neg
:
if
has_neg
:
if
new_inputs
[
0
]
.
dtype
in
([
*
uint_dtypes
,
"bool"
]):
if
new_inputs
[
0
]
.
dtype
in
([
*
uint_dtypes
,
"bool"
]):
return
return
else
:
rval
=
-
new_inputs
[
0
]
else
:
else
:
rval
=
new_inputs
[
0
]
rval
=
-
new_inputs
[
0
]
else
:
else
:
# The next case would cause a replace by an equivalent case.
rval
=
new_inputs
[
0
]
if
has_neg
and
nb_neg_node
==
0
and
nb_cst
==
1
:
return
elif
has_neg
:
# Don't add an extra neg node as we can't
# fully replace this mul by a neg.
m1
=
np
.
asarray
(
-
1
,
dtype
=
node
.
outputs
[
0
]
.
dtype
)
new_inputs
=
[
m1
,
*
new_inputs
]
rval
=
mul
(
*
new_inputs
)
return
[
alloc_like
(
rval
,
node
.
outputs
[
0
],
fgraph
)]
else
:
else
:
# there are no variable inputs to mul
# The next case would cause a replace by an equivalent case.
# N.B. this could have been constant-folded...
if
has_neg
and
nb_neg_node
==
0
and
nb_cst
==
1
:
if
has_neg
:
return
return
[
alloc_like
(
-
1
,
node
.
outputs
[
0
],
fgraph
)]
elif
has_neg
:
else
:
# Don't add an extra neg node as we can't
return
[
alloc_like
(
1
,
node
.
outputs
[
0
],
fgraph
)]
# fully replace this mul by a neg.
m1
=
np
.
asarray
(
-
1
,
dtype
=
node
.
outputs
[
0
]
.
dtype
)
new_inputs
=
[
m1
,
*
new_inputs
]
rval
=
mul
(
*
new_inputs
)
return
[
alloc_like
(
rval
,
node
.
outputs
[
0
],
fgraph
)]
else
:
# there are no variable inputs to mul
# N.B. this could have been constant-folded...
if
has_neg
:
return
[
alloc_like
(
-
1
,
node
.
outputs
[
0
],
fgraph
)]
else
:
return
[
alloc_like
(
1
,
node
.
outputs
[
0
],
fgraph
)]
@register_specialize
@register_specialize
...
@@ -2276,7 +2220,7 @@ def local_abs_lift(fgraph, node):
...
@@ -2276,7 +2220,7 @@ def local_abs_lift(fgraph, node):
This is needed for check_for_x_over_absX to apply in more case.
This is needed for check_for_x_over_absX to apply in more case.
"""
"""
if
node
.
op
==
pt_abs
and
node
.
inputs
[
0
]
.
owner
:
if
node
.
inputs
[
0
]
.
owner
:
assert
node
.
nin
==
1
assert
node
.
nin
==
1
if
node
.
inputs
[
0
]
.
owner
.
op
==
mul
:
if
node
.
inputs
[
0
]
.
owner
.
op
==
mul
:
return
[
mul
(
*
[
pt_abs
(
i
)
for
i
in
node
.
inputs
[
0
]
.
owner
.
inputs
])]
return
[
mul
(
*
[
pt_abs
(
i
)
for
i
in
node
.
inputs
[
0
]
.
owner
.
inputs
])]
...
@@ -2328,31 +2272,30 @@ def local_abs_merge(fgraph, node):
...
@@ -2328,31 +2272,30 @@ def local_abs_merge(fgraph, node):
def
local_log1p
(
fgraph
,
node
):
def
local_log1p
(
fgraph
,
node
):
# log(1+x) -> log1p(x)
# log(1+x) -> log1p(x)
# log(1-x) -> log1p(-x)
# log(1-x) -> log1p(-x)
if
node
.
op
==
log
:
(
log_arg
,)
=
node
.
inputs
(
log_arg
,)
=
node
.
inputs
if
log_arg
.
owner
and
log_arg
.
owner
.
op
==
add
:
if
log_arg
.
owner
and
log_arg
.
owner
.
op
==
add
:
scalars
,
scalar_inputs
,
nonconsts
=
scalarconsts_rest
(
scalars
,
scalar_inputs
,
nonconsts
=
scalarconsts_rest
(
log_arg
.
owner
.
inputs
,
only_process_constants
=
True
log_arg
.
owner
.
inputs
,
only_process_constants
=
True
)
)
# scalar_inputs are potentially dimshuffled and fill'd scalars
# scalar_inputs are potentially dimshuffled and fill'd scalars
if
scalars
and
np
.
allclose
(
np
.
sum
(
scalars
),
1
):
if
scalars
and
np
.
allclose
(
np
.
sum
(
scalars
),
1
):
if
nonconsts
:
if
nonconsts
:
if
len
(
nonconsts
)
>
1
:
if
len
(
nonconsts
)
>
1
:
ninp
=
add
(
*
nonconsts
)
ninp
=
add
(
*
nonconsts
)
else
:
else
:
ninp
=
nonconsts
[
0
]
ninp
=
nonconsts
[
0
]
if
ninp
.
dtype
!=
log_arg
.
type
.
dtype
:
if
ninp
.
dtype
!=
log_arg
.
type
.
dtype
:
ninp
=
ninp
.
astype
(
node
.
outputs
[
0
]
.
dtype
)
ninp
=
ninp
.
astype
(
node
.
outputs
[
0
]
.
dtype
)
return
[
alloc_like
(
log1p
(
ninp
),
node
.
outputs
[
0
],
fgraph
)]
return
[
alloc_like
(
log1p
(
ninp
),
node
.
outputs
[
0
],
fgraph
)]
elif
log_arg
.
owner
and
log_arg
.
owner
.
op
==
sub
:
elif
log_arg
.
owner
and
log_arg
.
owner
.
op
==
sub
:
one
=
extract_constant
(
log_arg
.
owner
.
inputs
[
0
],
only_process_constants
=
True
)
one
=
extract_constant
(
log_arg
.
owner
.
inputs
[
0
],
only_process_constants
=
True
)
if
one
!=
1
:
if
one
!=
1
:
return
return
other
=
log_arg
.
owner
.
inputs
[
1
]
other
=
log_arg
.
owner
.
inputs
[
1
]
if
other
.
dtype
!=
log_arg
.
dtype
:
if
other
.
dtype
!=
log_arg
.
dtype
:
other
=
other
.
astype
(
log_arg
.
dtype
)
other
=
other
.
astype
(
log_arg
.
dtype
)
return
[
log1p
(
neg
(
other
))]
return
[
log1p
(
neg
(
other
))]
@register_stabilize
@register_stabilize
...
@@ -2365,26 +2308,25 @@ def local_log_add_exp(fgraph, node):
...
@@ -2365,26 +2308,25 @@ def local_log_add_exp(fgraph, node):
TODO: in canonicalize, change log10 and log2 -> log
TODO: in canonicalize, change log10 and log2 -> log
"""
"""
if
node
.
op
==
log
:
z
=
node
.
inputs
[
0
]
z
=
node
.
inputs
[
0
]
if
z
.
owner
and
z
.
owner
.
op
==
add
:
if
z
.
owner
and
z
.
owner
.
op
==
add
:
zi
=
z
.
owner
.
inputs
zi
=
z
.
owner
.
inputs
pre_exp
=
[
x
.
owner
.
inputs
[
0
]
for
x
in
zi
if
x
.
owner
and
x
.
owner
.
op
==
exp
]
pre_exp
=
[
x
.
owner
.
inputs
[
0
]
for
x
in
zi
if
x
.
owner
and
x
.
owner
.
op
==
exp
]
# all arguments to add are exp(<something>)
# all arguments to add are exp(<something>)
if
len
(
pre_exp
)
==
len
(
zi
):
if
len
(
pre_exp
)
==
len
(
zi
):
# Do not offset when max_pre = -np.inf, to avoid nan in the output
# Do not offset when max_pre = -np.inf, to avoid nan in the output
# Switch statement is placed directly inside add to break the self-symmetry
# Switch statement is placed directly inside add to break the self-symmetry
# of the returned output (otherwise the rewrite would not stabilize)
# of the returned output (otherwise the rewrite would not stabilize)
max_pre
=
reduce
(
maximum
,
pre_exp
)
max_pre
=
reduce
(
maximum
,
pre_exp
)
ret
=
max_pre
+
log
(
ret
=
max_pre
+
log
(
add
(
add
(
*
[
*
[
switch
(
isinf
(
max_pre
),
exp
(
max_pre
),
exp
(
p
-
max_pre
))
switch
(
isinf
(
max_pre
),
exp
(
max_pre
),
exp
(
p
-
max_pre
))
for
p
in
pre_exp
for
p
in
pre_exp
]
]
)
)
)
return
[
ret
]
)
return
[
ret
]
@register_stabilize
@register_stabilize
...
@@ -2393,9 +2335,6 @@ def local_log_add_exp(fgraph, node):
...
@@ -2393,9 +2335,6 @@ def local_log_add_exp(fgraph, node):
def
local_log_sum_exp
(
fgraph
,
node
):
def
local_log_sum_exp
(
fgraph
,
node
):
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
if
node
.
op
!=
log
:
return
sum_node
=
node
.
inputs
[
0
]
.
owner
sum_node
=
node
.
inputs
[
0
]
.
owner
# If the sum has keepdims=True, there might be a dimshuffle
# If the sum has keepdims=True, there might be a dimshuffle
if
sum_node
and
isinstance
(
sum_node
.
op
,
DimShuffle
):
if
sum_node
and
isinstance
(
sum_node
.
op
,
DimShuffle
):
...
@@ -2720,8 +2659,7 @@ def local_log_erfc(fgraph, node):
...
@@ -2720,8 +2659,7 @@ def local_log_erfc(fgraph, node):
numpy.asarray([i],dtype='float32')))) for i in numpy.arange(
numpy.asarray([i],dtype='float32')))) for i in numpy.arange(
10.0541948,10.0541951,.0000001)]
10.0541948,10.0541951,.0000001)]
"""
"""
if
node
.
op
!=
log
:
return
False
if
not
node
.
inputs
[
0
]
.
owner
or
node
.
inputs
[
0
]
.
owner
.
op
!=
erfc
:
if
not
node
.
inputs
[
0
]
.
owner
or
node
.
inputs
[
0
]
.
owner
.
op
!=
erfc
:
return
False
return
False
...
@@ -2773,8 +2711,6 @@ def local_grad_log_erfc_neg(fgraph, node):
...
@@ -2773,8 +2711,6 @@ def local_grad_log_erfc_neg(fgraph, node):
Make it so that the test does not generate an error in that case!
Make it so that the test does not generate an error in that case!
"""
"""
if
node
.
op
!=
true_div
:
return
False
if
not
node
.
inputs
[
1
]
.
owner
or
node
.
inputs
[
1
]
.
owner
.
op
!=
erfc
:
if
not
node
.
inputs
[
1
]
.
owner
or
node
.
inputs
[
1
]
.
owner
.
op
!=
erfc
:
return
False
return
False
...
@@ -3147,46 +3083,45 @@ def local_exp_over_1_plus_exp(fgraph, node):
...
@@ -3147,46 +3083,45 @@ def local_exp_over_1_plus_exp(fgraph, node):
"""
"""
# This rewrite should be done for numerical stability
# This rewrite should be done for numerical stability
# so we don't care to check client counts
# so we don't care to check client counts
if
node
.
op
==
true_div
:
# find all the exp() terms in the numerator
# find all the exp() terms in the numerator
num
,
denom
=
node
.
inputs
num
,
denom
=
node
.
inputs
num_exp_x
,
num_rest
,
num_neg
=
partition_num_or_denom
(
num
,
is_exp
)
num_exp_x
,
num_rest
,
num_neg
=
partition_num_or_denom
(
num
,
is_exp
)
denom_1pexp
,
denom_rest
,
denom_neg
=
partition_num_or_denom
(
denom
,
is_1pexp
)
denom_1pexp
,
denom_rest
,
denom_neg
=
partition_num_or_denom
(
denom
,
is_1pexp
)
sigmoids
=
[]
sigmoids
=
[]
for
t
in
denom_1pexp
:
for
t
in
denom_1pexp
:
if
t
in
num_exp_x
:
if
t
in
num_exp_x
:
# case: exp(x) /(1+exp(x))
# case: exp(x) /(1+exp(x))
sigmoids
.
append
(
sigmoid
(
t
))
sigmoids
.
append
(
sigmoid
(
t
))
del
num_exp_x
[
num_exp_x
.
index
(
t
)]
del
num_exp_x
[
num_exp_x
.
index
(
t
)]
else
:
# case: 1/(1+exp(x))
sigmoids
.
append
(
sigmoid
(
-
t
))
copy_stack_trace
(
node
.
outputs
[
0
],
sigmoids
[
-
1
])
if
not
sigmoids
:
# we didn't find any. abort
return
# put the new numerator together
new_num
=
sigmoids
+
[
exp
(
t
)
for
t
in
num_exp_x
]
+
num_rest
if
len
(
new_num
)
==
1
:
new_num
=
new_num
[
0
]
else
:
else
:
new_num
=
mul
(
*
new_num
)
# case: 1/(1+exp(x))
sigmoids
.
append
(
sigmoid
(
-
t
))
copy_stack_trace
(
node
.
outputs
[
0
],
sigmoids
[
-
1
])
if
num_neg
^
denom_neg
:
if
not
sigmoids
:
# we didn't find any. abort
new_num
=
-
new_num
return
# put the new numerator together
new_num
=
sigmoids
+
[
exp
(
t
)
for
t
in
num_exp_x
]
+
num_rest
if
len
(
new_num
)
==
1
:
new_num
=
new_num
[
0
]
else
:
new_num
=
mul
(
*
new_num
)
copy_stack_trace
(
num
,
new_num
)
if
num_neg
^
denom_neg
:
new_num
=
-
new_num
if
len
(
denom_rest
)
==
0
:
copy_stack_trace
(
num
,
new_num
)
return
[
new_num
]
elif
len
(
denom_rest
)
==
1
:
out
=
new_num
/
denom_rest
[
0
]
else
:
out
=
new_num
/
mul
(
*
denom_rest
)
copy_stack_trace
(
node
.
outputs
[
0
],
out
)
if
len
(
denom_rest
)
==
0
:
return
[
out
]
return
[
new_num
]
elif
len
(
denom_rest
)
==
1
:
out
=
new_num
/
denom_rest
[
0
]
else
:
out
=
new_num
/
mul
(
*
denom_rest
)
copy_stack_trace
(
node
.
outputs
[
0
],
out
)
return
[
out
]
def
parse_mul_tree
(
root
):
def
parse_mul_tree
(
root
):
...
@@ -3498,9 +3433,6 @@ def local_sigm_times_exp(fgraph, node):
...
@@ -3498,9 +3433,6 @@ def local_sigm_times_exp(fgraph, node):
todo: add stack traces to the intermediate variables
todo: add stack traces to the intermediate variables
"""
"""
# Bail early if it is not a multiplication.
if
node
.
op
!=
mul
:
return
None
# Obtain tree of multiplications starting at this node.
# Obtain tree of multiplications starting at this node.
mul_tree
=
parse_mul_tree
(
node
.
outputs
[
0
])
mul_tree
=
parse_mul_tree
(
node
.
outputs
[
0
])
did_something
=
perform_sigm_times_exp
(
mul_tree
)
did_something
=
perform_sigm_times_exp
(
mul_tree
)
...
@@ -3528,31 +3460,30 @@ def local_reciprocal_1_plus_exp(fgraph, node):
...
@@ -3528,31 +3460,30 @@ def local_reciprocal_1_plus_exp(fgraph, node):
"""
"""
# This Rewrite should be done for numerical stability
# This Rewrite should be done for numerical stability
# so we don't care to check client counts
# so we don't care to check client counts
if
node
.
op
==
reciprocal
:
reciprocal_arg
=
node
.
inputs
[
0
]
reciprocal_arg
=
node
.
inputs
[
0
]
if
reciprocal_arg
.
owner
and
reciprocal_arg
.
owner
.
op
==
add
:
if
reciprocal_arg
.
owner
and
reciprocal_arg
.
owner
.
op
==
add
:
scalars_
,
scalar_inputs
,
nonconsts
=
scalarconsts_rest
(
scalars_
,
scalar_inputs
,
nonconsts
=
scalarconsts_rest
(
reciprocal_arg
.
owner
.
inputs
,
only_process_constants
=
True
reciprocal_arg
.
owner
.
inputs
,
only_process_constants
=
True
)
)
# scalar_inputs are potentially dimshuffled and fill'd scalars
# scalar_inputs are potentially dimshuffled and fill'd scalars
if
len
(
nonconsts
)
==
1
:
if
len
(
nonconsts
)
==
1
:
if
nonconsts
[
0
]
.
owner
and
nonconsts
[
0
]
.
owner
.
op
==
exp
:
if
nonconsts
[
0
]
.
owner
and
nonconsts
[
0
]
.
owner
.
op
==
exp
:
if
scalars_
and
np
.
allclose
(
np
.
sum
(
scalars_
),
1
):
if
scalars_
and
np
.
allclose
(
np
.
sum
(
scalars_
),
1
):
out
=
[
out
=
[
alloc_like
(
alloc_like
(
sigmoid
(
neg
(
nonconsts
[
0
]
.
owner
.
inputs
[
0
])),
sigmoid
(
neg
(
nonconsts
[
0
]
.
owner
.
inputs
[
0
])),
node
.
outputs
[
0
],
node
.
outputs
[
0
],
fgraph
,
fgraph
,
)
]
# keep combined stack traces of
# exp(x): nonconsts[0],
# 1 + exp(x): reciprocal_arg,
# 1 / (1 + exp(x)): node.outputs[0]
copy_stack_trace
(
[
nonconsts
[
0
],
reciprocal_arg
,
node
.
outputs
[
0
]],
out
)
)
return
out
]
# keep combined stack traces of
# exp(x): nonconsts[0],
# 1 + exp(x): reciprocal_arg,
# 1 / (1 + exp(x)): node.outputs[0]
copy_stack_trace
(
[
nonconsts
[
0
],
reciprocal_arg
,
node
.
outputs
[
0
]],
out
)
return
out
# 1 - sigmoid(x) -> sigmoid(-x)
# 1 - sigmoid(x) -> sigmoid(-x)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论