Commit 11c4882a authored by Pascal Lamblin

Merge pull request #1556 from nouiz/mixed2

Faster compilation and sparse stuff.
......@@ -123,6 +123,7 @@ List of Implemented Operations
Both gradients are implemented. Structured by default.
- :class:`SparseFromDense <theano.sparse.basic.SparseFromDense>` and ``csr_from_dense``, ``csc_from_dense``.
The grad implemented is structured.
- Theano SparseVariable objects have a method ``toarray()`` that is equivalent to ``dense_from_sparse`` (see the sketch after this hunk).
- Construction of Sparses and their Properties
- :class:`CSM <theano.sparse.basic.CSM>` and ``CSC``, ``CSR`` to construct a matrix.
......
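A minimal usage sketch of the documented equivalence, assuming SciPy and Theano are installed; the variable names are made up:

import numpy
import scipy.sparse
import theano
import theano.sparse as sparse

x = sparse.csr_matrix(name='x', dtype='float64')  # symbolic CSR matrix
d1 = sparse.dense_from_sparse(x)
d2 = x.toarray()                                  # same op, method form
f = theano.function([x], [d1, d2])
v1, v2 = f(scipy.sparse.csr_matrix(numpy.eye(3, dtype='float64')))
assert numpy.allclose(v1, v2)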
......@@ -208,6 +208,8 @@ def cleanup():
have_c_compiler = False
for obj in flatten(key):
if isinstance(obj, numpy.ndarray):
# Reuse have_npy_abi_version to
# force the removal of the key
have_npy_abi_version = False
break
elif isinstance(obj, basestring):
......@@ -219,6 +221,8 @@ def cleanup():
hasattr(obj, 'c_code_cache_version')):
v = obj.c_code_cache_version()
if v not in [(), None] and v not in key[0]:
# Reuse have_npy_abi_version to
# force the removal of the key
have_npy_abi_version = False
break
......
......@@ -442,7 +442,7 @@ if 0:
self.stale_droot = True
def on_change_input(self, fgraph, app, i, old_r, new_r):
def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
"""app.inputs[i] changed from old_r to new_r """
if app == 'output':
# app == 'output' is a special key that means the FunctionGraph is redefining which nodes are being
......@@ -827,7 +827,7 @@ class DestroyHandler(toolbox.Bookkeeper):
self.stale_droot = True
def on_change_input(self, fgraph, app, i, old_r, new_r):
def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
"""app.inputs[i] changed from old_r to new_r """
if app == 'output':
# app == 'output' is a special key that means the FunctionGraph is redefining which nodes are being
......
......@@ -376,7 +376,7 @@ class FunctionGraph(utils.object2):
current value of node.inputs[i] which we want to replace.
For each feature that has an 'on_change_input' method, calls:
feature.on_change_input(function_graph, node, i, old_r, new_r, [reason])
feature.on_change_input(function_graph, node, i, old_r, new_r, reason)
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == 'output':
......@@ -512,14 +512,7 @@ class FunctionGraph(utils.object2):
# not existing
continue
#####HORRIBLE OPTIONAL ARGUMENT HACK
try:
fn(self, *args, **kwargs)
except TypeError, e:
if str(e) == "on_change_input() got an unexpected keyword argument 'reason'" and len(kwargs) == 1:
fn(self, *args)
else:
raise
fn(self, *args, **kwargs)
def collect_callbacks(self, name, *args):
"""WRITEME
......
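Since `reason` is now a required parameter, the try/except shim in execute_callbacks above is gone. A minimal sketch of a feature conforming to the new signature (the class itself is hypothetical):

class ChangeLogger(object):
    # Hypothetical FunctionGraph feature showing the new callback
    # signature: on_change_input always receives `reason` (typically
    # the optimizer that requested the replacement).
    def on_attach(self, fgraph):
        self.log = []

    def on_detach(self, fgraph):
        del self.log

    def on_change_input(self, fgraph, node, i, old_r, new_r, reason):
        self.log.append((node, i, old_r, new_r, reason))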
......@@ -423,7 +423,7 @@ class MergeFeature(object):
for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
def on_change_input(self, fgraph, node, i, r, new_r):
def on_change_input(self, fgraph, node, i, r, new_r, reason):
# If inputs to node change, it is not guaranteed that it is distinct
# from the other nodes in nodes_seen
if node in self.nodes_seen:
......@@ -555,6 +555,9 @@ class MergeOptimizer(Optimizer):
# clear blacklist
fgraph.merge_feature.blacklist = []
def __str__(self):
return self.__class__.__name__
merge_optimizer = MergeOptimizer()
......@@ -1171,7 +1174,7 @@ class NavigatorOptimizer(Optimizer):
def on_prune(self, fgraph, node, reason):
pruner(node)
if chin is not None:
def on_change_input(self, fgraph, node, i, r, new_r):
def on_change_input(self, fgraph, node, i, r, new_r, reason):
chin(node, i, r, new_r)
u = Updater()
......@@ -1229,7 +1232,7 @@ class NavigatorOptimizer(Optimizer):
# If an output would be replaced by itself, no need to perform
# the replacement
repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements)
if rnew is not r]
if rnew is not r]
if len(repl_pairs) == 0:
return False
try:
......@@ -1302,6 +1305,10 @@ class TopoOptimizer(NavigatorOptimizer):
raise
self.detach_updater(fgraph, u)
def __str__(self):
return getattr(self, '__name__',
'<TopoOptimizer instance>')
class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME"""
......@@ -1360,7 +1367,7 @@ class ChangeTracker:
def on_import(self, fgraph, node, reason):
self.changed = True
def on_change_input(self, fgraph, node, i, r, new_r):
def on_change_input(self, fgraph, node, i, r, new_r, reason):
self.changed = True
def reset(self):
......@@ -1415,23 +1422,29 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def apply(self, fgraph, start_from=None):
if start_from is None:
start_from = fgraph.outputs
else:
for node in start_from:
assert node in fgraph.outputs
changed = True
max_use_abort = False
opt_name = None
process_count = {}
global_process_count = {}
max_nb_nodes = len(fgraph.apply_nodes)
max_use = max_nb_nodes * self.max_use_ratio
loop_timing = []
loop_process_count = []
global_opt_timing = []
time_opts = {}
io_toposort_timing = []
nb_nodes = []
for opt in self.global_optimizers + self.local_optimizers:
process_count.setdefault(opt, 0)
global_process_count.setdefault(opt, 0)
time_opts.setdefault(opt, 0)
while changed and not max_use_abort:
process_count = {}
t0 = time.time()
changed = False
......@@ -1442,9 +1455,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
gopt.apply(fgraph)
time_opts[gopt] += time.time() - t_opt
if fgraph.change_tracker.changed:
process_count.setdefault(gopt, 0)
process_count[gopt] += 1
global_process_count[gopt] += 1
changed = True
if process_count[gopt] > max_use:
if global_process_count[gopt] > max_use:
max_use_abort = True
opt_name = (getattr(gopt, "name", None)
or getattr(gopt, "__name__", ""))
......@@ -1452,9 +1467,6 @@ class EquilibriumOptimizer(NavigatorOptimizer):
global_opt_timing.append(float(time.time() - t0))
#apply local optimizer
for node in start_from:
assert node in fgraph.outputs
topo_t0 = time.time()
q = deque(graph.io_toposort(fgraph.inputs, start_from))
io_toposort_timing.append(time.time() - topo_t0)
......@@ -1485,9 +1497,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
lopt_change = self.process_node(fgraph, node, lopt)
time_opts[lopt] += time.time() - t_opt
if lopt_change:
process_count.setdefault(lopt, 0)
process_count[lopt] += 1
global_process_count[lopt] += 1
changed = True
if process_count[lopt] > max_use:
if global_process_count[lopt] > max_use:
max_use_abort = True
opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", ""))
......@@ -1497,6 +1511,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
finally:
self.detach_updater(fgraph, u)
loop_process_count.append(process_count)
loop_timing.append(float(time.time() - t0))
if max_use_abort:
......@@ -1505,7 +1520,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
+ "%f with the theano flag 'optdb.max_use_ratio'." %
config.optdb.max_use_ratio)
return (self, loop_timing, process_count, max_nb_nodes,
return (self, loop_timing, loop_process_count, max_nb_nodes,
global_opt_timing, nb_nodes, time_opts, io_toposort_timing)
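# A standalone sketch of the counting scheme above: each pass gets a
# fresh process_count that feeds the per-pass profile, while
# global_process_count accumulates across passes and drives the
# max_use abort. The pass data below is made up for illustration.
global_process_count = {}
loop_process_count = []
for _pass in range(3):                    # three hypothetical passes
    process_count = {}                    # reset at every pass
    for opt in ['optA', 'optA', 'optB']:  # hypothetical opt applications
        process_count[opt] = process_count.get(opt, 0) + 1
        global_process_count[opt] = global_process_count.get(opt, 0) + 1
    loop_process_count.append(process_count)
assert global_process_count == {'optA': 6, 'optB': 3}
assert all(c == {'optA': 2, 'optB': 1} for c in loop_process_count)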
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
......@@ -1519,40 +1534,68 @@ class EquilibriumOptimizer(NavigatorOptimizer):
@staticmethod
def print_profile(stream, prof, level=0):
(opt, loop_timing, process_count, max_nb_nodes,
(opt, loop_timing, loop_process_count, max_nb_nodes,
global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof
blanc = (' ' * level)
print >> stream, blanc, "EquilibriumOptimizer",
print >> stream, blanc, getattr(opt, "name",
getattr(opt, "__name__", ""))
print >> stream, blanc, " time %.3fs for %d passes, %d nodes max" % (
print >> stream, blanc, " time %.3fs for %d passes, %d nodes max" % (
sum(loop_timing), len(loop_timing), max_nb_nodes)
print >> stream, blanc, " time io_toposort %.3fs" % sum(
print >> stream, blanc, " time io_toposort %.3fs" % sum(
io_toposort_timing)
s = sum([time_opts[o] for o in opt.local_optimizers])
print >> stream, blanc, " time in local optimizers %.3fs" % s
s = sum([time_opts[o] for o in opt.global_optimizers])
print >> stream, blanc, " time in global optimizers %.3fs" % s
for i in range(len(loop_timing)):
print >> stream, blanc, ('%d - %.3fs (%.3fs in global opts, '
'%.3fs io_toposort) - %d nodes' % (
lopt = ""
if loop_process_count[i]:
d = list(reversed(sorted(loop_process_count[i].iteritems(),
key=lambda a: a[1])))
lopt = " ".join([str((str(k), v)) for k, v
in d[:5]])
if len(d) > 5:
lopt += " ..."
print >> stream, blanc, (' %2d - %.3fs %d (%.3fs in global opts, '
'%.3fs io_toposort) - %d nodes - %s' % (
i, loop_timing[i],
sum(loop_process_count[i].values()),
global_opt_timing[i],
io_toposort_timing[i], nb_nodes[i]))
io_toposort_timing[i], nb_nodes[i],
lopt))
count_opt = []
not_used = 0
not_used_time = 0
process_count = {}
for o in opt.global_optimizers + opt.local_optimizers:
process_count.setdefault(o, 0)
for count in loop_process_count:
for o, v in count.iteritems():
process_count[o] += v
for opt, count in process_count.iteritems():
if count > 0:
count_opt.append((time_opts[opt], count, opt))
else:
not_used += 1
not_used_time += time_opts[opt]
if count_opt:
print >> stream, blanc, \
'times applied - optimizer (only those applied):'
' times - times applied - name:'
count_opt.sort()
for (t, count, opt) in count_opt[::-1]:
print >> stream, blanc, ' %.3fs - %d - %s' % (
t, count, opt)
print >> stream, blanc, ' %.3fs - in %d optimizations that were not used' % (
not_used_time, not_used)
print >> stream
@staticmethod
def merge_profile(prof1, prof2):
#(opt, loop_timing, process_count, max_nb_nodes,
#(opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers = set(prof1[0].local_optimizers).union(
......@@ -1574,12 +1617,16 @@ class EquilibriumOptimizer(NavigatorOptimizer):
loop_timing = merge_list(prof1[1], prof2[1])
process_count = prof1[2].copy()
for process, count in prof2[2].iteritems():
if process in process_count:
process_count[process] += count
else:
process_count[process] = count
loop_process_count = [d.copy() for d in prof1[2]]  # don't mutate prof1's dicts
for i in range(len(loop_process_count)):
process_count = loop_process_count[i]
for process, count in prof2[2][i].iteritems():
if process in process_count:
process_count[process] += count
else:
process_count[process] = count
for i in range(len(loop_process_count), len(prof2[2])):
loop_process_count.append(prof2[2][i].copy())
max_nb_nodes = max(prof1[3], prof2[3])
......@@ -1601,7 +1648,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
assert len(loop_timing) == max(len(prof1[1]), len(prof2[1]))
return (new_opt,
loop_timing,
process_count,
loop_process_count,
max_nb_nodes,
global_opt_timing,
nb_nodes,
......
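A standalone sketch (hypothetical helper, Python 2 idioms as in the surrounding code) of the elementwise merge that the updated merge_profile performs on the two per-pass count lists:

def merge_loop_counts(counts1, counts2):
    # Merge two lists of per-pass {optimizer: count} dicts pass by pass;
    # extra passes from the longer list are copied over unchanged.
    merged = [d.copy() for d in counts1]
    for i in range(min(len(counts1), len(counts2))):
        for o, n in counts2[i].iteritems():
            merged[i][o] = merged[i].get(o, 0) + n
    for i in range(len(counts1), len(counts2)):
        merged.append(counts2[i].copy())
    return merged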
......@@ -159,7 +159,7 @@ class HintsFeature(object):
if k not in new_hints:
new_hints[k] = v
def on_change_input(self, fgraph, node, i, r, new_r):
def on_change_input(self, fgraph, node, i, r, new_r, reason):
# TODO:
# This tells us that r and new_r must have the same shape
# if we didn't know that the shapes are related, now we do.
......
......@@ -300,6 +300,8 @@ class _sparse_py_operators:
# def _as_TensorVariable(self):
# return dense_from_sparse(self)
def toarray(self):
return dense_from_sparse(self)
shape = property(lambda self: tensor.shape(dense_from_sparse(self)))
# don't worry!
# the plan is that the ShapeFeature in tensor.opt will do shape propagation
......@@ -1843,6 +1845,8 @@ class AddSD(gof.op.Op):
def infer_shape(self, node, shapes):
return [shapes[3]]
def c_code_cache_version(self):
return (1,)
add_s_d = AddSD()
......@@ -1918,6 +1922,10 @@ def add(x, y):
x = as_sparse_variable(x)
if hasattr(y, 'getnnz'):
y = as_sparse_variable(y)
if not isinstance(x, theano.Variable):
x = theano.tensor.as_tensor_variable(x)
if not isinstance(y, theano.Variable):
y = theano.tensor.as_tensor_variable(y)
x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y)
......
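A minimal sketch of the effect of the two new isinstance checks: raw numpy inputs to theano.sparse.add are now wrapped by as_tensor_variable automatically (assumes Theano and NumPy; the names are made up):

import numpy
import theano.sparse as sparse

x = sparse.csr_matrix(name='x', dtype='float64')
y = numpy.ones((3, 2))   # a plain ndarray, no manual wrapping needed
z = sparse.add(x, y)     # sparse + dense yields a dense variable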
......@@ -567,68 +567,58 @@ class T_AddMul(unittest.TestCase):
def _testSD(self, op, array1=numpy.array([[1., 0], [3, 0], [0, 6]]),
array2=numpy.asarray([[0, 2.], [0, 4], [5, 0]])):
for mtype in _mtypes:
a = numpy.array(array1)
aR = tensor.as_tensor_variable(a)
self.assertFalse(aR.data is a) # constants are copied
self.assertTrue(_is_dense(a))
self.assertTrue(_is_dense_variable(aR))
b = mtype(array2)
bR = as_sparse_variable(b)
self.assertFalse(bR.data is b) # constants are copied
self.assertTrue(_is_sparse(b))
self.assertTrue(_is_sparse_variable(bR))
apb = op(aR, bR)
self.assertTrue(apb.type.dtype == aR.type.dtype, apb.type.dtype)
self.assertTrue(apb.type.dtype == bR.type.dtype, apb.type.dtype)
val = eval_outputs([apb])
self.assertTrue(val.shape == (3, 2))
if op is add:
self.assertTrue(_is_dense_variable(apb))
self.assertTrue(numpy.all(val == (a + b)))
ans = numpy.array([[1., 2], [3, 4], [5, 6]])
self.assertTrue(numpy.all(val == ans))
elif op is mul:
self.assertTrue(_is_sparse_variable(apb))
self.assertTrue(numpy.all(val.todense() == (b.multiply(a))))
self.assertTrue(numpy.all(val.todense() == numpy.array(
[[1, 0], [9, 0], [0, 36]])))
for a in [numpy.array(array1), tensor.as_tensor_variable(array1)]:
b = mtype(array2)
bR = as_sparse_variable(b)
self.assertFalse(bR.data is b) # constants are copied
self.assertTrue(_is_sparse(b))
self.assertTrue(_is_sparse_variable(bR))
apb = op(a, bR)
self.assertTrue(apb.type.dtype == a.dtype, apb.type.dtype)
self.assertTrue(apb.type.dtype == bR.type.dtype, apb.type.dtype)
val = eval_outputs([apb])
self.assertTrue(val.shape == (3, 2))
if op is add:
self.assertTrue(_is_dense_variable(apb))
self.assertTrue(numpy.all(val == (array1 + b)))
ans = numpy.array([[1., 2], [3, 4], [5, 6]])
self.assertTrue(numpy.all(val == ans))
elif op is mul:
self.assertTrue(_is_sparse_variable(apb))
self.assertTrue(numpy.all(val.todense() == (b.multiply(array1))))
self.assertTrue(numpy.all(val.todense() == numpy.array(
[[1, 0], [9, 0], [0, 36]])))
def _testDS(self, op, array1=numpy.array([[1., 0], [3, 0], [0, 6]]),
array2=numpy.asarray([[0, 2.], [0, 4], [5, 0]])):
for mtype in _mtypes:
a = mtype(array1)
aR = as_sparse_variable(a)
self.assertFalse(aR.data is a)
self.assertTrue(_is_sparse(a))
self.assertTrue(_is_sparse_variable(aR))
b = numpy.asarray(array2)
bR = tensor.as_tensor_variable(b)
self.assertFalse(bR.data is b)
self.assertTrue(_is_dense(b))
self.assertTrue(_is_dense_variable(bR))
apb = op(aR, bR)
self.assertTrue(apb.type.dtype == aR.type.dtype, apb.type.dtype)
self.assertTrue(apb.type.dtype == bR.type.dtype, apb.type.dtype)
val = eval_outputs([apb])
self.assertTrue(val.shape == (3, 2))
if op is add:
self.assertTrue(_is_dense_variable(apb))
self.assertTrue(numpy.all(val == (a + b)))
ans = numpy.array([[1., 2], [3, 4], [5, 6]])
self.assertTrue(numpy.all(val == ans))
elif op is mul:
self.assertTrue(_is_sparse_variable(apb))
ans = numpy.array([[1, 0], [9, 0], [0, 36]])
self.assertTrue(numpy.all(val.todense() == (a.multiply(b))))
self.assertTrue(numpy.all(val.todense() == ans))
for b in [numpy.asarray(array2), tensor.as_tensor_variable(array2)]:
a = mtype(array1)
aR = as_sparse_variable(a)
self.assertFalse(aR.data is a)
self.assertTrue(_is_sparse(a))
self.assertTrue(_is_sparse_variable(aR))
apb = op(aR, b)
self.assertTrue(apb.type.dtype == aR.type.dtype, apb.type.dtype)
self.assertTrue(apb.type.dtype == b.dtype, apb.type.dtype)
val = eval_outputs([apb])
self.assertTrue(val.shape == (3, 2))
if op is add:
self.assertTrue(_is_dense_variable(apb))
self.assertTrue(numpy.all(val == (a + array2)))
ans = numpy.array([[1., 2], [3, 4], [5, 6]])
self.assertTrue(numpy.all(val == ans))
elif op is mul:
self.assertTrue(_is_sparse_variable(apb))
ans = numpy.array([[1, 0], [9, 0], [0, 36]])
self.assertTrue(numpy.all(val.todense() == (a.multiply(array2))))
self.assertTrue(numpy.all(val.todense() == ans))
def test_upcast(self):
array1 = numpy.array([[1, 0], [3, 0], [0, 6]], dtype='float32')
......@@ -718,18 +708,25 @@ class T_conversion(unittest.TestCase):
self.assertTrue(str(val.dtype) == 'float64')
self.assertTrue(val.format == 'csr')
if 1:
def test2(self):
#call dense_from_sparse
for t in _mtypes:
s = t(scipy.sparse.identity(5))
d = dense_from_sparse(s)
# s should be copied into the graph as a constant
s[0, 0] = 3.0 # changes s, but not the copy
val = eval_outputs([d])
return
self.assertTrue(str(val.dtype) == s.dtype)
self.assertTrue(numpy.all(val[0] == [1, 0, 0, 0, 0]))
def test_dense_from_sparse(self):
# call dense_from_sparse
for t in _mtypes:
s = t(scipy.sparse.identity(5))
s = as_sparse_variable(s)
d = dense_from_sparse(s)
val = eval_outputs([d])
self.assertTrue(str(val.dtype) == s.dtype)
self.assertTrue(numpy.all(val[0] == [1, 0, 0, 0, 0]))
def test_todense(self):
# call sparse_var.todense()
for t in _mtypes:
s = t(scipy.sparse.identity(5))
s = as_sparse_variable(s)
d = s.toarray()
val = eval_outputs([d])
self.assertTrue(str(val.dtype) == s.dtype)
self.assertTrue(numpy.all(val[0] == [1, 0, 0, 0, 0]))
@staticmethod
def check_format_ndim(format, ndim):
......
......@@ -252,6 +252,8 @@ class CGer(BaseBLAS, Ger):
def c_code_cache_version(self):
return (8, blas_header_version())
cger_inplace = CGer(True)
cger_no_inplace = CGer(False)
@local_optimizer([ger, ger_destructive])
......@@ -269,8 +271,8 @@ def use_c_ger(node):
@local_optimizer([CGer(False)])
def make_c_ger_destructive(node):
if node.op == CGer(False):
return [CGer(True)(*node.inputs)]
if node.op == cger_no_inplace:
return [cger_inplace(*node.inputs)]
####### ####### #######
......@@ -579,6 +581,8 @@ class CGemv(BaseBLAS, Gemv):
def c_code_cache_version(self):
return (10, blas_header_version())
cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False)
@local_optimizer([gemv_inplace, gemv_no_inplace])
......@@ -596,8 +600,8 @@ def use_c_gemv(node):
@local_optimizer([CGemv(inplace=False)])
def make_c_gemv_destructive(node):
if node.op == CGemv(inplace=False):
return [CGemv(inplace=True)(*node.inputs)]
if node.op == cgemv_no_inplace:
return [cgemv_inplace(*node.inputs)]
####### ####### #######
......
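The pattern behind both changes above: construct the two op instances once at module level and compare against them, instead of building a fresh CGer/CGemv on every optimizer invocation. A generic sketch with a hypothetical op class:

class MyOp(object):
    # Hypothetical op: equality is determined by type plus the inplace
    # flag, as for CGer/CGemv, so the module-level singletons below
    # compare equal to any equivalent instance while avoiding repeated
    # construction inside optimizers.
    def __init__(self, inplace):
        self.inplace = inplace

    def __eq__(self, other):
        return type(self) == type(other) and self.inplace == other.inplace

    def __hash__(self):
        return hash((type(self), self.inplace))

my_op_inplace = MyOp(True)
my_op_no_inplace = MyOp(False)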
......@@ -546,7 +546,7 @@ class Elemwise(Op):
args.append(DimShuffle(
input.type.broadcastable,
['x'] * difference + range(length),
inplace=True)(input))
inplace=False)(input))
inputs = args
#HERE: all the broadcast dims have the same length now
......
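A brief sketch of what this DimShuffle does: Elemwise.make_node left-pads missing broadcastable dimensions, and with inplace=False the freshly built graph shows DimShuffle rather than InplaceDimShuffle, leaving the inplace substitution to a later optimization (the dimshuffle_lift test updated at the end of this commit accepts both spellings):

import theano.tensor as T

x = T.vector('x')   # shape (n,)
m = T.matrix('m')   # shape (r, c)
e = m + x           # x is lifted to rank 2 via DimShuffle('x', 0)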
......@@ -47,29 +47,43 @@ theano.configparser.AddConfigVar('on_shape_error',
# Utilities
def out2in(*local_opts):
def out2in(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if there is only one optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace)
if not name:
name = local_opts.__name__
ret = opt.TopoOptimizer(local_opts,
order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
def in2out(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if there is only one optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if not name:
name = local_opts.__name__
ret = opt.TopoOptimizer(local_opts,
order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
def _fill_chain(new_out, orig_inputs):
......@@ -1075,7 +1089,7 @@ class ShapeFeature(object):
for r, s in izip(node.outputs, o_shapes):
self.set_shape(r, s)
def on_change_input(self, fgraph, node, i, r, new_r):
def on_change_input(self, fgraph, node, i, r, new_r, reason):
if new_r not in self.shape_of:
# It can happen that the fgraph didn't call on_import for some
# new_r. This happens when new_r doesn't have an
......@@ -2102,6 +2116,14 @@ def local_IncSubtensor_serialize(node):
# We register it in a TopoOptimizer inside the canonizer EQ optimizer.
# Otherwise, in some cases it made the EQ optimizer take 45 passes; with
# the TopoOptimizer, the EQ optimizer needs only 6 passes.
compile.optdb.register('pre_local_IncSubtensor_serialize',
in2out(local_IncSubtensor_serialize),
#Just before canonizer
0.99, 'fast_run')
#after priority 50 Destructive inplace operations
#gemm is the first one now, at priority 70
......@@ -3717,7 +3739,8 @@ register_specialize(local_add_specialize)
# mul_to_neg = out2in(gof.LocalOptGroup(local_mul_to_neg))
mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut,
local_fill_sink))
local_fill_sink),
name='mul_canonizer_groups')
def check_for_x_over_absX(numerators, denominators):
......@@ -3859,7 +3882,8 @@ def add_calculate(num, denum, aslist=False, out_type=None):
local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate)
add_canonizer = in2out(gof.LocalOptGroup(local_add_canonizer, local_fill_cut,
local_fill_sink))
local_fill_sink),
name='add_canonizer_group')
register_canonicalize(local_add_canonizer, name='local_add_canonizer')
......
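A minimal usage sketch of the new `name` keyword, assuming in2out and local_add_canonizer are importable from theano.tensor.opt as above: the wrapped TopoOptimizer now carries a __name__, which the EquilibriumOptimizer profile above can report:

from theano.tensor.opt import in2out, local_add_canonizer

named = in2out(local_add_canonizer, name='my_add_canonizer_pass')
assert named.__name__ == 'my_add_canonizer_pass'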
......@@ -124,13 +124,27 @@ class test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
e = x + y + z
g = FunctionGraph([x, y, z], [e])
self.assertTrue(str(g) == ("[Elemwise{add,no_inplace}("
"InplaceDimShuffle{x,0,1}(Elemwise{add,no_inplace}"
"(InplaceDimShuffle{x,0}(x), y)), z)]"), str(g))
# It does not really matter if the DimShuffles are inplace
# or not.
init_str_g_inplace = (
"[Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z)]")
init_str_g_noinplace = (
"[Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z)]")
self.assertTrue(str(g) in (init_str_g_inplace, init_str_g_noinplace),
str(g))
opt_str_g_inplace = (
"[Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]")
opt_str_g_noinplace = (
"[Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z)]")
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == ("[Elemwise{add,no_inplace}(Elemwise"
"{add,no_inplace}(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle"
"{x,0,1}(y)), z)]"), str(g))
self.assertTrue(str(g) in (opt_str_g_inplace, opt_str_g_noinplace),
str(g))
def test_add_canonizer_problem0():
......