提交 a6e6d494 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Convert inputs of maxpool to tensor

上级 b1963ef9
......@@ -208,7 +208,8 @@ class DownsampleFactorMax(Op):
def make_node(self, x):
if x.type.ndim != 4:
raise TypeError()
# TODO: consider restrucing the dtype?
# TODO: consider restricting the dtype?
x = tensor.as_tensor_variable(x)
return gof.Apply(self, [x], [x.type()])
def perform(self, node, inp, out):
......@@ -371,6 +372,9 @@ class DownsampleFactorMaxGrad(Op):
assert isinstance(x, Variable) and x.ndim == 4
assert isinstance(maxout, Variable) and maxout.ndim == 4
assert isinstance(gz, Variable) and gz.ndim == 4
x = tensor.as_tensor_variable(x)
maxout = tensor.as_tensor_variable(maxout)
gz = tensor.as_tensor_variable(gz)
return Apply(self, [x, maxout, gz], [x.type()])
......@@ -625,6 +629,9 @@ class DownsampleFactorMaxGradGrad(Op):
assert isinstance(x, Variable) and x.ndim == 4
assert isinstance(maxout, Variable) and maxout.ndim == 4
assert isinstance(gz, Variable) and gz.ndim == 4
x = tensor.as_tensor_variable(x)
maxout = tensor.as_tensor_variable(maxout)
gz = tensor.as_tensor_variable(gz)
return Apply(self, [x, maxout, gz], [x.type()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论