提交 3650c6e8 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

cleaned up gof/__init__.py, removed dead files gof/modes and gof/features

上级 1cfd8fda
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import time import time
import unittest import unittest
from gof import Result, Op, Env, modes from gof import Result, Op, Env
import gof import gof
from scalar import * from scalar import *
...@@ -11,15 +11,6 @@ import tensor ...@@ -11,15 +11,6 @@ import tensor
from elemwise import * from elemwise import *
# def inputs():
# x = modes.build(Tensor('float64', (0, 0), name = 'x'))
# y = modes.build(Tensor('float64', (1, 0), name = 'y'))
# z = modes.build(Tensor('float64', (0, 0), name = 'z'))
# return x, y, z
# def env(inputs, outputs, validate = True, features = []):
# return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_DimShuffle(unittest.TestCase): class _test_DimShuffle(unittest.TestCase):
...@@ -34,7 +25,6 @@ class _test_DimShuffle(unittest.TestCase): ...@@ -34,7 +25,6 @@ class _test_DimShuffle(unittest.TestCase):
ib = [(entry == 1) for entry in xsh] ib = [(entry == 1) for entry in xsh]
x = Tensor('float64', ib)('x') x = Tensor('float64', ib)('x')
e = DimShuffle(ib, shuffle)(x) e = DimShuffle(ib, shuffle)(x)
# print shuffle, e.owner.grad(e.owner.inputs, e.owner.outputs).owner.new_order
f = linker(Env([x], [e])).make_function() f = linker(Env([x], [e])).make_function()
assert f(numpy.ones(xsh)).shape == zsh assert f(numpy.ones(xsh)).shape == zsh
...@@ -58,18 +48,10 @@ class _test_Broadcast(unittest.TestCase): ...@@ -58,18 +48,10 @@ class _test_Broadcast(unittest.TestCase):
y = Tensor('float64', [(entry == 1) for entry in ysh])('y') y = Tensor('float64', [(entry == 1) for entry in ysh])('y')
e = Elemwise(add)(x, y) e = Elemwise(add)(x, y)
f = linker(Env([x, y], [e])).make_function() f = linker(Env([x, y], [e])).make_function()
# xv = numpy.array(range(numpy.product(xsh)))
# xv = xv.reshape(xsh)
# yv = numpy.array(range(numpy.product(ysh)))
# yv = yv.reshape(ysh)
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
yv = numpy.asarray(numpy.random.rand(*ysh)) yv = numpy.asarray(numpy.random.rand(*ysh))
zv = xv + yv zv = xv + yv
# print "AAAAAAAAAAAAAAAAAA"
# print f(xv, yv)
# print zv
# print "BBBBBBBBBBBBBBBBBB"
self.failUnless((f(xv, yv) == zv).all()) self.failUnless((f(xv, yv) == zv).all())
def with_linker_inplace(self, linker): def with_linker_inplace(self, linker):
...@@ -152,12 +134,6 @@ class _test_CAReduce(unittest.TestCase): ...@@ -152,12 +134,6 @@ class _test_CAReduce(unittest.TestCase):
zv = xv zv = xv
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.add.reduce(zv, axis) zv = numpy.add.reduce(zv, axis)
# print "AAAAAAAAAAAAAAAAAA"
# print xsh, tosum
# print f(xv)
# print zv
# print f(xv) - zv
# print "BBBBBBBBBBBBBBBBBB"
self.failUnless((numpy.abs(f(xv) - zv) < 1e-10).all()) self.failUnless((numpy.abs(f(xv) - zv) < 1e-10).all())
def test_perform(self): def test_perform(self):
...@@ -169,79 +145,3 @@ class _test_CAReduce(unittest.TestCase): ...@@ -169,79 +145,3 @@ class _test_CAReduce(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# # x = modes.build(Tensor('int32', [0, 0], name = 'x'))
# # y = modes.build(Tensor('int32', [0, 0], name = 'y'))
# from scalar import Scalar, composite
# x = modes.build(Tensor('float64', [0, 0], name = 'x'))
# y = modes.build(Tensor('float64', [0, 0], name = 'y'))
# xs, ys = Scalar('float64'), Scalar('float64')
# e = Broadcast(composite([xs, ys], [(xs * ys) + (xs / ys) * 7.0]), (x, y)).out
# f = gof.CLinker(env([x, y], [e])).make_function(inplace = False)
# size = 2000
# xv = numpy.random.rand(size, size)
# yv = numpy.random.rand(size, size)
# zv = numpy.random.rand(size, size)
# # xv = numpy.random.randint(1, 5, (1000, 1000))
# # yv = numpy.random.randint(1, 5, (1000, 1000))
# # t0 = time.time()
# # for i in xrange(100):
# # xv / yv
# # print time.time() - t0
# # t0 = time.time()
# # for i in xrange(10):
# # f(xv, yv)
# # print time.time() - t0
# # t0 = time.time()
# # for i in xrange(10):
# # (xv * yv) + (xv / yv) * 7.0
# # print time.time() - t0
# from scipy import weave
# import numpy
# t0 = time.time()
# for i in xrange(10):
# weave.blitz("zv = dot(xv, yv)", locals())
# print time.time() - t0
# speed ratios:
# add : 1
# mul : 1
# div : 2
# pow : 20
# def test_straightforward(self):
# x, y, z = inputs()
# e0 = CAReduce(Add, [x]).out
# # print e0.owner
# f = gof.PerformLinker(env([x], [e0])).make_function(inplace=True)
# assert f(numpy.ones((2, 2))) == 4.0
##########
##########
# def test_straightforward(self):
# x, y, z = inputs()
# e0 = Broadcast(Add, (x, y)).out
# f = gof.PerformLinker(env([x, y], [e0])).make_function(inplace=True)
# assert (f(numpy.ones((2, 2)), numpy.ones((1, 2))) == numpy.ones((2, 2))*2).all()
# # for result in e0.owner.grad(e0.owner.inputs, (z, )):
# # print env([x, y, z], [result])
# def test_c(self):
# x = modes.build(Tensor('float64', (0, 0), name = 'x'))
# y = modes.build(Tensor('float64', (0, 1), name = 'y'))
# z = modes.build(Tensor('float64', (0, 0), name = 'z'))
# # x = modes.build(Tensor('float64', (), name = 'x'))
# # y = modes.build(Tensor('float64', (), name = 'y'))
# # x, y, z = inputs()
# e0 = Broadcast(Add, (x, y)).out
# f = gof.CLinker(env([x, y], [e0])).make_function(inplace=True)
# print f(numpy.ones((4, 4), order = 'f'), numpy.array([[1], [2], [3], [4]]))
# # print f(numpy.ones(()), numpy.ones(()))
# assert (f(numpy.ones((2, 2)), numpy.ones((2, 1))) == numpy.ones((2, 2))*2).all()
import op, type, ext, link, env, features, toolbox, graph, cc, opt from cc import CLinker, OpWiseCLinker, DualLinker
from env import InconsistencyError, Env
from op import * from ext import DestroyHandler, view_roots
from graph import Apply, Result, Constant, as_apply, as_result from graph import Apply, Result, Constant, Value
from type import * from link import Linker, LocalLinker, PerformLinker, Profiler
from ext import * from op import Op
from link import * from opt import Optimizer, DummyOpt, SeqOptimizer, LocalOptimizer, OpSpecificOptimizer, OpSubOptimizer, OpRemover, PatternOptimizer, MergeOptimizer, MergeOptMerge
from env import * from toolbox import Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener
from features import * from type import Type, Generic, generic
from toolbox import * from utils import object2, AbstractFunctionError
from cc import *
from opt import *
# import op, ext, lib, link, result, env, prog, features, opt, graph
# from op import *
# from ext import *
# from lib import *
# from link import *
# from result import *
# from env import *
# from prog import *
# from features import *
# from opt import *
# import graph
import unittest
from modes import *
from graph import Result
from op import Op
from env import Env
class Double(Result):
    """A Result wrapping a single Python float (test fixture type)."""

    def __init__(self, data, name = "oignon"):
        """Wrap `data` (must be a float) as a graph Result called `name`."""
        Result.__init__(self, role = None, name = name)
        assert isinstance(data, float)
        self.data = data

    def __str__(self):
        # Display by name only, both for str() and repr().
        return self.name

    def __repr__(self):
        return self.name

    def __add__(self, other):
        # `add` is the modal constructor generated by
        # make_constructors(globals()) further down this file.
        return add(self, other)
def convert(x):
    """Coerce `x` to a Double.

    Floats are wrapped in a fresh Double; Doubles pass through
    unchanged.  Anything else raises TypeError (previously a bare
    ``Exception("Error 1")``, which hid what actually went wrong;
    TypeError is a subclass of Exception so existing handlers still
    catch it).
    """
    if isinstance(x, float):
        return Double(x)
    elif isinstance(x, Double):
        return x
    raise TypeError("cannot convert %r to a Double" % (x,))
class MyOp(Op):
    """Minimal test Op: fixed arity, Double-typed inputs and one output."""

    # Arity sentinel; concrete subclasses override it.
    nin = -1

    def __init__(self, *inputs):
        """Coerce the arguments through convert() and allocate the output."""
        assert len(inputs) == self.nin
        self.inputs = [convert(arg) for arg in inputs]
        out_name = self.__class__.__name__ + "_R"
        self.outputs = [Double(0.0, out_name)]

    def perform(self):
        """Apply self.impl to the inputs' data, storing into the output."""
        values = [arg.data for arg in self.inputs]
        self.outputs[0].data = self.impl(*values)
class Unary(MyOp):
    """One-input op; nin is checked by MyOp.__init__'s assertion."""
    nin = 1
class Binary(MyOp):
    """Two-input op; nin is checked by MyOp.__init__'s assertion."""
    nin = 2
class Add(Binary):
    """Binary op computing x + y on the inputs' float data."""
    def impl(self, x, y):
        return x + y
class Sub(Binary):
    """Binary op computing x - y on the inputs' float data."""
    def impl(self, x, y):
        return x - y
class Mul(Binary):
    """Binary op computing x * y on the inputs' float data."""
    def impl(self, x, y):
        return x * y
class Div(Binary):
    """Binary op computing x / y on the inputs' float data."""
    def impl(self, x, y):
        return x / y
make_constructors(globals())
def inputs(mode):
    """Return three Double leaves (x=1.0, y=2.0, z=3.0), each wrapped
    by the given mode setter (build / eval / build_eval)."""
    specs = [(1.0, 'x'), (2.0, 'y'), (3.0, 'z')]
    x, y, z = [mode(Double(value, label)) for value, label in specs]
    return x, y, z
def env(inputs, outputs, validate = True):
    """Build a feature-less Env over (inputs, outputs); `validate`
    toggles the consistency check."""
    return Env(inputs, outputs, features = [], consistency_check = validate)
# class _test_Modes(unittest.TestCase):
# def test_0(self):
# x, y, z = inputs(build)
# e = add(add(x, y), z)
# g = env([x, y, z], [e])
# assert str(g) == "[Add(Add(x, y), z)]"
# assert e.data == 0.0
# def test_1(self):
# x, y, z = inputs(build_eval)
# e = add(add(x, y), z)
# g = env([x, y, z], [e])
# assert str(g) == "[Add(Add(x, y), z)]"
# assert e.data == 6.0
# def test_2(self):
# x, y, z = inputs(eval)
# e = add(add(x, y), z)
# g = env([x, y, z], [e])
# assert str(g) == "[Add_R]"
# assert e.data == 6.0
# def test_3(self):
# x, y, z = inputs(build)
# e = x + y + z
# g = env([x, y, z], [e])
# assert str(g) == "[Add(Add(x, y), z)]"
# assert e.data == 0.0
# def test_4(self):
# x, y, z = inputs(build_eval)
# e = x + 34.0
# g = env([x, y, z], [e])
# assert str(g) == "[Add(x, oignon)]"
# assert e.data == 35.0
# def test_5(self):
# xb, yb, zb = inputs(build)
# xe, ye, ze = inputs(eval)
# try:
# e = xb + ye
# except TypeError:
# # Trying to add inputs from different modes is forbidden
# pass
# else:
# raise Exception("Expected an error.")
# Run the test suite when executed directly (note: the _test_Modes
# cases above are currently commented out, so this discovers nothing).
if __name__ == '__main__':
    unittest.main()
...@@ -3,7 +3,6 @@ import unittest ...@@ -3,7 +3,6 @@ import unittest
from graph import Result, as_result, Apply, Constant from graph import Result, as_result, Apply, Constant
from op import Op from op import Op
from ext import Destroyer
from opt import * from opt import *
from env import Env from env import Env
from toolbox import * from toolbox import *
......
...@@ -478,50 +478,50 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -478,50 +478,50 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
class Destroyer: # class Destroyer:
""" # """
Base class for Ops that destroy one or more of their inputs in an # Base class for Ops that destroy one or more of their inputs in an
inplace operation, use them as temporary storage, puts garbage in # inplace operation, use them as temporary storage, puts garbage in
them or anything else that invalidates the contents for use by other # them or anything else that invalidates the contents for use by other
Ops. # Ops.
Usage of this class in an env requires DestroyHandler. # Usage of this class in an env requires DestroyHandler.
""" # """
def destroyed_inputs(self): # def destroyed_inputs(self):
raise AbstractFunctionError() # raise AbstractFunctionError()
def destroy_map(self): # def destroy_map(self):
""" # """
Returns the map {output: [list of destroyed inputs]} # Returns the map {output: [list of destroyed inputs]}
While it typically means that the storage of the output is # While it typically means that the storage of the output is
shared with each of the destroyed inputs, it does necessarily # shared with each of the destroyed inputs, it does necessarily
have to be the case. # have to be the case.
""" # """
# compatibility # # compatibility
return {self.out: self.destroyed_inputs()} # return {self.out: self.destroyed_inputs()}
__env_require__ = DestroyHandler # __env_require__ = DestroyHandler
class Viewer: # class Viewer:
""" # """
Base class for Ops that return one or more views over one or more inputs, # Base class for Ops that return one or more views over one or more inputs,
which means that the inputs and outputs share their storage. Unless it also # which means that the inputs and outputs share their storage. Unless it also
extends Destroyer, this Op does not modify the storage in any way and thus # extends Destroyer, this Op does not modify the storage in any way and thus
the input is safe for use by other Ops even after executing this one. # the input is safe for use by other Ops even after executing this one.
""" # """
def view_map(self): # def view_map(self):
""" # """
Returns the map {output: [list of viewed inputs]} # Returns the map {output: [list of viewed inputs]}
It means that the output shares storage with each of the inputs # It means that the output shares storage with each of the inputs
in the list. # in the list.
Note: support for more than one viewed input is minimal, but # Note: support for more than one viewed input is minimal, but
this might improve in the future. # this might improve in the future.
""" # """
raise AbstractFunctionError() # raise AbstractFunctionError()
def view_roots(r): def view_roots(r):
......
import utils
# Public API of this module: the feature base classes plus the helper
# that de-duplicates feature lists by inheritance.
__all__ = ['Feature',
           'Listener',
           'Constraint',
           'Orderings',
           'Tool',
           'uniq_features',
           ]
class Feature(object):
    """Base class for every Env plugin; simply remembers its Env."""

    def __init__(self, env):
        """Store the owning `env` for use by subclasses."""
        self.env = env
class Listener(Feature):
    """Feature notified whenever an Op enters or leaves the Env's
    subgraph (at construction time and on every replacement).

    All hooks are abstract and raise utils.AbstractFunctionError.
    """

    def on_import(self, op):
        """Called by the Env when `op` is added to the graph."""
        raise utils.AbstractFunctionError()

    def on_prune(self, op):
        """Called by the Env when `op` is removed from the graph."""
        raise utils.AbstractFunctionError()

    def on_rewire(self, clients, r, new_r):
        """Called *after* clients switched from result r to new_r.

        @param clients: (op, i) pairs such that op.inputs[i] is new_r
            but used to be r
        @param r: the old result the clients used
        @param new_r: the new result they use now
        """
        raise utils.AbstractFunctionError()
class Constraint(Feature):
    """Feature restricting which Ops may appear in the subgraph, or how
    the Ops may interact with each other."""

    def validate(self):
        """Raise InconsistencyError if the Env is currently invalid
        from this constraint's point of view.  Abstract."""
        raise utils.AbstractFunctionError()
class Orderings(Feature):
    """Feature contributing extra ordering constraints to the Env's
    topological sort."""

    def orderings(self):
        """Return {op: set(ops that must be evaluated before op), ...}.

        Consumed by Env.orderings() / Env.toposort(), but not by
        Env.io_toposort().  Abstract.
        """
        raise utils.AbstractFunctionError()
class Tool(Feature):
    """Feature extending an Env's functionality, e.g. giving
    optimizations efficient ways to search the graph."""

    def publish(self):
        """Called exactly once, when the Tool is added to the Env;
        installs the Tool's methods on the Env.  Abstract."""
        raise utils.AbstractFunctionError()
def uniq_features(_features, *_rest):
    """Return a list such that no element is a subclass of another.

    All classes from `_features` and the extra iterables in `_rest` are
    pooled; whenever two entries are related by inheritance only the
    most derived one is kept (exact duplicates collapse to one entry).

    Bug fix: the original only scanned the *pending* list for
    subclasses, so a superclass popped after its subclass slipped
    through (uniq_features([Sub, Super]) kept both).  We now also check
    the already-accepted results.
    """
    # used in Env.__init__
    pool = [x for x in _features]
    for other in _rest:
        pool += [x for x in other]
    res = []
    while pool:
        feature = pool.pop()
        # Drop `feature` when a subclass (or duplicate) of it is still
        # pending OR was already accepted.
        for feature2 in pool + res:
            if issubclass(feature2, feature):
                break
        else:
            res.append(feature)
    return res
...@@ -568,36 +568,36 @@ def as_string(i, o, ...@@ -568,36 +568,36 @@ def as_string(i, o,
class Graph: # class Graph:
""" # """
Object-oriented wrapper for all the functions in this module. # Object-oriented wrapper for all the functions in this module.
""" # """
def __init__(self, inputs, outputs): # def __init__(self, inputs, outputs):
self.inputs = inputs # self.inputs = inputs
self.outputs = outputs # self.outputs = outputs
def ops(self): # def ops(self):
return ops(self.inputs, self.outputs) # return ops(self.inputs, self.outputs)
def values(self): # def values(self):
return values(self.inputs, self.outputs) # return values(self.inputs, self.outputs)
def orphans(self): # def orphans(self):
return orphans(self.inputs, self.outputs) # return orphans(self.inputs, self.outputs)
def io_toposort(self): # def io_toposort(self):
return io_toposort(self.inputs, self.outputs) # return io_toposort(self.inputs, self.outputs)
def toposort(self): # def toposort(self):
return self.io_toposort() # return self.io_toposort()
def clone(self): # def clone(self):
o = clone(self.inputs, self.outputs) # o = clone(self.inputs, self.outputs)
return Graph(self.inputs, o) # return Graph(self.inputs, o)
def __str__(self): # def __str__(self):
return as_string(self.inputs, self.outputs) # return as_string(self.inputs, self.outputs)
......
### CODE CAN BE SIMPLIFIED IF WE ONLY KEEP BUILD MODE ###
import utils
import traceback
from op import Op
# Names exported by `from modes import *`: the constructor wrapper, the
# mode setters, and the constructor-generation helpers.
__all__ = ['ModalConstructor',
           'add_modal_members',
           'build',
           'eval',
           'build_eval',
           'make_constructors',
           ]
class ModalConstructor:
    """Callable wrapper around an Op-building function `fn`.

    On call, it scans the arguments for a `__mode__` tag (set by a mode
    setter such as `build` or `eval`), requires all tagged arguments to
    agree, builds the op via `fn`, applies the mode hook to the new op,
    stamps the same tag on its outputs, and returns the single output
    (or the output list when there are several).
    """

    def __init__(self, fn):
        # fn builds an Op instance from the raw arguments.
        self.fn = fn

    def __call__(self, *args):
        wrapper = None
        collected = []
        for arg in args:
            tag = getattr(arg, '__mode__', False)
            if tag:
                if wrapper is None:
                    wrapper = tag
                elif tag != wrapper:
                    raise TypeError("Inconsistent modes.")
            collected.append(arg)
        op = self.fn(*collected)
        if wrapper:
            # Run the mode hook, then propagate the tag so downstream
            # constructors inherit the same mode.
            wrapper(op)
            for output in op.outputs:
                output.__mode__ = wrapper
        if len(op.outputs) == 1:
            return op.outputs[0]
        else:
            return op.outputs
def add_modal_members(cls, *members):
    """For each name in `members`, install a method on `cls` that
    forwards the call through a ModalConstructor built from the
    corresponding method of self.r's class."""

    def make_forwarder(name):
        # Factory binds `name` early, avoiding the late-binding-closure
        # pitfall a plain loop-body lambda would have.
        def forwarder(self, *args):
            target = getattr(self.r.__class__, name)
            return ModalConstructor(target)(self, *args)
        return forwarder

    for name in members:
        setattr(cls, name, make_forwarder(name))
def attach_trace(op):
    """Record the construction-time stack trace on `op.trace`.

    The last 3 frames are dropped: they belong to the mode machinery
    between user code and this call, not to the user's construction
    site.
    """
    op.trace = traceback.extract_stack()[:-3]
def build_mode(op):
    """Mode hook for `build`: record where `op` was constructed, but do
    not execute it."""
    attach_trace(op)
def eval_mode(op):
    """Mode hook for `eval`: record the trace, execute `op` right away,
    and clear each output's _role (presumably detaching the outputs
    from the producing op -- confirm against Result's role handling)."""
    attach_trace(op)
    op.perform()
    for out in op.outputs:
        out._role = None
def build_eval_mode(op):
    """Mode hook for `build_eval`: record the trace and execute `op`
    immediately, keeping its outputs attached to the graph."""
    attach_trace(op)
    op.perform()
def mode_setter(mode):
    """Return a setter that stamps `mode` onto a result's __mode__
    attribute and hands the result back unchanged."""
    def tag(r):
        r.__mode__ = mode
        return r
    return tag
# Public mode setters; ModalConstructor reads the tag back from wrapped
# results' __mode__ attribute.
build = mode_setter(build_mode)
eval = mode_setter(eval_mode)  # NOTE(review): shadows the builtin `eval`
build_eval = mode_setter(build_eval_mode)
def _is_op(x):
    """Return True iff `x` is a class deriving from Op.

    issubclass raises TypeError when `x` is not a class; treat that as
    "not an Op".  The original bare `except:` also swallowed unrelated
    errors (including KeyboardInterrupt), which hid bugs.
    """
    try:
        return issubclass(x, Op)
    except TypeError:
        return False
def make_constructors(source,
                      dest = None,
                      name_filter = utils.camelcase_to_separated,
                      candidate_filter = _is_op):
    """Install a ModalConstructor in `dest` for every entry of `source`
    accepted by `candidate_filter`, under the name produced by
    `name_filter` (CamelCase -> separated lowercase by default).

    @param source: mapping of names to candidate values (e.g. globals())
    @param dest: mapping to write into; defaults to `source` itself
    @return: `dest`
    """
    if dest is None:
        dest = source
    # Snapshot the items: when dest is source (the common
    # make_constructors(globals()) call) we would otherwise mutate the
    # mapping while iterating over it.
    for symbol, value in list(source.items()):
        if candidate_filter(value):
            dest[name_filter(symbol)] = ModalConstructor(value)
    return dest
...@@ -16,23 +16,23 @@ class Bookkeeper: ...@@ -16,23 +16,23 @@ class Bookkeeper:
self.on_prune(env, node) self.on_prune(env, node)
class Toposorter: # class Toposorter:
def on_attach(self, env): # def on_attach(self, env):
if hasattr(env, 'toposort'): # if hasattr(env, 'toposort'):
raise Exception("Toposorter feature is already present or in conflict with another plugin.") # raise Exception("Toposorter feature is already present or in conflict with another plugin.")
env.toposort = partial(self.toposort, env) # env.toposort = partial(self.toposort, env)
def on_detach(self, env): # def on_detach(self, env):
del env.toposort # del env.toposort
def toposort(self, env): # def toposort(self, env):
ords = {} # ords = {}
for feature in env._features: # for feature in env._features:
if hasattr(feature, 'orderings'): # if hasattr(feature, 'orderings'):
for op, prereqs in feature.orderings(env).items(): # for op, prereqs in feature.orderings(env).items():
ords.setdefault(op, set()).update(prereqs) # ords.setdefault(op, set()).update(prereqs)
order = graph.io_toposort(env.inputs, env.outputs, ords) # order = graph.io_toposort(env.inputs, env.outputs, ords)
return order # return order
# def supplemental_orderings(self): # def supplemental_orderings(self):
......
...@@ -5,7 +5,7 @@ from copy import copy ...@@ -5,7 +5,7 @@ from copy import copy
import numpy import numpy
import gof import gof
from gof import PropertiedType, Op, PropertiedOp, utils, Result, Constant, Type, Apply, Env from gof import Op, utils, Result, Constant, Type, Apply, Env
from gof.python25 import partial from gof.python25 import partial
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
......
...@@ -6,7 +6,7 @@ import numpy ...@@ -6,7 +6,7 @@ import numpy
from copy import copy from copy import copy
from gof import Result, Op, utils, Destroyer, Viewer, AbstractFunctionError, Type, Result, Constant, Apply, Value from gof import Result, Op, utils, AbstractFunctionError, Type, Result, Constant, Apply, Value
import gof import gof
import blas # for gemm, dot import blas # for gemm, dot
...@@ -574,54 +574,54 @@ transpose_inplace = TransposeInplace() ...@@ -574,54 +574,54 @@ transpose_inplace = TransposeInplace()
def transpose(x, **kwargs): def transpose(x, **kwargs):
return transpose_inplace(tensor_copy(x), **kwargs) return transpose_inplace(tensor_copy(x), **kwargs)
class Subtensor_dx(Op, Viewer): # class Subtensor_dx(Op, Viewer):
"""Return a tensor full of zeros, except for what was sliced from x by # """Return a tensor full of zeros, except for what was sliced from x by
Subtensor. # Subtensor.
@todo: pass the shape of x, rather than x itself. # @todo: pass the shape of x, rather than x itself.
@todo: add support for advanced tensor indexing (breaks current perform # @todo: add support for advanced tensor indexing (breaks current perform
implementation). # implementation).
""" # """
def __init__(self, inputs, idx_list, **kwargs): # def __init__(self, inputs, idx_list, **kwargs):
Op.__init__(self, **kwargs) # Op.__init__(self, **kwargs)
self.inputs = inputs # self.inputs = inputs
self.outputs = [Tensor(inputs[0].dtype, inputs[0].broadcastable)] # self.outputs = [Tensor(inputs[0].dtype, inputs[0].broadcastable)]
self.idx_list = idx_list # self.idx_list = idx_list
def perform(self): # def perform(self):
x = self.inputs[0] # x = self.inputs[0]
gz = self.inputs[-1] # gz = self.inputs[-1]
cdata = [] # cdata = []
for c in self.idx_list: # for c in self.idx_list:
if isinstance(c, slice): # if isinstance(c, slice):
if c.start is None: start = None # if c.start is None: start = None
else: start = self.inputs[c.start].data # else: start = self.inputs[c.start].data
if c.stop is None: stop = None # if c.stop is None: stop = None
else: stop = self.inputs[c.stop].data # else: stop = self.inputs[c.stop].data
if c.step is None: step = None # if c.step is None: step = None
else: step = self.inputs[c.step].data # else: step = self.inputs[c.step].data
cdata.append(slice(start, stop, step)) # cdata.append(slice(start, stop, step))
else: # else:
d = self.inputs[c].data # d = self.inputs[c].data
assert 'int' in str(d.dtype) # assert 'int' in str(d.dtype)
cdata.append(d) # cdata.append(d)
if len(cdata) > 1: # if len(cdata) > 1:
cdata = tuple(cdata) #there's a diff between tuple and list here... # cdata = tuple(cdata) #there's a diff between tuple and list here...
else: # else:
cdata = cdata[0] # cdata = cdata[0]
#print cdata # #print cdata
#print gz.data # #print gz.data
gx = numpy.zeros_like(x.data) # gx = numpy.zeros_like(x.data)
gx[cdata] = gz.data # gx[cdata] = gz.data
#print gx # #print gx
self.outputs[0].data = gx # self.outputs[0].data = gx
def clone_with_new_inputs(self, *new_inputs): # def clone_with_new_inputs(self, *new_inputs):
assert len(self.inputs) == len(new_inputs) # assert len(self.inputs) == len(new_inputs)
return Subtensor_dx(new_inputs, self.idx_list) # return Subtensor_dx(new_inputs, self.idx_list)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论