Unverified 提交 d10f2459 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Fix wrong dtype arguments (#1456)

上级 0ea61bcc
...@@ -2302,7 +2302,7 @@ def _is_zero(x): ...@@ -2302,7 +2302,7 @@ def _is_zero(x):
class ZeroGrad(ViewOp): class ZeroGrad(ViewOp):
def grad(self, args, g_outs): def grad(self, args, g_outs):
return [g_out.zeros_like(g_out) for g_out in g_outs] return [g_out.zeros_like() for g_out in g_outs]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
......
...@@ -3237,7 +3237,7 @@ class Exp2(UnaryScalarOp): ...@@ -3237,7 +3237,7 @@ class Exp2(UnaryScalarOp):
else: else:
return [x.zeros_like()] return [x.zeros_like()]
return (gz * exp2(x) * log(np.array(2, dtype=x.type)),) return (gz * exp2(x) * log(np.array(2, dtype=x.dtype)),)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
...@@ -3376,7 +3376,7 @@ class Deg2Rad(UnaryScalarOp): ...@@ -3376,7 +3376,7 @@ class Deg2Rad(UnaryScalarOp):
else: else:
return [x.zeros_like()] return [x.zeros_like()]
return (gz * np.array(np.pi / 180, dtype=gz.type),) return (gz * np.array(np.pi / 180, dtype=gz.dtype),)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
...@@ -3411,7 +3411,7 @@ class Rad2Deg(UnaryScalarOp): ...@@ -3411,7 +3411,7 @@ class Rad2Deg(UnaryScalarOp):
else: else:
return [x.zeros_like()] return [x.zeros_like()]
return (gz * np.array(180.0 / np.pi, dtype=gz.type),) return (gz * np.array(180.0 / np.pi, dtype=gz.dtype),)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
...@@ -3484,7 +3484,7 @@ class ArcCos(UnaryScalarOp): ...@@ -3484,7 +3484,7 @@ class ArcCos(UnaryScalarOp):
else: else:
return [x.zeros_like()] return [x.zeros_like()]
return (-gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),) return (-gz / sqrt(np.array(1, dtype=x.dtype) - sqr(x)),)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
...@@ -3558,7 +3558,7 @@ class ArcSin(UnaryScalarOp): ...@@ -3558,7 +3558,7 @@ class ArcSin(UnaryScalarOp):
else: else:
return [x.zeros_like()] return [x.zeros_like()]
return (gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),) return (gz / sqrt(np.array(1, dtype=x.dtype) - sqr(x)),)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
...@@ -3630,7 +3630,7 @@ class ArcTan(UnaryScalarOp): ...@@ -3630,7 +3630,7 @@ class ArcTan(UnaryScalarOp):
else: else:
return [x.zeros_like()] return [x.zeros_like()]
return (gz / (np.array(1, dtype=x.type) + sqr(x)),) return (gz / (np.array(1, dtype=x.dtype) + sqr(x)),)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
...@@ -3753,7 +3753,7 @@ class ArcCosh(UnaryScalarOp): ...@@ -3753,7 +3753,7 @@ class ArcCosh(UnaryScalarOp):
else: else:
return [x.zeros_like()] return [x.zeros_like()]
return (gz / sqrt(sqr(x) - np.array(1, dtype=x.type)),) return (gz / sqrt(sqr(x) - np.array(1, dtype=x.dtype)),)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
...@@ -3830,7 +3830,7 @@ class ArcSinh(UnaryScalarOp): ...@@ -3830,7 +3830,7 @@ class ArcSinh(UnaryScalarOp):
else: else:
return [x.zeros_like()] return [x.zeros_like()]
return (gz / sqrt(sqr(x) + np.array(1, dtype=x.type)),) return (gz / sqrt(sqr(x) + np.array(1, dtype=x.dtype)),)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
...@@ -3908,7 +3908,7 @@ class ArcTanh(UnaryScalarOp): ...@@ -3908,7 +3908,7 @@ class ArcTanh(UnaryScalarOp):
else: else:
return [x.zeros_like()] return [x.zeros_like()]
return (gz / (np.array(1, dtype=x.type) - sqr(x)),) return (gz / (np.array(1, dtype=x.dtype) - sqr(x)),)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
......
...@@ -193,7 +193,7 @@ class AddSD_ccode(_NoPythonCOp): ...@@ -193,7 +193,7 @@ class AddSD_ccode(_NoPythonCOp):
def local_inplace_addsd_ccode(fgraph, node): def local_inplace_addsd_ccode(fgraph, node):
"""Rewrite to insert inplace versions of `AddSD`.""" """Rewrite to insert inplace versions of `AddSD`."""
if isinstance(node.op, sparse.AddSD) and config.cxx: if isinstance(node.op, sparse.AddSD) and config.cxx:
out_dtype = ps.upcast(*node.inputs) out_dtype = ps.upcast(*[inp.type.dtype for inp in node.inputs])
if out_dtype != node.inputs[1].dtype: if out_dtype != node.inputs[1].dtype:
return return
new_node = AddSD_ccode(format=node.inputs[0].type.format, inplace=True)( new_node = AddSD_ccode(format=node.inputs[0].type.format, inplace=True)(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论