提交 3a4ecfea authored 作者: Razvan Pascanu's avatar Razvan Pascanu

more details when the optimization applies

上级 0b3c01b5
...@@ -1418,8 +1418,11 @@ def scan_pushout_dot1(node): ...@@ -1418,8 +1418,11 @@ def scan_pushout_dot1(node):
if not isinstance(node.op, scan_op.Scan): if not isinstance(node.op, scan_op.Scan):
return False return False
# Replace pattern of the form # Replace pattern of the form
# x[t] = x[t-1] + dot(seq[t], value[t]) # x[t] = x[t-1] + dot(seq[t], value)
# with Sequence.reshape((-1, seq.shape[2])) \dot Value # with Sequence.reshape((-1, seq.shape[2])) \dot Value
# When seq[t] is a vector/matrix and `value` is a matrix
# Note that this works when only you need X[-1] in the end
# and assumes dimshuffle are applied to vectors before calling dot
op = node.op op = node.op
sitsot_ins = op.inner_sitsot(op.inputs) sitsot_ins = op.inner_sitsot(op.inputs)
sitsot_outs = op.inner_sitsot_outs(op.outputs) sitsot_outs = op.inner_sitsot_outs(op.outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论