提交 29da6642 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5389 from nouiz/printer_crash

theano.pp crash fix and display fix
......@@ -967,35 +967,43 @@ class SubtensorPrinter:
elif isinstance(r.owner.op, Subtensor):
idxs = r.owner.op.idx_list
inputs = list(r.owner.inputs)
input = inputs.pop()
input = inputs.pop(0)
sidxs = []
inbrack_pstate = pstate.clone(precedence=-1000)
for entry in idxs:
if isinstance(entry, integer_types):
sidxs.append(str(entry))
elif isinstance(entry, scal.Scalar):
sidxs.append(inbrack_pstate.pprinter.process(inputs.pop()))
elif isinstance(entry, slice):
if entry.start is None or entry.start == 0:
msg1 = ""
else:
msg1 = entry.start
if entry.stop is None or entry.stop == sys.maxsize:
msg2 = ""
else:
msg2 = entry.stop
if entry.step is None:
msg3 = ""
else:
msg3 = ":%s" % entry.step
old_precedence = getattr(pstate, 'precedence', None)
try:
pstate.precedence = -1000
for entry in idxs:
if isinstance(entry, integer_types):
sidxs.append(str(entry))
elif isinstance(entry, scal.Scalar):
sidxs.append(pstate.pprinter.process(inputs.pop()))
elif isinstance(entry, slice):
if entry.start is None or entry.start == 0:
msg1 = ""
else:
msg1 = entry.start
if entry.stop is None or entry.stop == sys.maxsize:
msg2 = ""
else:
msg2 = entry.stop
if entry.step is None:
msg3 = ""
else:
msg3 = ":%s" % entry.step
sidxs.append("%s:%s%s" % (msg1, msg2, msg3))
finally:
pstate.precedence = old_precedence
sidxs.append("%s:%s%s" % (msg1, msg2, msg3))
return "%s[%s]" % (pstate.pprinter.process(
input,
pstate.clone(precedence=1000)),
", ".join(sidxs))
try:
pstate.precedence = 1000
sub = pstate.pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
return "%s[%s]" % (sub, ", ".join(sidxs))
else:
raise TypeError("Can only print Subtensor.")
......
......@@ -738,3 +738,9 @@ def test_printing_scan():
allow_input_downcast=True)
theano.printing.pydotprint(output, scan_graphs=True)
theano.printing.pydotprint(f, scan_graphs=True)
def test_subtensor():
x = theano.tensor.dvector()
y = x[1]
assert theano.pp(y) == "<TensorType(float64, vector)>[Constant{1}]"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论