提交 b303a4cc authored 作者: Frederic Bastien's avatar Frederic Bastien

remove duplicate optimization as one of them was not enabled.

上级 61c1d5cd
...@@ -712,13 +712,6 @@ def local_track_shape_i(node): ...@@ -712,13 +712,6 @@ def local_track_shape_i(node):
replacement = shape_feature.scheduled[node] replacement = shape_feature.scheduled[node]
return [shape_feature.shape_of[replacement][node.op.i]] return [shape_feature.shape_of[replacement][node.op.i]]
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.Join])
def local_useless_join(node):
if isinstance(node.op, T.Join) and len(node.inputs)==2:
return [node.inputs[1]]
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Subtensor]) @gof.local_optimizer([T.Subtensor])
...@@ -1203,6 +1196,8 @@ def apply_rebroadcast_opt(rval): ...@@ -1203,6 +1196,8 @@ def apply_rebroadcast_opt(rval):
############# #############
# Join opts # # Join opts #
############# #############
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.Join]) @gof.local_optimizer([T.Join])
def local_join_1(node): def local_join_1(node):
"""Join(i, x) => x """Join(i, x) => x
......
...@@ -2094,7 +2094,7 @@ def test_make_vector(): ...@@ -2094,7 +2094,7 @@ def test_make_vector():
except AssertionError: except AssertionError:
pass pass
def test_local_useless_join(): def test_local_join_1():
#test for vector #test for vector
a = TT.vector('a') a = TT.vector('a')
s = stack(a) s = stack(a)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论