提交 f7cf2734 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Extract ViewOp functionality into a base TypeCastOp

上级 ddcd9882
...@@ -33,11 +33,8 @@ def register_view_op_c_code(type, code, version=()): ...@@ -33,11 +33,8 @@ def register_view_op_c_code(type, code, version=()):
ViewOp.c_code_and_version[type] = (code, version) ViewOp.c_code_and_version[type] = (code, version)
class ViewOp(COp): class TypeCastingOp(COp):
""" """Op that performs a graph-level type cast operation, but has no effect computation-wise (identity function)."""
Returns an inplace view of the input. Used internally by PyTensor.
"""
view_map = {0: [0]} view_map = {0: [0]}
# Mapping from Type to C code (and version) to use. # Mapping from Type to C code (and version) to use.
...@@ -47,13 +44,8 @@ class ViewOp(COp): ...@@ -47,13 +44,8 @@ class ViewOp(COp):
__props__: tuple = () __props__: tuple = ()
_f16_ok: bool = True _f16_ok: bool = True
def make_node(self, x): def perform(self, node, inputs, outputs_storage):
return Apply(self, [x], [x.type()]) outputs_storage[0][0] = inputs[0]
def perform(self, node, inp, out):
(x,) = inp
(z,) = out
z[0] = x
def __str__(self): def __str__(self):
return f"{self.__class__.__name__}" return f"{self.__class__.__name__}"
...@@ -90,6 +82,13 @@ class ViewOp(COp): ...@@ -90,6 +82,13 @@ class ViewOp(COp):
return tuple(version) return tuple(version)
class ViewOp(TypeCastingOp):
    """Returns an inplace view of the input. Used internally by PyTensor."""

    def make_node(self, x):
        # The output shares the input's type exactly; runtime identity
        # behavior (`perform`, `view_map = {0: [0]}`) is presumably
        # inherited from TypeCastingOp — see the base class above.
        return Apply(self, [x], [x.type()])
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
return input_shapes return input_shapes
......
...@@ -8,7 +8,7 @@ import numpy as np ...@@ -8,7 +8,7 @@ import numpy as np
from pytensor.compile import JAX from pytensor.compile import JAX
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse from pytensor.ifelse import IfElse
...@@ -111,12 +111,12 @@ def jax_funcify_DeepCopyOp(op, **kwargs): ...@@ -111,12 +111,12 @@ def jax_funcify_DeepCopyOp(op, **kwargs):
return deepcopyop return deepcopyop
@jax_funcify.register(ViewOp) @jax_funcify.register(TypeCastingOp)
def jax_funcify_ViewOp(op, **kwargs): def jax_funcify_TypeCastingOp(op, **kwargs):
def viewop(x): def type_cast(x):
return x return x
return viewop return type_cast
@jax_funcify.register(OpFromGraph) @jax_funcify.register(OpFromGraph)
......
...@@ -2,7 +2,7 @@ import math ...@@ -2,7 +2,7 @@ import math
import numpy as np import numpy as np
from pytensor.compile.ops import ViewOp from pytensor.compile.ops import TypeCastingOp
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
...@@ -198,14 +198,14 @@ def numba_funcify_Cast(op, node, **kwargs): ...@@ -198,14 +198,14 @@ def numba_funcify_Cast(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def viewop(x): def identity(x):
return x return x
@numba_funcify.register(Identity) @numba_funcify.register(Identity)
@numba_funcify.register(ViewOp) @numba_funcify.register(TypeCastingOp)
def numba_funcify_ViewOp(op, **kwargs): def numba_funcify_type_casting(op, **kwargs):
return numba_basic.global_numba_func(viewop) return numba_basic.global_numba_func(identity)
@numba_basic.numba_njit @numba_basic.numba_njit
......
...@@ -9,7 +9,7 @@ from pytensor import In ...@@ -9,7 +9,7 @@ from pytensor import In
from pytensor.compile import PYTORCH from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse from pytensor.ifelse import IfElse
...@@ -71,6 +71,14 @@ def pytorch_funcify_FunctionGraph( ...@@ -71,6 +71,14 @@ def pytorch_funcify_FunctionGraph(
) )
@pytorch_funcify.register(TypeCastingOp)
def pytorch_funcify_CastingOp(op, node, **kwargs):
    """PyTorch dispatch for TypeCastingOp.

    A TypeCastingOp only re-types the variable at the graph level; at
    runtime it is the identity function, so the funcified form simply
    returns its input unchanged.
    """

    def type_cast(x):
        return x

    return type_cast
@pytorch_funcify.register(CheckAndRaise) @pytorch_funcify.register(CheckAndRaise)
def pytorch_funcify_CheckAndRaise(op, **kwargs): def pytorch_funcify_CheckAndRaise(op, **kwargs):
error = op.exc_type error = op.exc_type
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论