提交 ea704d6b authored 作者: Paul Christiano's avatar Paul Christiano

Clarified docstring and removed unnecessary subtensor.

上级 d6cd1459
...@@ -2181,7 +2181,7 @@ def local_subtensor_of_alloc(node): ...@@ -2181,7 +2181,7 @@ def local_subtensor_of_alloc(node):
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor])
def local_subtensor_of_dot(node): def local_subtensor_of_dot(node):
""" """
T.dot(A,B)[i] -> T.dot(A[i],B) This optimization translates T.dot(A, B)[xs+ys] into T.dot(A[xs,:], B[:,ys]).
""" """
if not isinstance(node.op, Subtensor): if not isinstance(node.op, Subtensor):
...@@ -2193,6 +2193,7 @@ def local_subtensor_of_dot(node): ...@@ -2193,6 +2193,7 @@ def local_subtensor_of_dot(node):
# We don't want to compute twice the sub part. # We don't want to compute twice the sub part.
if len(node.inputs[0].clients) > 1: if len(node.inputs[0].clients) > 1:
return return
a = node.inputs[0].owner.inputs[0] a = node.inputs[0].owner.inputs[0]
b = node.inputs[0].owner.inputs[1] b = node.inputs[0].owner.inputs[1]
...@@ -2209,7 +2210,7 @@ def local_subtensor_of_dot(node): ...@@ -2209,7 +2210,7 @@ def local_subtensor_of_dot(node):
import pdb; pdb.set_trace() import pdb; pdb.set_trace()
a_sub = Subtensor(a_indices).make_node(a, *a_inputs) a_sub = Subtensor(a_indices).make_node(a, *a_inputs)
b_sub = Subtensor(b_indices).make_node(b, *b_inputs) b_sub = b if len(b_indices) == 1 else Subtensor(b_indices).make_node(b, *b_inputs)
return [T.dot(a_sub, b_sub)] return [T.dot(a_sub, b_sub)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论