提交 f5f1ffa8 authored 作者: notoraptor's avatar notoraptor

Make code more clear by writing 'int64' directly in castings.

上级 c66b296e
...@@ -15,8 +15,6 @@ except ImportError as e: ...@@ -15,8 +15,6 @@ except ImportError as e:
# To make sure theano is importable # To make sure theano is importable
pass pass
dtype_name_for_casting = 'int64'
class GpuPool(CGpuKernelBase): class GpuPool(CGpuKernelBase):
""" """
...@@ -70,9 +68,9 @@ class GpuPool(CGpuKernelBase): ...@@ -70,9 +68,9 @@ class GpuPool(CGpuKernelBase):
if pad.dtype not in theano.tensor.int_dtypes: if pad.dtype not in theano.tensor.int_dtypes:
raise TypeError('Padding parameters must be ints.') raise TypeError('Padding parameters must be ints.')
ws = theano.tensor.cast(ws, dtype_name_for_casting) ws = theano.tensor.cast(ws, 'int64')
stride = theano.tensor.cast(stride, dtype_name_for_casting) stride = theano.tensor.cast(stride, 'int64')
pad = theano.tensor.cast(pad, dtype_name_for_casting) pad = theano.tensor.cast(pad, 'int64')
return Apply(self, [inp, ws, stride, pad], [inp.type()]) return Apply(self, [inp, ws, stride, pad], [inp.type()])
...@@ -190,9 +188,9 @@ class GpuMaxPoolGrad(CGpuKernelBase): ...@@ -190,9 +188,9 @@ class GpuMaxPoolGrad(CGpuKernelBase):
if pad.dtype not in theano.tensor.int_dtypes: if pad.dtype not in theano.tensor.int_dtypes:
raise TypeError('Padding parameters must be ints.') raise TypeError('Padding parameters must be ints.')
ws = theano.tensor.cast(ws, dtype_name_for_casting) ws = theano.tensor.cast(ws, 'int64')
stride = theano.tensor.cast(stride, dtype_name_for_casting) stride = theano.tensor.cast(stride, 'int64')
pad = theano.tensor.cast(pad, dtype_name_for_casting) pad = theano.tensor.cast(pad, 'int64')
return Apply(self, [inp, out, out_grad, ws, stride, pad], [inp.type()]) return Apply(self, [inp, out, out_grad, ws, stride, pad], [inp.type()])
...@@ -269,9 +267,9 @@ class GpuAveragePoolGrad(CGpuKernelBase): ...@@ -269,9 +267,9 @@ class GpuAveragePoolGrad(CGpuKernelBase):
if pad.dtype not in theano.tensor.int_dtypes: if pad.dtype not in theano.tensor.int_dtypes:
raise TypeError('Padding parameters must be ints.') raise TypeError('Padding parameters must be ints.')
ws = theano.tensor.cast(ws, dtype_name_for_casting) ws = theano.tensor.cast(ws, 'int64')
stride = theano.tensor.cast(stride, dtype_name_for_casting) stride = theano.tensor.cast(stride, 'int64')
pad = theano.tensor.cast(pad, dtype_name_for_casting) pad = theano.tensor.cast(pad, 'int64')
return Apply(self, [inp, out_grad, ws, stride, pad], [inp.type()]) return Apply(self, [inp, out_grad, ws, stride, pad], [inp.type()])
...@@ -351,9 +349,9 @@ class GpuDownsampleFactorMaxGradGrad(CGpuKernelBase): ...@@ -351,9 +349,9 @@ class GpuDownsampleFactorMaxGradGrad(CGpuKernelBase):
if pad.dtype not in theano.tensor.int_dtypes: if pad.dtype not in theano.tensor.int_dtypes:
raise TypeError('Padding parameters must be ints.') raise TypeError('Padding parameters must be ints.')
ws = theano.tensor.cast(ws, dtype_name_for_casting) ws = theano.tensor.cast(ws, 'int64')
stride = theano.tensor.cast(stride, dtype_name_for_casting) stride = theano.tensor.cast(stride, 'int64')
pad = theano.tensor.cast(pad, dtype_name_for_casting) pad = theano.tensor.cast(pad, 'int64')
return Apply(self, [inp, out, out_grad, ws, stride, pad], [inp.type()]) return Apply(self, [inp, out, out_grad, ws, stride, pad], [inp.type()])
...@@ -430,9 +428,9 @@ class GpuMaxPoolRop(CGpuKernelBase): ...@@ -430,9 +428,9 @@ class GpuMaxPoolRop(CGpuKernelBase):
if pad.dtype not in theano.tensor.int_dtypes: if pad.dtype not in theano.tensor.int_dtypes:
raise TypeError('Padding parameters must be ints.') raise TypeError('Padding parameters must be ints.')
ws = theano.tensor.cast(ws, dtype_name_for_casting) ws = theano.tensor.cast(ws, 'int64')
stride = theano.tensor.cast(stride, dtype_name_for_casting) stride = theano.tensor.cast(stride, 'int64')
pad = theano.tensor.cast(pad, dtype_name_for_casting) pad = theano.tensor.cast(pad, 'int64')
return Apply(self, [inp, eval_point, ws, stride, pad], [eval_point.type()]) return Apply(self, [inp, eval_point, ws, stride, pad], [eval_point.type()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论