提交 bf69ab87 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

lots of stuff

上级 67b911ab
...@@ -540,7 +540,8 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=0.0000001, to ...@@ -540,7 +540,8 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=0.0000001, to
num_grad = gradient.numeric_grad(cost_fn, pt) num_grad = gradient.numeric_grad(cost_fn, pt)
symbolic_grad = grad(cost, tensor_pt,as_tensor(1.0,name='g_cost')) #symbolic_grad = exec_grad(cost, tensor_pt,as_tensor(1.0,name='g_cost'))
symbolic_grad = grad.make_node(cost, tensor_pt).outputs
if 0: if 0:
print '-------' print '-------'
print '----------' print '----------'
...@@ -898,7 +899,7 @@ class T_subtensor(unittest.TestCase): ...@@ -898,7 +899,7 @@ class T_subtensor(unittest.TestCase):
n = as_tensor(numpy.random.rand(2,3)) n = as_tensor(numpy.random.rand(2,3))
z = scal.constant(0) z = scal.constant(0)
t = n[z:,z] t = n[z:,z]
gn = grad(sum(exp(t)), n) gn = exec_grad(sum(exp(t)), n)
gval = eval_outputs([gn]) gval = eval_outputs([gn])
s0 = 'array([ 2.05362099, 0. , 0. ])' s0 = 'array([ 2.05362099, 0. , 0. ])'
s1 = 'array([ 1.55009327, 0. , 0. ])' s1 = 'array([ 1.55009327, 0. , 0. ])'
...@@ -908,7 +909,7 @@ class T_subtensor(unittest.TestCase): ...@@ -908,7 +909,7 @@ class T_subtensor(unittest.TestCase):
def test_grad_0d(self): def test_grad_0d(self):
n = as_tensor(numpy.random.rand(2,3)) n = as_tensor(numpy.random.rand(2,3))
t = n[1,0] t = n[1,0]
gn = grad(sum(exp(t)), n) gn = exec_grad(sum(exp(t)), n)
gval = eval_outputs([gn]) gval = eval_outputs([gn])
g0 = repr(gval[0,:]) g0 = repr(gval[0,:])
g1 = repr(gval[1,:]) g1 = repr(gval[1,:])
...@@ -937,7 +938,7 @@ class T_Stack(unittest.TestCase): ...@@ -937,7 +938,7 @@ class T_Stack(unittest.TestCase):
a = as_tensor(numpy.array([[1, 2, 3], [4, 5, 6]])) a = as_tensor(numpy.array([[1, 2, 3], [4, 5, 6]]))
b = as_tensor(numpy.array([[7, 8, 9]])) b = as_tensor(numpy.array([[7, 8, 9]]))
s = vertical_stack(a, b) s = vertical_stack(a, b)
ga,gb = grad(sum(vertical_stack(a,b)), [a,b]) ga,gb = exec_grad(sum(vertical_stack(a,b)), [a,b])
gval = eval_outputs([ga, gb]) gval = eval_outputs([ga, gb])
self.failUnless(numpy.all(gval[0] == 1.0)) self.failUnless(numpy.all(gval[0] == 1.0))
...@@ -1671,13 +1672,13 @@ class _test_grad(unittest.TestCase): ...@@ -1671,13 +1672,13 @@ class _test_grad(unittest.TestCase):
"""grad: Test passing a single result param""" """grad: Test passing a single result param"""
o = _test_grad.O() o = _test_grad.O()
a1 = o.make_node() a1 = o.make_node()
self.failUnless(o.gval0 is grad(a1.outputs[0], a1.inputs[0])) self.failUnless(o.gval0 is exec_grad(a1.outputs[0], a1.inputs[0]))
def test_Nparam(self): def test_Nparam(self):
"""grad: Test passing multiple result params""" """grad: Test passing multiple result params"""
o = _test_grad.O() o = _test_grad.O()
a1 = o.make_node() a1 = o.make_node()
g0,g1 = grad(a1.outputs[0], a1.inputs) g0,g1 = exec_grad(a1.outputs[0], a1.inputs)
self.failUnless(o.gval0 is g0) self.failUnless(o.gval0 is g0)
self.failUnless(o.gval1 is g1) self.failUnless(o.gval1 is g1)
...@@ -1685,13 +1686,13 @@ class _test_grad(unittest.TestCase): ...@@ -1685,13 +1686,13 @@ class _test_grad(unittest.TestCase):
"""grad: Test returning a single None from grad""" """grad: Test returning a single None from grad"""
o = _test_grad.O() o = _test_grad.O()
a1 = o.make_node() a1 = o.make_node()
self.failUnless(None is grad(a1.outputs[0], a1.outputs[1])) self.failUnless(None is exec_grad(a1.outputs[0], a1.outputs[1]))
self.failUnless(None is grad(a1.outputs[0], 'wtf')) self.failUnless(None is exec_grad(a1.outputs[0], 'wtf'))
def test_NNone_rval(self): def test_NNone_rval(self):
"""grad: Test returning some Nones from grad""" """grad: Test returning some Nones from grad"""
o = _test_grad.O() o = _test_grad.O()
a1 = o.make_node() a1 = o.make_node()
g0,g1,g2 = grad(a1.outputs[0], a1.inputs + ['wtf']) g0,g1,g2 = exec_grad(a1.outputs[0], a1.inputs + ['wtf'])
self.failUnless(o.gval0 is g0) self.failUnless(o.gval0 is g0)
self.failUnless(o.gval1 is g1) self.failUnless(o.gval1 is g1)
self.failUnless(None is g2) self.failUnless(None is g2)
......
## PENDING REWRITE OF tensor_opt.py ## PENDING REWRITE OF tensor_opt.py
# import unittest import unittest
# import gof import gof
# from tensor_opt import * from tensor_opt import *
# import tensor import tensor
# from tensor import Tensor from tensor import Tensor
# from gof import Env from gof import Env
# from elemwise import DimShuffle from elemwise import DimShuffle
# import numpy import numpy
# import scalar_opt #import scalar_opt
# def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)): def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
# x = Tensor(broadcastable = xbc, dtype = 'float64')('x') x = Tensor(broadcastable = xbc, dtype = 'float64')('x')
# y = Tensor(broadcastable = ybc, dtype = 'float64')('y') y = Tensor(broadcastable = ybc, dtype = 'float64')('y')
# z = Tensor(broadcastable = zbc, dtype = 'float64')('z') z = Tensor(broadcastable = zbc, dtype = 'float64')('z')
# return x, y, z return x, y, z
# ds = DimShuffle
# class _test_inplace_opt(unittest.TestCase): # class _test_inplace_opt(unittest.TestCase):
...@@ -60,39 +58,45 @@ ...@@ -60,39 +58,45 @@
# self.failUnless(str(g) == "[Broadcast{Add}{0: 1}(x, y), Broadcast{Mul}{0: 0}(x, z)]") # self.failUnless(str(g) == "[Broadcast{Add}{0: 1}(x, y), Broadcast{Mul}{0: 0}(x, z)]")
# class _test_dimshuffle_lift(unittest.TestCase): ds = lambda x, y: DimShuffle(x.type.broadcastable, y)(x)
# def test_double_transpose(self): class _test_dimshuffle_lift(unittest.TestCase):
# x, y, z = inputs()
# e = ds(ds(x, (1, 0)), (1, 0)) def test_double_transpose(self):
# g = Env([x], [e]) x, y, z = inputs()
# self.failUnless(str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x))]") e = ds(ds(x, (1, 0)), (1, 0))
# lift_dimshuffle.optimize(g) g = Env([x], [e])
# self.failUnless(str(g) == "[x]") self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{1,0}(x))]")
lift_dimshuffle.optimize(g)
# def test_merge2(self): self.failUnless(str(g) == "[x]")
# x, y, z = inputs()
# e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1)) def test_merge2(self):
# g = Env([x], [e]) x, y, z = inputs()
# self.failUnless(str(g) == "[InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x))]", str(g)) e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
# lift_dimshuffle.optimize(g) g = Env([x], [e])
# self.failUnless(str(g) == "[InplaceDimShuffle{0,1,x,x}(x)]", str(g)) self.failUnless(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g))
lift_dimshuffle.optimize(g)
# def test_elim3(self): self.failUnless(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
# x, y, z = inputs()
# e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0)) def test_elim3(self):
# g = Env([x], [e]) x, y, z = inputs()
# self.failUnless(str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{0,x,1}(x)))]", str(g)) e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
# lift_dimshuffle.optimize(g) g = Env([x], [e])
# self.failUnless(str(g) == "[x]", str(g)) self.failUnless(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g))
lift_dimshuffle.optimize(g)
# def test_lift(self): self.failUnless(str(g) == "[x]", str(g))
# x, y, z = inputs([0]*1, [0]*2, [0]*3)
# e = x + y + z def test_lift(self):
# g = Env([x, y, z], [e]) x, y, z = inputs([False]*1, [False]*2, [False]*3)
# self.failUnless(str(g) == "[Broadcast{Add}(InplaceDimShuffle{x,0,1}(Broadcast{Add}(InplaceDimShuffle{x,0}(x), y)), z)]", str(g)) e = x + y + z
# lift_dimshuffle.optimize(g) g = Env([x, y, z], [e])
# self.failUnless(str(g) == "[Broadcast{Add}(Broadcast{Add}(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g)) gof.ExpandMacros().optimize(g)
self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
print g
lift_dimshuffle.optimize(g)
gof.ExpandMacros().optimize(g)
print g
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
# class _test_cliques(unittest.TestCase): # class _test_cliques(unittest.TestCase):
...@@ -185,8 +189,8 @@ ...@@ -185,8 +189,8 @@
# if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() unittest.main()
......
...@@ -104,6 +104,9 @@ class FunctionFactory: ...@@ -104,6 +104,9 @@ class FunctionFactory:
if not isinstance(r, gof.Result): if not isinstance(r, gof.Result):
raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r) raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r)
env = std_env(inputs, outputs, disown_inputs = disown_inputs) env = std_env(inputs, outputs, disown_inputs = disown_inputs)
gof.ExpandMacros().optimize(env)
#gof.ExpandMacros(lambda node: getattr(node.op, 'level', 0) <= 1).optimize(env)
#gof.ExpandMacros(lambda node: node.op.level == 2).optimize(env)
if None is not optimizer: if None is not optimizer:
optimizer(env) optimizer(env)
env.validate() env.validate()
......
...@@ -29,11 +29,6 @@ def TensorConstant(*inputs, **kwargs): ...@@ -29,11 +29,6 @@ def TensorConstant(*inputs, **kwargs):
### DimShuffle ### ### DimShuffle ###
################## ##################
## TODO: rule-based version of DimShuffle
## would allow for Transpose, LComplete, RComplete, etc.
## Can be optimized into DimShuffle later on.
class ShuffleRule(Macro): class ShuffleRule(Macro):
""" """
ABSTRACT Op - it has no perform and no c_code ABSTRACT Op - it has no perform and no c_code
...@@ -41,20 +36,30 @@ class ShuffleRule(Macro): ...@@ -41,20 +36,30 @@ class ShuffleRule(Macro):
Apply ExpandMacros to this node to obtain Apply ExpandMacros to this node to obtain
an equivalent DimShuffle which can be performed. an equivalent DimShuffle which can be performed.
""" """
def __init__(self, rule = None, name = None): level = 1
def __init__(self, rule = None, inplace = False, name = None):
if rule is not None: if rule is not None:
self.rule = rule self.rule = rule
self.inplace = inplace
if inplace:
self.view_map = {0: [0]}
self.name = name self.name = name
def make_node(self, input, *models): def make_node(self, input, *models):
pattern = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models)) pattern = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
ib = input.type.broadcastable
return gof.Apply(self, return gof.Apply(self,
(input,) + models, (input,) + models,
[Tensor(dtype = input.type.dtype, [Tensor(dtype = input.type.dtype,
broadcastable = [x == 'x' for x in pattern]).make_result()]) broadcastable = [x == 'x' or ib[x] for x in pattern]).make_result()])
def expand(self, r): def expand(self, node):
input, models = r.owner.inputs[0], r.owner.inputs[1:] input, models = node.inputs[0], node.inputs[1:]
new_order = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models)) new_order = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
return DimShuffle(input.type.broadcastable, new_order)(input) #print new_order, node.outputs[0].type, DimShuffle(input.type.broadcastable, new_order)(input).type, node.outputs[0].type == DimShuffle(input.type.broadcastable, new_order)(input).type
if list(new_order) == range(input.type.ndim) and self.inplace:
return [input]
else:
return [DimShuffle(input.type.broadcastable, new_order, self.inplace)(input)]
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.rule == other.rule return type(self) == type(other) and self.rule == other.rule
def __hash__(self, other): def __hash__(self, other):
...@@ -66,10 +71,13 @@ class ShuffleRule(Macro): ...@@ -66,10 +71,13 @@ class ShuffleRule(Macro):
return "ShuffleRule{%s}" % self.role return "ShuffleRule{%s}" % self.role
_transpose = ShuffleRule(rule = lambda input: range(len(input)-1, -1, -1), _transpose = ShuffleRule(rule = lambda input: range(len(input)-1, -1, -1),
inplace = True,
name = 'transpose') name = 'transpose')
lcomplete = ShuffleRule(rule = lambda input, *models: ['x']*(max([0]+map(len,models))-len(input)) + range(len(input)), lcomplete = ShuffleRule(rule = lambda input, *models: ['x']*(max([0]+map(len,models))-len(input)) + range(len(input)),
inplace = True,
name = 'lcomplete') name = 'lcomplete')
rcomplete = ShuffleRule(rule = lambda input, *models: range(len(input)) + ['x']*(max(map(len,models))-len(input)), rcomplete = ShuffleRule(rule = lambda input, *models: range(len(input)) + ['x']*(max(map(len,models))-len(input)),
inplace = True,
name = 'rcomplete') name = 'rcomplete')
...@@ -170,7 +178,7 @@ class DimShuffle(Op): ...@@ -170,7 +178,7 @@ class DimShuffle(Op):
ob = [] ob = []
for value in self.new_order: for value in self.new_order:
if value == 'x': if value == 'x':
ob.append(1) ob.append(True)
else: else:
ob.append(ib[value]) ob.append(ib[value])
...@@ -304,8 +312,10 @@ class Elemwise(Op): ...@@ -304,8 +312,10 @@ class Elemwise(Op):
shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs]) shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
target_length = max([input.type.ndim for input in inputs]) target_length = max([input.type.ndim for input in inputs])
if len(inputs) > 1: if len(inputs) > 1:
inputs = [lcomplete(input, *inputs) for input in inputs] inputs = [lcomplete(input, *inputs) for input in inputs]
# args = [] # args = []
# for input in inputs: # for input in inputs:
# length = input.type.ndim # length = input.type.ndim
...@@ -316,7 +326,7 @@ class Elemwise(Op): ...@@ -316,7 +326,7 @@ class Elemwise(Op):
# # TODO: use LComplete instead # # TODO: use LComplete instead
# args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length))(input)) # args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length))(input))
# inputs = args # inputs = args
try: try:
assert len(set([len(input.type.broadcastable) for input in inputs])) == 1 assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
except (AssertionError, AttributeError): except (AssertionError, AttributeError):
......
差异被折叠。
import link import link
from functools import partial from functools import partial
class DebugException(Exception): class DebugException(Exception):
pass pass
......
...@@ -30,14 +30,14 @@ class Optimizer: ...@@ -30,14 +30,14 @@ class Optimizer:
""" """
pass pass
def optimize(self, env): def optimize(self, env, *args, **kwargs):
""" """
This is meant as a shortcut to:: This is meant as a shortcut to::
env.satisfy(opt) env.satisfy(opt)
opt.apply(env) opt.apply(env)
""" """
self.add_requirements(env) self.add_requirements(env)
self.apply(env) self.apply(env, *args, **kwargs)
def __call__(self, env): def __call__(self, env):
""" """
...@@ -218,8 +218,14 @@ class LocalOpKeyOptGroup(LocalOptGroup): ...@@ -218,8 +218,14 @@ class LocalOpKeyOptGroup(LocalOptGroup):
class ExpandMacro(LocalOptimizer): class ExpandMacro(LocalOptimizer):
def __init__(self, filter = None):
if filter is None:
self.filter = lambda node: True
else:
self.filter = filter
def transform(self, node): def transform(self, node):
if not isinstance(node.op, op.Macro): if not isinstance(node.op, op.Macro) or not self.filter(node):
return False return False
return node.op.expand(node) return node.op.expand(node)
...@@ -466,7 +472,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -466,7 +472,7 @@ class NavigatorOptimizer(Optimizer):
def process_node(self, env, node): def process_node(self, env, node):
replacements = self.local_opt.transform(node) replacements = self.local_opt.transform(node)
if replacements is False: if replacements is False or replacements is None:
return return
repl_pairs = zip(node.outputs, replacements) repl_pairs = zip(node.outputs, replacements)
try: try:
...@@ -490,13 +496,15 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -490,13 +496,15 @@ class TopoOptimizer(NavigatorOptimizer):
self.order = order self.order = order
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback) NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback)
def apply(self, env): def apply(self, env, start_from = None):
q = deque(graph.io_toposort(env.inputs, env.outputs)) if start_from is None: start_from = env.outputs
q = deque(graph.io_toposort(env.inputs, start_from))
def importer(node): def importer(node):
q.append(node) q.append(node)
def pruner(node): def pruner(node):
if node is not current_node: if node is not current_node:
q.remove(node) try: q.remove(node)
except ValueError: pass
u = self.attach_updater(env, importer, pruner) u = self.attach_updater(env, importer, pruner)
try: try:
...@@ -529,7 +537,8 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -529,7 +537,8 @@ class OpKeyOptimizer(NavigatorOptimizer):
if node.op == op: q.append(node) if node.op == op: q.append(node)
def pruner(node): def pruner(node):
if node is not current_node and node.op == op: if node is not current_node and node.op == op:
q.remove(node) try: q.remove(node)
except ValueError: pass
u = self.attach_updater(env, importer, pruner) u = self.attach_updater(env, importer, pruner)
try: try:
while q: while q:
...@@ -554,7 +563,10 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -554,7 +563,10 @@ class OpKeyOptimizer(NavigatorOptimizer):
### Pre-defined optimizers ### ### Pre-defined optimizers ###
############################## ##############################
expand_macros = TopoOptimizer(ExpandMacro(), ignore_newtrees = False) def ExpandMacros(filter = None):
return TopoOptimizer(ExpandMacro(filter = filter),
order = 'in_to_out',
ignore_newtrees = False)
......
...@@ -83,7 +83,7 @@ class DimShufflePrinter: ...@@ -83,7 +83,7 @@ class DimShufflePrinter:
raise TypeError("Can only print DimShuffle.") raise TypeError("Can only print DimShuffle.")
elif isinstance(r.owner.op, ShuffleRule): elif isinstance(r.owner.op, ShuffleRule):
#print r, r.owner.op #print r, r.owner.op
new_r = r.owner.op.expand(r) new_r = r.owner.op.expand(r.owner)
#print new_r.owner, isinstance(new_r.owner.op, ShuffleRule) #print new_r.owner, isinstance(new_r.owner.op, ShuffleRule)
return self.process(new_r, pstate) return self.process(new_r, pstate)
elif isinstance(r.owner.op, DimShuffle): elif isinstance(r.owner.op, DimShuffle):
...@@ -163,16 +163,6 @@ class PPrinter: ...@@ -163,16 +163,6 @@ class PPrinter:
return cp return cp
# class ExtendedPPrinter:
# def __init__(self, pprinter, leaf_pprinter):
# self.pprinter = pprinter
# self.leaf_pprinter = pprinter
# def process(self, r, pstate = None):
from tensor import * from tensor import *
from elemwise import Sum, ShuffleRule from elemwise import Sum, ShuffleRule
...@@ -194,18 +184,18 @@ pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimS ...@@ -194,18 +184,18 @@ pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimS
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, ShuffleRule), DimShufflePrinter()) pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, ShuffleRule), DimShufflePrinter())
print pp.process(x + y * z) # print pp.process(x + y * z)
print pp.process((x + y) * z) # print pp.process((x + y) * z)
print pp.process(x * (y * z)) # print pp.process(x * (y * z))
print pp.process(x / (y / z) / x) # print pp.process(x / (y / z) / x)
print pp.process((x ** y) ** z) # print pp.process((x ** y) ** z)
print pp.process(-x+y) # print pp.process(-x+y)
print pp.process(-x*y) # print pp.process(-x*y)
print pp.process(sum(x)) # print pp.process(sum(x))
print pp.process(sum(x * 10)) # print pp.process(sum(x * 10))
a = Tensor(broadcastable=(False,False,False), dtype='float64')('alpha') # a = Tensor(broadcastable=(False,False,False), dtype='float64')('alpha')
print a.type # print a.type
print pp.process(DimShuffle((False,)*2, [1, 0])(x) + a) # print pp.process(DimShuffle((False,)*2, [1, 0])(x) + a)
print pp.process(x / (y * z)) # print pp.process(x / (y * z))
...@@ -1142,15 +1142,19 @@ gemm = Gemm() ...@@ -1142,15 +1142,19 @@ gemm = Gemm()
# Gradient # Gradient
######################### #########################
class SGrad(gof.Op): class Grad(gof.Macro):
level = 2
def make_node(self, cost, wrt): def make_node(self, cost, wrt):
return Apply(self, [cost, wrt], [wrt.type()]) if not isinstance(wrt, list):
def expand(self, r): wrt = [wrt]
cost, wrt = r.owner.inputs return Apply(self, [cost] + wrt, [_wrt.type() for _wrt in wrt])
return grad(cost, wrt) def expand(self, node):
sgrad = SGrad() cost, wrt = node.inputs[0], node.inputs[1:]
g = exec_grad(cost, wrt)
def grad(cost, wrt, g_cost=None): return g
grad = Grad()
def exec_grad(cost, wrt, g_cost=None):
""" """
@type cost: L{Result} @type cost: L{Result}
@type wrt: L{Result} or list of L{Result}s. @type wrt: L{Result} or list of L{Result}s.
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论