提交 35268e1f authored 作者: Frederic Bastien's avatar Frederic Bastien

Call as_tensor_variable() before checking the attribute on it.

上级 734db136
......@@ -261,10 +261,10 @@ class Pool(Op):
self.mode = mode
def make_node(self, x):
if x.type.ndim != 4:
raise TypeError()
# TODO: consider restricting the dtype?
x = tensor.as_tensor_variable(x)
if x.type.ndim != 4:
raise TypeError()
# If the input shape are broadcastable we can have 0 in the output shape
broad = x.broadcastable[:2] + (False, False)
out = tensor.TensorType(x.dtype, broad)
......@@ -641,12 +641,12 @@ class MaxPoolGrad(PoolGrad):
def make_node(self, x, maxout, gz):
# make_node should only be called by the grad function of
# Pool, so these asserts should not fail.
assert isinstance(x, Variable) and x.ndim == 4
assert isinstance(maxout, Variable) and maxout.ndim == 4
assert isinstance(gz, Variable) and gz.ndim == 4
x = tensor.as_tensor_variable(x)
maxout = tensor.as_tensor_variable(maxout)
gz = tensor.as_tensor_variable(gz)
assert isinstance(x, Variable) and x.ndim == 4
assert isinstance(maxout, Variable) and maxout.ndim == 4
assert isinstance(gz, Variable) and gz.ndim == 4
return Apply(self, [x, maxout, gz], [x.type()])
......@@ -823,10 +823,10 @@ class AveragePoolGrad(PoolGrad):
def make_node(self, x, gz, dummy=None):
# make_node should only be called by the grad function of
# Pool, so these asserts should not fail.
assert isinstance(x, Variable) and x.ndim == 4
assert isinstance(gz, Variable) and gz.ndim == 4
x = tensor.as_tensor_variable(x)
gz = tensor.as_tensor_variable(gz)
assert isinstance(x, Variable) and x.ndim == 4
assert isinstance(gz, Variable) and gz.ndim == 4
return Apply(self, [x, gz], [x.type()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论