提交 7ccde642 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Change TypedList Count and Index output to int64

上级 f49a6c51
......@@ -2,11 +2,10 @@ import numpy as np
import pytensor.tensor as pt
from pytensor.compile.debugmode import _lessbroken_deepcopy
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op
from pytensor.link.c.op import COp
from pytensor.tensor.type import scalar
from pytensor.tensor.type import lscalar
from pytensor.tensor.type_other import SliceType
from pytensor.tensor.variable import TensorVariable
from pytensor.typed_list.type import TypedListType
......@@ -508,7 +507,7 @@ class Index(Op):
def make_node(self, x, elem):
assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type
return Apply(self, [x, elem], [scalar()])
return Apply(self, [x, elem], [lscalar()])
def perform(self, node, inputs, outputs):
"""
......@@ -520,7 +519,7 @@ class Index(Op):
(out,) = outputs
for y in range(len(x)):
if node.inputs[0].ttype.values_eq(x[y], elem):
out[0] = np.asarray(y, dtype=config.floatX)
out[0] = np.asarray(y, dtype="int64")
break
def __str__(self):
......@@ -537,7 +536,7 @@ class Count(Op):
def make_node(self, x, elem):
assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type
return Apply(self, [x, elem], [scalar()])
return Apply(self, [x, elem], [lscalar()])
def perform(self, node, inputs, outputs):
"""
......@@ -551,7 +550,7 @@ class Count(Op):
for y in range(len(x)):
if node.inputs[0].ttype.values_eq(x[y], elem):
out[0] += 1
out[0] = np.asarray(out[0], dtype=config.floatX)
out[0] = np.asarray(out[0], "int64")
def __str__(self):
return self.__class__.__name__
......@@ -583,7 +582,7 @@ class Length(COp):
def make_node(self, x):
assert isinstance(x.type, TypedListType)
return Apply(self, [x], [scalar(dtype="int64")])
return Apply(self, [x], [lscalar()])
def perform(self, node, x, outputs):
(out,) = outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论