提交 01de70fe authored 作者: Frederic's avatar Frederic

Make FunctionGraph() clone the input by default to don't have constant from the cache.

Disable the clone in some tests and other place where it should not be done.
上级 f7cee63b
......@@ -129,8 +129,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
updates = [spec.update for spec in input_specs if spec.update]
orig_outputs = [spec.variable for spec in output_specs] + updates
inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs)
fgraph = gof.fg.FunctionGraph(inputs, outputs)
fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs)
for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None):
......@@ -143,7 +142,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
# We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature(
Supervisor(input
for spec, input in zip(input_specs, inputs)
for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or
(hasattr(fgraph, 'destroyers') and
fgraph.destroyers(input)))))
......
......@@ -74,7 +74,7 @@ class FunctionGraph(utils.object2):
"""
def __init__(self, inputs, outputs, features=None):
def __init__(self, inputs, outputs, features=None, clone=True):
"""
Create an FunctionGraph which operates on the subgraph bound by the inputs and
outputs sets.
......@@ -85,7 +85,12 @@ class FunctionGraph(utils.object2):
#TODO: document what variables are[not] set in the FunctionGraph when a feature
is added via the constructor. How constructed is the FunctionGraph?
:param clone: If true, we will clone the graph. This is
usefull to remove the constant cache problem.
"""
if clone:
inputs, outputs = graph.clone(inputs, outputs)
self.execute_callbacks_time = 0
......
......@@ -593,7 +593,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
#logger.debug('Compiling node %i of graph' % node_idx)
if self._op_use_c_code:
try:
e = FunctionGraph(*graph.clone(node.inputs, node.outputs))
e = FunctionGraph(node.inputs, node.outputs)
e_no_recycling = [new_o
for (new_o, old_o) in zip(e.outputs, node.outputs)
......
......@@ -631,7 +631,9 @@ def is_same_graph_with_merge(var1, var2, givens=None):
givens = copied[2]
# Create FunctionGraph.
inputs = theano.gof.graph.inputs(vars)
fgraph = theano.gof.fg.FunctionGraph(inputs, vars)
# The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
fgraph = theano.gof.fg.FunctionGraph(inputs, vars, clone=False)
# Perform Variable substitution.
for to_replace, replace_by in givens.iteritems():
fgraph.replace(to_replace, replace_by)
......
......@@ -93,7 +93,7 @@ def inputs():
def Env(inputs, outputs, validate=True):
e = FunctionGraph(inputs, outputs)
e = FunctionGraph(inputs, outputs, clone=False)
e.attach_feature(destroyhandler.DestroyHandler())
e.attach_feature(ReplaceValidate())
if validate:
......
import unittest
import theano
from theano.gof import CachedConstantError, FunctionGraph
class TFunctionGraph(unittest.TestCase):
def test_constant_cache_error(self):
v = theano.tensor.constant(1)
assert v.cached
self.assertRaises(CachedConstantError, FunctionGraph, [], [v + 1],
clone=False)
def test_clone(self):
v = theano.tensor.constant(1)
assert v.cached
FunctionGraph([], [v + 1])
......@@ -3,7 +3,7 @@ from theano.gof.graph import Variable, Apply
from theano.gof.type import Type
from theano.gof.op import Op
from theano.gof.fg import FunctionGraph as Env, InconsistencyError
from theano.gof.fg import FunctionGraph, InconsistencyError
from theano.gof.toolbox import *
......@@ -61,14 +61,13 @@ def inputs():
return x, y, z
class TestNodeFinder:
def test_straightforward(self):
x, y, z = inputs()
e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e], clone=False)
g.attach_feature(NodeFinder())
assert hasattr(g, 'get_nodes')
......
......@@ -2928,7 +2928,7 @@ class Composite(ScalarOp):
self.name = rval
def init_fgraph(self):
fgraph = FunctionGraph(*gof.graph.clone(self.inputs, self.outputs))
fgraph = FunctionGraph(self.inputs, self.outputs)
gof.MergeOptimizer().optimize(fgraph)
for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论