提交 688d6883 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move TypeCastingOp dispatcher to basic.py

This isn't strictly needed but it's a more intuitive placement
上级 686ed878
......@@ -10,7 +10,7 @@ from pytensor import In, config
from pytensor.compile import NUMBA
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
......@@ -328,6 +328,15 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
return opfromgraph
@numba_funcify.register(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs):
@numba_njit
def identity(x):
return x
return identity
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
if isinstance(node.inputs[0].type, TensorType):
......
......@@ -2,7 +2,6 @@ import math
import numpy as np
from pytensor.compile.ops import TypeCastingOp
from pytensor.graph.basic import Variable
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
......@@ -197,7 +196,6 @@ def numba_funcify_Cast(op, node, **kwargs):
@numba_funcify.register(Identity)
@numba_funcify.register(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs):
@numba_basic.numba_njit
def identity(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论