Unverified 提交 e96b285d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #39 from brandonwillard/fix-subtensor-shape-inference

- Fix `*Subtensor*` `Op` shape inference - Make `theano.tensor.basic.as_tensor_variable` consistent in how it handles constant conversions - Remove `TensorConstant` caching
.PHONY: help venv conda docker docstyle format style black test lint check coverage pypi
.DEFAULT_GOAL = help
PROJECT_NAME = theano
PROJECT_DIR = theano/
PYTHON = python
PIP = pip
CONDA = conda
SHELL = bash
help:
@printf "Usage:\n"
@grep -E '^[a-zA-Z_-]+:.*?# .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?# "}; {printf "\033[1;34mmake %-10s\033[0m%s\n", $$1, $$2}'
conda: # Set up a conda environment for development.
@printf "Creating conda environment...\n"
${CONDA} create --yes --name ${PROJECT_NAME}-env python=3.6
( \
${CONDA} activate ${PROJECT_NAME}-env; \
${PIP} install -U pip; \
${PIP} install -r requirements.txt; \
${CONDA} deactivate; \
)
@printf "\n\nConda environment created! \033[1;34mRun \`conda activate ${PROJECT_NAME}-env\` to activate it.\033[0m\n\n\n"
venv: # Set up a Python virtual environment for development.
@printf "Creating Python virtual environment...\n"
rm -rf ${PROJECT_NAME}-venv
${PYTHON} -m venv ${PROJECT_NAME}-venv
( \
source ${PROJECT_NAME}-venv/bin/activate; \
${PIP} install -U pip; \
${PIP} install -r requirements.txt; \
deactivate; \
)
@printf "\n\nVirtual environment created! \033[1;34mRun \`source ${PROJECT_NAME}-venv/bin/activate\` to activate it.\033[0m\n\n\n"
docstyle:
@printf "Checking documentation with pydocstyle...\n"
pydocstyle ${PROJECT_DIR}
@printf "\033[1;34mPydocstyle passes!\033[0m\n\n"
format:
@printf "Checking code format with black...\n"
black -t py36 --check ${PROJECT_DIR} tests/ setup.py conftest.py
@printf "\033[1;34mBlack passes!\033[0m\n\n"
style:
@printf "Checking code style with pylint...\n"
flake8
@printf "\033[1;34mPylint passes!\033[0m\n\n"
black: # Format code in-place using black.
black ${PROJECT_DIR} tests/ setup.py conftest.py
test: # Test code using pytest.
pytest -v tests/ ${PROJECT_DIR} --cov=${PROJECT_DIR} --cov-report=xml --html=testing-report.html --self-contained-html
coverage: test
diff-cover coverage.xml --compare-branch=master --fail-under=100
pypi:
${PYTHON} setup.py clean --all; \
${PYTHON} setup.py rotate --match=.tar.gz,.whl,.egg,.zip --keep=0; \
${PYTHON} setup.py sdist bdist_wheel; \
twine upload --skip-existing dist/*;
lint: docstyle format style # Lint code using pydocstyle, black and pylint.
check: lint test coverage # Both lint and test code. Runs `make lint` followed by `make test`.
import os
import pytest import pytest
def pytest_sessionstart(session):
os.environ["THEANO_FLAGS"] = ",".join(
[
os.environ.setdefault("THEANO_FLAGS", ""),
"warn.ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise",
]
)
def pytest_addoption(parser): def pytest_addoption(parser):
parser.addoption( parser.addoption(
"--runslow", action="store_true", default=False, help="run slow tests" "--runslow", action="store_true", default=False, help="run slow tests"
......
...@@ -12,3 +12,4 @@ sympy ...@@ -12,3 +12,4 @@ sympy
versioneer versioneer
jax; python_version > '3.6' jax; python_version > '3.6'
jaxlib; python_version > '3.6' jaxlib; python_version > '3.6'
diff-cover
...@@ -5,25 +5,14 @@ import pytest ...@@ -5,25 +5,14 @@ import pytest
import theano import theano
from theano.compat import PY3 from theano.compat import PY3
from theano.gof import CachedConstantError, FunctionGraph from theano.gof.fg import FunctionGraph
from theano import tensor as tt from theano import tensor as tt
class TFunctionGraph: class TestFunctionGraph:
def test_constant_cache_error(self):
v = theano.tensor.constant(1)
assert v.cached
with pytest.raises(CachedConstantError):
FunctionGraph([], [v + 1], clone=False)
def test_clone(self):
v = theano.tensor.constant(1)
assert v.cached
FunctionGraph([], [v + 1])
def test_pickle(self): def test_pickle(self):
v = tt.vector() v = tt.vector()
func = theano.gof.FunctionGraph([v], [v + 1]) func = FunctionGraph([v], [v + 1])
s = pickle.dumps(func) s = pickle.dumps(func)
pickle.loads(s) pickle.loads(s)
...@@ -31,6 +20,7 @@ class TFunctionGraph: ...@@ -31,6 +20,7 @@ class TFunctionGraph:
@pytest.mark.skipif( @pytest.mark.skipif(
not theano.config.cxx, reason="G++ not available, so we need to skip this test." not theano.config.cxx, reason="G++ not available, so we need to skip this test."
) )
@pytest.mark.slow
def test_node_outputs_not_used(self): def test_node_outputs_not_used(self):
# In the past, we where removing some not used variable from # In the past, we where removing some not used variable from
# fgraph.variables event if the apply had other output used in # fgraph.variables event if the apply had other output used in
......
from itertools import count
import pickle import pickle
import pytest import pytest
import numpy as np import numpy as np
from itertools import count
from theano import sparse, shared, tensor from theano import sparse, shared, tensor
from theano.gof.graph import ( from theano.gof.graph import (
Apply, Apply,
...@@ -12,8 +14,8 @@ from theano.gof.graph import ( ...@@ -12,8 +14,8 @@ from theano.gof.graph import (
general_toposort, general_toposort,
inputs, inputs,
io_toposort, io_toposort,
is_same_graph,
Variable, Variable,
equal_computations,
) )
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
...@@ -58,10 +60,6 @@ class MyOp(Op): ...@@ -58,10 +60,6 @@ class MyOp(Op):
MyOp = MyOp() MyOp = MyOp()
##########
# inputs #
##########
class TestInputs: class TestInputs:
def test_inputs(self): def test_inputs(self):
...@@ -77,11 +75,6 @@ class TestInputs: ...@@ -77,11 +75,6 @@ class TestInputs:
assert i == [r1, r2, r5], i assert i == [r1, r2, r5], i
#############
# as_string #
#############
class X: class X:
def leaf_formatter(self, leaf): def leaf_formatter(self, leaf):
return str(leaf.type) return str(leaf.type)
...@@ -126,11 +119,6 @@ class TestStr(X): ...@@ -126,11 +119,6 @@ class TestStr(X):
assert self.str(node2.inputs, node2.outputs) == ["MyOp(R3, R3)"] assert self.str(node2.inputs, node2.outputs) == ["MyOp(R3, R3)"]
#########
# clone #
#########
class TestClone(X): class TestClone(X):
def test_accurate(self): def test_accurate(self):
r1, r2 = MyVariable(1), MyVariable(2) r1, r2 = MyVariable(1), MyVariable(2)
...@@ -186,11 +174,6 @@ class TestClone(X): ...@@ -186,11 +174,6 @@ class TestClone(X):
assert i[0] is c1 and o[0] is c1 assert i[0] is c1 and o[0] is c1
############
# toposort #
############
def prenode(obj): def prenode(obj):
if isinstance(obj, Variable): if isinstance(obj, Variable):
if obj.owner: if obj.owner:
...@@ -258,154 +241,6 @@ class TestToposort: ...@@ -258,154 +241,6 @@ class TestToposort:
assert all == [o0] assert all == [o0]
#################
# is_same_graph #
#################
class TestIsSameGraph:
def check(self, expected, debug=True):
"""
Core function to perform comparison.
:param expected: A list of tuples (v1, v2, ((g1, o1), ..., (gN, oN)))
with:
- `v1` and `v2` two Variables (the graphs to be compared)
- `gj` a `givens` dictionary to give as input to `is_same_graph`
- `oj` the expected output of `is_same_graph(v1, v2, givens=gj)`
:param debug: If True, then we make sure we are testing both
implementations of `is_same_graph`.
This function also tries to call `is_same_graph` by inverting `v1` and
`v2`, and ensures the output remains the same.
"""
for v1, v2, go in expected:
for gj, oj in go:
r1 = is_same_graph(v1, v2, givens=gj, debug=debug)
assert r1 == oj
r2 = is_same_graph(v2, v1, givens=gj, debug=debug)
assert r2 == oj
def test_single_var(self):
# Test `is_same_graph` with some trivial graphs (one Variable).
x, y, z = tensor.vectors("x", "y", "z")
self.check(
[
(x, x, (({}, True),)),
(
x,
y,
(
({}, False),
({y: x}, True),
),
),
(x, tensor.neg(x), (({}, False),)),
(x, tensor.neg(y), (({}, False),)),
]
)
def test_full_graph(self):
# Test `is_same_graph` with more complex graphs.
x, y, z = tensor.vectors("x", "y", "z")
t = x * y
self.check(
[
(x * 2, x * 2, (({}, True),)),
(
x * 2,
y * 2,
(
({}, False),
({y: x}, True),
),
),
(
x * 2,
y * 2,
(
({}, False),
({x: y}, True),
),
),
(
x * 2,
y * 3,
(
({}, False),
({y: x}, False),
),
),
(
t * 2,
z * 2,
(
({}, False),
({t: z}, True),
),
),
(
t * 2,
z * 2,
(
({}, False),
({z: t}, True),
),
),
(x * (y * z), (x * y) * z, (({}, False),)),
]
)
def test_merge_only(self):
# Test `is_same_graph` when `equal_computations` cannot be used.
x, y, z = tensor.vectors("x", "y", "z")
t = x * y
self.check(
[
(x, t, (({}, False), ({t: x}, True))),
(
t * 2,
x * 2,
(
({}, False),
({t: x}, True),
),
),
(
x * x,
x * y,
(
({}, False),
({y: x}, True),
),
),
(
x * x,
x * y,
(
({}, False),
({y: x}, True),
),
),
(
x * x + z,
x * y + t,
(({}, False), ({y: x}, False), ({y: x, t: z}, True)),
),
],
debug=False,
)
################
# eval #
################
class TestEval: class TestEval:
def setup_method(self): def setup_method(self):
self.x, self.y = tensor.scalars("x", "y") self.x, self.y = tensor.scalars("x", "y")
...@@ -421,9 +256,6 @@ class TestEval: ...@@ -421,9 +256,6 @@ class TestEval:
), "temporary functions must not be serialized" ), "temporary functions must not be serialized"
################
# autoname #
################
class TestAutoName: class TestAutoName:
def test_auto_name(self): def test_auto_name(self):
# Get counter value # Get counter value
...@@ -434,27 +266,14 @@ class TestAutoName: ...@@ -434,27 +266,14 @@ class TestAutoName:
assert r2.auto_name == "auto_" + str(autoname_id + 1) assert r2.auto_name == "auto_" + str(autoname_id + 1)
def test_constant(self): def test_constant(self):
# Make sure the value we will use for the test aren't yet in the cache.
r1 = tensor.constant(1.5)
del tensor.constant_cache[r1.signature()]
r1 = tensor.constant(1.6)
del tensor.constant_cache[r1.signature()]
# Get counter value # Get counter value
autoname_id = next(Variable.__count__) autoname_id = next(Variable.__count__)
Variable.__count__ = count(autoname_id) Variable.__count__ = count(autoname_id)
r1 = tensor.constant(1.5) r1 = tensor.constant(1.5)
r2 = tensor.constant(1.5)
assert r1.auto_name == "auto_" + str(autoname_id), ( assert r1.auto_name == "auto_" + str(autoname_id), (
r1.auto_name, r1.auto_name,
"auto_" + str(autoname_id), "auto_" + str(autoname_id),
) )
# We reuse the same variable
assert r2.auto_name == "auto_" + str(autoname_id), (
r2.auto_name,
"auto_" + str(autoname_id),
)
assert r1 is r2
r3 = tensor.constant(1.6) r3 = tensor.constant(1.6)
assert r3.auto_name == "auto_" + str(autoname_id + 1) assert r3.auto_name == "auto_" + str(autoname_id + 1)
...@@ -506,3 +325,13 @@ class TestAutoName: ...@@ -506,3 +325,13 @@ class TestAutoName:
r2 = r1.clone() r2 = r1.clone()
assert r1.auto_name == "auto_" + str(autoname_id) assert r1.auto_name == "auto_" + str(autoname_id)
assert r2.auto_name == "auto_" + str(autoname_id + 1) assert r2.auto_name == "auto_" + str(autoname_id + 1)
def test_equal_computations():
# This was a bug report by a Theano user.
c = tensor.type_other.NoneConst
assert equal_computations([c], [c])
m = tensor.matrix()
max_argmax1 = tensor.max_and_argmax(m)
max_argmax2 = tensor.max_and_argmax(m)
assert equal_computations(max_argmax1, max_argmax2)
from theano import tensor
from theano.gof.graph import Variable, Apply from theano.gof.graph import Variable, Apply
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.toolbox import NodeFinder from theano.gof.toolbox import NodeFinder, is_same_graph
def as_variable(x): class TestNodeFinder:
assert isinstance(x, Variable) def test_straightforward(self):
return x class MyType(Type):
def __init__(self, name):
self.name = name
class MyType(Type):
def __init__(self, name):
self.name = name
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __eq__(self, other):
return isinstance(other, MyType)
def MyVariable(name): def __str__(self):
return Variable(MyType(name), None, None) return self.name
def __repr__(self):
return self.name
class MyOp(Op): def __eq__(self, other):
return isinstance(other, MyType)
__props__ = ("nin", "name") class MyOp(Op):
def __init__(self, nin, name): __props__ = ("nin", "name")
self.nin = nin
self.name = name
def make_node(self, *inputs): def __init__(self, nin, name):
assert len(inputs) == self.nin self.nin = nin
inputs = list(map(as_variable, inputs)) self.name = name
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType(self.name + "_R")()]
return Apply(self, inputs, outputs)
def __str__(self): def make_node(self, *inputs):
return self.name def as_variable(x):
assert isinstance(x, Variable)
return x
assert len(inputs) == self.nin
inputs = list(map(as_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType(self.name + "_R")()]
return Apply(self, inputs, outputs)
sigmoid = MyOp(1, "Sigmoid") def __str__(self):
add = MyOp(2, "Add") return self.name
dot = MyOp(2, "Dot")
sigmoid = MyOp(1, "Sigmoid")
add = MyOp(2, "Add")
dot = MyOp(2, "Dot")
def inputs(): def MyVariable(name):
x = MyVariable("x") return Variable(MyType(name), None, None)
y = MyVariable("y")
z = MyVariable("z")
return x, y, z
def inputs():
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
return x, y, z
class TestNodeFinder:
def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e0 = dot(y, z) e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0)) e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
...@@ -83,3 +78,137 @@ class TestNodeFinder: ...@@ -83,3 +78,137 @@ class TestNodeFinder:
for type, num in ((add, 4), (sigmoid, 3), (dot, 1)): for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if not len([t for t in g.get_nodes(type)]) == num: if not len([t for t in g.get_nodes(type)]) == num:
raise Exception("Expected: %i times %s" % (num, type)) raise Exception("Expected: %i times %s" % (num, type))
class TestIsSameGraph:
def check(self, expected):
"""
Core function to perform comparison.
:param expected: A list of tuples (v1, v2, ((g1, o1), ..., (gN, oN)))
with:
- `v1` and `v2` two Variables (the graphs to be compared)
- `gj` a `givens` dictionary to give as input to `is_same_graph`
- `oj` the expected output of `is_same_graph(v1, v2, givens=gj)`
This function also tries to call `is_same_graph` by inverting `v1` and
`v2`, and ensures the output remains the same.
"""
for v1, v2, go in expected:
for gj, oj in go:
r1 = is_same_graph(v1, v2, givens=gj)
assert r1 == oj
r2 = is_same_graph(v2, v1, givens=gj)
assert r2 == oj
def test_single_var(self):
# Test `is_same_graph` with some trivial graphs (one Variable).
x, y, z = tensor.vectors("x", "y", "z")
self.check(
[
(x, x, (({}, True),)),
(
x,
y,
(
({}, False),
({y: x}, True),
),
),
(x, tensor.neg(x), (({}, False),)),
(x, tensor.neg(y), (({}, False),)),
]
)
def test_full_graph(self):
# Test `is_same_graph` with more complex graphs.
x, y, z = tensor.vectors("x", "y", "z")
t = x * y
self.check(
[
(x * 2, x * 2, (({}, True),)),
(
x * 2,
y * 2,
(
({}, False),
({y: x}, True),
),
),
(
x * 2,
y * 2,
(
({}, False),
({x: y}, True),
),
),
(
x * 2,
y * 3,
(
({}, False),
({y: x}, False),
),
),
(
t * 2,
z * 2,
(
({}, False),
({t: z}, True),
),
),
(
t * 2,
z * 2,
(
({}, False),
({z: t}, True),
),
),
(x * (y * z), (x * y) * z, (({}, False),)),
]
)
def test_merge_only(self):
# Test `is_same_graph` when `equal_computations` cannot be used.
x, y, z = tensor.vectors("x", "y", "z")
t = x * y
self.check(
[
(x, t, (({}, False), ({t: x}, True))),
(
t * 2,
x * 2,
(
({}, False),
({t: x}, True),
),
),
(
x * x,
x * y,
(
({}, False),
({y: x}, True),
),
),
(
x * x,
x * y,
(
({}, False),
({y: x}, True),
),
),
(
x * x + z,
x * y + t,
(({}, False), ({y: x}, False), ({y: x, t: z}, True)),
),
],
)
...@@ -1673,72 +1673,72 @@ class TestScan: ...@@ -1673,72 +1673,72 @@ class TestScan:
| |<RandomStateType> [id DD] | |<RandomStateType> [id DD]
| |Shape [id DE] '' | |Shape [id DE] ''
| | |Subtensor{int64::} [id DA] '' | | |Subtensor{int64::} [id DA] ''
| |TensorConstant{0.1} [id CW] | |TensorConstant{0.1} [id DF]
| |TensorConstant{0.9} [id CX] | |TensorConstant{0.9} [id DG]
|Sum{acc_dtype=float64} [id DF] '' |Sum{acc_dtype=float64} [id DH] ''
|Elemwise{mul,no_inplace} [id DG] '' |Elemwise{mul,no_inplace} [id DI] ''
|for{cpu,scan_fn}.2 [id H] '' |for{cpu,scan_fn}.2 [id H] ''
|RandomFunction{uniform}.1 [id DH] '' |RandomFunction{uniform}.1 [id DJ] ''
|<RandomStateType> [id DI] |<RandomStateType> [id DK]
|Shape [id DJ] '' |Shape [id DL] ''
| |for{cpu,scan_fn}.2 [id H] '' | |for{cpu,scan_fn}.2 [id H] ''
|TensorConstant{0.1} [id CW] |TensorConstant{0.1} [id DM]
|TensorConstant{0.9} [id CX] |TensorConstant{0.9} [id DN]
Inner graphs of the scan ops: Inner graphs of the scan ops:
for{cpu,scan_fn}.1 [id H] '' for{cpu,scan_fn}.1 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id DK] '' >Elemwise{Composite{((i0 + i1) * i2)}} [id DO] ''
> |y0[t-1] [id DL] -> [id BR] > |y0[t-1] [id DP] -> [id BR]
> |y0[t-3] [id DM] -> [id BR] > |y0[t-3] [id DQ] -> [id BR]
> |InplaceDimShuffle{} [id DN] '' > |InplaceDimShuffle{} [id DR] ''
> |CGemv{inplace} [id DO] '' > |CGemv{inplace} [id DS] ''
> |AllocEmpty{dtype='%(float)s'} [id DP] '' > |AllocEmpty{dtype='%(float)s'} [id DT] ''
> | |TensorConstant{1} [id DQ] > | |TensorConstant{1} [id DU]
> |TensorConstant{1.0} [id DR] > |TensorConstant{1.0} [id DV]
> |InplaceDimShuffle{x,0} [id DS] '' > |InplaceDimShuffle{x,0} [id DW] ''
> | |wout_copy [id DT] -> [id CQ] > | |wout_copy [id DX] -> [id CQ]
> |x0[t-1] [id DU] -> [id CB] > |x0[t-1] [id DY] -> [id CB]
> |TensorConstant{0.0} [id DV] > |TensorConstant{0.0} [id DZ]
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id DW] '' >Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id EA] ''
> |CGemv{no_inplace} [id DX] '' > |CGemv{no_inplace} [id EB] ''
> | |AllocEmpty{dtype='%(float)s'} [id DY] '' > | |AllocEmpty{dtype='%(float)s'} [id EC] ''
> | | |Shape_i{1} [id DZ] '' > | | |Shape_i{1} [id ED] ''
> | | |win_copy [id EA] -> [id CR] > | | |win_copy [id EE] -> [id CR]
> | |TensorConstant{1.0} [id DR] > | |TensorConstant{1.0} [id DV]
> | |InplaceDimShuffle{1,0} [id EB] 'win_copy.T' > | |InplaceDimShuffle{1,0} [id EF] 'win_copy.T'
> | | |win_copy [id EA] -> [id CR] > | | |win_copy [id EE] -> [id CR]
> | |u1[t] [id EC] -> [id BJ] > | |u1[t] [id EG] -> [id BJ]
> | |TensorConstant{0.0} [id DV] > | |TensorConstant{0.0} [id DZ]
> |u2[t] [id ED] -> [id BN] > |u2[t] [id EH] -> [id BN]
> |u2[t-1] [id EE] -> [id BL] > |u2[t-1] [id EI] -> [id BL]
> |u2[t+1] [id EF] -> [id BP] > |u2[t+1] [id EJ] -> [id BP]
> |win2_copy [id EG] -> [id CO] > |win2_copy [id EK] -> [id CO]
> |CGemv{inplace} [id EH] '' > |CGemv{inplace} [id EL] ''
> |AllocEmpty{dtype='%(float)s'} [id EI] '' > |AllocEmpty{dtype='%(float)s'} [id EM] ''
> | |Shape_i{1} [id EJ] '' > | |Shape_i{1} [id EN] ''
> | |w_copy [id EK] -> [id CP] > | |w_copy [id EO] -> [id CP]
> |TensorConstant{1.0} [id DR] > |TensorConstant{1.0} [id DV]
> |InplaceDimShuffle{1,0} [id EL] 'w_copy.T' > |InplaceDimShuffle{1,0} [id EP] 'w_copy.T'
> | |w_copy [id EK] -> [id CP] > | |w_copy [id EO] -> [id CP]
> |x0[t-1] [id DU] -> [id CB] > |x0[t-1] [id DY] -> [id CB]
> |TensorConstant{0.0} [id DV] > |TensorConstant{0.0} [id DZ]
>CGemv{no_inplace} [id DX] '' >CGemv{no_inplace} [id EB] ''
for{cpu,scan_fn}.0 [id H] '' for{cpu,scan_fn}.0 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id DK] '' >Elemwise{Composite{((i0 + i1) * i2)}} [id DO] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id DW] '' >Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id EA] ''
>CGemv{no_inplace} [id DX] '' >CGemv{no_inplace} [id EB] ''
for{cpu,scan_fn}.2 [id H] '' for{cpu,scan_fn}.2 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id DK] '' >Elemwise{Composite{((i0 + i1) * i2)}} [id DO] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id DW] '' >Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id EA] ''
>CGemv{no_inplace} [id DX] '' >CGemv{no_inplace} [id EB] ''
for{cpu,scan_fn}.2 [id H] '' for{cpu,scan_fn}.2 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id DK] '' >Elemwise{Composite{((i0 + i1) * i2)}} [id DO] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id DW] '' >Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id EA] ''
>CGemv{no_inplace} [id DX] '' >CGemv{no_inplace} [id EB] ''
""" % { """ % {
"float": theano.config.floatX "float": theano.config.floatX
} }
......
import itertools import itertools
import numpy as np
import pytest
import theano
from theano import tensor
from theano.scan_module.scan_utils import equal_computations, map_variables
from theano.tensor.type_other import NoneConst
import pytest
def test_equal_compuations(): import numpy as np
# This was a bug report by a Theano user.
c = NoneConst
assert equal_computations([c], [c])
m = theano.tensor.matrix()
max_argmax1 = theano.tensor.max_and_argmax(m)
max_argmax2 = theano.tensor.max_and_argmax(m)
assert equal_computations(max_argmax1, max_argmax2)
import theano
################# from theano import tensor
# map_variables # from theano.scan_module.scan_utils import map_variables
#################
class TestMapVariables: class TestMapVariables:
......
...@@ -4,6 +4,7 @@ import theano.tensor.inplace ...@@ -4,6 +4,7 @@ import theano.tensor.inplace
from theano import tensor as T, config from theano import tensor as T, config
from theano.tensor import basic as tensor from theano.tensor import basic as tensor
from theano.gof.opt import check_stack_trace from theano.gof.opt import check_stack_trace
from theano.gof.toolbox import is_same_graph
from theano.tensor.nnet import ( from theano.tensor.nnet import (
sigmoid, sigmoid,
sigmoid_inplace, sigmoid_inplace,
...@@ -351,9 +352,7 @@ class TestSigmoidOpts: ...@@ -351,9 +352,7 @@ class TestSigmoidOpts:
trees = [parse_mul_tree(e) for e in (expr1, expr2)] trees = [parse_mul_tree(e) for e in (expr1, expr2)]
perform_sigm_times_exp(trees[0]) perform_sigm_times_exp(trees[0])
trees[0] = simplify_mul(trees[0]) trees[0] = simplify_mul(trees[0])
good = theano.gof.graph.is_same_graph( good = is_same_graph(compute_mul(trees[0]), compute_mul(trees[1]))
compute_mul(trees[0]), compute_mul(trees[1])
)
if not good: if not good:
print(trees[0]) print(trees[0])
print(trees[1]) print(trees[1])
...@@ -541,7 +540,7 @@ class TestSigmoidUtils: ...@@ -541,7 +540,7 @@ class TestSigmoidUtils:
tree = (x * y) * -z tree = (x * y) * -z
mul_tree = parse_mul_tree(tree) mul_tree = parse_mul_tree(tree)
assert parse_mul_tree(compute_mul(mul_tree)) == mul_tree assert parse_mul_tree(compute_mul(mul_tree)) == mul_tree
assert theano.gof.graph.is_same_graph(compute_mul(parse_mul_tree(tree)), tree) assert is_same_graph(compute_mul(parse_mul_tree(tree)), tree)
def test_parse_mul_tree(self): def test_parse_mul_tree(self):
x, y, z = tensor.vectors("x", "y", "z") x, y, z = tensor.vectors("x", "y", "z")
...@@ -566,7 +565,7 @@ class TestSigmoidUtils: ...@@ -566,7 +565,7 @@ class TestSigmoidUtils:
lambda x: is_1pexp(x, only_process_constants=False), lambda x: is_1pexp(x, only_process_constants=False),
[(1 + exp(-x)), (exp(-x) + 1)], [(1 + exp(-x)), (exp(-x) + 1)],
): ):
assert not neg and theano.gof.graph.is_same_graph(exp_arg, -x) assert not neg and is_same_graph(exp_arg, -x)
assert is_1pexp(1 - exp(x), False) is None assert is_1pexp(1 - exp(x), False) is None
assert is_1pexp(2 + exp(x), False) is None assert is_1pexp(2 + exp(x), False) is None
assert is_1pexp(exp(x) + 2, False) is None assert is_1pexp(exp(x) + 2, False) is None
......
差异被折叠。
...@@ -2781,6 +2781,22 @@ class TestLocalSubtensorMakeVector: ...@@ -2781,6 +2781,22 @@ class TestLocalSubtensorMakeVector:
r = f(0, 1, 2) r = f(0, 1, 2)
assert r[0] == 0 and r[1] == 2 assert r[0] == 0 and r[1] == 2
@pytest.mark.xfail(
reason="local_subtensor_make_vector doesn't handle all index cases"
)
def test_MakeVector_idx(self):
x, y, z, q = tensor.lscalars("xyzq")
v = make_vector(x, y, z)
q = make_vector(0, 2)
f = function([x, y, z], v[q], mode=mode_opt)
prog = f.maker.fgraph.toposort()
assert len(prog) == 1
assert isinstance(prog[0].op, MakeVector)
assert len(prog[0].inputs) == 2
r = f(0, 1, 2)
assert r[0] == 0 and r[1] == 2
def test_stack_trace(self): def test_stack_trace(self):
x, y, z = tensor.lscalars("xyz") x, y, z = tensor.lscalars("xyz")
v = make_vector(x, y, z) v = make_vector(x, y, z)
......
...@@ -7,14 +7,15 @@ import theano.tensor as tt ...@@ -7,14 +7,15 @@ import theano.tensor as tt
from numpy.testing import assert_equal, assert_string_equal from numpy.testing import assert_equal, assert_string_equal
from theano.tensor import ( from theano.tensor.var import TensorConstant
from theano.tensor.subtensor import (
Subtensor, Subtensor,
AdvancedSubtensor, AdvancedSubtensor,
AdvancedBooleanSubtensor,
AdvancedSubtensor1, AdvancedSubtensor1,
IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
) )
from theano.tensor.elemwise import DimShuffle
from theano.tensor.type_other import MakeSlice
import tests.unittest_tools as utt import tests.unittest_tools as utt
...@@ -79,37 +80,103 @@ def test_copy(): ...@@ -79,37 +80,103 @@ def test_copy():
assert_string_equal(y.name, "y") assert_string_equal(y.name, "y")
def test_None_dimShuffle_replace(): def test__getitem__Subtensor():
# tests replacing None usage in subtensor with dimshuffle # Make sure we get `Subtensor`s for basic indexing operations
# x = tt.matrix("x")
# tests whenever None is used in subtensor to reshape a variable, it is i = tt.iscalar("i")
# replaced by dimshuffle. If the replacement is done properly, Subtensor op
# (or any of its variants) should not be used anymore.
x = tt.dmatrix("x") z = x[i]
y = x[:, None, :] op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
f = theano.function([x], y) assert op_types[-1] == Subtensor
for elem in f.maker.fgraph.toposort():
assert type(elem.op) not in [ # This should ultimately do nothing (i.e. just return `x`)
Subtensor, z = x[()]
AdvancedSubtensor, assert len(z.owner.op.idx_list) == 0
AdvancedSubtensor1, # assert z is x
IncSubtensor,
AdvancedIncSubtensor, # This is a poorly placed optimization that produces a `DimShuffle`
AdvancedIncSubtensor1, # It lands in the `full_slices` condition in
] # `_tensor_py_operators.__getitem__`
z = x[..., None]
x = tt.tensor3("x") op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
y1 = x[:, :, None, :] assert all(op_type == DimShuffle for op_type in op_types)
y2 = x[None, :, :, None, :]
y3 = x[:, :, None, :, None, None] z = x[None, :, None, :]
f = theano.function([x], [y1, y2, y3]) op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
for elem in f.maker.fgraph.toposort(): assert all(op_type == DimShuffle for op_type in op_types)
assert type(elem.op) not in [
Subtensor, # This one lands in the non-`full_slices` condition in
AdvancedSubtensor, # `_tensor_py_operators.__getitem__`
AdvancedSubtensor1, z = x[:i, :, None]
IncSubtensor, op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
AdvancedIncSubtensor, assert op_types[1:] == [DimShuffle, Subtensor]
AdvancedIncSubtensor1,
] z = x[:]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == Subtensor
z = x[..., :]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == Subtensor
z = x[..., i, :]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == Subtensor
def test__getitem__AdvancedBooleanSubtensor():
# Make sure we get `AdvancedBooleanSubtensor`s for basic indexing operations
x = tt.matrix("x")
i = tt.type.TensorType("bool", (False, False))("i")
z = x[i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
i = tt.type.TensorType("bool", (False,))("i")
z = x[:, i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
i = tt.type.TensorType("bool", (False,))("i")
z = x[..., i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
with pytest.raises(TypeError):
z = x[[True, False], i]
z = x[tt.ivector("b"), i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
def test__getitem__AdvancedSubtensor():
# Make sure we get `AdvancedSubtensor`s for basic indexing operations
x = tt.matrix("x")
i = tt.ivector("i")
# This is a `__getitem__` call that's redirected to `_tensor_py_operators.take`
z = x[i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedSubtensor1
# This should index nothing (i.e. return an empty copy of `x`)
# We check that the index is empty
z = x[[]]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types == [AdvancedSubtensor1]
assert isinstance(z.owner.inputs[1], TensorConstant)
# This is also a `__getitem__` call that's redirected to `_tensor_py_operators.take`
z = x[:, i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types == [DimShuffle, AdvancedSubtensor1, DimShuffle]
z = x[..., i, None]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types == [MakeSlice, AdvancedSubtensor]
z = x[i, None]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedSubtensor
...@@ -656,61 +656,61 @@ def test_scan_debugprint5(): ...@@ -656,61 +656,61 @@ def test_scan_debugprint5():
| | | | | | |for{cpu,scan_fn} [id F] '' | | | | | | |for{cpu,scan_fn} [id F] ''
| | | | | | |Constant{1} [id BT] | | | | | | |Constant{1} [id BT]
| | | | | |InplaceDimShuffle{x,x} [id BU] '' | | | | | |InplaceDimShuffle{x,x} [id BU] ''
| | | | | |TensorConstant{0.0} [id BP] | | | | | |TensorConstant{0.0} [id BV]
| | | | |Elemwise{second} [id BV] '' | | | | |Elemwise{second} [id BW] ''
| | | | | |Subtensor{int64} [id BW] '' | | | | | |Subtensor{int64} [id BX] ''
| | | | | | |Subtensor{int64::} [id BS] '' | | | | | | |Subtensor{int64::} [id BS] ''
| | | | | | |Constant{-1} [id BX] | | | | | | |Constant{-1} [id BY]
| | | | | |InplaceDimShuffle{x} [id BY] '' | | | | | |InplaceDimShuffle{x} [id BZ] ''
| | | | | |Elemwise{second,no_inplace} [id BZ] '' | | | | | |Elemwise{second,no_inplace} [id CA] ''
| | | | | |Sum{acc_dtype=float64} [id CA] '' | | | | | |Sum{acc_dtype=float64} [id CB] ''
| | | | | | |Subtensor{int64} [id BW] '' | | | | | | |Subtensor{int64} [id BX] ''
| | | | | |TensorConstant{1.0} [id R] | | | | | |TensorConstant{1.0} [id CC]
| | | | |Constant{-1} [id BX] | | | | |Constant{-1} [id BY]
| | | |Constant{1} [id BT] | | | |Constant{1} [id BT]
| | |Constant{-1} [id CB] | | |Constant{-1} [id CD]
| |Alloc [id CC] '' | |Alloc [id CE] ''
| | |TensorConstant{0.0} [id BP] | | |TensorConstant{0.0} [id CF]
| | |Elemwise{add,no_inplace} [id CD] '' | | |Elemwise{add,no_inplace} [id CG] ''
| | | |Elemwise{sub,no_inplace} [id C] '' | | | |Elemwise{sub,no_inplace} [id C] ''
| | | |TensorConstant{1} [id Y] | | | |TensorConstant{1} [id CH]
| | |Subtensor{int64} [id CE] '' | | |Subtensor{int64} [id CI] ''
| | |Shape [id CF] '' | | |Shape [id CJ] ''
| | | |A [id P] | | | |A [id P]
| | |Constant{0} [id CG] | | |Constant{0} [id CK]
| |A [id P] | |A [id P]
|Constant{-1} [id CH] |Constant{-1} [id CL]
Inner graphs of the scan ops: Inner graphs of the scan ops:
for{cpu,grad_of_scan_fn}.1 [id B] '' for{cpu,grad_of_scan_fn}.1 [id B] ''
>Elemwise{add,no_inplace} [id CI] '' >Elemwise{add,no_inplace} [id CM] ''
> |Elemwise{mul} [id CJ] '' > |Elemwise{mul} [id CN] ''
> | |<TensorType(float64, vector)> [id CK] -> [id BL] > | |<TensorType(float64, vector)> [id CO] -> [id BL]
> | |A_copy [id CL] -> [id P] > | |A_copy [id CP] -> [id P]
> |<TensorType(float64, vector)> [id CM] -> [id BL] > |<TensorType(float64, vector)> [id CQ] -> [id BL]
>Elemwise{add,no_inplace} [id CN] '' >Elemwise{add,no_inplace} [id CR] ''
> |Elemwise{mul} [id CO] '' > |Elemwise{mul} [id CS] ''
> | |<TensorType(float64, vector)> [id CK] -> [id BL] > | |<TensorType(float64, vector)> [id CO] -> [id BL]
> | |<TensorType(float64, vector)> [id CP] -> [id Z] > | |<TensorType(float64, vector)> [id CT] -> [id Z]
> |<TensorType(float64, vector)> [id CQ] -> [id CC] > |<TensorType(float64, vector)> [id CU] -> [id CE]
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CR] '' >Elemwise{mul,no_inplace} [id CV] ''
> |<TensorType(float64, vector)> [id CP] -> [id H] > |<TensorType(float64, vector)> [id CT] -> [id H]
> |A_copy [id CL] -> [id P] > |A_copy [id CP] -> [id P]
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CR] '' >Elemwise{mul,no_inplace} [id CV] ''
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CR] '' >Elemwise{mul,no_inplace} [id CV] ''
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CR] '' >Elemwise{mul,no_inplace} [id CV] ''
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CR] ''""" >Elemwise{mul,no_inplace} [id CV] ''"""
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
......
...@@ -33,21 +33,6 @@ import sys ...@@ -33,21 +33,6 @@ import sys
import warnings import warnings
def has_handlers(logger):
# copied from Logger.hasHandlers() (introduced in Python 3.2)
_logger = logger
_has_handler = False
while _logger:
if _logger.handlers:
_has_handler = True
break
if not _logger.propagate:
break
else:
_logger = _logger.parent
return _has_handler
theano_logger = logging.getLogger("theano") theano_logger = logging.getLogger("theano")
logging_default_handler = logging.StreamHandler() logging_default_handler = logging.StreamHandler()
logging_default_formatter = logging.Formatter( logging_default_formatter = logging.Formatter(
...@@ -56,40 +41,27 @@ logging_default_formatter = logging.Formatter( ...@@ -56,40 +41,27 @@ logging_default_formatter = logging.Formatter(
logging_default_handler.setFormatter(logging_default_formatter) logging_default_handler.setFormatter(logging_default_formatter)
theano_logger.setLevel(logging.WARNING) theano_logger.setLevel(logging.WARNING)
if has_handlers(theano_logger) is False: if not theano_logger.hasHandlers():
theano_logger.addHandler(logging_default_handler) theano_logger.addHandler(logging_default_handler)
# Disable default log handler added to theano_logger when the module # Disable default log handler added to theano_logger when the module
# is imported. # is imported.
def disable_log_handler(logger=theano_logger, handler=logging_default_handler): def disable_log_handler(logger=theano_logger, handler=logging_default_handler):
if has_handlers(logger): if logger.hasHandlers():
logger.removeHandler(handler) logger.removeHandler(handler)
# Version information. # Version information.
from theano.version import version as __version__ from theano.version import version as __version__
# Raise a meaning full warning/error if the theano directory is in the # Raise a meaningful warning/error if the theano directory is in the Python
# Python path. # path.
from six import PY3
rpath = os.path.realpath(__path__[0]) rpath = os.path.realpath(__path__[0])
for p in sys.path: for p in sys.path:
if os.path.realpath(p) != rpath: if os.path.realpath(p) != rpath:
continue continue
if PY3: raise RuntimeError("You have the theano directory in your Python path.")
raise RuntimeError(
"You have the theano directory in your Python path."
" This do not work in Python 3."
)
else:
warnings.warn(
"You have the theano directory in your Python path."
" This is will not work in Python 3."
)
break
from theano.configdefaults import config from theano.configdefaults import config
from theano.configparser import change_flags from theano.configparser import change_flags
...@@ -225,7 +197,7 @@ def dot(l, r): ...@@ -225,7 +197,7 @@ def dot(l, r):
def get_scalar_constant_value(v): def get_scalar_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v` """Return the constant scalar (i.e. 0-D) value underlying variable `v`.
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
this function digs through them. this function digs through them.
...@@ -236,7 +208,8 @@ def get_scalar_constant_value(v): ...@@ -236,7 +208,8 @@ def get_scalar_constant_value(v):
tensor.basic.NotScalarConstantError. tensor.basic.NotScalarConstantError.
""" """
# Is it necessary to test for presence of theano.sparse at runtime? # Is it necessary to test for presence of theano.sparse at runtime?
if "sparse" in globals() and isinstance(v.type, sparse.SparseType): sparse = globals().get("sparse")
if sparse and isinstance(v.type, sparse.SparseType):
if v.owner is not None and isinstance(v.owner.op, sparse.CSM): if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
data = v.owner.inputs[0] data = v.owner.inputs[0]
return tensor.get_scalar_constant_value(data) return tensor.get_scalar_constant_value(data)
......
...@@ -96,9 +96,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= ...@@ -96,9 +96,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
if verbose: if verbose:
print("unable to find command, tried %s" % (commands,)) print("unable to find command, tried %s" % (commands,))
return None, None return None, None
stdout = p.communicate()[0].strip() stdout = p.communicate()[0].strip().decode()
if sys.version_info[0] >= 3:
stdout = stdout.decode()
if p.returncode != 0: if p.returncode != 0:
if verbose: if verbose:
print("unable to run %s (error)" % dispcmd) print("unable to run %s (error)" % dispcmd)
......
...@@ -122,7 +122,7 @@ class OpFromGraph(gof.Op): ...@@ -122,7 +122,7 @@ class OpFromGraph(gof.Op):
.. TODO: .. TODO:
- examples for a multi-layer mlp. where? - examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try - __hash__, __eq__ otherwise won't merge, try
gof.opt.is_same_graph_with_merge(op1.local_outputs, op2, is_same_graph_with_merge(op1.local_outputs, op2,
local_outputs) local_outputs)
- c_code() to remove the double overhead? - c_code() to remove the double overhead?
- grad() make it support DisconnectedType and the new interface - grad() make it support DisconnectedType and the new interface
......
...@@ -24,7 +24,7 @@ from theano import config, gof ...@@ -24,7 +24,7 @@ from theano import config, gof
from theano.gof import graph from theano.gof import graph
from theano.compile.io import In, SymbolicInput, SymbolicOutput from theano.compile.io import In, SymbolicInput, SymbolicOutput
from theano.compile.ops import deep_copy_op, view_op from theano.compile.ops import deep_copy_op, view_op
from theano.gof.graph import is_same_graph from theano.gof.toolbox import is_same_graph
from theano.gof.op import ops_with_inner_function from theano.gof.op import ops_with_inner_function
_logger = logging.getLogger("theano.compile.function_module") _logger = logging.getLogger("theano.compile.function_module")
......
...@@ -18,7 +18,7 @@ from collections import OrderedDict ...@@ -18,7 +18,7 @@ from collections import OrderedDict
from six import integer_types from six import integer_types
from theano import gof from theano.gof import Op, Apply, ParamsType, Variable
def register_view_op_c_code(type, code, version=()): def register_view_op_c_code(type, code, version=()):
...@@ -39,7 +39,7 @@ def register_view_op_c_code(type, code, version=()): ...@@ -39,7 +39,7 @@ def register_view_op_c_code(type, code, version=()):
ViewOp.c_code_and_version[type] = (code, version) ViewOp.c_code_and_version[type] = (code, version)
class ViewOp(gof.Op): class ViewOp(Op):
""" """
Returns an inplace view of the input. Used internally by Theano. Returns an inplace view of the input. Used internally by Theano.
...@@ -54,7 +54,7 @@ class ViewOp(gof.Op): ...@@ -54,7 +54,7 @@ class ViewOp(gof.Op):
_f16_ok = True _f16_ok = True
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
def perform(self, node, inp, out): def perform(self, node, inp, out):
(x,) = inp (x,) = inp
...@@ -151,7 +151,7 @@ def register_deep_copy_op_c_code(typ, code, version=()): ...@@ -151,7 +151,7 @@ def register_deep_copy_op_c_code(typ, code, version=()):
DeepCopyOp.c_code_and_version[typ] = (code, version) DeepCopyOp.c_code_and_version[typ] = (code, version)
class DeepCopyOp(gof.Op): class DeepCopyOp(Op):
# Mapping from Type to C code (and version) to use. # Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s, # In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s. # the output variable is %(oname)s.
...@@ -165,7 +165,7 @@ class DeepCopyOp(gof.Op): ...@@ -165,7 +165,7 @@ class DeepCopyOp(gof.Op):
pass pass
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
def perform(self, node, args, outs): def perform(self, node, args, outs):
if hasattr(args[0], "copy"): if hasattr(args[0], "copy"):
...@@ -235,7 +235,7 @@ def register_shape_c_code(type, code, version=()): ...@@ -235,7 +235,7 @@ def register_shape_c_code(type, code, version=()):
Shape.c_code_and_version[type] = (code, version) Shape.c_code_and_version[type] = (code, version)
class Shape(gof.Op): class Shape(Op):
""" """
L{Op} to return the shape of a matrix. L{Op} to return the shape of a matrix.
...@@ -260,7 +260,7 @@ class Shape(gof.Op): ...@@ -260,7 +260,7 @@ class Shape(gof.Op):
# This will fail at execution time. # This will fail at execution time.
if not isinstance(x, theano.Variable): if not isinstance(x, theano.Variable):
x = theano.tensor.as_tensor_variable(x) x = theano.tensor.as_tensor_variable(x)
return gof.Apply(self, [x], [theano.tensor.lvector()]) return Apply(self, [x], [theano.tensor.lvector()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
(x,) = inp (x,) = inp
...@@ -329,7 +329,7 @@ shape = Shape() ...@@ -329,7 +329,7 @@ shape = Shape()
_shape = shape # was used in the past, now use shape directly. _shape = shape # was used in the past, now use shape directly.
class Shape_i(gof.Op): class Shape_i(Op):
""" """
L{Op} to return the shape of a matrix. L{Op} to return the shape of a matrix.
...@@ -369,7 +369,7 @@ class Shape_i(gof.Op): ...@@ -369,7 +369,7 @@ class Shape_i(gof.Op):
# using params. # using params.
@property @property
def params_type(self): def params_type(self):
return gof.ParamsType(i=theano.scalar.basic.int64) return ParamsType(i=theano.scalar.basic.int64)
def __str__(self): def __str__(self):
return "%s{%i}" % (self.__class__.__name__, self.i) return "%s{%i}" % (self.__class__.__name__, self.i)
...@@ -540,7 +540,7 @@ def load_back(mod, name): ...@@ -540,7 +540,7 @@ def load_back(mod, name):
return obj return obj
class FromFunctionOp(gof.Op): class FromFunctionOp(Op):
""" """
Build a basic Theano Op around a function. Build a basic Theano Op around a function.
...@@ -666,7 +666,7 @@ def register_rebroadcast_c_code(typ, code, version=()): ...@@ -666,7 +666,7 @@ def register_rebroadcast_c_code(typ, code, version=()):
Rebroadcast.c_code_and_version[typ] = (code, version) Rebroadcast.c_code_and_version[typ] = (code, version)
class Rebroadcast(gof.Op): class Rebroadcast(Op):
""" """
Change the input's broadcastable fields in some predetermined way. Change the input's broadcastable fields in some predetermined way.
...@@ -737,7 +737,7 @@ class Rebroadcast(gof.Op): ...@@ -737,7 +737,7 @@ class Rebroadcast(gof.Op):
self.axis.get(i, b) for i, b in enumerate(x.type.broadcastable) self.axis.get(i, b) for i, b in enumerate(x.type.broadcastable)
] ]
) )
return gof.Apply(self, [x], [t()]) return Apply(self, [x], [t()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
(x,) = inp (x,) = inp
...@@ -848,7 +848,7 @@ def register_specify_shape_c_code(typ, code, version=(), c_support_code_apply=No ...@@ -848,7 +848,7 @@ def register_specify_shape_c_code(typ, code, version=(), c_support_code_apply=No
SpecifyShape.c_code_and_version[typ] = (code, version, c_support_code_apply) SpecifyShape.c_code_and_version[typ] = (code, version, c_support_code_apply)
class SpecifyShape(gof.Op): class SpecifyShape(Op):
""" """
L{Op} that puts into the graph the user-provided shape. L{Op} that puts into the graph the user-provided shape.
...@@ -876,14 +876,14 @@ class SpecifyShape(gof.Op): ...@@ -876,14 +876,14 @@ class SpecifyShape(gof.Op):
_f16_ok = True _f16_ok = True
def make_node(self, x, shape): def make_node(self, x, shape):
if not isinstance(x, gof.Variable): if not isinstance(x, Variable):
x = theano.tensor.as_tensor_variable(x) x = theano.tensor.as_tensor_variable(x)
shape = theano.tensor.as_tensor_variable(shape) shape = theano.tensor.as_tensor_variable(shape)
assert shape.ndim == 1 assert shape.ndim == 1
assert shape.dtype in theano.tensor.integer_dtypes assert shape.dtype in theano.tensor.integer_dtypes
if isinstance(shape, theano.tensor.TensorConstant): if isinstance(shape, theano.tensor.TensorConstant):
assert shape.data.size == x.ndim assert shape.data.size == x.ndim
return gof.Apply(self, [x, shape], [x.type()]) return Apply(self, [x, shape], [x.type()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
x, shape = inp x, shape = inp
......
...@@ -40,7 +40,6 @@ e-mail thread "What is gof?". ...@@ -40,7 +40,6 @@ e-mail thread "What is gof?".
from theano.gof.cc import CLinker, OpWiseCLinker, DualLinker, HideC from theano.gof.cc import CLinker, OpWiseCLinker, DualLinker, HideC
from theano.gof.fg import ( from theano.gof.fg import (
CachedConstantError,
InconsistencyError, InconsistencyError,
MissingInputError, MissingInputError,
FunctionGraph, FunctionGraph,
......
...@@ -1767,10 +1767,7 @@ def std_lib_dirs_and_libs(): ...@@ -1767,10 +1767,7 @@ def std_lib_dirs_and_libs():
else: else:
if platform.python_implementation() == "PyPy": if platform.python_implementation() == "PyPy":
# Assume Linux (note: Ubuntu doesn't ship this .so) # Assume Linux (note: Ubuntu doesn't ship this .so)
if sys.version_info < (3,): libname = "pypy3-c"
libname = "pypy-c"
else:
libname = "pypy3-c"
# Unfortunately the only convention of this .so is that it appears # Unfortunately the only convention of this .so is that it appears
# next to the location of the interpreter binary. # next to the location of the interpreter binary.
libdir = os.path.dirname(os.path.realpath(sys.executable)) libdir = os.path.dirname(os.path.realpath(sys.executable))
...@@ -2353,7 +2350,7 @@ class GCC_compiler(Compiler): ...@@ -2353,7 +2350,7 @@ class GCC_compiler(Compiler):
# redefinition for recent CPython versions (>=2.7.16 and >=3.7.3). # redefinition for recent CPython versions (>=2.7.16 and >=3.7.3).
# The following nullifies that redefinition, if it is found. # The following nullifies that redefinition, if it is found.
python_version = sys.version_info[:3] python_version = sys.version_info[:3]
if python_version < (2, 7, 16) or (3,) <= python_version < (3, 7, 3): if (3,) <= python_version < (3, 7, 3):
config_h_filename = distutils.sysconfig.get_config_h_filename() config_h_filename = distutils.sysconfig.get_config_h_filename()
try: try:
with open(config_h_filename) as config_h: with open(config_h_filename) as config_h:
......
...@@ -20,17 +20,6 @@ from theano.misc.ordered_set import OrderedSet ...@@ -20,17 +20,6 @@ from theano.misc.ordered_set import OrderedSet
NullType = None NullType = None
class CachedConstantError(Exception):
"""
An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
cached constant in other FunctionGraph.
"""
pass
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
This exception should be thrown by listeners to FunctionGraph when the This exception should be thrown by listeners to FunctionGraph when the
...@@ -186,15 +175,7 @@ class FunctionGraph(utils.object2): ...@@ -186,15 +175,7 @@ class FunctionGraph(utils.object2):
self.__setup_r__(input) self.__setup_r__(input)
self.variables.add(input) self.variables.add(input)
# Setup a Variable #
def __setup_r__(self, r): def __setup_r__(self, r):
# sets up r so it belongs to this fgraph
if getattr(r, "cached", False):
raise CachedConstantError(
"You manually constructed a FunctionGraph, but you passed it a"
" graph that has a cached constant. This should not happen."
" Clone the graph before building the FunctionGraph."
)
if hasattr(r, "fgraph") and r.fgraph is not None and r.fgraph is not self: if hasattr(r, "fgraph") and r.fgraph is not None and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r) raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self r.fgraph = self
......
差异被折叠。
...@@ -93,12 +93,7 @@ class Optimizer(object): ...@@ -93,12 +93,7 @@ class Optimizer(object):
""" """
self.add_requirements(fgraph) self.add_requirements(fgraph)
try: ret = self.apply(fgraph, *args, **kwargs)
orig = theano.tensor.basic.constant.enable
theano.tensor.basic.constant.enable = False
ret = self.apply(fgraph, *args, **kwargs)
finally:
theano.tensor.basic.constant.enable = orig
return ret return ret
def __call__(self, fgraph): def __call__(self, fgraph):
...@@ -1060,45 +1055,6 @@ class MergeOptimizer(Optimizer): ...@@ -1060,45 +1055,6 @@ class MergeOptimizer(Optimizer):
) )
def is_same_graph_with_merge(var1, var2, givens=None):
"""
Merge-based implementation of `theano.gof.graph.is_same_graph`.
See help on `theano.gof.graph.is_same_graph` for additional documentation.
"""
if givens is None:
givens = {}
# Copy variables since the MergeOptimizer will modify them.
copied = copy.deepcopy([var1, var2, givens])
vars = copied[0:2]
givens = copied[2]
# Create FunctionGraph.
inputs = theano.gof.graph.inputs(vars)
# The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
fgraph = theano.gof.fg.FunctionGraph(inputs, vars, clone=False)
# Perform Variable substitution.
for to_replace, replace_by in givens.items():
fgraph.replace(to_replace, replace_by)
# Perform merge optimization.
MergeOptimizer().optimize(fgraph)
# When two variables perform the same computations, they will have the same
# owner in the optimized graph.
# We need to be careful with the special case where the owner is None,
# which happens when the graph is made of a single Variable.
# We also need to make sure we replace a Variable if it is present in
# `givens`.
vars_replaced = [givens.get(v, v) for v in vars]
o1, o2 = [v.owner for v in vars_replaced]
if o1 is None and o2 is None:
# Comparing two single-Variable graphs: they are equal if they are
# the same Variable.
return vars_replaced[0] == vars_replaced[1]
else:
return o1 is o2
def pre_constant_merge(vars): def pre_constant_merge(vars):
""" """
Merge constants in the subgraph used to compute nodes in `vars`. Merge constants in the subgraph used to compute nodes in `vars`.
......
from functools import partial
from collections import OrderedDict
import sys import sys
import copy
import time import time
import inspect import inspect
import numpy as np import numpy as np
import theano
from functools import partial
from collections import OrderedDict
from six.moves import StringIO from six.moves import StringIO
import theano
from theano import config from theano import config
from theano.gof import graph from theano.gof.graph import (
inputs,
io_toposort,
equal_computations,
variables,
)
class AlreadyThere(Exception): class AlreadyThere(Exception):
...@@ -312,7 +319,7 @@ class Bookkeeper(Feature): ...@@ -312,7 +319,7 @@ class Bookkeeper(Feature):
FunctionGraph is initially populated, this is where you should FunctionGraph is initially populated, this is where you should
run checks on the initial contents of the FunctionGraph. run checks on the initial contents of the FunctionGraph.
""" """
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in io_toposort(fgraph.inputs, fgraph.outputs):
self.on_import(fgraph, node, "on_attach") self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph): def on_detach(self, fgraph):
...@@ -320,7 +327,7 @@ class Bookkeeper(Feature): ...@@ -320,7 +327,7 @@ class Bookkeeper(Feature):
Should remove any dynamically added functionality Should remove any dynamically added functionality
that it installed into the function_graph that it installed into the function_graph
""" """
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in io_toposort(fgraph.inputs, fgraph.outputs):
self.on_prune(fgraph, node, "Bookkeeper.detach") self.on_prune(fgraph, node, "Bookkeeper.detach")
...@@ -801,3 +808,154 @@ class NoOutputFromInplace(Feature): ...@@ -801,3 +808,154 @@ class NoOutputFromInplace(Feature):
"being computed by modifying another variable ", "being computed by modifying another variable ",
"inplace.", "inplace.",
) )
def is_same_graph_with_merge(var1, var2, givens=None):
"""
Merge-based implementation of `theano.gof.graph.is_same_graph`.
See help on `theano.gof.graph.is_same_graph` for additional documentation.
"""
from theano.gof.opt import MergeOptimizer
if givens is None:
givens = {}
# Copy variables since the MergeOptimizer will modify them.
copied = copy.deepcopy([var1, var2, givens])
vars = copied[0:2]
givens = copied[2]
# Create FunctionGraph.
graph_inputs = inputs(vars)
# The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
fgraph = theano.gof.fg.FunctionGraph(graph_inputs, vars, clone=False)
# Perform Variable substitution.
for to_replace, replace_by in givens.items():
fgraph.replace(to_replace, replace_by)
# Perform merge optimization.
MergeOptimizer().optimize(fgraph)
# When two variables perform the same computations, they will have the same
# owner in the optimized graph.
# We need to be careful with the special case where the owner is None,
# which happens when the graph is made of a single Variable.
# We also need to make sure we replace a Variable if it is present in
# `givens`.
vars_replaced = [givens.get(v, v) for v in vars]
o1, o2 = [v.owner for v in vars_replaced]
if o1 is None and o2 is None:
# Comparing two single-Variable graphs: they are equal if they are
# the same Variable.
return vars_replaced[0] == vars_replaced[1]
else:
return o1 is o2
def is_same_graph(var1, var2, givens=None):
"""
Return True iff Variables `var1` and `var2` perform the same computation.
By 'performing the same computation', we mean that they must share the same
graph, so that for instance this function will return False when comparing
(x * (y * z)) with ((x * y) * z).
The current implementation is not efficient since, when possible, it
verifies equality by calling two different functions that are expected to
return the same output. The goal is to verify this assumption, to
eventually get rid of one of them in the future.
Parameters
----------
var1
The first Variable to compare.
var2
The second Variable to compare.
givens
Similar to the `givens` argument of `theano.function`, it can be used
to perform substitutions in the computational graph of `var1` and
`var2`. This argument is associated to neither `var1` nor `var2`:
substitutions may affect both graphs if the substituted variable
is present in both.
Examples
--------
====== ====== ====== ======
var1 var2 givens output
====== ====== ====== ======
x + 1 x + 1 {} True
x + 1 y + 1 {} False
x + 1 y + 1 {x: y} True
====== ====== ====== ======
"""
use_equal_computations = True
if givens is None:
givens = {}
if not isinstance(givens, dict):
givens = dict(givens)
# Get result from the merge-based function.
rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)
if givens:
# We need to build the `in_xs` and `in_ys` lists. To do this, we need
# to be able to tell whether a variable belongs to the computational
# graph of `var1` or `var2`.
# The typical case we want to handle is when `to_replace` belongs to
# one of these graphs, and `replace_by` belongs to the other one. In
# other situations, the current implementation of `equal_computations`
# is probably not appropriate, so we do not call it.
ok = True
in_xs = []
in_ys = []
# Compute the sets of all variables found in each computational graph.
inputs_var = list(map(inputs, ([var1], [var2])))
all_vars = [
set(variables(v_i, v_o))
for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2]))
]
def in_var(x, k):
# Return True iff `x` is in computation graph of variable `vark`.
return x in all_vars[k - 1]
for to_replace, replace_by in givens.items():
# Map a substitution variable to the computational graphs it
# belongs to.
inside = dict(
(v, [in_var(v, k) for k in (1, 2)]) for v in (to_replace, replace_by)
)
if (
inside[to_replace][0]
and not inside[to_replace][1]
and inside[replace_by][1]
and not inside[replace_by][0]
):
# Substitute variable in `var1` by one from `var2`.
in_xs.append(to_replace)
in_ys.append(replace_by)
elif (
inside[to_replace][1]
and not inside[to_replace][0]
and inside[replace_by][0]
and not inside[replace_by][1]
):
# Substitute variable in `var2` by one from `var1`.
in_xs.append(replace_by)
in_ys.append(to_replace)
else:
ok = False
break
if not ok:
# We cannot directly use `equal_computations`.
use_equal_computations = False
else:
in_xs = None
in_ys = None
if use_equal_computations:
rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys)
assert rval2 == rval1
return rval1
...@@ -65,7 +65,7 @@ from theano.compile import function, In, Out ...@@ -65,7 +65,7 @@ from theano.compile import function, In, Out
from theano.compile.mode import AddFeatureOptimizer from theano.compile.mode import AddFeatureOptimizer
from theano import compile, config, gradient, gof, tensor from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply from theano.gof import PureOp, Apply
from theano.gof.graph import io_connection_pattern from theano.gof.graph import io_connection_pattern, equal_computations
from theano.gof.toolbox import NoOutputFromInplace from theano.gof.toolbox import NoOutputFromInplace
from theano.tensor import as_tensor_variable, TensorType from theano.tensor import as_tensor_variable, TensorType
...@@ -770,7 +770,7 @@ class Scan(PureOp): ...@@ -770,7 +770,7 @@ class Scan(PureOp):
if self_in.type != other_in.type: if self_in.type != other_in.type:
return False return False
return scan_utils.equal_computations( return equal_computations(
self.outputs, other.outputs, self.inputs, other.inputs self.outputs, other.outputs, self.inputs, other.inputs
) )
......
...@@ -62,19 +62,17 @@ from collections import OrderedDict ...@@ -62,19 +62,17 @@ from collections import OrderedDict
from six import integer_types from six import integer_types
from theano import tensor, scalar from theano import gof, tensor, scalar
from theano.tensor import opt, get_scalar_constant_value, Alloc, AllocEmpty from theano.tensor import opt, get_scalar_constant_value, Alloc, AllocEmpty
from theano import gof
from theano.compile import optdb from theano.compile import optdb
from theano.compile.function_module import deep_copy_op from theano.compile.function_module import deep_copy_op
from theano.gof import toolbox, DestroyHandler, InconsistencyError from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer, pre_constant_merge, pre_greedy_local_optimizer
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer from theano.gof.graph import equal_computations
from theano.scan_module import scan_op from theano.scan_module import scan_op, scan_utils
from theano.scan_module import scan_utils from theano.scan_module.scan_utils import scan_args
from theano.scan_module.scan_utils import equal_computations, scan_args
__docformat__ = "restructedtext en" __docformat__ = "restructedtext en"
__authors__ = ( __authors__ = (
...@@ -172,7 +170,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -172,7 +170,7 @@ def remove_constants_and_unused_inputs_scan(node):
elif op_ins[idx] in all_ins: elif op_ins[idx] in all_ins:
# Check for identical other sequence # Check for identical other sequence
identical_seqs = [ identical_seqs = [
x for x in nw_outer if scan_utils.equal_computations([x], [node_inp]) x for x in nw_outer if equal_computations([x], [node_inp])
] ]
if identical_seqs: if identical_seqs:
index = node.inputs.index(identical_seqs[0]) - 1 index = node.inputs.index(identical_seqs[0]) - 1
...@@ -198,7 +196,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -198,7 +196,7 @@ def remove_constants_and_unused_inputs_scan(node):
identical_nonseq_idx = [ identical_nonseq_idx = [
i i
for (i, x) in enumerate(nw_outer_nonseq) for (i, x) in enumerate(nw_outer_nonseq)
if scan_utils.equal_computations([x], [nw_out]) if equal_computations([x], [nw_out])
] ]
if identical_nonseq_idx: if identical_nonseq_idx:
givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]] givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]]
...@@ -1907,9 +1905,7 @@ class ScanMerge(gof.Optimizer): ...@@ -1907,9 +1905,7 @@ class ScanMerge(gof.Optimizer):
return True return True
cond = node.op.outputs[-1] cond = node.op.outputs[-1]
rep_cond = rep.op.outputs[-1] rep_cond = rep.op.outputs[-1]
return scan_utils.equal_computations( return equal_computations([cond], [rep_cond], node.op.inputs, rep.op.inputs)
[cond], [rep_cond], node.op.inputs, rep.op.inputs
)
def apply(self, fgraph): def apply(self, fgraph):
# Collect all scan nodes ordered according to toposort # Collect all scan nodes ordered according to toposort
......
...@@ -635,143 +635,6 @@ def expand_empty(tensor_var, size): ...@@ -635,143 +635,6 @@ def expand_empty(tensor_var, size):
return ret return ret
def equal_computations(xs, ys, in_xs=None, in_ys=None):
"""Checks if Theano graphs represent the same computations.
The two lists `xs`, `ys` should have the same number of entries. The
function checks if for any corresponding pair `(x,y)` from `zip(xs,ys)`
`x` and `y` represent the same computations on the same variables
(unless equivalences are provided using `in_xs`, `in_ys`).
If `in_xs` and `in_ys` are provided, then when comparing a node `x` with
a node `y` they are automatically considered as equal if there is some
index `i` such that `x == in_xs[i]` and `y == in_ys[i]`(and they both
have the same type). Note that `x` and `y` can be in the list `xs` and
`ys`, but also represent subgraphs of a computational graph in `xs`
or `ys`.
"""
assert len(xs) == len(ys)
if in_xs is None:
in_xs = []
if in_ys is None:
in_ys = []
for x, y in zip(xs, ys):
if x.owner and not y.owner:
return False
if y.owner and not x.owner:
return False
if x.owner: # Check above tell that y.owner eval to True too.
if x.owner.outputs.index(x) != y.owner.outputs.index(y):
return False
if x not in in_xs and x.type != y.type:
return False
if len(in_xs) != len(in_ys):
return False
for _x, _y in zip(in_xs, in_ys):
if _x.type != _y.type:
return False
common = set(zip(in_xs, in_ys))
different = set()
for dx, dy in zip(xs, ys):
# We checked above that both dx and dy have an owner or not
if not dx.owner:
if isinstance(dx, tensor.Constant) and isinstance(dy, tensor.Constant):
if not dx.equals(dy):
return False
else:
pass
elif (dx, dy) not in common and dx != dy:
return False
# Explore the two graphs, in parallel, depth first, comparing the nodes
# along the way for equality.
def compare_nodes(nd_x, nd_y, common, different):
"""
Compare two nodes to determine if they perform equal computation.
This is done by comparing the ops, the number of inputs, outputs and
by ensuring that the inputs themselves are the result of equal
computation.
NOTE : This function relies on the variable common to cache
results to be more efficient.
"""
if nd_x.op != nd_y.op:
return False
elif len(nd_x.inputs) != len(nd_y.inputs):
return False
elif len(nd_x.outputs) != len(nd_y.outputs):
return False
else:
all_in_common = True
for dx, dy in zip(nd_x.outputs, nd_y.outputs):
if (dx, dy) in different:
return False
if (dx, dy) not in common:
all_in_common = False
if all_in_common:
return True
# Compare the individual inputs for equality
for dx, dy in zip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common:
# Equality between the variables is unknown, compare
# their respective owners, if they have some
if (
dx.owner
and dy.owner
and dx.owner.outputs.index(dx) == dy.owner.outputs.index(dy)
):
nodes_equal = compare_nodes(
dx.owner, dy.owner, common, different
)
if not nodes_equal:
different.add((dx, dy))
return False
# If both variables don't have an owner, then they are
# inputs and can be directly compared
elif dx.owner is None and dy.owner is None:
if dx != dy:
if isinstance(dx, tensor.Constant) and isinstance(
dy, tensor.Constant
):
if not dx.equals(dy):
return False
else:
return False
else:
return False
# If the code reaches this statement then the inputs are pair-wise
# equivalent so the outputs of the current nodes are also
# pair-wise equivalents
for dx, dy in zip(nd_x.outputs, nd_y.outputs):
common.add((dx, dy))
return True
# Validate that each xs[i], ys[i] pair represents the same computation
for i in range(len(xs)):
if xs[i].owner:
# The case where pairs of x[i]s and y[i]s don't both have an owner
# have already been addressed.
is_equal = compare_nodes(xs[i].owner, ys[i].owner, common, different)
if not is_equal:
return False
return True
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
""" """
Compute the shape of the outputs given the shape of the inputs of a theano Compute the shape of the outputs given the shape of the inputs of a theano
...@@ -1413,7 +1276,7 @@ def forced_replace(out, x, y): ...@@ -1413,7 +1276,7 @@ def forced_replace(out, x, y):
if graph in visited: if graph in visited:
continue continue
visited.add(graph) visited.add(graph)
if equal_computations([graph], [x]): if gof.graph.equal_computations([graph], [x]):
to_replace.append((graph, y)) to_replace.append((graph, y))
elif graph.owner: elif graph.owner:
q.extendleft(graph.owner.inputs) q.extendleft(graph.owner.inputs)
......
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论