提交 b8141fca authored 作者: Ian Goodfellow's avatar Ian Goodfellow

improved doc for R_op

上级 9517a606
...@@ -313,7 +313,7 @@ class PureOp(object): ...@@ -313,7 +313,7 @@ class PureOp(object):
add_stack_trace_on_call = True add_stack_trace_on_call = True
"""This class variable governs whether __call__ adds a stack trace to the node it creates. """This class variable governs whether __call__ adds a stack trace to the node it creates.
The tag trace is meant to connect a node to the line a user typed. It is nice for The tag trace is meant to connect a node to the line a user typed. It is nice for
debugging. It does not make as much sense during optimizations to store this information. debugging. It does not make as much sense during optimizations to store this information.
""" """
...@@ -449,6 +449,38 @@ class PureOp(object): ...@@ -449,6 +449,38 @@ class PureOp(object):
# Python implementation # # Python implementation #
######################### #########################
def R_op(self, inputs, eval_points):
    """Construct a graph for the R-operator (used by ``tensor.Rop``).

    If this op computes ``[f_1(inputs), ..., f_n(inputs)]``, the
    R-operator evaluates the directional derivative of each output
    in the direction given by ``eval_points``.

    inputs: a Variable or list of Variables.
    eval_points: a Variable or list of Variables, one per input;
        element i is the direction associated with ``inputs[i]``.

    returns: a list of n elements, where element i is
        ``Rop(f=f_i(inputs), wrt=inputs, eval_points=eval_points)``.

    Raises NotImplementedError unless a subclass overrides this.
    """
    message = ("%s of type %s does not implement R_op. If this is a "
               "theano op, write to the theano-dev mailing list for "
               "assistance. If it is your own op, implement the R_op "
               "method." % (self, self.__class__.__name__))
    raise NotImplementedError(message)
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
""" """
Required: Calculate the function on the inputs and put the variables in the Required: Calculate the function on the inputs and put the variables in the
......
...@@ -1007,6 +1007,7 @@ class CAReduce(Op): ...@@ -1007,6 +1007,7 @@ class CAReduce(Op):
if scalar_op.nin not in [-1, 2] or scalar_op.nout != 1: if scalar_op.nin not in [-1, 2] or scalar_op.nout != 1:
raise NotImplementedError("CAReduce only supports binary functions with a single output.") raise NotImplementedError("CAReduce only supports binary functions with a single output.")
self.scalar_op = scalar_op self.scalar_op = scalar_op
if axis is None: if axis is None:
self.axis = axis self.axis = axis
elif isinstance(axis, int): elif isinstance(axis, int):
......
...@@ -65,8 +65,10 @@ def Rop(f, wrt, eval_points): ...@@ -65,8 +65,10 @@ def Rop(f, wrt, eval_points):
assert len(wrt) == len(eval_points) assert len(wrt) == len(eval_points)
#check that each element of wrt corresponds to an element
#of eval_points with the same dimensionality
for pack in enumerate(zip(wrt, eval_points)): for pack in enumerate(zip(wrt, eval_points)):
i = pack[0] i = pack[0]
wrt_elem, eval_point = pack[1] wrt_elem, eval_point = pack[1]
wrt_elem = as_tensor_variable(wrt_elem) wrt_elem = as_tensor_variable(wrt_elem)
...@@ -82,15 +84,13 @@ def Rop(f, wrt, eval_points): ...@@ -82,15 +84,13 @@ def Rop(f, wrt, eval_points):
seen_nodes = {} seen_nodes = {}
def _traverse(node): def _traverse(node):
""" TODO: writeme """
if node is None: if node is None:
return None return None
else: else:
op = node.op op = node.op
inputs = node.inputs inputs = node.inputs
if not hasattr(op, 'R_op'):
raise Exception((' R_op was not implemented for %s'
' operation. Email the mailing list'
' for help') % op.__class__.__name__)
# Compute the evaluation points corresponding to each of the # Compute the evaluation points corresponding to each of the
# inputs of the node # inputs of the node
local_eval_points = [] local_eval_points = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论