提交 688d6883 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move TypeCastingOp dispatcher to basic.py

This isn't strictly needed but it's a more intuitive placement
上级 686ed878
...@@ -10,7 +10,7 @@ from pytensor import In, config ...@@ -10,7 +10,7 @@ from pytensor import In, config
from pytensor.compile import NUMBA from pytensor.compile import NUMBA
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type from pytensor.graph.type import Type
...@@ -328,6 +328,15 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs): ...@@ -328,6 +328,15 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
return opfromgraph return opfromgraph
@numba_funcify.register(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs):
@numba_njit
def identity(x):
return x
return identity
@numba_funcify.register(DeepCopyOp) @numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs): def numba_funcify_DeepCopyOp(op, node, **kwargs):
if isinstance(node.inputs[0].type, TensorType): if isinstance(node.inputs[0].type, TensorType):
......
...@@ -2,7 +2,6 @@ import math ...@@ -2,7 +2,6 @@ import math
import numpy as np import numpy as np
from pytensor.compile.ops import TypeCastingOp
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
...@@ -197,7 +196,6 @@ def numba_funcify_Cast(op, node, **kwargs): ...@@ -197,7 +196,6 @@ def numba_funcify_Cast(op, node, **kwargs):
@numba_funcify.register(Identity) @numba_funcify.register(Identity)
@numba_funcify.register(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs): def numba_funcify_type_casting(op, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def identity(x): def identity(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论