提交 eb2b9afb authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Remove unnecessary use of patternbroadcast

The behavior was already accounted for by `filter_variable`, which is called directly, or as a fallback, by the optimizer routines
上级 e6d204ec
......@@ -19,7 +19,7 @@ from aesara.sparse.basic import (
usmm,
)
from aesara.tensor import blas
from aesara.tensor.basic import as_tensor_variable, cast, patternbroadcast
from aesara.tensor.basic import as_tensor_variable, cast
from aesara.tensor.basic_opt import register_canonicalize, register_specialize
from aesara.tensor.math import mul, neg, sub
from aesara.tensor.shape import shape, specify_shape
......@@ -42,13 +42,7 @@ def local_csm_properties_csm(fgraph, node):
if node.op == csm_properties:
(csm,) = node.inputs
if csm.owner and (csm.owner.op == CSC or csm.owner.op == CSR):
# csm.owner.inputs could be broadcastable. In that case, we have
# to adjust the broadcasting flag here.
ret_var = [
patternbroadcast(i, o.broadcastable)
for i, o in zip(csm.owner.inputs, node.outputs)
]
return ret_var
return csm.owner.inputs
return False
......
......@@ -61,7 +61,6 @@ from aesara.tensor.basic import (
get_scalar_constant_value,
join,
ones_like,
patternbroadcast,
stack,
switch,
tensor_copy,
......@@ -2425,15 +2424,6 @@ def local_join_empty(fgraph, node):
# by an error in the old join op.
copy_stack_trace(node.outputs, ret)
if not o.type.is_super(ret.type):
assert ret.dtype == o.dtype
assert ret.ndim == o.ndim
ret = patternbroadcast(ret, node.outputs[0].broadcastable)
# Copy over stacktrace from previous output
# (after patternbroadcast op) for same reasons as before.
copy_stack_trace(node.outputs, ret)
return [ret]
......@@ -2832,20 +2822,7 @@ def local_reshape_lift(fgraph, node):
# Copy stacktrace from both previous Reshape and UnaryElemwise op
# because an error in new cg could have been caused by either ops.
copy_stack_trace(node.outputs + node.inputs, e)
# In rare case the original broadcast was (False, True), but
# the new one is (False, False). So don't crash in that case.
if not node.outputs[0].type.is_super(e.type):
re = patternbroadcast(e, node.outputs[0].broadcastable)
# Copy over stack trace.
# If the graph fails it is usually due to the fact that a dimension
# that should be broadcastable does not actually have length 1,
copy_stack_trace(e, re)
else:
re = e
return [re]
return [e]
register_canonicalize(OpRemove(tensor_copy), name="remove_tensor_copy")
......
......@@ -30,11 +30,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import Op
from aesara.raise_op import Assert
from aesara.tensor.basic import (
as_tensor_variable,
get_scalar_constant_value,
patternbroadcast,
)
from aesara.tensor.basic import as_tensor_variable, get_scalar_constant_value
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.var import TensorConstant, TensorVariable
......@@ -2704,11 +2700,7 @@ class AbstractConv2d(AbstractConv):
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
# Also make sure that the gradient lives on the same device than
# the corresponding input.
d_bottom = patternbroadcast(d_bottom, bottom.broadcastable)
d_bottom = bottom.type.filter_variable(d_bottom)
d_weights = patternbroadcast(d_weights, weights.broadcastable)
d_weights = weights.type.filter_variable(d_weights)
return d_bottom, d_weights
......@@ -2765,11 +2757,7 @@ class AbstractConv3d(AbstractConv):
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
# Also make sure that the gradient lives on the same device than
# the corresponding input.
d_bottom = patternbroadcast(d_bottom, bottom.broadcastable)
d_bottom = bottom.type.filter_variable(d_bottom)
d_weights = patternbroadcast(d_weights, weights.broadcastable)
d_weights = weights.type.filter_variable(d_weights)
return d_bottom, d_weights
......@@ -3062,11 +3050,7 @@ class AbstractConv2d_gradWeights(AbstractConv_gradWeights):
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
# Also make sure that the gradient lives on the same device than
# the corresponding input.
d_bottom = patternbroadcast(d_bottom, bottom.broadcastable)
d_bottom = bottom.type.filter_variable(d_bottom)
d_top = patternbroadcast(d_top, top.broadcastable)
d_top = top.type.filter_variable(d_top)
d_height_width = (aesara.gradient.DisconnectedType()(),)
......@@ -3129,11 +3113,7 @@ class AbstractConv3d_gradWeights(AbstractConv_gradWeights):
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
# Also make sure that the gradient lives on the same device than
# the corresponding input.
d_bottom = patternbroadcast(d_bottom, bottom.broadcastable)
d_bottom = bottom.type.filter_variable(d_bottom)
d_top = patternbroadcast(d_top, top.broadcastable)
d_top = top.type.filter_variable(d_top)
d_depth_height_width = (aesara.gradient.DisconnectedType()(),)
......@@ -3452,11 +3432,7 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs):
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
# Also make sure that the gradient lives on the same device than
# the corresponding input.
d_weights = patternbroadcast(d_weights, weights.broadcastable)
d_weights = weights.type.filter_variable(d_weights)
d_top = patternbroadcast(d_top, top.broadcastable)
d_top = top.type.filter_variable(d_top)
d_height_width = (aesara.gradient.DisconnectedType()(),)
......@@ -3519,11 +3495,7 @@ class AbstractConv3d_gradInputs(AbstractConv_gradInputs):
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
# that the dimensions are broadcastable.
# Also make sure that the gradient lives on the same device than
# the corresponding input.
d_weights = patternbroadcast(d_weights, weights.broadcastable)
d_weights = weights.type.filter_variable(d_weights)
d_top = patternbroadcast(d_top, top.broadcastable)
d_top = top.type.filter_variable(d_top)
d_depth_height_width = (aesara.gradient.DisconnectedType()(),)
......
......@@ -823,11 +823,6 @@ def local_abstract_batch_norm_train(fgraph, node):
)
results.append(running_var)
results = [
at.patternbroadcast(r, r_orig.broadcastable)
for (r, r_orig) in zip(results, node.outputs)
]
for var in aesara.graph.basic.vars_between(node.inputs, results):
if var not in node.inputs:
copy_stack_trace(node.outputs[0], var)
......@@ -862,11 +857,6 @@ def local_abstract_batch_norm_train_grad(fgraph, node):
g_wrt_bias = at_sum(dy, axis=axes, keepdims=True)
results = [g_wrt_inputs, g_wrt_scale, g_wrt_bias]
results = [
at.patternbroadcast(r, r_orig.broadcastable)
for (r, r_orig) in zip(results, node.outputs)
]
for var in aesara.graph.basic.vars_between(node.inputs, results):
if var not in node.inputs:
copy_stack_trace(node.outputs[0], var)
......@@ -895,7 +885,6 @@ def local_abstract_batch_norm_inference(fgraph, node):
epsilon = epsilon.astype("float32")
result = (x - estimated_mean) * (scale / sqrt(estimated_variance + epsilon)) + bias
result = at.patternbroadcast(result, node.outputs[0].broadcastable)
for var in aesara.graph.basic.vars_between(node.inputs, [result]):
if var not in node.inputs:
......
......@@ -164,7 +164,6 @@ def local_abstractconv_gradweight_gemm(fgraph, node):
if node.op.filter_flip:
flip = (slice(None),) * (rval.ndim - 2) + (slice(None, None, -1),) * 2
rval = rval[flip]
rval = aesara.tensor.patternbroadcast(rval, node.outputs[0].broadcastable)
copy_stack_trace(node.outputs[0], rval)
return [rval]
......@@ -193,7 +192,6 @@ def local_abstractconv3d_gradweight_gemm(fgraph, node):
# need to flip the kernel if necessary
if node.op.filter_flip:
rval = rval[:, :, ::-1, ::-1, ::-1]
rval = aesara.tensor.patternbroadcast(rval, node.outputs[0].broadcastable)
copy_stack_trace(node.outputs[0], rval)
return [rval]
......@@ -393,10 +391,8 @@ def local_conv2d_gradweight_cpu(fgraph, node):
if node.op.border_mode == "valid":
res = res.dimshuffle((1, 0, 2, 3))
res = res[:, :, ::-1, ::-1]
res = aesara.tensor.patternbroadcast(res, node.outputs[0].broadcastable)
copy_stack_trace(node.outputs[0], res)
return [res]
......@@ -485,8 +481,6 @@ def local_conv2d_gradinputs_cpu(fgraph, node):
)
din = din(topgrad, filters)
copy_stack_trace(node.outputs[0], din)
din = aesara.tensor.patternbroadcast(din, node.outputs[0].broadcastable)
copy_stack_trace(node.outputs[0], din)
return [din]
......
......@@ -23,7 +23,6 @@ from aesara.tensor.basic import (
concatenate,
extract_constant,
get_scalar_constant_value,
patternbroadcast,
switch,
)
from aesara.tensor.basic_opt import (
......@@ -533,14 +532,6 @@ def local_subtensor_merge(fgraph, node):
# because of either of the two original slicing operations
orig_out = node.outputs[0]
copy_stack_trace([orig_out, node.inputs[0]], out)
# Restore original broadcastable dimensions that `subtens()` may
# have been unable to infer again
if not orig_out.type.is_super(out.type):
assert out.dtype == orig_out.dtype
assert out.ndim == orig_out.ndim
out = patternbroadcast(out, orig_out.broadcastable)
copy_stack_trace([orig_out, node.inputs[0]], out)
return [out]
......@@ -658,11 +649,6 @@ def local_subtensor_of_alloc(fgraph, node):
rval = alloc(nw_val, *nw_dims)
if not isinstance(rval, (list, tuple)):
rval = [rval]
if not node.outputs[0].type.is_super(rval[0].type):
# It happen that the make_node() isn't able to infer the same pattern.
# We know it is safe, so fix that.
rval[0] = patternbroadcast(rval[0], node.outputs[0].broadcastable)
return rval
......@@ -766,7 +752,6 @@ def local_subtensor_make_vector(fgraph, node):
values = list(map(int, list(idx.value)))
ret = make_vector_op(*[x.owner.inputs[v] for v in values])
copy_stack_trace(node.outputs[0], ret)
ret = patternbroadcast(ret, node.outputs[0].broadcastable)
return [ret]
elif isinstance(idx, slice):
# The index is a slice. If it's a constant slice, we can perform the
......@@ -777,7 +762,6 @@ def local_subtensor_make_vector(fgraph, node):
)[0]
ret = make_vector_op(*x.owner.inputs[const_slice])
copy_stack_trace(node.outputs, ret)
ret = patternbroadcast(ret, node.outputs[0].broadcastable)
return [ret]
except NotScalarConstantError:
pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论