提交 4089d4d9 authored 作者: João Victor Risso's avatar João Victor Risso

Use integer_dtypes instead of int_dtypes to check grid dimensions' type

上级 8b25d693
...@@ -2774,7 +2774,7 @@ class GpuDnnTransformerDesc(COp): ...@@ -2774,7 +2774,7 @@ class GpuDnnTransformerDesc(COp):
def make_node(self, dimensions): def make_node(self, dimensions):
dimensions = as_tensor_variable(dimensions) dimensions = as_tensor_variable(dimensions)
assert dimensions.dtype in theano.tensor.basic.int_dtypes assert dimensions.dtype in theano.tensor.basic.integer_dtypes
assert dimensions.ndim == 1 assert dimensions.ndim == 1
dimensions = theano.tensor.basic.cast(dimensions, 'int64') dimensions = theano.tensor.basic.cast(dimensions, 'int64')
...@@ -2815,7 +2815,7 @@ class GpuDnnTransformerGrid(DnnBase): ...@@ -2815,7 +2815,7 @@ class GpuDnnTransformerGrid(DnnBase):
# Setup grid dimensions using input from descriptor # Setup grid dimensions using input from descriptor
grid_dims = as_tensor_variable(desc.owner.inputs[0]) grid_dims = as_tensor_variable(desc.owner.inputs[0])
assert grid_dims.dtype in theano.tensor.basic.int_dtypes assert grid_dims.dtype in theano.tensor.basic.integer_dtypes
assert grid_dims.ndim == 1 assert grid_dims.ndim == 1
# Ensure 64-bit ints are passed to the C code # Ensure 64-bit ints are passed to the C code
grid_dims = theano.tensor.basic.cast(grid_dims, 'int64') grid_dims = theano.tensor.basic.cast(grid_dims, 'int64')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论