提交 8d3a512a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a pretty printer for IncSubtensor

上级 4a818873
...@@ -1190,50 +1190,44 @@ class Subtensor(COp): ...@@ -1190,50 +1190,44 @@ class Subtensor(COp):
class SubtensorPrinter(Printer): class SubtensorPrinter(Printer):
def process(self, r, pstate): def process(self, r, pstate):
if r.owner is None: return self._process(r.owner.op.idx_list, r.owner.inputs, pstate)
raise TypeError("Can only print Subtensor.")
elif isinstance(r.owner.op, Subtensor): def _process(self, idxs, op_inputs, pstate):
idxs = r.owner.op.idx_list inputs = list(op_inputs)
inputs = list(r.owner.inputs) input = inputs.pop(0)
input = inputs.pop(0) sidxs = []
sidxs = [] old_precedence = getattr(pstate, "precedence", None)
old_precedence = getattr(pstate, "precedence", None) for entry in idxs:
try: if isinstance(entry, aes.Scalar):
pstate.precedence = -1000 pstate.precedence = -1000
try:
sidxs.append(pstate.pprinter.process(inputs.pop()))
finally:
pstate.precedence = old_precedence
elif isinstance(entry, slice):
if entry.start is None or entry.start == 0:
msg1 = ""
else:
msg1 = entry.start
for entry in idxs: if entry.stop is None or entry.stop == sys.maxsize:
if isinstance(entry, int): msg2 = ""
sidxs.append(str(entry)) else:
elif isinstance(entry, aes.Scalar): msg2 = entry.stop
sidxs.append(pstate.pprinter.process(inputs.pop()))
elif isinstance(entry, slice): if entry.step is None:
if entry.start is None or entry.start == 0: msg3 = ""
msg1 = "" else:
else: msg3 = f":{entry.step}"
msg1 = entry.start
sidxs.append(f"{msg1}:{msg2}{msg3}")
if entry.stop is None or entry.stop == sys.maxsize:
msg2 = "" try:
else: pstate.precedence = 1000
msg2 = entry.stop sub = pstate.pprinter.process(input, pstate)
finally:
if entry.step is None: pstate.precedence = old_precedence
msg3 = "" return f"{sub}[{', '.join(sidxs)}]"
else:
msg3 = f":{entry.step}"
sidxs.append(f"{msg1}:{msg2}{msg3}")
finally:
pstate.precedence = old_precedence
try:
pstate.precedence = 1000
sub = pstate.pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
return f"{sub}[{', '.join(sidxs)}]"
else:
raise TypeError("Can only print Subtensor.")
pprint.assign(Subtensor, SubtensorPrinter()) pprint.assign(Subtensor, SubtensorPrinter())
...@@ -1839,6 +1833,29 @@ class IncSubtensor(COp): ...@@ -1839,6 +1833,29 @@ class IncSubtensor(COp):
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy] + [DisconnectedType()()] * len(idx_list)
class IncSubtensorPrinter(SubtensorPrinter):
def process(self, r, pstate):
x, y, *idx_args = r.owner.inputs
res = self._process(r.owner.op.idx_list, [x] + idx_args, pstate)
old_precedence = pstate.precedence
try:
pstate.precedence = 1000
y_str = pstate.pprinter.process(r.owner.inputs[1], pstate)
finally:
pstate.precedence = old_precedence
if r.owner.op.set_instead_of_inc:
res = f"set_subtensor({res}, {y_str})"
else:
res = f"inc_subtensor({res}, {y_str})"
return res
pprint.assign(IncSubtensor, IncSubtensorPrinter())
def _sum_grad_over_bcasted_dims(x, gx): def _sum_grad_over_bcasted_dims(x, gx):
""" """
Sum of gx over dimensions to reproduce x.broadcastable. Sum of gx over dimensions to reproduce x.broadcastable.
......
...@@ -14,6 +14,7 @@ from aesara.compile.io import In ...@@ -14,6 +14,7 @@ from aesara.compile.io import In
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.opt_utils import is_same_graph from aesara.graph.opt_utils import is_same_graph
from aesara.printing import pprint
from aesara.scalar.basic import as_scalar from aesara.scalar.basic import as_scalar
from aesara.tensor import get_vector_length from aesara.tensor import get_vector_length
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
...@@ -2524,3 +2525,36 @@ def test_get_vector_length(): ...@@ -2524,3 +2525,36 @@ def test_get_vector_length():
with pytest.raises(ValueError, match="^Length of .*"): with pytest.raises(ValueError, match="^Length of .*"):
get_vector_length(x[lscalar() :]) get_vector_length(x[lscalar() :])
@pytest.mark.parametrize(
"indices, exp_res",
[
((0,), "x[0]"),
# TODO: The numbers should be printed
((slice(None, 2),), "x[:int64]"),
((slice(0, None),), "x[int64:]"),
((slice(0, 2),), "x[int64:int64]"),
((slice(0, 2, 2),), "x[int64:int64:int64]"),
((slice(0, 2), 0, slice(0, 2)), "x[int64:int64, 2, int64:int64]"),
],
)
def test_pprint_Subtensor(indices, exp_res):
x = tensor4("x")
y = x[indices]
assert pprint(y) == exp_res
@pytest.mark.parametrize(
"indices, set_instead_of_inc, exp_res",
[
((0,), False, "inc_subtensor(x[0], z)"),
((0,), True, "set_subtensor(x[0], z)"),
((slice(0, 2),), True, "set_subtensor(x[int64:int64], z)"),
],
)
def test_pprint_IncSubtensor(indices, set_instead_of_inc, exp_res):
x = tensor4("x")
z = tensor3("z")
y = inc_subtensor(x[indices], z, set_instead_of_inc=set_instead_of_inc)
assert pprint(y) == exp_res
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论