提交 99589bb9 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Simplify boolean operations with any and all

上级 8a6d2aae
......@@ -1041,13 +1041,12 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# list of bools indicating if each input is connected to the cost
inputs_connected = [
(
True
in [
any(
input_to_output and output_to_cost
for input_to_output, output_to_cost in zip(
input_to_outputs, outputs_connected
)
]
)
)
for input_to_outputs in connection_pattern
]
......@@ -1067,25 +1066,24 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# List of bools indicating if each input only has NullType outputs
only_connected_to_nan = [
(
True
not in [
not any(
in_to_out and out_to_cost and not out_nan
for in_to_out, out_to_cost, out_nan in zip(
in_to_outs, outputs_connected, ograd_is_nan
)
]
)
)
for in_to_outs in connection_pattern
]
if True not in inputs_connected:
if not any(inputs_connected):
# All outputs of this op are disconnected so we can skip
# Calling the op's grad method and report that the inputs
# are disconnected
# (The op's grad method could do this too, but this saves the
# implementer the trouble of worrying about this case)
input_grads = [disconnected_type() for ipt in inputs]
elif False not in only_connected_to_nan:
elif all(only_connected_to_nan):
# All inputs are only connected to nan gradients, so we don't
# need to bother calling the grad method. We know the gradient
# with respect to all connected inputs is nan.
......
......@@ -201,12 +201,12 @@ class DimShuffle(ExternalCOp):
f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
)
for expected, b in zip(self.input_broadcastable, ib):
if expected is True and b is False:
if expected and not b:
raise TypeError(
"The broadcastable pattern of the "
f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
)
# else, expected == b or expected is False and b is True
# else, expected == b or not expected and b
# Both case are good.
out_static_shape = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论