提交 1b02e36d authored 作者: James Bergstra's avatar James Bergstra

Print - refactored for more flexibility and printing of GPU tensors

上级 c84d102d
......@@ -66,6 +66,16 @@ def debugprint(obj, depth=-1, print_type=False, file=None):
else:
_file.flush()
def _print_fn(op, xin):
for attr in op.attrs:
temp = getattr(xin, attr)
if callable(temp):
pmsg = temp()
else:
pmsg = temp
print op.message, attr,'=', pmsg
class Print(Op):
"""This identity-like Op has the side effect of printing a message followed by its inputs
when it runs. Default behaviour is to print the __str__ representation. Optionally, one
......@@ -80,9 +90,10 @@ class Print(Op):
:note: WARNING. This can disable some optimization(speed and stabilization)!
"""
view_map={0:[0]}
def __init__(self,message="", attrs=("__str__",)):
def __init__(self,message="", attrs=("__str__",), global_fn=_print_fn):
self.message=message
self.attrs=tuple(attrs) # attrs should be a hashable iterable
self.global_fn=global_fn
def make_node(self,xin):
xout = xin.type.make_variable()
......@@ -92,13 +103,7 @@ class Print(Op):
xin, = inputs
xout, = output_storage
xout[0] = xin
for attr in self.attrs:
temp = getattr(xin, attr)
if callable(temp):
pmsg = temp()
else:
pmsg = temp
print self.message, attr,'=', pmsg
self.global_fn(self, xin)
def grad(self,input,output_gradients):
return output_gradients
......@@ -109,6 +114,10 @@ class Print(Op):
def __hash__(self):
return hash(self.message) ^ hash(self.attrs)
def __setstate__(self, dct):
dct.setdefault('global_fn', _print_fn)
self.__dict__.update(dct)
def c_code_cache_version(self):
return (1,)
......
......@@ -390,6 +390,10 @@ def local_gpu_rebroadcast(node):
gpu_x = x.owner.inputs[0]
return [host_from_gpu(node.op(gpu_x))]
def gpu_print_wrapper(op, cnda):
op.old_op.global_fn(op.old_op, numpy.asarray(cnda))
@register_opt()
@local_optimizer([])
def local_print_op(node):
......@@ -397,7 +401,9 @@ def local_print_op(node):
x, = node.inputs
if x.owner and x.owner.op == host_from_gpu:
gpu_x, = x.owner.inputs
return [host_from_gpu(node.op(gpu_x))]
new_op = node.op.__class__(global_fn=gpu_print_wrapper)
new_op.old_op = node.op
return [host_from_gpu(new_op(gpu_x))]
return False
def cast(x, dtype):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论