Commit aa616e6f, authored by ricardoV94, committed by Ricardo Vieira

Pass dtype directly to zeros_like

Parent 935ce79a
...@@ -273,7 +273,7 @@ class IfElse(_NoPythonOp): ...@@ -273,7 +273,7 @@ class IfElse(_NoPythonOp):
# `condition` does affect the elements of the output so it is connected. # `condition` does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that # For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition # condition + epsilon always triggers the same branch as condition
condition_grad = condition.zeros_like().astype(config.floatX) condition_grad = condition.zeros_like(dtype=config.floatX)
return [ return [
condition_grad, condition_grad,
......
...@@ -1323,8 +1323,8 @@ class LogicalComparison(BinaryScalarOp): ...@@ -1323,8 +1323,8 @@ class LogicalComparison(BinaryScalarOp):
x, y = inputs x, y = inputs
assert outputs[0].type == bool assert outputs[0].type == bool
return [ return [
x.zeros_like().astype(config.floatX), x.zeros_like(dtype=config.floatX),
y.zeros_like().astype(config.floatX), y.zeros_like(dtype=config.floatX),
] ]
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -1358,7 +1358,7 @@ class FixedLogicalComparison(UnaryScalarOp): ...@@ -1358,7 +1358,7 @@ class FixedLogicalComparison(UnaryScalarOp):
def L_op(self, inputs, outputs, output_gradients): def L_op(self, inputs, outputs, output_gradients):
(x,) = inputs (x,) = inputs
assert outputs[0].type == bool assert outputs[0].type == bool
return [x.zeros_like().astype(config.floatX)] return [x.zeros_like(dtype=config.floatX)]
def c_code_cache_version(self): def c_code_cache_version(self):
super_version = super().c_code_cache_version() super_version = super().c_code_cache_version()
...@@ -1577,7 +1577,7 @@ class InRange(LogicalComparison): ...@@ -1577,7 +1577,7 @@ class InRange(LogicalComparison):
) )
raise NotImplementedError(msg) raise NotImplementedError(msg)
elif elem.type in discrete_types: elif elem.type in discrete_types:
return elem.zeros_like().astype(config.floatX) return elem.zeros_like(dtype=config.floatX)
else: else:
return elem.zeros_like() return elem.zeros_like()
...@@ -1611,13 +1611,13 @@ class Switch(ScalarOp): ...@@ -1611,13 +1611,13 @@ class Switch(ScalarOp):
second_part = switch(cond, 0.0, gz) second_part = switch(cond, 0.0, gz)
if outputs[0].type in discrete_types: if outputs[0].type in discrete_types:
first_part = ift.zeros_like(config.floatX) first_part = ift.zeros_like(dtype=config.floatX)
second_part = iff.zeros_like(config.floatX) second_part = iff.zeros_like(dtype=config.floatX)
# cond does affect the elements of the output so it is connected. # cond does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that # For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition # condition + epsilon always triggers the same branch as condition
condition_grad = cond.zeros_like().astype(config.floatX) condition_grad = cond.zeros_like(dtype=config.floatX)
return (condition_grad, first_part, second_part) return (condition_grad, first_part, second_part)
...@@ -1644,7 +1644,7 @@ class UnaryBitOp(UnaryScalarOp): ...@@ -1644,7 +1644,7 @@ class UnaryBitOp(UnaryScalarOp):
return upcast_out(*input_types[0]) return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
return [inputs[0].zeros_like().astype(config.floatX)] return [inputs[0].zeros_like(dtype=config.floatX)]
class BinaryBitOp(BinaryScalarOp): class BinaryBitOp(BinaryScalarOp):
...@@ -1664,8 +1664,8 @@ class BinaryBitOp(BinaryScalarOp): ...@@ -1664,8 +1664,8 @@ class BinaryBitOp(BinaryScalarOp):
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
a, b = inputs a, b = inputs
return [ return [
a.zeros_like().astype(config.floatX), a.zeros_like(dtype=config.floatX),
b.zeros_like().astype(config.floatX), b.zeros_like(dtype=config.floatX),
] ]
...@@ -1776,8 +1776,8 @@ class ScalarMaximum(BinaryScalarOp): ...@@ -1776,8 +1776,8 @@ class ScalarMaximum(BinaryScalarOp):
if outputs[0].type in discrete_types: if outputs[0].type in discrete_types:
return [ return [
x.zeros_like().astype(config.floatX), x.zeros_like(dtype=config.floatX),
y.zeros_like().astype(config.floatX), y.zeros_like(dtype=config.floatX),
] ]
# This form handle the case when both value are the same. # This form handle the case when both value are the same.
# In that case, gx will be gz, gy will be 0. # In that case, gx will be gz, gy will be 0.
...@@ -1818,8 +1818,8 @@ class ScalarMinimum(BinaryScalarOp): ...@@ -1818,8 +1818,8 @@ class ScalarMinimum(BinaryScalarOp):
if outputs[0].type in discrete_types: if outputs[0].type in discrete_types:
return [ return [
x.zeros_like().astype(config.floatX), x.zeros_like(dtype=config.floatX),
y.zeros_like().astype(config.floatX), y.zeros_like(dtype=config.floatX),
] ]
# This form handle the case when both value are the same. # This form handle the case when both value are the same.
# In that case, gx will be gz, gy will be 0. # In that case, gx will be gz, gy will be 0.
...@@ -1861,7 +1861,7 @@ class Add(ScalarOp): ...@@ -1861,7 +1861,7 @@ class Add(ScalarOp):
retval = [] retval = []
for ii, inp in enumerate(inputs): for ii, inp in enumerate(inputs):
if hasattr(inp, "zeros_like"): if hasattr(inp, "zeros_like"):
retval.append(inp.zeros_like().astype(config.floatX)) retval.append(inp.zeros_like(dtype=config.floatX))
else: else:
retval.append(grad_undefined(self, ii, inp)) retval.append(grad_undefined(self, ii, inp))
else: else:
...@@ -1937,7 +1937,7 @@ class Mul(ScalarOp): ...@@ -1937,7 +1937,7 @@ class Mul(ScalarOp):
) )
if output_type in discrete_types: if output_type in discrete_types:
return [ipt.zeros_like().astype(config.floatX) for ipt in inputs] return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]
for input in inputs: for input in inputs:
if gz.type in complex_types: if gz.type in complex_types:
...@@ -1980,8 +1980,8 @@ class Sub(BinaryScalarOp): ...@@ -1980,8 +1980,8 @@ class Sub(BinaryScalarOp):
raise NotImplementedError() raise NotImplementedError()
if outputs[0].type in discrete_types: if outputs[0].type in discrete_types:
return [ return [
x.zeros_like().astype(config.floatX), x.zeros_like(dtype=config.floatX),
y.zeros_like().astype(config.floatX), y.zeros_like(dtype=config.floatX),
] ]
first_part = gz first_part = gz
...@@ -2293,8 +2293,8 @@ class Pow(BinaryScalarOp): ...@@ -2293,8 +2293,8 @@ class Pow(BinaryScalarOp):
if outputs[0].type in discrete_types: if outputs[0].type in discrete_types:
return [ return [
x.zeros_like().astype(config.floatX), x.zeros_like(dtype=config.floatX),
y.zeros_like().astype(config.floatX), y.zeros_like(dtype=config.floatX),
] ]
first_part = gz * y * x ** (y - 1) first_part = gz * y * x ** (y - 1)
...@@ -2385,7 +2385,7 @@ class Clip(ScalarOp): ...@@ -2385,7 +2385,7 @@ class Clip(ScalarOp):
def handle_int(v): def handle_int(v):
if outputs[0].type in int_types: if outputs[0].type in int_types:
return v.zeros_like().astype(config.floatX) return v.zeros_like(dtype=config.floatX)
return v return v
return list(map(handle_int, [gx, gmn, gmx])) return list(map(handle_int, [gx, gmn, gmx]))
...@@ -2422,7 +2422,7 @@ class Second(BinaryScalarOp): ...@@ -2422,7 +2422,7 @@ class Second(BinaryScalarOp):
# to deal with real-valued inputs by rounding them to the # to deal with real-valued inputs by rounding them to the
# nearest integer. f(x+eps) thus equals f(x) so the gradient # nearest integer. f(x+eps) thus equals f(x) so the gradient
# is zero, not disconnected or undefined # is zero, not disconnected or undefined
return DisconnectedType()(), y.zeros_like() return DisconnectedType()(), y.zeros_like(dtype=config.floatX)
second = Second(transfer_type(1), name="second") second = Second(transfer_type(1), name="second")
...@@ -2494,7 +2494,7 @@ class Cast(UnaryScalarOp): ...@@ -2494,7 +2494,7 @@ class Cast(UnaryScalarOp):
if self.o_type in continuous_types: if self.o_type in continuous_types:
return [gz] return [gz]
else: else:
return [x.zeros_like().astype(config.floatX)] return [x.zeros_like(dtype=config.floatX)]
def c_code_cache_version(self): def c_code_cache_version(self):
s = super().c_code_cache_version() s = super().c_code_cache_version()
...@@ -2715,7 +2715,7 @@ class Trunc(UnaryScalarOp): ...@@ -2715,7 +2715,7 @@ class Trunc(UnaryScalarOp):
def grad(self, inputs, gout): def grad(self, inputs, gout):
(x,) = inputs (x,) = inputs
(gz,) = gout (gz,) = gout
return [x.zeros_like().astype(config.floatX)] return [x.zeros_like(dtype=config.floatX)]
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
......
...@@ -589,7 +589,7 @@ class TensorFromScalar(COp): ...@@ -589,7 +589,7 @@ class TensorFromScalar(COp):
# Currently, pytensor.grad insists that the dtype of the returned # Currently, pytensor.grad insists that the dtype of the returned
# gradient has a float dtype, so we use floatX. # gradient has a float dtype, so we use floatX.
if s.type.dtype in discrete_dtypes: if s.type.dtype in discrete_dtypes:
return [s.zeros_like().astype(config.floatX)] return [s.zeros_like(dtype=config.floatX)]
raise NotImplementedError("grad not implemented for complex dtypes") raise NotImplementedError("grad not implemented for complex dtypes")
...@@ -1876,7 +1876,7 @@ class MakeVector(COp): ...@@ -1876,7 +1876,7 @@ class MakeVector(COp):
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
# If the output is of an integer dtype, no gradient shall pass # If the output is of an integer dtype, no gradient shall pass
if self.dtype in discrete_dtypes: if self.dtype in discrete_dtypes:
return [ipt.zeros_like().astype(config.floatX) for ipt in inputs] return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]
grads = [output_gradients[0][i] for i in range(len(inputs))] grads = [output_gradients[0][i] for i in range(len(inputs))]
return grads return grads
......
...@@ -946,7 +946,7 @@ class Subtensor(COp): ...@@ -946,7 +946,7 @@ class Subtensor(COp):
x = inputs[0] x = inputs[0]
rest = inputs[1:] rest = inputs[1:]
if x.dtype in discrete_dtypes: if x.dtype in discrete_dtypes:
first = x.zeros_like().astype(config.floatX) first = x.zeros_like(dtype=config.floatX)
else: else:
# For best optimization, we let this as an inc. # For best optimization, we let this as an inc.
# This allow the opt local_IncSubtensor_serialize to apply first. # This allow the opt local_IncSubtensor_serialize to apply first.
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment