Unverified 提交 d10f2459 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Fix wrong dtype arguments (#1456)

上级 0ea61bcc
......@@ -2302,7 +2302,7 @@ def _is_zero(x):
class ZeroGrad(ViewOp):
def grad(self, args, g_outs):
return [g_out.zeros_like(g_out) for g_out in g_outs]
return [g_out.zeros_like() for g_out in g_outs]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
......
......@@ -3237,7 +3237,7 @@ class Exp2(UnaryScalarOp):
else:
return [x.zeros_like()]
return (gz * exp2(x) * log(np.array(2, dtype=x.type)),)
return (gz * exp2(x) * log(np.array(2, dtype=x.dtype)),)
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
......@@ -3376,7 +3376,7 @@ class Deg2Rad(UnaryScalarOp):
else:
return [x.zeros_like()]
return (gz * np.array(np.pi / 180, dtype=gz.type),)
return (gz * np.array(np.pi / 180, dtype=gz.dtype),)
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
......@@ -3411,7 +3411,7 @@ class Rad2Deg(UnaryScalarOp):
else:
return [x.zeros_like()]
return (gz * np.array(180.0 / np.pi, dtype=gz.type),)
return (gz * np.array(180.0 / np.pi, dtype=gz.dtype),)
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
......@@ -3484,7 +3484,7 @@ class ArcCos(UnaryScalarOp):
else:
return [x.zeros_like()]
return (-gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),)
return (-gz / sqrt(np.array(1, dtype=x.dtype) - sqr(x)),)
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
......@@ -3558,7 +3558,7 @@ class ArcSin(UnaryScalarOp):
else:
return [x.zeros_like()]
return (gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),)
return (gz / sqrt(np.array(1, dtype=x.dtype) - sqr(x)),)
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
......@@ -3630,7 +3630,7 @@ class ArcTan(UnaryScalarOp):
else:
return [x.zeros_like()]
return (gz / (np.array(1, dtype=x.type) + sqr(x)),)
return (gz / (np.array(1, dtype=x.dtype) + sqr(x)),)
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
......@@ -3753,7 +3753,7 @@ class ArcCosh(UnaryScalarOp):
else:
return [x.zeros_like()]
return (gz / sqrt(sqr(x) - np.array(1, dtype=x.type)),)
return (gz / sqrt(sqr(x) - np.array(1, dtype=x.dtype)),)
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
......@@ -3830,7 +3830,7 @@ class ArcSinh(UnaryScalarOp):
else:
return [x.zeros_like()]
return (gz / sqrt(sqr(x) + np.array(1, dtype=x.type)),)
return (gz / sqrt(sqr(x) + np.array(1, dtype=x.dtype)),)
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
......@@ -3908,7 +3908,7 @@ class ArcTanh(UnaryScalarOp):
else:
return [x.zeros_like()]
return (gz / (np.array(1, dtype=x.type) - sqr(x)),)
return (gz / (np.array(1, dtype=x.dtype) - sqr(x)),)
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
......
......@@ -193,7 +193,7 @@ class AddSD_ccode(_NoPythonCOp):
def local_inplace_addsd_ccode(fgraph, node):
"""Rewrite to insert inplace versions of `AddSD`."""
if isinstance(node.op, sparse.AddSD) and config.cxx:
out_dtype = ps.upcast(*node.inputs)
out_dtype = ps.upcast(*[inp.type.dtype for inp in node.inputs])
if out_dtype != node.inputs[1].dtype:
return
new_node = AddSD_ccode(format=node.inputs[0].type.format, inplace=True)(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论