提交 219d33bf authored 作者: Frederic's avatar Frederic

Make an opt to not add useless op that will get removed.

上级 3ee52a9a
...@@ -4242,25 +4242,25 @@ def local_sum_prod_div_dimshuffle(node): ...@@ -4242,25 +4242,25 @@ def local_sum_prod_div_dimshuffle(node):
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum):
op_on_compatible_dims = T.sum( op_on_compatible_dims = T.sum(
numerator, axis=compatible_dims) numerator, axis=compatible_dims)
div_op = T.true_div( rval = T.true_div(
op_on_compatible_dims, op_on_compatible_dims,
optimized_dimshuffle) optimized_dimshuffle)
op_on_incompatible_dims = T.sum( if len(reordered_incompatible_dims) > 0:
div_op, rval = T.sum(rval,
axis=reordered_incompatible_dims) axis=reordered_incompatible_dims)
elif isinstance(node.op, T.elemwise.Prod): elif isinstance(node.op, T.elemwise.Prod):
op_on_compatible_dims = T.prod( op_on_compatible_dims = T.prod(
numerator, axis=compatible_dims) numerator, axis=compatible_dims)
dtype = numerator.dtype dtype = numerator.dtype
div_op = T.true_div( rval = T.true_div(
op_on_compatible_dims, op_on_compatible_dims,
(optimized_dimshuffle ** (optimized_dimshuffle **
T.prod([numerator.shape[ax].astype(dtype) T.prod([numerator.shape[ax].astype(dtype)
for ax in compatible_dims]))) for ax in compatible_dims])))
op_on_incompatible_dims = T.prod( if len(reordered_incompatible_dims) > 0:
div_op, rval = T.prod(rval,
axis=reordered_incompatible_dims) axis=reordered_incompatible_dims)
return [op_on_incompatible_dims] return [rval]
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论