提交 a4315d4b authored 作者: Frederic's avatar Frederic

More check and move to GPU the symbolic input if needed.

上级 8d8899b2
# This is work in progress
from theano import Op, Apply
from theano import Op, Apply, tensor
from theano.gof import local_optimizer
from theano.sandbox.cuda import cuda_available, GpuOp
......@@ -7,7 +7,8 @@ from theano.sandbox.neighbours import Images2Neibs
if cuda_available:
from theano.sandbox.cuda import CudaNdarrayType
from theano.sandbox.cuda.basic_ops import host_from_gpu, gpu_from_host
from theano.sandbox.cuda.basic_ops import (
as_cuda_ndarray_variable, host_from_gpu, gpu_from_host)
from theano.sandbox.cuda.opt import register_opt as register_gpu_opt
......@@ -21,13 +22,16 @@ class GpuImages2Neibs(Images2Neibs, GpuOp):
self.mode = mode
def make_node(self, ten4, neib_shape, neib_step):
assert ten4.dtype == 'float32'
if not isinstance(ten4.type, CudaNdarrayType):
raise TypeError('ten4 must be cudandarray', ten4)
ten4 = as_cuda_ndarray_variable(ten4)
neib_shape = tensor.as_tensor_variable(neib_shape)
neib_step = tensor.as_tensor_variable(neib_step)
assert ten4.ndim == 4
assert ten4.dtype == 'float32'
assert neib_shape.ndim == 1
assert neib_step.ndim == 1
assert "int" in neib_shape.dtype
assert "int" in neib_step.dtype
return Apply(self, [ten4, neib_shape, neib_step],
[CudaNdarrayType(broadcastable=(False, False),
......
......@@ -29,6 +29,9 @@ class GpuImages2Neibs(Images2Neibs, Op):
self.mode = mode
def make_node(self, ten4, neib_shape, neib_step):
ten4 = as_gpuarray_variable(ten4)
neib_shape = T.as_tensor_variable(neib_shape)
neib_step = T.as_tensor_variable(neib_step)
assert ten4.ndim == 4
assert neib_shape.ndim == 1
......@@ -36,10 +39,6 @@ class GpuImages2Neibs(Images2Neibs, Op):
assert "int" in neib_shape.dtype
assert "int" in neib_step.dtype
ten4 = as_gpuarray_variable(ten4)
neib_shape = T.as_tensor_variable(neib_shape)
neib_step = T.as_tensor_variable(neib_step)
return Apply(self, [ten4, neib_shape, neib_step],
[GpuArrayType(broadcastable=(False, False),
dtype=ten4.type.dtype)()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论