提交 071e7535 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

reverted Elemwise behavior

上级 d1df6bb3
......@@ -13,6 +13,7 @@ import numpy
#import scalar_opt
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
x = Tensor(broadcastable = xbc, dtype = 'float64')('x')
y = Tensor(broadcastable = ybc, dtype = 'float64')('y')
......@@ -92,10 +93,8 @@ class _test_dimshuffle_lift(unittest.TestCase):
g = Env([x, y, z], [e])
gof.ExpandMacros().optimize(g)
self.failUnless(str(g) == "[add(InplaceDimShuffle{x,0,1}(add(InplaceDimShuffle{x,0}(x), y)), z)]", str(g))
print g
lift_dimshuffle.optimize(g)
gof.ExpandMacros().optimize(g)
print g
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
......
......@@ -313,19 +313,19 @@ class Elemwise(Op):
target_length = max([input.type.ndim for input in inputs])
if len(inputs) > 1:
inputs = [lcomplete(input, *inputs) for input in inputs]
# args = []
# for input in inputs:
# length = input.type.ndim
# difference = target_length - length
# if not difference:
# args.append(input)
# else:
# # TODO: use LComplete instead
# args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length))(input))
# inputs = args
# if len(inputs) > 1:
# inputs = [lcomplete(input, *inputs) for input in inputs]
args = []
for input in inputs:
length = input.type.ndim
difference = target_length - length
if not difference:
args.append(input)
else:
# TODO: use LComplete instead
args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length), inplace = True)(input))
inputs = args
# try:
# assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
......
......@@ -488,7 +488,7 @@ class NavigatorOptimizer(Optimizer):
class TopoOptimizer(NavigatorOptimizer):
def __init__(self, local_opt, order = 'out_to_in', ignore_newtrees = False, failure_callback = None):
def __init__(self, local_opt, order = 'in_to_out', ignore_newtrees = False, failure_callback = None):
if order not in ['out_to_in', 'in_to_out']:
raise ValueError("order must be 'out_to_in' or 'in_to_out'")
self.order = order
......@@ -516,6 +516,7 @@ class TopoOptimizer(NavigatorOptimizer):
except:
self.detach_updater(env, u)
raise
class OpKeyOptimizer(NavigatorOptimizer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论