提交 01de70fe authored 作者: Frederic's avatar Frederic

Make FunctionGraph() clone the input by default to don't have constant from the cache.

Disable the clone in some tests and other place where it should not be done.
上级 f7cee63b
...@@ -129,8 +129,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False): ...@@ -129,8 +129,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
updates = [spec.update for spec in input_specs if spec.update] updates = [spec.update for spec in input_specs if spec.update]
orig_outputs = [spec.variable for spec in output_specs] + updates orig_outputs = [spec.variable for spec in output_specs] + updates
inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs) fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs)
fgraph = gof.fg.FunctionGraph(inputs, outputs)
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None): if getattr(node.op, 'destroy_map', None):
...@@ -143,7 +142,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False): ...@@ -143,7 +142,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
# We need to protect all immutable inputs from inplace operations. # We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature( fgraph.attach_feature(
Supervisor(input Supervisor(input
for spec, input in zip(input_specs, inputs) for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or if not (spec.mutable or
(hasattr(fgraph, 'destroyers') and (hasattr(fgraph, 'destroyers') and
fgraph.destroyers(input))))) fgraph.destroyers(input)))))
......
...@@ -74,7 +74,7 @@ class FunctionGraph(utils.object2): ...@@ -74,7 +74,7 @@ class FunctionGraph(utils.object2):
""" """
def __init__(self, inputs, outputs, features=None): def __init__(self, inputs, outputs, features=None, clone=True):
""" """
Create an FunctionGraph which operates on the subgraph bound by the inputs and Create an FunctionGraph which operates on the subgraph bound by the inputs and
outputs sets. outputs sets.
...@@ -85,7 +85,12 @@ class FunctionGraph(utils.object2): ...@@ -85,7 +85,12 @@ class FunctionGraph(utils.object2):
#TODO: document what variables are[not] set in the FunctionGraph when a feature #TODO: document what variables are[not] set in the FunctionGraph when a feature
is added via the constructor. How constructed is the FunctionGraph? is added via the constructor. How constructed is the FunctionGraph?
:param clone: If true, we will clone the graph. This is
usefull to remove the constant cache problem.
""" """
if clone:
inputs, outputs = graph.clone(inputs, outputs)
self.execute_callbacks_time = 0 self.execute_callbacks_time = 0
......
...@@ -593,7 +593,7 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -593,7 +593,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
#logger.debug('Compiling node %i of graph' % node_idx) #logger.debug('Compiling node %i of graph' % node_idx)
if self._op_use_c_code: if self._op_use_c_code:
try: try:
e = FunctionGraph(*graph.clone(node.inputs, node.outputs)) e = FunctionGraph(node.inputs, node.outputs)
e_no_recycling = [new_o e_no_recycling = [new_o
for (new_o, old_o) in zip(e.outputs, node.outputs) for (new_o, old_o) in zip(e.outputs, node.outputs)
......
...@@ -631,7 +631,9 @@ def is_same_graph_with_merge(var1, var2, givens=None): ...@@ -631,7 +631,9 @@ def is_same_graph_with_merge(var1, var2, givens=None):
givens = copied[2] givens = copied[2]
# Create FunctionGraph. # Create FunctionGraph.
inputs = theano.gof.graph.inputs(vars) inputs = theano.gof.graph.inputs(vars)
fgraph = theano.gof.fg.FunctionGraph(inputs, vars) # The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
fgraph = theano.gof.fg.FunctionGraph(inputs, vars, clone=False)
# Perform Variable substitution. # Perform Variable substitution.
for to_replace, replace_by in givens.iteritems(): for to_replace, replace_by in givens.iteritems():
fgraph.replace(to_replace, replace_by) fgraph.replace(to_replace, replace_by)
......
...@@ -93,7 +93,7 @@ def inputs(): ...@@ -93,7 +93,7 @@ def inputs():
def Env(inputs, outputs, validate=True): def Env(inputs, outputs, validate=True):
e = FunctionGraph(inputs, outputs) e = FunctionGraph(inputs, outputs, clone=False)
e.attach_feature(destroyhandler.DestroyHandler()) e.attach_feature(destroyhandler.DestroyHandler())
e.attach_feature(ReplaceValidate()) e.attach_feature(ReplaceValidate())
if validate: if validate:
......
import unittest
import theano
from theano.gof import CachedConstantError, FunctionGraph
class TFunctionGraph(unittest.TestCase):
def test_constant_cache_error(self):
v = theano.tensor.constant(1)
assert v.cached
self.assertRaises(CachedConstantError, FunctionGraph, [], [v + 1],
clone=False)
def test_clone(self):
v = theano.tensor.constant(1)
assert v.cached
FunctionGraph([], [v + 1])
...@@ -3,7 +3,7 @@ from theano.gof.graph import Variable, Apply ...@@ -3,7 +3,7 @@ from theano.gof.graph import Variable, Apply
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.fg import FunctionGraph as Env, InconsistencyError from theano.gof.fg import FunctionGraph, InconsistencyError
from theano.gof.toolbox import * from theano.gof.toolbox import *
...@@ -61,14 +61,13 @@ def inputs(): ...@@ -61,14 +61,13 @@ def inputs():
return x, y, z return x, y, z
class TestNodeFinder: class TestNodeFinder:
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e0 = dot(y, z) e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0)) e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e], clone=False)
g.attach_feature(NodeFinder()) g.attach_feature(NodeFinder())
assert hasattr(g, 'get_nodes') assert hasattr(g, 'get_nodes')
......
...@@ -2928,7 +2928,7 @@ class Composite(ScalarOp): ...@@ -2928,7 +2928,7 @@ class Composite(ScalarOp):
self.name = rval self.name = rval
def init_fgraph(self): def init_fgraph(self):
fgraph = FunctionGraph(*gof.graph.clone(self.inputs, self.outputs)) fgraph = FunctionGraph(self.inputs, self.outputs)
gof.MergeOptimizer().optimize(fgraph) gof.MergeOptimizer().optimize(fgraph)
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp): if not isinstance(node.op, ScalarOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论