提交 c080a640 authored 作者: james@X40's avatar james@X40

added support for dot pseudo-operator

上级 60b10b99
...@@ -575,6 +575,11 @@ class _tensor_py_operators: ...@@ -575,6 +575,11 @@ class _tensor_py_operators:
""" """
dtype = property(lambda self: self.type.dtype) dtype = property(lambda self: self.type.dtype)
""" The dtype of this tensor. """ """ The dtype of this tensor. """
# Extra pseudo-operator symbols (not a real Python dunder protocol;
# presumably invoked by project-level machinery — TODO confirm caller).
def __dot__(left, right):
    """Pseudo-operator hook: delegate ``left . right`` to ``dot(left, right)``."""
    result = dot(left, right)
    return result
def __rdot__(right, left): return dot(left, right)
class TensorResult(Result, _tensor_py_operators): class TensorResult(Result, _tensor_py_operators):
...@@ -2089,13 +2094,14 @@ outer = Outer() ...@@ -2089,13 +2094,14 @@ outer = Outer()
# Gradient # Gradient
######################### #########################
def grad(cost, wrt, g_cost=None): def grad(cost, wrt, g_cost=None, consider_constant=[]):
""" """
@type cost: L{Result} @type cost: L{Result}
@type wrt: L{Result} or list of L{Result}s. @type wrt: L{Result} or list of L{Result}s.
@type g_cost: L{Result} broadcastable to size of I{cost}, or None @type g_cost: L{Result} broadcastable to size of I{cost}, or None
@param g_cost: an expression for the gradient through cost. The default is @param g_cost: an expression for the gradient through cost. The default is
{{{ones_like(cost)}}} {{{ones_like(cost)}}}
@param consider_constant: a list of expressions not to backpropagate through
@rtype: L{Result} or list of L{Result}s (depending upon I{wrt}) @rtype: L{Result} or list of L{Result}s (depending upon I{wrt})
@return: symbolic expression of gradient of I{cost} with respect to I{wrt}. @return: symbolic expression of gradient of I{cost} with respect to I{wrt}.
...@@ -2111,7 +2117,7 @@ def grad(cost, wrt, g_cost=None): ...@@ -2111,7 +2117,7 @@ def grad(cost, wrt, g_cost=None):
if g_cost is None: if g_cost is None:
g_cost = ones_like(cost) g_cost = ones_like(cost)
inputs = gof.graph.inputs([cost]) inputs = gof.graph.inputs([cost])
gmap = gradient.grad_sources_inputs([(cost, g_cost)], inputs) gmap = gradient.grad_sources_inputs([(cost, g_cost)], inputs + consider_constant)
def zero(p): def zero(p):
return TensorConstant( return TensorConstant(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论