提交 7ccde642 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Change TypedList Count and Index output to int64

上级 f49a6c51
...@@ -2,11 +2,10 @@ import numpy as np ...@@ -2,11 +2,10 @@ import numpy as np
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile.debugmode import _lessbroken_deepcopy from pytensor.compile.debugmode import _lessbroken_deepcopy
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.tensor.type import scalar from pytensor.tensor.type import lscalar
from pytensor.tensor.type_other import SliceType from pytensor.tensor.type_other import SliceType
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
from pytensor.typed_list.type import TypedListType from pytensor.typed_list.type import TypedListType
...@@ -508,7 +507,7 @@ class Index(Op): ...@@ -508,7 +507,7 @@ class Index(Op):
def make_node(self, x, elem): def make_node(self, x, elem):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type assert x.ttype == elem.type
return Apply(self, [x, elem], [scalar()]) return Apply(self, [x, elem], [lscalar()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
""" """
...@@ -520,7 +519,7 @@ class Index(Op): ...@@ -520,7 +519,7 @@ class Index(Op):
(out,) = outputs (out,) = outputs
for y in range(len(x)): for y in range(len(x)):
if node.inputs[0].ttype.values_eq(x[y], elem): if node.inputs[0].ttype.values_eq(x[y], elem):
out[0] = np.asarray(y, dtype=config.floatX) out[0] = np.asarray(y, dtype="int64")
break break
def __str__(self): def __str__(self):
...@@ -537,7 +536,7 @@ class Count(Op): ...@@ -537,7 +536,7 @@ class Count(Op):
def make_node(self, x, elem): def make_node(self, x, elem):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type assert x.ttype == elem.type
return Apply(self, [x, elem], [scalar()]) return Apply(self, [x, elem], [lscalar()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
""" """
...@@ -551,7 +550,7 @@ class Count(Op): ...@@ -551,7 +550,7 @@ class Count(Op):
for y in range(len(x)): for y in range(len(x)):
if node.inputs[0].ttype.values_eq(x[y], elem): if node.inputs[0].ttype.values_eq(x[y], elem):
out[0] += 1 out[0] += 1
out[0] = np.asarray(out[0], dtype=config.floatX) out[0] = np.asarray(out[0], "int64")
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
...@@ -583,7 +582,7 @@ class Length(COp): ...@@ -583,7 +582,7 @@ class Length(COp):
def make_node(self, x): def make_node(self, x):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
return Apply(self, [x], [scalar(dtype="int64")]) return Apply(self, [x], [lscalar()])
def perform(self, node, x, outputs): def perform(self, node, x, outputs):
(out,) = outputs (out,) = outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论