提交 29da6642 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5389 from nouiz/printer_crash

theano.pp crash fix and display fix
...@@ -967,14 +967,17 @@ class SubtensorPrinter: ...@@ -967,14 +967,17 @@ class SubtensorPrinter:
elif isinstance(r.owner.op, Subtensor): elif isinstance(r.owner.op, Subtensor):
idxs = r.owner.op.idx_list idxs = r.owner.op.idx_list
inputs = list(r.owner.inputs) inputs = list(r.owner.inputs)
input = inputs.pop() input = inputs.pop(0)
sidxs = [] sidxs = []
inbrack_pstate = pstate.clone(precedence=-1000) old_precedence = getattr(pstate, 'precedence', None)
try:
pstate.precedence = -1000
for entry in idxs: for entry in idxs:
if isinstance(entry, integer_types): if isinstance(entry, integer_types):
sidxs.append(str(entry)) sidxs.append(str(entry))
elif isinstance(entry, scal.Scalar): elif isinstance(entry, scal.Scalar):
sidxs.append(inbrack_pstate.pprinter.process(inputs.pop())) sidxs.append(pstate.pprinter.process(inputs.pop()))
elif isinstance(entry, slice): elif isinstance(entry, slice):
if entry.start is None or entry.start == 0: if entry.start is None or entry.start == 0:
msg1 = "" msg1 = ""
...@@ -992,10 +995,15 @@ class SubtensorPrinter: ...@@ -992,10 +995,15 @@ class SubtensorPrinter:
msg3 = ":%s" % entry.step msg3 = ":%s" % entry.step
sidxs.append("%s:%s%s" % (msg1, msg2, msg3)) sidxs.append("%s:%s%s" % (msg1, msg2, msg3))
return "%s[%s]" % (pstate.pprinter.process( finally:
input, pstate.precedence = old_precedence
pstate.clone(precedence=1000)),
", ".join(sidxs)) try:
pstate.precedence = 1000
sub = pstate.pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
return "%s[%s]" % (sub, ", ".join(sidxs))
else: else:
raise TypeError("Can only print Subtensor.") raise TypeError("Can only print Subtensor.")
......
...@@ -738,3 +738,9 @@ def test_printing_scan(): ...@@ -738,3 +738,9 @@ def test_printing_scan():
allow_input_downcast=True) allow_input_downcast=True)
theano.printing.pydotprint(output, scan_graphs=True) theano.printing.pydotprint(output, scan_graphs=True)
theano.printing.pydotprint(f, scan_graphs=True) theano.printing.pydotprint(f, scan_graphs=True)
def test_subtensor():
x = theano.tensor.dvector()
y = x[1]
assert theano.pp(y) == "<TensorType(float64, vector)>[Constant{1}]"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论