提交 7c8b5808 authored 作者: Jesse Livezey's avatar Jesse Livezey

cleanup from review

上级 c7653446
......@@ -75,33 +75,13 @@ class BaseCorrMM(gof.Op):
str(self.subsample),
str(self.filter_dilation))
def cast(self, in1, in2):
@staticmethod
def as_common_dtype(in1, in2):
"""
Upcast input variables if neccesary.
"""
float_types = ['float32', 'float64']
dtype_1 = in1.type.dtype
dtype_2 = in2.type.dtype
assert dtype_1 in float_types
assert dtype_2 in float_types
if dtype_1 == 'float64' or dtype_2 == 'float64':
dtype_o = 'float64'
else:
dtype_o = 'float32'
if dtype_1 != dtype_o:
out1 = in1.astype(dtype_o)
else:
out1 = in1
if dtype_2 != dtype_o:
out2 = in2.astype(dtype_o)
else:
out2 = in2
return out1, out2
dtype = theano.scalar.upcast(in1.dtype, in2.dtype)
return in1.astype(dtype), in2.astype(dtype)
def c_support_code(self):
return blas_header_text()
......@@ -411,7 +391,7 @@ class CorrMM(BaseCorrMM):
def make_node(self, img, kern):
img = as_tensor_variable(img)
kern = as_tensor_variable(kern)
img, kern = self.cast(img, kern)
img, kern = self.as_common_dtype(img, kern)
if img.type.ndim != 4:
raise TypeError('img must be 4D tensor')
if kern.type.ndim != 4:
......@@ -474,7 +454,7 @@ class CorrMM_gradWeights(BaseCorrMM):
def make_node(self, img, topgrad, shape=None):
img = as_tensor_variable(img)
topgrad = as_tensor_variable(topgrad)
img, topgrad = self.cast(img, topgrad)
img, topgrad = self.as_common_dtype(img, topgrad)
if img.type.ndim != 4:
raise TypeError('img must be 4D tensor')
if topgrad.type.ndim != 4:
......@@ -576,7 +556,7 @@ class CorrMM_gradInputs(BaseCorrMM):
def make_node(self, kern, topgrad, shape=None):
kern = as_tensor_variable(kern)
topgrad = as_tensor_variable(topgrad)
kern, topgrad = self.cast(kern, topgrad)
kern, topgrad = self.as_common_dtype(kern, topgrad)
if kern.type.ndim != 4:
raise TypeError('kern must be 4D tensor')
if topgrad.type.ndim != 4:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论