提交 a14c3837 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

allow one of the eval points to be None

上级 11781d77
...@@ -4556,11 +4556,12 @@ class Dot(Op): ...@@ -4556,11 +4556,12 @@ class Dot(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# R_op for a \dot b evaluted at c for a and d for b is # R_op for a \dot b evaluted at c for a and d for b is
# simply c \dot b + a \dot d # simply c \dot b + a \dot d
if None in eval_points:
return [None]
assert len(inputs) == 2 assert len(inputs) == 2
assert len(eval_points) == 2 assert len(eval_points) == 2
if eval_points[0] is None and eval_points[1] is None:
return [None]
debugger_available = config.compute_test_value != 'off' debugger_available = config.compute_test_value != 'off'
...@@ -4605,11 +4606,17 @@ class Dot(Op): ...@@ -4605,11 +4606,17 @@ class Dot(Op):
' %s and %s, respectively' % ( ' %s and %s, respectively' % (
str(input_values[i].shape), str(input_values[i].shape),
str(eval_point_values[i].shape))) str(eval_point_values[i].shape)))
if eval_points[0]:
t1 = self(eval_points[0], inputs[1]) t1 = self(eval_points[0], inputs[1])
if eval_points[1]:
t2 = self(inputs[0], eval_points[1]) t2 = self(inputs[0], eval_points[1])
if eval_points[0] and eval_points[1]:
return [t1 + t2] return [t1 + t2]
elif eval_points[0]:
return [t1]
else:
return [t2]
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
xshp, yshp = shapes xshp, yshp = shapes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论