提交 c77a8a1b authored 作者: Frederic's avatar Frederic

Add contiguous versions of pow, cos, sin, log, log2, log10

上级 68aad6d2
...@@ -1649,6 +1649,31 @@ class Pow(BinaryScalarOp): ...@@ -1649,6 +1649,31 @@ class Pow(BinaryScalarOp):
return (first_part, second_part) return (first_part, second_part)
def c_code_contiguous(self, node, name, inputs, outputs, sub):
    """Return C code that computes pow with one vectorized amdlibm call.

    The returned snippet takes the raw data pointers of both inputs and of
    the output and hands them, together with the output's element count,
    to ``amd_vrsa_powf`` in a single call.

    :param node: node whose ``inputs``/``outputs`` types decide whether the
        vectorized path can be used.
    :param name: not used by this method.
    :param inputs: pair ``(x, y)`` of names substituted into the C code for
        the two input arrays.
    :param outputs: 1-tuple ``(z,)`` holding the name substituted into the
        C code for the output array.
    :param sub: not used by this method.
    :returns: the C code snippet, as a string.
    :raises theano.gof.utils.MethodNotDefined: if ``lib.amdlibm`` is
        disabled, if either input type differs from the output type, or if
        the input type is not float32.
    """
    # Unpack here rather than in the signature: tuple parameters are a
    # syntax error on Python 3, and positional callers see no difference.
    (x, y) = inputs
    (z,) = outputs

    if not theano.config.lib.amdlibm:
        raise theano.gof.utils.MethodNotDefined()

    # We compare the whole type (dtype AND broadcast flags), as the
    # generated code does not broadcast.
    out_type = node.outputs[0].type
    if node.inputs[0].type != out_type or node.inputs[1].type != out_type:
        raise theano.gof.utils.MethodNotDefined()

    # amdlibm 3.0 does not have a float64 version of this SIMD function,
    # so only float32 is handled.  The function name is hard-coded, so no
    # per-class ``amd_float32`` attribute is consulted.
    if node.inputs[0].type == float32:
        dtype = 'float'
        fct = "amd_vrsa_powf"
        return """
        npy_intp n = PyArray_SIZE(%(z)s);
        %(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
        %(dtype)s * y = (%(dtype)s*) PyArray_DATA(%(y)s);
        %(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
        %(fct)s(n, x, y, z);
        """ % dict(x=x, y=y, z=z, dtype=dtype, fct=fct)

    raise theano.gof.utils.MethodNotDefined()
pow = Pow(upcast_out, name='pow') pow = Pow(upcast_out, name='pow')
...@@ -2053,6 +2078,9 @@ inv = Inv(upgrade_to_float, name='inv') ...@@ -2053,6 +2078,9 @@ inv = Inv(upgrade_to_float, name='inv')
class Log(UnaryScalarOp): class Log(UnaryScalarOp):
""" log base e """ """ log base e """
amd_float32 = "amd_vrsa_logf"
amd_float64 = "amd_vrda_log"
def impl(self, x): def impl(self, x):
return numpy.log(x) return numpy.log(x)
...@@ -2076,6 +2104,9 @@ log = Log(upgrade_to_float, name='log') ...@@ -2076,6 +2104,9 @@ log = Log(upgrade_to_float, name='log')
class Log2(UnaryScalarOp): class Log2(UnaryScalarOp):
""" log base 2 """ """ log base 2 """
amd_float32 = "amd_vrsa_log2f"
amd_float64 = "amd_vrda_log2"
def impl(self, x): def impl(self, x):
return numpy.log2(x) return numpy.log2(x)
...@@ -2096,6 +2127,9 @@ log2 = Log2(upgrade_to_float, name='log2') ...@@ -2096,6 +2127,9 @@ log2 = Log2(upgrade_to_float, name='log2')
class Log10(UnaryScalarOp): class Log10(UnaryScalarOp):
""" log base 10 """ """ log base 10 """
amd_float32 = "amd_vrsa_log10f"
amd_float64 = "amd_vrda_log10"
def impl(self, x): def impl(self, x):
return numpy.log10(x) return numpy.log10(x)
...@@ -2268,6 +2302,9 @@ rad2deg = Rad2Deg(upgrade_to_float, name='rad2deg') ...@@ -2268,6 +2302,9 @@ rad2deg = Rad2Deg(upgrade_to_float, name='rad2deg')
class Cos(UnaryScalarOp): class Cos(UnaryScalarOp):
amd_float32 = "amd_vrsa_cosf"
amd_float64 = "amd_vrda_cos"
def impl(self, x): def impl(self, x):
return numpy.cos(x) return numpy.cos(x)
...@@ -2306,6 +2343,9 @@ arccos = ArcCos(upgrade_to_float, name='arccos') ...@@ -2306,6 +2343,9 @@ arccos = ArcCos(upgrade_to_float, name='arccos')
class Sin(UnaryScalarOp): class Sin(UnaryScalarOp):
amd_float32 = "amd_vrsa_sinf"
amd_float64 = "amd_vrda_sin"
def impl(self, x): def impl(self, x):
return numpy.sin(x) return numpy.sin(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论