提交 a8103cfd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix local_subtensor_make_vector dtype bug

上级 f8e7fe07
...@@ -21,7 +21,6 @@ from aesara.tensor.basic import ( ...@@ -21,7 +21,6 @@ from aesara.tensor.basic import (
cast, cast,
extract_constant, extract_constant,
get_scalar_constant_value, get_scalar_constant_value,
make_vector,
patternbroadcast, patternbroadcast,
switch, switch,
) )
...@@ -714,6 +713,8 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -714,6 +713,8 @@ def local_subtensor_make_vector(fgraph, node):
if not x.owner or not isinstance(x.owner.op, MakeVector): if not x.owner or not isinstance(x.owner.op, MakeVector):
return False return False
make_vector_op = x.owner.op
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
try: try:
...@@ -753,7 +754,7 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -753,7 +754,7 @@ def local_subtensor_make_vector(fgraph, node):
pass pass
elif idx.ndim == 1 and isinstance(idx, Constant): elif idx.ndim == 1 and isinstance(idx, Constant):
values = list(map(int, list(idx.value))) values = list(map(int, list(idx.value)))
ret = make_vector(*[x.owner.inputs[v] for v in values]) ret = make_vector_op(*[x.owner.inputs[v] for v in values])
# Copy over stack trace from previous output to new output # Copy over stack trace from previous output to new output
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
...@@ -767,7 +768,7 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -767,7 +768,7 @@ def local_subtensor_make_vector(fgraph, node):
# it can, then try to unpack them. # it can, then try to unpack them.
try: try:
const_slice = node.op.get_constant_idx(node.inputs, allow_partial=False)[0] const_slice = node.op.get_constant_idx(node.inputs, allow_partial=False)[0]
ret = make_vector(*x.owner.inputs[const_slice]) ret = make_vector_op(*x.owner.inputs[const_slice])
# Copy over stack trace from previous outputs to new output # Copy over stack trace from previous outputs to new output
copy_stack_trace(node.outputs, ret) copy_stack_trace(node.outputs, ret)
ret = patternbroadcast(ret, node.outputs[0].broadcastable) ret = patternbroadcast(ret, node.outputs[0].broadcastable)
......
...@@ -45,6 +45,7 @@ from aesara.tensor.type import ( ...@@ -45,6 +45,7 @@ from aesara.tensor.type import (
dmatrix, dmatrix,
fmatrix, fmatrix,
iscalar, iscalar,
iscalars,
ivector, ivector,
lscalar, lscalar,
lscalars, lscalars,
...@@ -567,6 +568,30 @@ class TestLocalSubtensorMakeVector: ...@@ -567,6 +568,30 @@ class TestLocalSubtensorMakeVector:
assert isinstance(prog[0].op, DeepCopyOp) assert isinstance(prog[0].op, DeepCopyOp)
assert f(0, 1, 2) == 0 assert f(0, 1, 2) == 0
def test_idx_symbolic(self):
    """Indexing a ``MakeVector`` with a constant vector index keeps the op's dtype.

    Regression test for the dtype bug in ``local_subtensor_make_vector``:
    the rewrite used to rebuild the result with the default ``make_vector``
    helper (which infers a new dtype, e.g. int64) instead of reusing the
    original node's ``MakeVector`` op, losing the explicit "int32" here.
    (Fixes the misspelled name ``test_idx_sybmolic`` -> ``test_idx_symbolic``.)
    """
    x, y, z = iscalars("xyz")
    v = MakeVector("int32")(x, y, z)
    # A rank-1 constant index exercises the advanced-indexing branch
    # of the rewrite (idx.ndim == 1 and isinstance(idx, Constant)).
    idx = aet.as_tensor([0], dtype=np.int64)
    f = function([x, y, z], v[idx], mode=mode_opt)

    opt_fgraph = f.maker.fgraph
    # The rewritten graph must preserve the original int32 dtype and
    # still be a MakeVector (wrapped in a Rebroadcast by patternbroadcast).
    assert opt_fgraph.outputs[0].dtype == "int32"
    assert isinstance(opt_fgraph.outputs[0].owner.op, Rebroadcast)
    assert isinstance(opt_fgraph.outputs[0].owner.inputs[0].owner.op, MakeVector)
    assert f(0, 1, 2) == np.array([0], dtype=np.int32)
def test_slice_idx_start(self):
    """Slicing a ``MakeVector`` from a non-zero start is rewritten to a
    smaller ``MakeVector`` that keeps the original dtype and drops the
    skipped inputs."""
    a, b, c = iscalars("xyz")
    vec = MakeVector("int32")(a, b, c)
    # `a` is sliced away entirely, hence on_unused_input="ignore".
    fn = function([a, b, c], vec[1:], mode=mode_opt, on_unused_input="ignore")

    out_var = fn.maker.fgraph.outputs[0]
    assert out_var.dtype == "int32"
    assert isinstance(out_var.owner.op, MakeVector)
    # Only the last two inputs survive the slice.
    assert len(out_var.owner.inputs) == 2

    res = fn(0, 1, 2)
    assert res[0] == 1
    assert res[1] == 2
def test_slice_idx_stop(self): def test_slice_idx_stop(self):
x, y, z = lscalars("xyz") x, y, z = lscalars("xyz")
v = make_vector(x, y, z) v = make_vector(x, y, z)
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论