提交 49e25b0f authored 作者: Frederic's avatar Frederic

Raise error if the user try to use shared variable.

上级 2d4eca23
......@@ -47,6 +47,10 @@ class OpFromGraph(gof.Op):
if 'updates' in kwargs:
raise TypeError('updates are not allowed in kwargs')
shared_inputs = [var for var in gof.graph.inputs(outputs)
if isinstance(var, SharedVariable)]
if shared_inputs:
raise NotImplementedError("OpFromGraph do not support SharedVariable in the inner graph")
# TODO: the graph may have implicit inputs like
# SharedVariable instances.
# what impact to they have on the validity of this Op?
......
import numpy
import unittest
from theano import config
from theano import config, shared
from theano.compile import function
......@@ -56,6 +56,23 @@ class T_OpFromGraph(unittest.TestCase):
zv = numpy.ones((2, 2), dtype=config.floatX)*5
assert numpy.all(11.0 == fn(xv, yv, zv))
def test_shared(self):
x, y, z = T.matrices('xyz')
s = shared(numpy.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s
self.assertRaises(NotImplementedError, OpFromGraph, [x, y, z], [e], mode='FAST_RUN')
return
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN')
f = op(x, y, z) - op(y, z, x) # (1+3*5=array of 16) - (3+1*5=array of 8)
fn = function([x, y, z], f)
xv = numpy.ones((2, 2), dtype=config.floatX)
yv = numpy.ones((2, 2), dtype=config.floatX)*3
zv = numpy.ones((2, 2), dtype=config.floatX)*5
#print function, function.__module__
#print fn.maker.fgraph.toposort()
assert numpy.allclose(8.0, fn(xv, yv, zv))
assert numpy.allclose(8.0, fn(xv, yv, zv))
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论