提交 99589bb9 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Simplify boolean operations with any and all

上级 8a6d2aae
...@@ -1041,13 +1041,12 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1041,13 +1041,12 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# list of bools indicating if each input is connected to the cost # list of bools indicating if each input is connected to the cost
inputs_connected = [ inputs_connected = [
( (
True any(
in [
input_to_output and output_to_cost input_to_output and output_to_cost
for input_to_output, output_to_cost in zip( for input_to_output, output_to_cost in zip(
input_to_outputs, outputs_connected input_to_outputs, outputs_connected
) )
] )
) )
for input_to_outputs in connection_pattern for input_to_outputs in connection_pattern
] ]
...@@ -1067,25 +1066,24 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1067,25 +1066,24 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# List of bools indicating if each input only has NullType outputs # List of bools indicating if each input only has NullType outputs
only_connected_to_nan = [ only_connected_to_nan = [
( (
True not any(
not in [
in_to_out and out_to_cost and not out_nan in_to_out and out_to_cost and not out_nan
for in_to_out, out_to_cost, out_nan in zip( for in_to_out, out_to_cost, out_nan in zip(
in_to_outs, outputs_connected, ograd_is_nan in_to_outs, outputs_connected, ograd_is_nan
) )
] )
) )
for in_to_outs in connection_pattern for in_to_outs in connection_pattern
] ]
if True not in inputs_connected: if not any(inputs_connected):
# All outputs of this op are disconnected so we can skip # All outputs of this op are disconnected so we can skip
# Calling the op's grad method and report that the inputs # Calling the op's grad method and report that the inputs
# are disconnected # are disconnected
# (The op's grad method could do this too, but this saves the # (The op's grad method could do this too, but this saves the
# implementer the trouble of worrying about this case) # implementer the trouble of worrying about this case)
input_grads = [disconnected_type() for ipt in inputs] input_grads = [disconnected_type() for ipt in inputs]
elif False not in only_connected_to_nan: elif all(only_connected_to_nan):
# All inputs are only connected to nan gradients, so we don't # All inputs are only connected to nan gradients, so we don't
# need to bother calling the grad method. We know the gradient # need to bother calling the grad method. We know the gradient
# with respect to all connected inputs is nan. # with respect to all connected inputs is nan.
......
...@@ -201,12 +201,12 @@ class DimShuffle(ExternalCOp): ...@@ -201,12 +201,12 @@ class DimShuffle(ExternalCOp):
f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}." f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
) )
for expected, b in zip(self.input_broadcastable, ib): for expected, b in zip(self.input_broadcastable, ib):
if expected is True and b is False: if expected and not b:
raise TypeError( raise TypeError(
"The broadcastable pattern of the " "The broadcastable pattern of the "
f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}." f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
) )
# else, expected == b or expected is False and b is True # else, expected == b or not expected and b
# Both case are good. # Both case are good.
out_static_shape = [] out_static_shape = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论