提交 3a2556a8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make equal_computations support mixed NumPy/primitive inputs

上级 3635eacd
......@@ -326,9 +326,30 @@ class TestAutoName:
def test_equal_computations():
# This was a bug report by a Theano user.
a, b = tensor.iscalars(2)
with pytest.raises(ValueError):
equal_computations([a], [a, b])
assert equal_computations([a], [a])
assert equal_computations([tensor.as_tensor(1)], [tensor.as_tensor(1)])
assert not equal_computations([b], [a])
assert not equal_computations([tensor.as_tensor(1)], [tensor.as_tensor(2)])
assert equal_computations([2], [2])
assert equal_computations([np.r_[2, 1]], [np.r_[2, 1]])
assert equal_computations([np.r_[2, 1]], [tensor.as_tensor(np.r_[2, 1])])
assert equal_computations([tensor.as_tensor(np.r_[2, 1])], [np.r_[2, 1]])
assert not equal_computations([2], [a])
assert not equal_computations([np.r_[2, 1]], [a])
assert not equal_computations([a], [2])
assert not equal_computations([a], [np.r_[2, 1]])
c = tensor.type_other.NoneConst
assert equal_computations([c], [c])
m = tensor.matrix()
max_argmax1 = tensor.max_and_argmax(m)
max_argmax2 = tensor.max_and_argmax(m)
......
......@@ -7,6 +7,8 @@ from collections import deque
from copy import copy
from itertools import count
import numpy as np
import theano
from theano import config
from theano.gof.utils import (
......@@ -1420,14 +1422,37 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
`ys`, but also represent subgraphs of a computational graph in `xs`
or `ys`.
Parameters
----------
xs : list of Variable
ys : list of Variable
Returns
-------
bool
"""
assert len(xs) == len(ys)
if len(xs) != len(ys):
raise ValueError("The number of graphs/Variables in each argument must match.")
if in_xs is None:
in_xs = []
if in_ys is None:
in_ys = []
for x, y in zip(xs, ys):
if not isinstance(x, Variable) and not isinstance(y, Variable):
return np.array_equal(x, y)
if not isinstance(x, Variable):
if isinstance(y, Constant):
return np.array_equal(y.data, x)
return False
if not isinstance(y, Variable):
if isinstance(x, Constant):
return np.array_equal(x.data, y)
return False
if not isinstance(y, Variable):
return False
if x.owner and not y.owner:
return False
if y.owner and not x.owner:
......@@ -1437,8 +1462,10 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return False
if x not in in_xs and x.type != y.type:
return False
if len(in_xs) != len(in_ys):
return False
for _x, _y in zip(in_xs, in_ys):
if _x.type != _y.type:
return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论