提交 3f31dc24 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Normalize AbstractConv inputs types to not have nodes with mixed input types.

上级 42318305
...@@ -235,6 +235,11 @@ class AbstractConv2d(BaseAbstractConv2d): ...@@ -235,6 +235,11 @@ class AbstractConv2d(BaseAbstractConv2d):
filter_flip) filter_flip)
def make_node(self, img, kern): def make_node(self, img, kern):
# Normalize the inputs types
ktype = img.type.clone(dtype=kern.dtype,
broadcastable=kern.broadcastable)
kern = ktype.filter_variable(kern)
if img.type.ndim != 4: if img.type.ndim != 4:
raise TypeError('img must be 4D tensor') raise TypeError('img must be 4D tensor')
if kern.type.ndim != 4: if kern.type.ndim != 4:
...@@ -328,6 +333,11 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d): ...@@ -328,6 +333,11 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
# Update shape/height_width # Update shape/height_width
def make_node(self, img, topgrad, shape): def make_node(self, img, topgrad, shape):
# Normalize the inputs types
gtype = img.type.clone(dtype=topgrad.dtype,
broadcastable=topgrad.broadcastable)
topgrad = gtype.filter_variable(topgrad)
if img.type.ndim != 4: if img.type.ndim != 4:
raise TypeError('img must be 4D tensor') raise TypeError('img must be 4D tensor')
if topgrad.type.ndim != 4: if topgrad.type.ndim != 4:
...@@ -415,6 +425,11 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d): ...@@ -415,6 +425,11 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
# Update shape/height_width # Update shape/height_width
def make_node(self, kern, topgrad, shape): def make_node(self, kern, topgrad, shape):
# Normalize the inputs types
gtype = kern.type.clone(dtype=topgrad.dtype,
broadcastable=topgrad.broadcastable)
topgrad = gtype.filter_variable(topgrad)
if kern.type.ndim != 4: if kern.type.ndim != 4:
raise TypeError('kern must be 4D tensor') raise TypeError('kern must be 4D tensor')
if topgrad.type.ndim != 4: if topgrad.type.ndim != 4:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论