提交 9f4587bf authored 作者: Frederic Bastien's avatar Frederic Bastien

implemented Shape_i.c_code for TensorType and CudaNdarrayType.

上级 191db4ae
"""
This file test tensor op that should also operate on CudaNdaray.
"""
import sys, time
from theano import shared
from theano.compile.pfunc import pfunc
from theano import tensor
import numpy
import theano
import theano.tensor as T
# Skip test if cuda_ndarray is not available.
from nose.plugins.skip import SkipTest
import theano.sandbox.cuda as cuda_ndarray
if cuda_ndarray.cuda_available == False:
raise SkipTest('Optional package cuda disabled')
import theano.sandbox.cuda as tcn
import theano.sandbox.cuda as cuda
import theano.compile.mode
from theano.tests import unittest_tools as utt
if theano.config.mode=='FAST_COMPILE':
mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu')
mode_without_gpu = theano.compile.mode.get_mode('FAST_RUN').excluding('gpu')
else:
mode_with_gpu = theano.compile.mode.get_default_mode().including('gpu')
mode_without_gpu = theano.compile.mode.get_default_mode().excluding('gpu')
def test_shape_i():
x = cuda.ftensor3()
v = numpy.zeros((3,4,5),dtype='float32')
f = theano.function([x],x.shape[1])
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op,T.opt.Shape_i)
assert f(v)==4
......@@ -271,6 +271,26 @@ class Shape_i(T.Op):
out[0] = theano._asarray(x.shape[self.i], dtype='int64')
else:
out[0][...] = x.shape[self.i]
def c_code_cache_version(self):
return (0,1)
def c_code(self, node, name, (x, ), (out, ), sub):
i = self.i
if isinstance(node.inputs[0].type,T.TensorType):
return """
if(!%(out)s)
%(out)s=(PyArrayObject*)PyArray_ZEROS(0, NULL, PyArray_INT64, 0);
((npy_int64*)PyArray_DATA(%(out)s))[0]=%(x)s->dimensions[%(i)s];
"""%locals()
elif node.inputs[0].type.__class__.__name__=="CudaNdarrayType":
#Don't want to import cuda stuff here.
return """
if(!%(out)s)
%(out)s=(PyArrayObject*)PyArray_ZEROS(0, NULL, PyArray_INT64, 0);
((npy_int64*)PyArray_DATA(%(out)s))[0]=CudaNdarray_HOST_DIMS(%(x)s)[%(i)s];
"""%locals()
else:
return super(Shape_i, self).c_code(node, name, (x,), (out,), sub)
def grad(self, (x,), (gz,)):
return [None]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论