提交 66d7f5c7 作者: abergeron

Merge pull request #1789 from nouiz/faster_opt

Faster opt (and other unrelated things).
...@@ -296,38 +296,15 @@ class GpuDimShuffle(GpuOp): ...@@ -296,38 +296,15 @@ class GpuDimShuffle(GpuOp):
def __init__(self, input_broadcastable, new_order): def __init__(self, input_broadcastable, new_order):
input_broadcastable = tuple(input_broadcastable) input_broadcastable = tuple(input_broadcastable)
self.input_broadcastable = input_broadcastable self.input_broadcastable = input_broadcastable
new_order = tuple(new_order)
self.new_order = new_order self.new_order = new_order
# list of dimensions of the input to drop
self.drop = []
# this maps i before dropping dimensions to j after dropping
# dimensions so self.shuffle can be set properly later on
i2j = {}
j = 0
for i, b in enumerate(input_broadcastable): for i, b in enumerate(input_broadcastable):
if i not in new_order: if i not in new_order:
# we want to drop this dimension because it's not a if not b:
# value in new_order
if b == 1: # 1 aka True
self.drop.append(i)
else:
# we cannot drop non-broadcastable dimensions # we cannot drop non-broadcastable dimensions
raise ValueError("You cannot drop a non-broadcastable" raise ValueError("You cannot drop a non-broadcastable"
" dimension.", " dimension.",
(input_broadcastable, new_order)) (input_broadcastable, new_order))
else:
i2j[i] = j
j += 1
# transposition of non-broadcastable dimensions This is how
# the dimensions will be permuted, without accounting for the
# extra 'x' broadcastable dimensions to insert.
self.shuffle = [i2j[x] for x in new_order if x != 'x']
# list of dimensions of the output that are broadcastable and
# were not in the original input
self.augment = [i for i, x in enumerate(new_order) if x == 'x']
self.view_map = {0: [0]} self.view_map = {0: [0]}
...@@ -481,8 +458,6 @@ class GpuDimShuffle(GpuOp): ...@@ -481,8 +458,6 @@ class GpuDimShuffle(GpuOp):
print self print self
print "IN BROAD", self.input_broadcastable print "IN BROAD", self.input_broadcastable
print "NEW ORDER", self.new_order print "NEW ORDER", self.new_order
print "SHUFFLE", self.shuffle
print "AUGMENT", self.augment
print '------------' print '------------'
print '' print ''
print sio.getvalue() print sio.getvalue()
......
...@@ -2611,8 +2611,10 @@ register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy') ...@@ -2611,8 +2611,10 @@ register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy')
def local_fill_sink(node): def local_fill_sink(node):
""" """
f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e))) f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e)))
f need to be an elemwise
""" """
if not (node.op and isinstance(node.op, T.Elemwise) and node.op != T.fill): if not isinstance(node.op, T.Elemwise) or node.op == T.fill:
return False return False
models = [] models = []
inputs = [] inputs = []
...@@ -2622,7 +2624,7 @@ def local_fill_sink(node): ...@@ -2622,7 +2624,7 @@ def local_fill_sink(node):
inputs.append(input.owner.inputs[1]) inputs.append(input.owner.inputs[1])
else: else:
inputs.append(input) inputs.append(input)
if inputs == node.inputs: if not models:
return False return False
c = node.op(*inputs) c = node.op(*inputs)
for model in models: for model in models:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论