提交 e6d204ec authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Replace use of Rebroadcast by SpecifyShape in convert_variable

Adds condition in convert_variable_test which would fail before this change
上级 be719a61
......@@ -328,10 +328,8 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
# Note that, in this case, `var.type != self`, because that's
# covered by the branch above.
# Use the more specific broadcast/shape information of the two
return aesara.tensor.basic.Rebroadcast(
*[(i, b) for i, b in enumerate(self.broadcastable)]
)(var)
# Use the more specific static shape information of the two
return aesara.tensor.specify_shape(var, self.shape)
def value_zeros(self, shape):
"""Create an numpy ndarray full of 0 values.
......
......@@ -141,15 +141,17 @@ more specific/informative than ``v1``'s--and both are compatible.
>>> v3 = v2.type.filter_variable(v1)
>>> v3
Rebroadcast{(0, False),(1, True)}.0
SpecifyShape.0
>>> import aesara
>>> aesara.dprint(v3, print_type=True)
Rebroadcast{(0, False),(1, True)} [id A] <TensorType(float64, (None, 1))> ''
SpecifyShape [id A] <TensorType(float64, (2, 1))>
|<TensorType(float64, (2, None))> [id B] <TensorType(float64, (2, None))>
|TensorConstant{2} [id C] <TensorType(int8, ())>
|TensorConstant{1} [id D] <TensorType(int8, ())>
Performing this in the opposite direction returned the output of a
:class:`Rebroadcast`\ :class:`Op`. This :class:`Rebroadcast` uses ``v1`` as an
:class:`SpecifyShape`\ :class:`Op`. This :class:`SpecifyShape` uses ``v1`` static shape as an
input and serves to produce a new :class:`Variable` that has a :class:`Type` compatible with
both ``v1`` and ``v2``.
......
......@@ -37,17 +37,19 @@ Aesara propagates information about shapes within a graph using specialized
Specifying Exact Shape
======================
Currently, specifying a shape is not as easy and flexible as we wish and we plan some
upgrade. Here is the current state of what can be done:
You can create variables with static shape information as follows:
.. code-block:: python
aesara.tensor.tensor("float64", shape=(4, 3, 2))
- You can pass the shape info directly to the ``ConvOp`` created
when calling ``conv2d``. You simply set the parameters ``image_shape``
and ``filter_shape`` inside the call. They must be tuples of 4
elements. For example:
You can also pass shape infomation directly to some :class:`Op`\s, like ``RandomVariables``
.. code-block:: python
aesara.tensor.nnet.conv2d(..., image_shape=(7, 3, 5, 5), filter_shape=(2, 3, 4, 4))
aesara.tensor.random.normal(size=(7, 3, 5, 5))
- You can use the :class:`SpecifyShape`\ :class:`Op` to add shape information anywhere in the
graph. This allows to perform some optimizations. In the following example,
......
......@@ -3214,9 +3214,6 @@ def test_local_Unique_scalar(return_index, return_counts, return_inverse):
y_opt = y_opt_fg.outputs[0]
y_opt_start = y_opt
if isinstance(y_opt.owner.op, Rebroadcast):
y_opt_start = y_opt.owner.inputs[0]
assert isinstance(y_opt_start.owner.op, DimShuffle)
assert y_opt_start.owner.inputs[0] == x
......@@ -3266,11 +3263,6 @@ def test_local_Unique_Alloc_lift(
y_opt = y_opt_fg.outputs[0]
y_opt_start = y_opt
# Ignore any initial `Rebroadcast`s (they serve to
# make the replacement match the original type)
if isinstance(y_opt.owner.op, Rebroadcast):
y_opt_start = y_opt.owner.inputs[0]
assert isinstance(y_opt_start.owner.op, Unique)
assert y_opt_start.owner.inputs[0] == x
assert not any(isinstance(node.op, Alloc) for node in y_opt_fg.apply_nodes)
......@@ -3329,11 +3321,6 @@ def test_local_Unique_BroadcastTo(
y_opt = y_opt_fg.outputs[0]
y_opt_start = y_opt
# Ignore any initial `Rebroadcast`s (they serve to
# make the replacement match the original type)
if isinstance(y_opt.owner.op, Rebroadcast):
y_opt_start = y_opt.owner.inputs[0]
assert isinstance(y_opt_start.owner.op, Unique)
assert y_opt_start.owner.inputs[0] == x
assert not any(isinstance(node.op, BroadcastTo) for node in y_opt_fg.apply_nodes)
......@@ -3395,11 +3382,6 @@ def test_local_Unique_Repeat(
y_opt = y_opt_fg.outputs[0]
y_opt_start = y_opt
# Ignore any initial `Rebroadcast`s (they serve to
# make the replacement match the original type)
if isinstance(y_opt.owner.op, Rebroadcast):
y_opt_start = y_opt.owner.inputs[0]
assert isinstance(y_opt_start.owner.op, Unique)
assert y_opt_start.owner.inputs[0] == x
assert not any(isinstance(node.op, Repeat) for node in y_opt_fg.apply_nodes)
......@@ -3456,11 +3438,6 @@ def test_local_Unique_second(
y_opt = y_opt_fg.outputs[0]
y_opt_start = y_opt
# Ignore any initial `Rebroadcast`s (they serve to
# make the replacement match the original type)
if y_opt.owner and isinstance(y_opt.owner.op, Rebroadcast):
y_opt_start = y_opt.owner.inputs[0]
assert isinstance(y_opt_start.owner.op, Unique)
y_opt_start = y_opt_start.owner.inputs[0]
......
......@@ -6,7 +6,7 @@ import pytest
import aesara.tensor as at
from aesara.configdefaults import config
from aesara.tensor.basic import Rebroadcast
from aesara.tensor.shape import SpecifyShape
from aesara.tensor.type import TensorType
......@@ -93,6 +93,10 @@ def test_filter_variable():
res = test_type.filter_variable(test_var2, allow_convert=True)
assert res.type == test_type
test_type3 = TensorType(config.floatX, shape=(1, 20))
res = test_type3.filter_variable(test_var, allow_convert=True)
assert res.type == test_type3
def test_filter_strict():
test_type = TensorType(config.floatX, [])
......@@ -277,7 +281,7 @@ def test_fixed_shape_convert_variable():
t3 = TensorType("float64", (False, True))
t3_var = t3()
res = t2.convert_variable(t3_var)
assert isinstance(res.owner.op, Rebroadcast)
assert isinstance(res.owner.op, SpecifyShape)
t3 = TensorType("float64", (False, False))
t4 = TensorType("float64", (3, 2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论