提交 40437b94 authored 作者: khaotik's avatar khaotik

added name and __str__

上级 cda13bdc
...@@ -72,6 +72,9 @@ class OpFromGraph(gof.Op): ...@@ -72,6 +72,9 @@ class OpFromGraph(gof.Op):
:class:`Variable <theano.gof.Variable>`. Each list element corresponds :class:`Variable <theano.gof.Variable>`. Each list element corresponds
to a specific output of R_op. to a specific output of R_op.
name : string, optional
A name for debugging purposes
**kwargs : optional **kwargs : optional
Check Check
:func:`orig_function <theano.compile.function_module.orig_function>` :func:`orig_function <theano.compile.function_module.orig_function>`
...@@ -157,7 +160,7 @@ class OpFromGraph(gof.Op): ...@@ -157,7 +160,7 @@ class OpFromGraph(gof.Op):
fn(2., 3., 4.) # [1., 8., 3.] fn(2., 3., 4.) # [1., 8., 3.]
""" """
def __init__(self, inputs, outputs, inline=False, grad_overrides=None, rop_overrides=None, **kwargs): def __init__(self, inputs, outputs, inline=False, grad_overrides=None, rop_overrides=None, name=None, **kwargs):
if not isinstance(outputs, list): if not isinstance(outputs, list):
raise TypeError('outputs must be list', outputs) raise TypeError('outputs must be list', outputs)
for i in inputs + outputs: for i in inputs + outputs:
...@@ -195,6 +198,10 @@ class OpFromGraph(gof.Op): ...@@ -195,6 +198,10 @@ class OpFromGraph(gof.Op):
self.set_grad_overrides(grad_overrides) self.set_grad_overrides(grad_overrides)
self.set_rop_overrides(rop_overrides) self.set_rop_overrides(rop_overrides)
if name is not None:
assert isinstance(name, str), 'name must be None or string object'
self.name = name
def __eq__(self, other): def __eq__(self, other):
# TODO: recognize a copy # TODO: recognize a copy
return self is other return self is other
...@@ -203,6 +210,11 @@ class OpFromGraph(gof.Op): ...@@ -203,6 +210,11 @@ class OpFromGraph(gof.Op):
# TODO: use internal variables in hash # TODO: use internal variables in hash
return hash(type(self)) return hash(type(self))
def __str__(self):
name = self.__class__.__name__ if self.name is None else self.name
is_inline = 'inline' if self.is_inline else 'compiled'
return '%(name)s{%(is_inline)s}'%locals()
def _recompute_grad_op(self): def _recompute_grad_op(self):
if isinstance(self._grad_op, OpFromGraph): if isinstance(self._grad_op, OpFromGraph):
self._grad_op_is_cached = True self._grad_op_is_cached = True
...@@ -259,8 +271,9 @@ class OpFromGraph(gof.Op): ...@@ -259,8 +271,9 @@ class OpFromGraph(gof.Op):
self._grad_op = type(self)( self._grad_op = type(self)(
inputs=self.local_inputs + output_grads, inputs=self.local_inputs + output_grads,
outputs=all_grads_l, outputs=all_grads_l,
inline=self.is_inline, on_unused_input='ignore', inline=self.is_inline,
) name=(None if self.name is None else self.name + '_grad'),
on_unused_input='ignore')
self._grad_op_is_cached = True self._grad_op_is_cached = True
def _recompute_rop_op(self): def _recompute_rop_op(self):
...@@ -320,7 +333,9 @@ class OpFromGraph(gof.Op): ...@@ -320,7 +333,9 @@ class OpFromGraph(gof.Op):
self._rop_op = type(self)( self._rop_op = type(self)(
inputs=self.local_inputs + eval_points, inputs=self.local_inputs + eval_points,
outputs=all_rops_l, outputs=all_rops_l,
inline=self.is_inline, on_unused_input='ignore') inline=self.is_inline,
name=(None if self.name is None else self.name + '_rop'),
on_unused_input='ignore')
self._rop_op_is_cached = True self._rop_op_is_cached = True
def get_grad_op(self): def get_grad_op(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论