提交 12b63cdc authored 作者: Jesse Livezey's avatar Jesse Livezey

tests pass with new casting

上级 b00ab8eb
......@@ -75,6 +75,34 @@ class BaseCorrMM(gof.Op):
str(self.subsample),
str(self.filter_dilation))
def cast(self, in1, in2):
"""
Upcast input variables if neccesary.
"""
float_types = ['float32', 'float64']
dtype_1 = in1.type.dtype
dtype_2 = in2.type.dtype
assert dtype_1 in float_types
assert dtype_2 in float_types
if dtype_1 == 'float64' or dtype_2 == 'float64':
dtype_o = 'float64'
else:
dtype_o = 'float32'
if dtype_1 != dtype_o:
out1 = in1.astype(dtype_o)
else:
out1 = in1
if dtype_2 != dtype_o:
out2 = in2.astype(dtype_o)
else:
out2 = in2
return out1, out2
def c_support_code(self):
return blas_header_text()
......@@ -383,6 +411,7 @@ class CorrMM(BaseCorrMM):
def make_node(self, img, kern):
img = as_tensor_variable(img)
kern = as_tensor_variable(kern)
img, kern = self.cast(img, kern)
if img.type.ndim != 4:
raise TypeError('img must be 4D tensor')
if kern.type.ndim != 4:
......@@ -445,6 +474,7 @@ class CorrMM_gradWeights(BaseCorrMM):
def make_node(self, img, topgrad, shape=None):
img = as_tensor_variable(img)
topgrad = as_tensor_variable(topgrad)
img, topgrad = self.cast(img, topgrad)
if img.type.ndim != 4:
raise TypeError('img must be 4D tensor')
if topgrad.type.ndim != 4:
......@@ -546,6 +576,7 @@ class CorrMM_gradInputs(BaseCorrMM):
def make_node(self, kern, topgrad, shape=None):
kern = as_tensor_variable(kern)
topgrad = as_tensor_variable(topgrad)
kern, topgrad = self.cast(kern, topgrad)
if kern.type.ndim != 4:
raise TypeError('kern must be 4D tensor')
if topgrad.type.ndim != 4:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论