提交 a3831280 authored 作者: Frederic's avatar Frederic

Make test work for sparse type too.

上级 5e9bf42a
......@@ -256,10 +256,8 @@ def Rop(f, wrt, eval_points):
eval_point = as_tensor_variable(eval_point)
try:
wrt_dim = len(wrt_elem.type.broadcastable)
eval_dim = len(eval_point.type.broadcastable)
if wrt_dim != eval_dim:
if wrt_elem.type.ndim != eval_point.type.ndim:
raise ValueError('Element ' +
str(i) +
' of wrt/eval_point have mismatched ' +
......@@ -267,9 +265,9 @@ def Rop(f, wrt, eval_points):
str(wrt_dim) +
' versus ' +
str(eval_dim))
except:
# wrt_elem and eval_point can be non-tensor variable which do
# not have broadcastable flags
except AttributeError:
# wrt_elem and eval_point don't always have ndim like random type
# Tensor, Sparse and CudaNdArray have the ndim attribute
pass
seen_nodes = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论