提交 40437b94 authored 作者: khaotik's avatar khaotik

added name and __str__

上级 cda13bdc
......@@ -72,6 +72,9 @@ class OpFromGraph(gof.Op):
:class:`Variable <theano.gof.Variable>`. Each list element corresponds
to a specific output of R_op.
name : string, optional
A name for debugging purposes
**kwargs : optional
Check
:func:`orig_function <theano.compile.function_module.orig_function>`
......@@ -157,7 +160,7 @@ class OpFromGraph(gof.Op):
fn(2., 3., 4.) # [1., 8., 3.]
"""
def __init__(self, inputs, outputs, inline=False, grad_overrides=None, rop_overrides=None, **kwargs):
def __init__(self, inputs, outputs, inline=False, grad_overrides=None, rop_overrides=None, name=None, **kwargs):
if not isinstance(outputs, list):
raise TypeError('outputs must be list', outputs)
for i in inputs + outputs:
......@@ -195,6 +198,10 @@ class OpFromGraph(gof.Op):
self.set_grad_overrides(grad_overrides)
self.set_rop_overrides(rop_overrides)
if name is not None:
assert isinstance(name, str), 'name must be None or string object'
self.name = name
def __eq__(self, other):
# TODO: recognize a copy
return self is other
......@@ -203,6 +210,11 @@ class OpFromGraph(gof.Op):
# TODO: use internal variables in hash
return hash(type(self))
def __str__(self):
name = self.__class__.__name__ if self.name is None else self.name
is_inline = 'inline' if self.is_inline else 'compiled'
return '%(name)s{%(is_inline)s}'%locals()
def _recompute_grad_op(self):
if isinstance(self._grad_op, OpFromGraph):
self._grad_op_is_cached = True
......@@ -259,8 +271,9 @@ class OpFromGraph(gof.Op):
self._grad_op = type(self)(
inputs=self.local_inputs + output_grads,
outputs=all_grads_l,
inline=self.is_inline, on_unused_input='ignore',
)
inline=self.is_inline,
name=(None if self.name is None else self.name + '_grad'),
on_unused_input='ignore')
self._grad_op_is_cached = True
def _recompute_rop_op(self):
......@@ -320,7 +333,9 @@ class OpFromGraph(gof.Op):
self._rop_op = type(self)(
inputs=self.local_inputs + eval_points,
outputs=all_rops_l,
inline=self.is_inline, on_unused_input='ignore')
inline=self.is_inline,
name=(None if self.name is None else self.name + '_rop'),
on_unused_input='ignore')
self._rop_op_is_cached = True
def get_grad_op(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论