提交 2c9f692a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move tensor_copy rewrite to aesara.tensor.rewriting.basic

上级 63f52536
......@@ -11,6 +11,7 @@ from aesara.compile.ops import ViewOp
from aesara.graph.basic import Constant, Variable
from aesara.graph.rewriting.basic import (
NodeRewriter,
RemovalNodeRewriter,
Rewriter,
copy_stack_trace,
in2out,
......@@ -35,6 +36,7 @@ from aesara.tensor.basic import (
join,
ones_like,
switch,
tensor_copy,
zeros,
zeros_like,
)
......@@ -1294,3 +1296,6 @@ def __getattr__(name):
return fn()
raise AttributeError(f"module {__name__} has no attribute {name}")
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
......@@ -13,7 +13,6 @@ from aesara.graph.features import AlreadyThere, Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import (
GraphRewriter,
RemovalNodeRewriter,
check_chain,
copy_stack_trace,
node_rewriter,
......@@ -27,7 +26,6 @@ from aesara.tensor.basic import (
extract_constant,
get_scalar_constant_value,
stack,
tensor_copy,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
......@@ -972,9 +970,6 @@ def local_reshape_lift(fgraph, node):
return [e]
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
@register_useless
@register_canonicalize
@node_rewriter([SpecifyShape])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论