提交 1b02e36d authored 作者: James Bergstra's avatar James Bergstra

Print - refactored for more flexibility and printing of GPU tensors

上级 c84d102d
...@@ -66,6 +66,16 @@ def debugprint(obj, depth=-1, print_type=False, file=None): ...@@ -66,6 +66,16 @@ def debugprint(obj, depth=-1, print_type=False, file=None):
else: else:
_file.flush() _file.flush()
def _print_fn(op, xin):
    # Default printing hook for the Print Op.  For every attribute name
    # listed in op.attrs, look that attribute up on the runtime value xin
    # and print it, prefixed with op.message.
    for attr in op.attrs:
        temp = getattr(xin, attr)
        if callable(temp):
            # e.g. the default attrs=("__str__",): call the bound method
            # to obtain the text to print
            pmsg = temp()
        else:
            # a plain data attribute — print its value directly
            pmsg = temp
        # Python 2 print statement; output is "<message> <attr> = <value>"
        print op.message, attr,'=', pmsg
class Print(Op): class Print(Op):
"""This identity-like Op has the side effect of printing a message followed by its inputs """This identity-like Op has the side effect of printing a message followed by its inputs
when it runs. Default behaviour is to print the __str__ representation. Optionally, one when it runs. Default behaviour is to print the __str__ representation. Optionally, one
...@@ -80,9 +90,10 @@ class Print(Op): ...@@ -80,9 +90,10 @@ class Print(Op):
:note: WARNING. This can disable some optimization(speed and stabilization)! :note: WARNING. This can disable some optimization(speed and stabilization)!
""" """
view_map={0:[0]} view_map={0:[0]}
def __init__(self,message="", attrs=("__str__",)): def __init__(self,message="", attrs=("__str__",), global_fn=_print_fn):
self.message=message self.message=message
self.attrs=tuple(attrs) # attrs should be a hashable iterable self.attrs=tuple(attrs) # attrs should be a hashable iterable
self.global_fn=global_fn
def make_node(self,xin): def make_node(self,xin):
xout = xin.type.make_variable() xout = xin.type.make_variable()
...@@ -92,13 +103,7 @@ class Print(Op): ...@@ -92,13 +103,7 @@ class Print(Op):
xin, = inputs xin, = inputs
xout, = output_storage xout, = output_storage
xout[0] = xin xout[0] = xin
for attr in self.attrs: self.global_fn(self, xin)
temp = getattr(xin, attr)
if callable(temp):
pmsg = temp()
else:
pmsg = temp
print self.message, attr,'=', pmsg
def grad(self,input,output_gradients): def grad(self,input,output_gradients):
return output_gradients return output_gradients
...@@ -109,6 +114,10 @@ class Print(Op): ...@@ -109,6 +114,10 @@ class Print(Op):
def __hash__(self): def __hash__(self):
return hash(self.message) ^ hash(self.attrs) return hash(self.message) ^ hash(self.attrs)
def __setstate__(self, dct):
    """Restore pickled state.

    Pickles written before the ``global_fn`` attribute existed do not
    carry it, so supply the default printing hook in that case before
    installing the state dict.
    """
    if 'global_fn' not in dct:
        dct['global_fn'] = _print_fn
    self.__dict__.update(dct)
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
......
...@@ -390,6 +390,10 @@ def local_gpu_rebroadcast(node): ...@@ -390,6 +390,10 @@ def local_gpu_rebroadcast(node):
gpu_x = x.owner.inputs[0] gpu_x = x.owner.inputs[0]
return [host_from_gpu(node.op(gpu_x))] return [host_from_gpu(node.op(gpu_x))]
def gpu_print_wrapper(op, cnda):
    """Printing hook used when a Print Op is moved to the GPU.

    Copies the CudaNdarray ``cnda`` to a host numpy array and hands it,
    together with the original host-side Print Op, to that Op's
    ``global_fn`` so the usual printing behaviour is preserved.
    """
    host_op = op.old_op
    host_op.global_fn(host_op, numpy.asarray(cnda))
@register_opt() @register_opt()
@local_optimizer([]) @local_optimizer([])
def local_print_op(node): def local_print_op(node):
...@@ -397,7 +401,9 @@ def local_print_op(node): ...@@ -397,7 +401,9 @@ def local_print_op(node):
x, = node.inputs x, = node.inputs
if x.owner and x.owner.op == host_from_gpu: if x.owner and x.owner.op == host_from_gpu:
gpu_x, = x.owner.inputs gpu_x, = x.owner.inputs
return [host_from_gpu(node.op(gpu_x))] new_op = node.op.__class__(global_fn=gpu_print_wrapper)
new_op.old_op = node.op
return [host_from_gpu(new_op(gpu_x))]
return False return False
def cast(x, dtype): def cast(x, dtype):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论