提交 aa616e6f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Pass dtype directly to zeros_like

上级 935ce79a
......@@ -273,7 +273,7 @@ class IfElse(_NoPythonOp):
# `condition` does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition
condition_grad = condition.zeros_like().astype(config.floatX)
condition_grad = condition.zeros_like(dtype=config.floatX)
return [
condition_grad,
......
......@@ -1323,8 +1323,8 @@ class LogicalComparison(BinaryScalarOp):
x, y = inputs
assert outputs[0].type == bool
return [
x.zeros_like().astype(config.floatX),
y.zeros_like().astype(config.floatX),
x.zeros_like(dtype=config.floatX),
y.zeros_like(dtype=config.floatX),
]
def c_code_cache_version(self):
......@@ -1358,7 +1358,7 @@ class FixedLogicalComparison(UnaryScalarOp):
def L_op(self, inputs, outputs, output_gradients):
(x,) = inputs
assert outputs[0].type == bool
return [x.zeros_like().astype(config.floatX)]
return [x.zeros_like(dtype=config.floatX)]
def c_code_cache_version(self):
super_version = super().c_code_cache_version()
......@@ -1577,7 +1577,7 @@ class InRange(LogicalComparison):
)
raise NotImplementedError(msg)
elif elem.type in discrete_types:
return elem.zeros_like().astype(config.floatX)
return elem.zeros_like(dtype=config.floatX)
else:
return elem.zeros_like()
......@@ -1611,13 +1611,13 @@ class Switch(ScalarOp):
second_part = switch(cond, 0.0, gz)
if outputs[0].type in discrete_types:
first_part = ift.zeros_like(config.floatX)
second_part = iff.zeros_like(config.floatX)
first_part = ift.zeros_like(dtype=config.floatX)
second_part = iff.zeros_like(dtype=config.floatX)
# cond does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition
condition_grad = cond.zeros_like().astype(config.floatX)
condition_grad = cond.zeros_like(dtype=config.floatX)
return (condition_grad, first_part, second_part)
......@@ -1644,7 +1644,7 @@ class UnaryBitOp(UnaryScalarOp):
return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients):
return [inputs[0].zeros_like().astype(config.floatX)]
return [inputs[0].zeros_like(dtype=config.floatX)]
class BinaryBitOp(BinaryScalarOp):
......@@ -1664,8 +1664,8 @@ class BinaryBitOp(BinaryScalarOp):
def grad(self, inputs, output_gradients):
a, b = inputs
return [
a.zeros_like().astype(config.floatX),
b.zeros_like().astype(config.floatX),
a.zeros_like(dtype=config.floatX),
b.zeros_like(dtype=config.floatX),
]
......@@ -1776,8 +1776,8 @@ class ScalarMaximum(BinaryScalarOp):
if outputs[0].type in discrete_types:
return [
x.zeros_like().astype(config.floatX),
y.zeros_like().astype(config.floatX),
x.zeros_like(dtype=config.floatX),
y.zeros_like(dtype=config.floatX),
]
# This form handle the case when both value are the same.
# In that case, gx will be gz, gy will be 0.
......@@ -1818,8 +1818,8 @@ class ScalarMinimum(BinaryScalarOp):
if outputs[0].type in discrete_types:
return [
x.zeros_like().astype(config.floatX),
y.zeros_like().astype(config.floatX),
x.zeros_like(dtype=config.floatX),
y.zeros_like(dtype=config.floatX),
]
# This form handle the case when both value are the same.
# In that case, gx will be gz, gy will be 0.
......@@ -1861,7 +1861,7 @@ class Add(ScalarOp):
retval = []
for ii, inp in enumerate(inputs):
if hasattr(inp, "zeros_like"):
retval.append(inp.zeros_like().astype(config.floatX))
retval.append(inp.zeros_like(dtype=config.floatX))
else:
retval.append(grad_undefined(self, ii, inp))
else:
......@@ -1937,7 +1937,7 @@ class Mul(ScalarOp):
)
if output_type in discrete_types:
return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]
for input in inputs:
if gz.type in complex_types:
......@@ -1980,8 +1980,8 @@ class Sub(BinaryScalarOp):
raise NotImplementedError()
if outputs[0].type in discrete_types:
return [
x.zeros_like().astype(config.floatX),
y.zeros_like().astype(config.floatX),
x.zeros_like(dtype=config.floatX),
y.zeros_like(dtype=config.floatX),
]
first_part = gz
......@@ -2293,8 +2293,8 @@ class Pow(BinaryScalarOp):
if outputs[0].type in discrete_types:
return [
x.zeros_like().astype(config.floatX),
y.zeros_like().astype(config.floatX),
x.zeros_like(dtype=config.floatX),
y.zeros_like(dtype=config.floatX),
]
first_part = gz * y * x ** (y - 1)
......@@ -2385,7 +2385,7 @@ class Clip(ScalarOp):
def handle_int(v):
if outputs[0].type in int_types:
return v.zeros_like().astype(config.floatX)
return v.zeros_like(dtype=config.floatX)
return v
return list(map(handle_int, [gx, gmn, gmx]))
......@@ -2422,7 +2422,7 @@ class Second(BinaryScalarOp):
# to deal with real-valued inputs by rounding them to the
# nearest integer. f(x+eps) thus equals f(x) so the gradient
# is zero, not disconnected or undefined
return DisconnectedType()(), y.zeros_like()
return DisconnectedType()(), y.zeros_like(dtype=config.floatX)
second = Second(transfer_type(1), name="second")
......@@ -2494,7 +2494,7 @@ class Cast(UnaryScalarOp):
if self.o_type in continuous_types:
return [gz]
else:
return [x.zeros_like().astype(config.floatX)]
return [x.zeros_like(dtype=config.floatX)]
def c_code_cache_version(self):
s = super().c_code_cache_version()
......@@ -2715,7 +2715,7 @@ class Trunc(UnaryScalarOp):
def grad(self, inputs, gout):
(x,) = inputs
(gz,) = gout
return [x.zeros_like().astype(config.floatX)]
return [x.zeros_like(dtype=config.floatX)]
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
......
......@@ -589,7 +589,7 @@ class TensorFromScalar(COp):
# Currently, pytensor.grad insists that the dtype of the returned
# gradient has a float dtype, so we use floatX.
if s.type.dtype in discrete_dtypes:
return [s.zeros_like().astype(config.floatX)]
return [s.zeros_like(dtype=config.floatX)]
raise NotImplementedError("grad not implemented for complex dtypes")
......@@ -1876,7 +1876,7 @@ class MakeVector(COp):
def grad(self, inputs, output_gradients):
# If the output is of an integer dtype, no gradient shall pass
if self.dtype in discrete_dtypes:
return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]
grads = [output_gradients[0][i] for i in range(len(inputs))]
return grads
......
......@@ -946,7 +946,7 @@ class Subtensor(COp):
x = inputs[0]
rest = inputs[1:]
if x.dtype in discrete_dtypes:
first = x.zeros_like().astype(config.floatX)
first = x.zeros_like(dtype=config.floatX)
else:
# For best optimization, we let this as an inc.
# This allow the opt local_IncSubtensor_serialize to apply first.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论