提交 b6732d25 authored 作者: skaae's avatar skaae

add c code

上级 921b8eb2
...@@ -406,11 +406,22 @@ class J1(UnaryScalarOp): ...@@ -406,11 +406,22 @@ class J1(UnaryScalarOp):
return scipy.special.j1(x) return scipy.special.j1(x)
def impl(self, x): def impl(self, x):
return self.st_impl(x) if imported_scipy_special:
return self.st_impl(x)
else:
super(J1, self).impl(x)
def grad(self, inp, grads): def grad(self, inp, grads):
raise NotImplementedError() raise NotImplementedError()
def c_code(self, node, name, inp, out, sub):
x, = inp
z, = out
if node.inputs[0].type in float_types:
return """%(z)s =
j1(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented')
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -429,13 +440,24 @@ class J0(UnaryScalarOp): ...@@ -429,13 +440,24 @@ class J0(UnaryScalarOp):
return scipy.special.j0(x) return scipy.special.j0(x)
def impl(self, x): def impl(self, x):
return self.st_impl(x) if imported_scipy_special:
return self.st_impl(x)
else:
super(J0, self).impl(x)
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
gz, = grads gz, = grads
return [gz * -1 * j1(x)] return [gz * -1 * j1(x)]
def c_code(self, node, name, inp, out, sub):
x, = inp
z, = out
if node.inputs[0].type in float_types:
return """%(z)s =
j0(%(x)s);""" % locals()
raise NotImplementedError('only floating point is implemented')
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论