提交 2c9f692a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move tensor_copy rewrite to aesara.tensor.rewriting.basic

上级 63f52536
...@@ -11,6 +11,7 @@ from aesara.compile.ops import ViewOp ...@@ -11,6 +11,7 @@ from aesara.compile.ops import ViewOp
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.graph.rewriting.basic import ( from aesara.graph.rewriting.basic import (
NodeRewriter, NodeRewriter,
RemovalNodeRewriter,
Rewriter, Rewriter,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
...@@ -35,6 +36,7 @@ from aesara.tensor.basic import ( ...@@ -35,6 +36,7 @@ from aesara.tensor.basic import (
join, join,
ones_like, ones_like,
switch, switch,
tensor_copy,
zeros, zeros,
zeros_like, zeros_like,
) )
...@@ -1294,3 +1296,6 @@ def __getattr__(name): ...@@ -1294,3 +1296,6 @@ def __getattr__(name):
return fn() return fn()
raise AttributeError(f"module {__name__} has no attribute {name}") raise AttributeError(f"module {__name__} has no attribute {name}")
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
...@@ -13,7 +13,6 @@ from aesara.graph.features import AlreadyThere, Feature ...@@ -13,7 +13,6 @@ from aesara.graph.features import AlreadyThere, Feature
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import ( from aesara.graph.rewriting.basic import (
GraphRewriter, GraphRewriter,
RemovalNodeRewriter,
check_chain, check_chain,
copy_stack_trace, copy_stack_trace,
node_rewriter, node_rewriter,
...@@ -27,7 +26,6 @@ from aesara.tensor.basic import ( ...@@ -27,7 +26,6 @@ from aesara.tensor.basic import (
extract_constant, extract_constant,
get_scalar_constant_value, get_scalar_constant_value,
stack, stack,
tensor_copy,
) )
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
...@@ -972,9 +970,6 @@ def local_reshape_lift(fgraph, node): ...@@ -972,9 +970,6 @@ def local_reshape_lift(fgraph, node):
return [e] return [e]
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@node_rewriter([SpecifyShape]) @node_rewriter([SpecifyShape])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论