提交 d3f724ab authored 作者: Frederic's avatar Frederic

Refactor code to reuse it later.

上级 b0a47681
...@@ -909,7 +909,22 @@ class UnaryScalarOp(ScalarOp): ...@@ -909,7 +909,22 @@ class UnaryScalarOp(ScalarOp):
node.inputs[0].type != node.outputs[0].type): node.inputs[0].type != node.outputs[0].type):
raise theano.gof.utils.MethodNotDefined() raise theano.gof.utils.MethodNotDefined()
dtype = node.inputs[0].dtype dtype = node.inputs[0].type.dtype_specs()[1]
fct_call = self.c_code_contiguous_raw(dtype, 'n', 'x', 'z')
return """
{
npy_intp n = PyArray_SIZE(%(z)s);
%(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
%(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
%(fct_call)s;
}
""" % locals()
def c_code_contiguous_raw(self, dtype, n, i, o):
if not config.lib.amdlibm:
raise theano.gof.utils.MethodNotDefined()
if dtype.startswith('npy_'):
dtype = dtype[4:]
if dtype == 'float32' and self.amd_float32 is not None: if dtype == 'float32' and self.amd_float32 is not None:
dtype = 'float' dtype = 'float'
fct = self.amd_float32 fct = self.amd_float32
...@@ -918,12 +933,7 @@ class UnaryScalarOp(ScalarOp): ...@@ -918,12 +933,7 @@ class UnaryScalarOp(ScalarOp):
fct = self.amd_float64 fct = self.amd_float64
else: else:
raise theano.gof.utils.MethodNotDefined() raise theano.gof.utils.MethodNotDefined()
return """ return "%(fct)s(%(n)s, %(i)s, %(o)s)" % locals()
npy_intp n = PyArray_SIZE(%(z)s);
%(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
%(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
%(fct)s(n, x, z);
""" % locals()
class BinaryScalarOp(ScalarOp): class BinaryScalarOp(ScalarOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论