提交 c39dce3e authored 作者: Frederic's avatar Frederic

New opt: T.dot(A,B)[i] -> T.dot(A[i],B)

上级 c7e88a00
......@@ -2175,6 +2175,34 @@ def local_subtensor_of_alloc(node):
return rval
@register_canonicalize
@register_stabilize
@register_specialize
@gof.local_optimizer([Subtensor])
def local_subtensor_of_dot(node):
"""
T.dot(A,B)[i] -> T.dot(A[i],B)
"""
if not isinstance(node.op, Subtensor):
return
if (not node.inputs[0].owner or
not isinstance(node.inputs[0].owner.op, T.Dot)):
return
# If there is other node that use the outputs of the dot
# We don't want to compute twice the sub part.
if len(node.inputs[0].clients) > 1:
return
# Do we index/slice on the outer dimensions only?
if node.inputs[0].ndim >= 1 and len(node.op.idx_list) == 1:
a = node.inputs[0].owner.inputs[0]
b = node.inputs[0].owner.inputs[1]
a_sub = node.op.make_node(a, *node.inputs[1:]).outputs[0]
pad_dim = a.ndim - a_sub.ndim
return [T.dot(a_sub, b)]
return
@register_canonicalize
@gof.local_optimizer([T.add])
def local_IncSubtensor_serialize(node):
......
......@@ -2476,6 +2476,29 @@ class Test_alloc_zero(unittest.TestCase):
_e1[2], _e2[1])
def test_local_subtensor_of_dot():
m1 = theano.tensor.matrix()
m2 = theano.tensor.matrix()
d1 = numpy.arange(6).reshape((3, 2)).astype(config.floatX)
d2 = numpy.arange(8).reshape((2, 4)).astype(config.floatX) + 10
mode = compile.get_default_mode().including("local_subtensor_of_dot")
# [cst]
f = theano.function([m1, m2], theano.dot(m1, m2)[1], mode=mode)
topo = f.maker.fgraph.toposort()
assert numpy.allclose(f(d1, d2), numpy.dot(d1, d2)[1])
# DimShuffle happen in FAST_COMPILE
assert isinstance(topo[-1].op, (T.blas_c.CGemv, T.blas.Gemv, T.DimShuffle))
# slice
f = theano.function([m1, m2], theano.dot(m1, m2)[1:2], mode=mode)
topo = f.maker.fgraph.toposort()
assert numpy.allclose(f(d1, d2), numpy.dot(d1, d2)[1:2])
assert isinstance(topo[-1].op, (T.blas.Dot22))
# TODO: tests with vector and tensor3d.
# TODO: new opt for AdvancedSubtensor1
def test_local_subtensor_of_alloc():
# DebugMode should detect if something goes wrong.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论