提交 ea704d6b authored 作者: Paul Christiano's avatar Paul Christiano

Clarified docstring and removed unnecessary subtensor.

上级 d6cd1459
......@@ -2181,7 +2181,7 @@ def local_subtensor_of_alloc(node):
@gof.local_optimizer([Subtensor])
def local_subtensor_of_dot(node):
"""
T.dot(A,B)[i] -> T.dot(A[i],B)
This optimization translates T.dot(A, B)[xs+ys] into T.dot(A[xs,:], B[:,ys]).
"""
if not isinstance(node.op, Subtensor):
......@@ -2193,6 +2193,7 @@ def local_subtensor_of_dot(node):
# We don't want to compute twice the sub part.
if len(node.inputs[0].clients) > 1:
return
a = node.inputs[0].owner.inputs[0]
b = node.inputs[0].owner.inputs[1]
......@@ -2209,7 +2210,7 @@ def local_subtensor_of_dot(node):
import pdb; pdb.set_trace()
a_sub = Subtensor(a_indices).make_node(a, *a_inputs)
b_sub = Subtensor(b_indices).make_node(b, *b_inputs)
b_sub = b if len(b_indices) == 1 else Subtensor(b_indices).make_node(b, *b_inputs)
return [T.dot(a_sub, b_sub)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论