提交 ad44ab0d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #318 from goodfeli/CAReduce_R_op

improved doc for R_op it all looks good. There is one line that I think does not respect PEP8 convention but I'll fix it in another pull request
......@@ -449,6 +449,38 @@ class PureOp(object):
# Python implementation #
#########################
def R_op(self, inputs, eval_points):
"""
This method is primarily used by tensor.Rop
Suppose the op outputs
[ f_1(inputs), ..., f_n(inputs) ]
inputs: a Variable or list of Variables
eval_points: a Variable or list of Variables with
the same length as inputs. Each element
of eval_points specifies the value of
the corresponding input at the point
where the R op is to be evaluated.
returns: a list of n elements
rval[i] should be Rop(f = f_i(inputs),
wrt = inputs,
eval_points = eval_points)
"""
raise NotImplementedError(str(self)+' of type '+str(self.__class__.__name__)
+" does not "
"implement R_op. If this is a theano op, write to the "
"theano-dev mailing list for assistance. If it is your "
"own op, implement the R_op method.")
def perform(self, node, inputs, output_storage):
"""
Required: Calculate the function on the inputs and put the variables in the
......
......@@ -1007,6 +1007,7 @@ class CAReduce(Op):
if scalar_op.nin not in [-1, 2] or scalar_op.nout != 1:
raise NotImplementedError("CAReduce only supports binary functions with a single output.")
self.scalar_op = scalar_op
if axis is None:
self.axis = axis
elif isinstance(axis, int):
......
......@@ -65,8 +65,10 @@ def Rop(f, wrt, eval_points):
assert len(wrt) == len(eval_points)
#check that each element of wrt corresponds to an element
#of eval_points with the same dimensionality
for pack in enumerate(zip(wrt, eval_points)):
i = pack[0]
i = pack[0]
wrt_elem, eval_point = pack[1]
wrt_elem = as_tensor_variable(wrt_elem)
......@@ -82,15 +84,13 @@ def Rop(f, wrt, eval_points):
seen_nodes = {}
def _traverse(node):
""" TODO: writeme """
if node is None:
return None
else:
op = node.op
inputs = node.inputs
if not hasattr(op, 'R_op'):
raise Exception((' R_op was not implemented for %s'
' operation. Email the mailing list'
' for help') % op.__class__.__name__)
# Compute the evaluation points corresponding to each of the
# inputs of the node
local_eval_points = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论