提交 536d3b00 authored 作者: --global's avatar --global

Alter PdbBreakpoint to show the user modifiables cpu-versions of the monitored variables

上级 ee6ec110
import numpy
import pdb
import theano
......@@ -78,14 +79,17 @@ class PdbBreakpoint(Op):
# Because the user might be tempted to instantiate PdbBreakpoint only
# once and apply it many times on different number of inputs, we must
# create a new instance of the op here, define the view_map in that
# instance and then apply it on the inputs.
# create a new instance of the op here, define the instance attributes
# (view_map and var_types) in that instance and then apply it on the
# inputs.
new_op = PdbBreakpoint(name=self.name)
new_op.view_map = {}
new_op.inp_types = []
for i in range(len(monitored_vars)):
# Every output i is a view of the input i+1 because of the input
# condition.
new_op.view_map[i] = [i+1]
new_op.inp_types.append(monitored_vars[i].type)
# Build the Apply node
inputs = [condition] + list(monitored_vars)
......@@ -94,18 +98,28 @@ class PdbBreakpoint(Op):
def perform(self, node, inputs, output_storage):
condition = inputs[0]
monitored = inputs[1:]
try:
monitored = [numpy.asarray(inp) for inp in inputs[1:]]
except:
raise ValueError("Some of the inputs to the PdbBreakpoint op '%s'"
"could not be casted to NumPy arrays" %
self.name)
if condition:
print "-------------------------------------------------"
print "Conditional breakpoint %s activated" % self.name
print "The monitored variables are stored, in order,"
print "in the list variable 'monitored'"
print "-------------------------------------------------"
print("\n")
print("-------------------------------------------------")
print("Conditional breakpoint '%s' activated\n" % self.name)
print("The monitored variables are stored, in order,")
print("in the list variable 'monitored' as NumPy arrays.\n")
print("Their contents can be altered and, when execution")
print("resumes, the updated values will be used.")
print("-------------------------------------------------")
pdb.set_trace()
# Take the new values in monitored, cast them back to their original
# type and store them in the output_storage
for i in range(len(output_storage)):
output_storage[i][0] = monitored[i]
output_storage[i][0] = self.inp_types[i].filter(monitored[i])
def grad(self, inputs, output_gradients):
return ([DisconnectedType()()] + output_gradients)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论