提交 b2810ac1 authored 作者: Frederic's avatar Frederic

Clone the NoneConst variable. Why this is needed?

上级 4b2c6694
...@@ -1404,7 +1404,7 @@ class MaxAndArgmax(Op): ...@@ -1404,7 +1404,7 @@ class MaxAndArgmax(Op):
else: else:
all_axes = range(x.ndim) all_axes = range(x.ndim)
if axis is None: if axis is None:
axis = NoneConst axis = NoneConst.clone()
else: else:
axis = _as_tensor_variable(axis) axis = _as_tensor_variable(axis)
assert axis.ndim == 0 assert axis.ndim == 0
...@@ -1428,7 +1428,7 @@ class MaxAndArgmax(Op): ...@@ -1428,7 +1428,7 @@ class MaxAndArgmax(Op):
x, axis = inp x, axis = inp
max, argmax = out max, argmax = out
fail = sub["fail"] fail = sub["fail"]
assert node.inputs[1] is theano.tensor.type_other.NoneConst or node.inputs[1].ndim == 0 assert NoneConst.equals(node.inputs[1]) or node.inputs[1].ndim == 0
ret = """ ret = """
int axis; int axis;
if((PyObject*)%(axis)s == Py_None){ if((PyObject*)%(axis)s == Py_None){
...@@ -1532,7 +1532,7 @@ class MaxAndArgmax(Op): ...@@ -1532,7 +1532,7 @@ class MaxAndArgmax(Op):
# the gradient on its inputs is zero # the gradient on its inputs is zero
if g_max_disconnected: if g_max_disconnected:
return [x.zeros_like(), axis_grad] return [x.zeros_like(), axis_grad]
if axis is NoneConst: if NoneConst.equals(axis):
axis_ = range(x.ndim) axis_ = range(x.ndim)
else: else:
axis_ = axis axis_ = axis
...@@ -1541,7 +1541,7 @@ class MaxAndArgmax(Op): ...@@ -1541,7 +1541,7 @@ class MaxAndArgmax(Op):
# Raise the g_max and xmax to the same number of dim as the input. # Raise the g_max and xmax to the same number of dim as the input.
pattern = [] pattern = []
out_dim = 0 out_dim = 0
if axis is NoneConst: if NoneConst.equals(axis):
# We are taking the max/argmax over all dimensions. # We are taking the max/argmax over all dimensions.
axis = None axis = None
for i in range(x.ndim): for i in range(x.ndim):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论