提交 6b6ed43a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor aesara.tensor.basic_opt.local_fill_to_alloc

上级 3bd247e4
...@@ -1652,50 +1652,42 @@ def local_fill_sink(fgraph, node): ...@@ -1652,50 +1652,42 @@ def local_fill_sink(fgraph, node):
@register_specialize
@register_stabilize
# @register_canonicalize  # We make full pass after the canonizer phase.
@local_optimizer([fill])
def local_fill_to_alloc(fgraph, node):
    r"""Remove `fill`\s or replace them with `Alloc`\s.

    `Alloc`\s are preferable because they replace explicit tensor dependencies
    with their dependencies on those tensors' shapes, and sometimes those
    shapes can be computed without needing to compute the tensors themselves.

    XXX: This rewrite can produce inconsistent results, so do *not* consider
    making it a canonicalization until those inconsistencies are
    resolved/justified.

    Parameters
    ----------
    fgraph
        The function graph being rewritten.
    node
        A ``fill(shape_ref, values_ref)`` node.

    Returns
    -------
    list or None
        A single-element list with the replacement variable, or ``None`` when
        no rewrite applies.
    """
    shape_ref, values_ref = node.inputs
    out_type = node.outputs[0].type

    if values_ref.type.broadcastable == out_type.broadcastable:
        # The assumption here is that `values_ref` already has the same shape
        # as `shape_ref`, so a `fill`/`Alloc` is unnecessary.

        # XXX FIXME TODO: The only way this can be determined is if one
        # absolutely knows that the shapes of `shape_ref` and `values_ref` are
        # equal.
        # This is an old rewrite, and it's only a
        # "specialization/stabilization", so we're going to leave it be for
        # now.

        # Preserve the output dtype: the pre-refactor code applied
        # `cast(v, out.dtype)` when only the dtypes differed; returning
        # `values_ref` unchanged would alter the dtype of the graph.
        if values_ref.type.dtype != out_type.dtype:
            return [cast(values_ref, out_type.dtype)]
        return [values_ref]

    if shape_ref.type.broadcastable == out_type.broadcastable:
        # In this case, we assume that some broadcasting is needed (otherwise
        # the condition above would've been true), so we replace the `fill`
        # with an `Alloc`.
        o = broadcast_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype)
        copy_stack_trace(node.outputs[0], o)
        return [o]

    # Both inputs would need to be broadcast against each other; computing
    # the output shape for that case is not implemented, so leave the node
    # untouched.
    return
# Register this after stabilize at 1.5 to make sure stabilize don't # Register this after stabilize at 1.5 to make sure stabilize don't
......
...@@ -1292,12 +1292,10 @@ def test_local_fill_useless(): ...@@ -1292,12 +1292,10 @@ def test_local_fill_useless():
x = dvector() x = dvector()
y = dvector() y = dvector()
z = lvector() z = lvector()
m = dmatrix()
x_ = np.random.random((5,)) x_ = np.random.random((5,))
y_ = np.random.random((5,)) y_ = np.random.random((5,))
z_ = (np.random.random((5,)) * 5).astype("int64") z_ = (np.random.random((5,)) * 5).astype("int64")
m_ = np.random.random((5, 5))
# basic case # basic case
f = function([x], at.fill(x, x) * 2, mode=mode_opt) f = function([x], at.fill(x, x) * 2, mode=mode_opt)
...@@ -1329,12 +1327,35 @@ def test_local_fill_useless(): ...@@ -1329,12 +1327,35 @@ def test_local_fill_useless():
assert [node.op for node in f.maker.fgraph.toposort()] == [mul] assert [node.op for node in f.maker.fgraph.toposort()] == [mul]
f(x_, y_) f(x_, y_)
# Test with different number of dimensions
def test_local_fill_to_alloc():
    """Check that `fill` is replaced by `Alloc` only when broadcasting is
    actually required, and dropped otherwise."""
    x = dvector()
    m = dmatrix()

    x_val = np.random.random((5,))
    m_val = np.random.random((5, 5))

    mode = mode_opt.including("stabilize", "local_fill_to_alloc").excluding(
        "useless", "local_useless_fill"
    )

    # Broadcasting a vector up to a matrix shape requires an `Alloc`.
    f = function([m, x], at.fill(m, x), mode=mode)
    assert Alloc in [node.op.__class__ for node in f.maker.fgraph.toposort()]
    assert np.array_equal(f(m_val, x_val), np.broadcast_to(x_val, m_val.shape))

    # `fill(x, m)` already produces the matrix's shape, so the `fill` is
    # removed without introducing an `Alloc`.
    f = function([m, x], at.fill(x, m), mode=mode)
    assert Alloc not in [
        node.op.__class__ for node in f.maker.fgraph.toposort()
    ]
    assert np.array_equal(f(m_val, x_val), m_val)
class TestLocalCanonicalizeAlloc: class TestLocalCanonicalizeAlloc:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论