Commit 6b6ed43a authored by Brandon T. Willard, committed by Brandon T. Willard

Refactor aesara.tensor.basic_opt.local_fill_to_alloc

Parent 3bd247e4
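For context, the rewrite being refactored here turns fill(s, v) into alloc(v, shape(s)), so the compiled graph depends only on the shape of s rather than on s itself. A minimal sketch of the observable effect, assuming only the public aesara API already exercised by the tests in this diff:

    import numpy as np
    import aesara
    import aesara.tensor as at

    m = at.dmatrix("m")
    x = at.dvector("x")

    # fill(m, x) broadcasts x to m's shape; after local_fill_to_alloc the
    # compiled graph should contain an Alloc instead of the fill (second) op.
    f = aesara.function([m, x], at.fill(m, x))
    print([node.op for node in f.maker.fgraph.toposort()])

    m_ = np.zeros((3, 4))
    x_ = np.arange(4, dtype="float64")
    assert np.array_equal(f(m_, x_), np.broadcast_to(x_, m_.shape))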
@@ -1652,50 +1652,42 @@ def local_fill_sink(fgraph, node):
@register_specialize
@register_stabilize
# @register_canonicalize # We make full pass after the canonizer phase.
@local_optimizer([fill])
def local_fill_to_alloc(fgraph, node):
"""fill(s,v) -> alloc(v, shape(s))
r"""Remove `fill`\s or replace them with `Alloc`\s.
This is an important optimization because with the shape_to_shape_i
optimization, the dependency on 's' is often removed.
`Alloc`\s are preferable because they replace explicit tensor dependencies
with their dependencies on those tensors' shapes, and sometimes those
shapes can be computed without needing to compute the tensors themselves.
XXX: This rewrite can produce inconsistent results, so do *not* consider
making it a canonicalization until those inconsistencies are
resolved/justified.
"""
if node.op == fill:
r, v = node.inputs
if v.type == node.outputs[0].type:
# this is a useless fill, erase it.
rval = [v]
elif v.type.broadcastable == node.outputs[0].type.broadcastable:
# this is a cast
rval = [cast(v, node.outputs[0].type.dtype)]
elif r.type.broadcastable == node.outputs[0].type.broadcastable:
# we are broadcasting v somehow, but not r
o = broadcast_like(v, r, fgraph, dtype=v.dtype)
copy_stack_trace(node.outputs[0], o)
rval = [o]
else:
# we are broadcasting both v and r,
# the output shape must be computed
#
# TODO: implement this case (including a test!)
#
# I think the strategy should be to extend the shorter
# shape vector with 1s (how?) and then take the
# elementwise max of the two. - how to flag an error of
# shape mismatch where broadcasting should be illegal?
return
# TODO: cut out un-necessary dimshuffles of v
assert rval[0].type == node.outputs[0].type, (
"rval",
rval[0].type,
"orig",
node.outputs[0].type,
"node",
node,
) # aesara.printing.debugprint(node.outputs[0], file='str'))
return rval
shape_ref, values_ref = node.inputs
out_type = node.outputs[0].type
if values_ref.type.broadcastable == out_type.broadcastable:
# The assumption here is that `values_ref` already has the same shape
# as `shape_ref`, so a `fill`/`Alloc` is unnecessary.
# XXX FIXME TODO: The only way this can be determined is if one
# absolutely knows that the shapes of `shape_ref` and `values_ref` are
# equal.
# This is an old rewrite, and it's only a
# "specialization/stabilization", so we're going to leave it be for
# now.
return [values_ref]
if shape_ref.type.broadcastable == out_type.broadcastable:
# In this case, we assume that some broadcasting is needed (otherwise
# the condition above would've been true), so we replace the `fill`
# with an `Alloc`.
o = broadcast_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype)
copy_stack_trace(node.outputs[0], o)
return [o]
return
# Register this after stabilize at 1.5 to make sure stabilize don't
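The branch decision in the new implementation is made on type-level broadcast patterns, not runtime shapes. A rough sketch of how those patterns line up for the case the new test below covers, assuming only the aesara types already used in this diff:

    import aesara.tensor as at

    m = at.dmatrix("m")    # m.type.broadcastable == (False, False)
    x = at.dvector("x")    # x.type.broadcastable == (False,)
    out = at.fill(m, x)    # out.type.broadcastable == (False, False)

    # x's pattern differs from the output's, so the fill is not useless;
    # m's pattern matches the output's, so the fill is replaced by an Alloc
    # built via broadcast_like(x, m, fgraph, dtype=x.dtype).
    print(x.type.broadcastable, m.type.broadcastable, out.type.broadcastable)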
@@ -1292,12 +1292,10 @@ def test_local_fill_useless():
x = dvector()
y = dvector()
z = lvector()
m = dmatrix()
x_ = np.random.random((5,))
y_ = np.random.random((5,))
z_ = (np.random.random((5,)) * 5).astype("int64")
m_ = np.random.random((5, 5))
# basic case
f = function([x], at.fill(x, x) * 2, mode=mode_opt)
@@ -1329,12 +1327,35 @@ def test_local_fill_useless():
assert [node.op for node in f.maker.fgraph.toposort()] == [mul]
f(x_, y_)
# Test with different number of dimensions
# The fill is not useless, so it should stay
f = function([m, x], at.fill(m, x) * 2, mode=mode_opt)
ops = [node.op.__class__ for node in f.maker.fgraph.toposort()]
assert Alloc in ops
f(m_, x_)
def test_local_fill_to_alloc():
x = dvector()
m = dmatrix()
x_ = np.random.random((5,))
m_ = np.random.random((5, 5))
y = at.fill(m, x)
mode = mode_opt.including("stabilize", "local_fill_to_alloc").excluding(
"useless", "local_useless_fill"
)
f = function([m, x], y, mode=mode)
assert Alloc in [node.op.__class__ for node in f.maker.fgraph.toposort()]
res = f(m_, x_)
exp_res = np.broadcast_to(x_, m_.shape)
assert np.array_equal(res, exp_res)
y = at.fill(x, m)
f = function([m, x], y, mode=mode)
assert Alloc not in [node.op.__class__ for node in f.maker.fgraph.toposort()]
res = f(m_, x_)
assert np.array_equal(res, m_)
class TestLocalCanonicalizeAlloc: