提交 3a2556a8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make equal_computations support mixed NumPy/primitive inputs

上级 3635eacd
...@@ -326,9 +326,30 @@ class TestAutoName: ...@@ -326,9 +326,30 @@ class TestAutoName:
def test_equal_computations(): def test_equal_computations():
# This was a bug report by a Theano user.
a, b = tensor.iscalars(2)
with pytest.raises(ValueError):
equal_computations([a], [a, b])
assert equal_computations([a], [a])
assert equal_computations([tensor.as_tensor(1)], [tensor.as_tensor(1)])
assert not equal_computations([b], [a])
assert not equal_computations([tensor.as_tensor(1)], [tensor.as_tensor(2)])
assert equal_computations([2], [2])
assert equal_computations([np.r_[2, 1]], [np.r_[2, 1]])
assert equal_computations([np.r_[2, 1]], [tensor.as_tensor(np.r_[2, 1])])
assert equal_computations([tensor.as_tensor(np.r_[2, 1])], [np.r_[2, 1]])
assert not equal_computations([2], [a])
assert not equal_computations([np.r_[2, 1]], [a])
assert not equal_computations([a], [2])
assert not equal_computations([a], [np.r_[2, 1]])
c = tensor.type_other.NoneConst c = tensor.type_other.NoneConst
assert equal_computations([c], [c]) assert equal_computations([c], [c])
m = tensor.matrix() m = tensor.matrix()
max_argmax1 = tensor.max_and_argmax(m) max_argmax1 = tensor.max_and_argmax(m)
max_argmax2 = tensor.max_and_argmax(m) max_argmax2 = tensor.max_and_argmax(m)
......
...@@ -7,6 +7,8 @@ from collections import deque ...@@ -7,6 +7,8 @@ from collections import deque
from copy import copy from copy import copy
from itertools import count from itertools import count
import numpy as np
import theano import theano
from theano import config from theano import config
from theano.gof.utils import ( from theano.gof.utils import (
...@@ -1420,14 +1422,37 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -1420,14 +1422,37 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
`ys`, but also represent subgraphs of a computational graph in `xs` `ys`, but also represent subgraphs of a computational graph in `xs`
or `ys`. or `ys`.
Parameters
----------
xs : list of Variable
ys : list of Variable
Returns
-------
bool
""" """
assert len(xs) == len(ys) if len(xs) != len(ys):
raise ValueError("The number of graphs/Variables in each argument must match.")
if in_xs is None: if in_xs is None:
in_xs = [] in_xs = []
if in_ys is None: if in_ys is None:
in_ys = [] in_ys = []
for x, y in zip(xs, ys): for x, y in zip(xs, ys):
if not isinstance(x, Variable) and not isinstance(y, Variable):
return np.array_equal(x, y)
if not isinstance(x, Variable):
if isinstance(y, Constant):
return np.array_equal(y.data, x)
return False
if not isinstance(y, Variable):
if isinstance(x, Constant):
return np.array_equal(x.data, y)
return False
if not isinstance(y, Variable):
return False
if x.owner and not y.owner: if x.owner and not y.owner:
return False return False
if y.owner and not x.owner: if y.owner and not x.owner:
...@@ -1437,8 +1462,10 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -1437,8 +1462,10 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return False return False
if x not in in_xs and x.type != y.type: if x not in in_xs and x.type != y.type:
return False return False
if len(in_xs) != len(in_ys): if len(in_xs) != len(in_ys):
return False return False
for _x, _y in zip(in_xs, in_ys): for _x, _y in zip(in_xs, in_ys):
if _x.type != _y.type: if _x.type != _y.type:
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论