提交 6a2259f4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Remove constant caching

Closes #99
上级 b99dea2f
...@@ -5,25 +5,14 @@ import pytest ...@@ -5,25 +5,14 @@ import pytest
import theano import theano
from theano.compat import PY3 from theano.compat import PY3
from theano.gof import CachedConstantError, FunctionGraph from theano.gof.fg import FunctionGraph
from theano import tensor as tt from theano import tensor as tt
class TFunctionGraph: class TestFunctionGraph:
def test_constant_cache_error(self):
v = theano.tensor.constant(1)
assert v.cached
with pytest.raises(CachedConstantError):
FunctionGraph([], [v + 1], clone=False)
def test_clone(self):
v = theano.tensor.constant(1)
assert v.cached
FunctionGraph([], [v + 1])
def test_pickle(self): def test_pickle(self):
v = tt.vector() v = tt.vector()
func = theano.gof.FunctionGraph([v], [v + 1]) func = FunctionGraph([v], [v + 1])
s = pickle.dumps(func) s = pickle.dumps(func)
pickle.loads(s) pickle.loads(s)
...@@ -31,6 +20,7 @@ class TFunctionGraph: ...@@ -31,6 +20,7 @@ class TFunctionGraph:
@pytest.mark.skipif( @pytest.mark.skipif(
not theano.config.cxx, reason="G++ not available, so we need to skip this test." not theano.config.cxx, reason="G++ not available, so we need to skip this test."
) )
@pytest.mark.slow
def test_node_outputs_not_used(self): def test_node_outputs_not_used(self):
# In the past, we where removing some not used variable from # In the past, we where removing some not used variable from
# fgraph.variables event if the apply had other output used in # fgraph.variables event if the apply had other output used in
......
...@@ -266,27 +266,14 @@ class TestAutoName: ...@@ -266,27 +266,14 @@ class TestAutoName:
assert r2.auto_name == "auto_" + str(autoname_id + 1) assert r2.auto_name == "auto_" + str(autoname_id + 1)
def test_constant(self): def test_constant(self):
# Make sure the value we will use for the test aren't yet in the cache.
r1 = tensor.constant(1.5)
del tensor.constant_cache[r1.signature()]
r1 = tensor.constant(1.6)
del tensor.constant_cache[r1.signature()]
# Get counter value # Get counter value
autoname_id = next(Variable.__count__) autoname_id = next(Variable.__count__)
Variable.__count__ = count(autoname_id) Variable.__count__ = count(autoname_id)
r1 = tensor.constant(1.5) r1 = tensor.constant(1.5)
r2 = tensor.constant(1.5)
assert r1.auto_name == "auto_" + str(autoname_id), ( assert r1.auto_name == "auto_" + str(autoname_id), (
r1.auto_name, r1.auto_name,
"auto_" + str(autoname_id), "auto_" + str(autoname_id),
) )
# We reuse the same variable
assert r2.auto_name == "auto_" + str(autoname_id), (
r2.auto_name,
"auto_" + str(autoname_id),
)
assert r1 is r2
r3 = tensor.constant(1.6) r3 = tensor.constant(1.6)
assert r3.auto_name == "auto_" + str(autoname_id + 1) assert r3.auto_name == "auto_" + str(autoname_id + 1)
......
...@@ -1673,72 +1673,72 @@ class TestScan: ...@@ -1673,72 +1673,72 @@ class TestScan:
| |<RandomStateType> [id DD] | |<RandomStateType> [id DD]
| |Shape [id DE] '' | |Shape [id DE] ''
| | |Subtensor{int64::} [id DA] '' | | |Subtensor{int64::} [id DA] ''
| |TensorConstant{0.1} [id CW] | |TensorConstant{0.1} [id DF]
| |TensorConstant{0.9} [id CX] | |TensorConstant{0.9} [id DG]
|Sum{acc_dtype=float64} [id DF] '' |Sum{acc_dtype=float64} [id DH] ''
|Elemwise{mul,no_inplace} [id DG] '' |Elemwise{mul,no_inplace} [id DI] ''
|for{cpu,scan_fn}.2 [id H] '' |for{cpu,scan_fn}.2 [id H] ''
|RandomFunction{uniform}.1 [id DH] '' |RandomFunction{uniform}.1 [id DJ] ''
|<RandomStateType> [id DI] |<RandomStateType> [id DK]
|Shape [id DJ] '' |Shape [id DL] ''
| |for{cpu,scan_fn}.2 [id H] '' | |for{cpu,scan_fn}.2 [id H] ''
|TensorConstant{0.1} [id CW] |TensorConstant{0.1} [id DM]
|TensorConstant{0.9} [id CX] |TensorConstant{0.9} [id DN]
Inner graphs of the scan ops: Inner graphs of the scan ops:
for{cpu,scan_fn}.1 [id H] '' for{cpu,scan_fn}.1 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id DK] '' >Elemwise{Composite{((i0 + i1) * i2)}} [id DO] ''
> |y0[t-1] [id DL] -> [id BR] > |y0[t-1] [id DP] -> [id BR]
> |y0[t-3] [id DM] -> [id BR] > |y0[t-3] [id DQ] -> [id BR]
> |InplaceDimShuffle{} [id DN] '' > |InplaceDimShuffle{} [id DR] ''
> |CGemv{inplace} [id DO] '' > |CGemv{inplace} [id DS] ''
> |AllocEmpty{dtype='%(float)s'} [id DP] '' > |AllocEmpty{dtype='%(float)s'} [id DT] ''
> | |TensorConstant{1} [id DQ] > | |TensorConstant{1} [id DU]
> |TensorConstant{1.0} [id DR] > |TensorConstant{1.0} [id DV]
> |InplaceDimShuffle{x,0} [id DS] '' > |InplaceDimShuffle{x,0} [id DW] ''
> | |wout_copy [id DT] -> [id CQ] > | |wout_copy [id DX] -> [id CQ]
> |x0[t-1] [id DU] -> [id CB] > |x0[t-1] [id DY] -> [id CB]
> |TensorConstant{0.0} [id DV] > |TensorConstant{0.0} [id DZ]
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id DW] '' >Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id EA] ''
> |CGemv{no_inplace} [id DX] '' > |CGemv{no_inplace} [id EB] ''
> | |AllocEmpty{dtype='%(float)s'} [id DY] '' > | |AllocEmpty{dtype='%(float)s'} [id EC] ''
> | | |Shape_i{1} [id DZ] '' > | | |Shape_i{1} [id ED] ''
> | | |win_copy [id EA] -> [id CR] > | | |win_copy [id EE] -> [id CR]
> | |TensorConstant{1.0} [id DR] > | |TensorConstant{1.0} [id DV]
> | |InplaceDimShuffle{1,0} [id EB] 'win_copy.T' > | |InplaceDimShuffle{1,0} [id EF] 'win_copy.T'
> | | |win_copy [id EA] -> [id CR] > | | |win_copy [id EE] -> [id CR]
> | |u1[t] [id EC] -> [id BJ] > | |u1[t] [id EG] -> [id BJ]
> | |TensorConstant{0.0} [id DV] > | |TensorConstant{0.0} [id DZ]
> |u2[t] [id ED] -> [id BN] > |u2[t] [id EH] -> [id BN]
> |u2[t-1] [id EE] -> [id BL] > |u2[t-1] [id EI] -> [id BL]
> |u2[t+1] [id EF] -> [id BP] > |u2[t+1] [id EJ] -> [id BP]
> |win2_copy [id EG] -> [id CO] > |win2_copy [id EK] -> [id CO]
> |CGemv{inplace} [id EH] '' > |CGemv{inplace} [id EL] ''
> |AllocEmpty{dtype='%(float)s'} [id EI] '' > |AllocEmpty{dtype='%(float)s'} [id EM] ''
> | |Shape_i{1} [id EJ] '' > | |Shape_i{1} [id EN] ''
> | |w_copy [id EK] -> [id CP] > | |w_copy [id EO] -> [id CP]
> |TensorConstant{1.0} [id DR] > |TensorConstant{1.0} [id DV]
> |InplaceDimShuffle{1,0} [id EL] 'w_copy.T' > |InplaceDimShuffle{1,0} [id EP] 'w_copy.T'
> | |w_copy [id EK] -> [id CP] > | |w_copy [id EO] -> [id CP]
> |x0[t-1] [id DU] -> [id CB] > |x0[t-1] [id DY] -> [id CB]
> |TensorConstant{0.0} [id DV] > |TensorConstant{0.0} [id DZ]
>CGemv{no_inplace} [id DX] '' >CGemv{no_inplace} [id EB] ''
for{cpu,scan_fn}.0 [id H] '' for{cpu,scan_fn}.0 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id DK] '' >Elemwise{Composite{((i0 + i1) * i2)}} [id DO] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id DW] '' >Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id EA] ''
>CGemv{no_inplace} [id DX] '' >CGemv{no_inplace} [id EB] ''
for{cpu,scan_fn}.2 [id H] '' for{cpu,scan_fn}.2 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id DK] '' >Elemwise{Composite{((i0 + i1) * i2)}} [id DO] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id DW] '' >Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id EA] ''
>CGemv{no_inplace} [id DX] '' >CGemv{no_inplace} [id EB] ''
for{cpu,scan_fn}.2 [id H] '' for{cpu,scan_fn}.2 [id H] ''
>Elemwise{Composite{((i0 + i1) * i2)}} [id DK] '' >Elemwise{Composite{((i0 + i1) * i2)}} [id DO] ''
>Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id DW] '' >Elemwise{Composite{(i0 + ((i1 + (i2 * i3)) * i4) + i5)}} [id EA] ''
>CGemv{no_inplace} [id DX] '' >CGemv{no_inplace} [id EB] ''
""" % { """ % {
"float": theano.config.floatX "float": theano.config.floatX
} }
......
...@@ -2761,6 +2761,10 @@ class TestAsTensorVariable: ...@@ -2761,6 +2761,10 @@ class TestAsTensorVariable:
def setup_method(self): def setup_method(self):
self.x = tensor.scalar("x") self.x = tensor.scalar("x")
def test_tensor_from_scalar(self):
y = as_tensor_variable(scal.int8())
assert isinstance(y.owner.op, TensorFromScalar)
def test_one_output(self): def test_one_output(self):
good_apply_var = ApplyDefaultTestOp(0).make_node(self.x) good_apply_var = ApplyDefaultTestOp(0).make_node(self.x)
as_tensor_variable(good_apply_var) as_tensor_variable(good_apply_var)
...@@ -5747,81 +5751,50 @@ class TestDot: ...@@ -5747,81 +5751,50 @@ class TestDot:
assert g.broadcastable == y.broadcastable assert g.broadcastable == y.broadcastable
class TestTensorfromscalar: def test_TensorFromScalar():
def test_basic(self): s = scal.constant(56)
s = scal.constant(56) t = tensor_from_scalar(s)
t = tensor_from_scalar(s) assert t.owner.op is tensor_from_scalar
assert t.owner.op is tensor_from_scalar assert t.type.broadcastable == (), t.type.broadcastable
assert t.type.broadcastable == (), t.type.broadcastable assert t.type.ndim == 0, t.type.ndim
assert t.type.ndim == 0, t.type.ndim assert t.type.dtype == s.type.dtype
assert t.type.dtype == s.type.dtype
v = eval_outputs([t])
assert v == 56, v
assert isinstance(v, np.ndarray)
assert v.shape == (), v.shape
def test_basic_1(self):
s = scal.constant(56)
t = as_tensor_variable(s)
assert t.owner.op is tensor_from_scalar
assert t.type.broadcastable == (), t.type.broadcastable
assert t.type.ndim == 0, t.type.ndim
assert t.type.dtype == s.type.dtype
v = eval_outputs([t])
assert v == 56, v
assert isinstance(v, np.ndarray)
assert v.shape == (), v.shape
g = grad(t, s) v = eval_outputs([t])
assert eval_outputs([g]) == 0.0
def test_basic_2(self): assert v == 56, v
s = scal.constant(56.0) assert isinstance(v, np.ndarray)
t = as_tensor_variable(s) assert v.shape == (), v.shape
assert t.owner.op is tensor_from_scalar
assert t.type.broadcastable == (), t.type.broadcastable
assert t.type.ndim == 0, t.type.ndim
assert t.type.dtype == s.type.dtype
v = eval_outputs([t])
assert v == 56.0, v g = grad(t, s)
assert isinstance(v, np.ndarray) assert eval_outputs([g]) == 0.0
assert v.shape == (), v.shape
g = grad(t, s)
assert eval_outputs([g]) == 1.0
def test_ScalarFromTensor():
tt = constant(56) # scal.constant(56)
ss = scalar_from_tensor(tt)
assert ss.owner.op is scalar_from_tensor
assert ss.type.dtype == tt.type.dtype
class TestScalarfromtensor: v = eval_outputs([ss])
def test_basic(self):
tt = constant(56) # scal.constant(56)
ss = scalar_from_tensor(tt)
assert ss.owner.op is scalar_from_tensor
assert ss.type.dtype == tt.type.dtype
v = eval_outputs([ss]) assert v == 56
assert v.shape == ()
assert v == 56 if config.cast_policy == "custom":
if config.cast_policy == "custom": assert isinstance(v, np.int8)
assert isinstance(v, np.int8) elif config.cast_policy in ("numpy", "numpy+floatX"):
elif config.cast_policy in ("numpy", "numpy+floatX"): assert isinstance(v, str(np.asarray(56).dtype))
assert isinstance(v, str(np.asarray(56).dtype)) else:
else: raise NotImplementedError(config.cast_policy)
raise NotImplementedError(config.cast_policy)
assert v.shape == () tt = lscalar()
tt = lscalar() ss = scalar_from_tensor(tt)
ss = scalar_from_tensor(tt) ss.owner.op.grad([tt], [ss])
ss.owner.op.grad([tt], [ss]) fff = function([tt], ss)
fff = function([tt], ss) v = fff(np.asarray(5))
v = fff(np.asarray(5)) assert v == 5
assert v == 5 assert isinstance(v, np.int64)
assert isinstance(v, np.int64) assert v.shape == ()
assert v.shape == ()
class TestGrad: class TestGrad:
......
...@@ -656,61 +656,61 @@ def test_scan_debugprint5(): ...@@ -656,61 +656,61 @@ def test_scan_debugprint5():
| | | | | | |for{cpu,scan_fn} [id F] '' | | | | | | |for{cpu,scan_fn} [id F] ''
| | | | | | |Constant{1} [id BT] | | | | | | |Constant{1} [id BT]
| | | | | |InplaceDimShuffle{x,x} [id BU] '' | | | | | |InplaceDimShuffle{x,x} [id BU] ''
| | | | | |TensorConstant{0.0} [id BP] | | | | | |TensorConstant{0.0} [id BV]
| | | | |Elemwise{second} [id BV] '' | | | | |Elemwise{second} [id BW] ''
| | | | | |Subtensor{int64} [id BW] '' | | | | | |Subtensor{int64} [id BX] ''
| | | | | | |Subtensor{int64::} [id BS] '' | | | | | | |Subtensor{int64::} [id BS] ''
| | | | | | |Constant{-1} [id BX] | | | | | | |Constant{-1} [id BY]
| | | | | |InplaceDimShuffle{x} [id BY] '' | | | | | |InplaceDimShuffle{x} [id BZ] ''
| | | | | |Elemwise{second,no_inplace} [id BZ] '' | | | | | |Elemwise{second,no_inplace} [id CA] ''
| | | | | |Sum{acc_dtype=float64} [id CA] '' | | | | | |Sum{acc_dtype=float64} [id CB] ''
| | | | | | |Subtensor{int64} [id BW] '' | | | | | | |Subtensor{int64} [id BX] ''
| | | | | |TensorConstant{1.0} [id R] | | | | | |TensorConstant{1.0} [id CC]
| | | | |Constant{-1} [id BX] | | | | |Constant{-1} [id BY]
| | | |Constant{1} [id BT] | | | |Constant{1} [id BT]
| | |Constant{-1} [id CB] | | |Constant{-1} [id CD]
| |Alloc [id CC] '' | |Alloc [id CE] ''
| | |TensorConstant{0.0} [id BP] | | |TensorConstant{0.0} [id CF]
| | |Elemwise{add,no_inplace} [id CD] '' | | |Elemwise{add,no_inplace} [id CG] ''
| | | |Elemwise{sub,no_inplace} [id C] '' | | | |Elemwise{sub,no_inplace} [id C] ''
| | | |TensorConstant{1} [id Y] | | | |TensorConstant{1} [id CH]
| | |Subtensor{int64} [id CE] '' | | |Subtensor{int64} [id CI] ''
| | |Shape [id CF] '' | | |Shape [id CJ] ''
| | | |A [id P] | | | |A [id P]
| | |Constant{0} [id CG] | | |Constant{0} [id CK]
| |A [id P] | |A [id P]
|Constant{-1} [id CH] |Constant{-1} [id CL]
Inner graphs of the scan ops: Inner graphs of the scan ops:
for{cpu,grad_of_scan_fn}.1 [id B] '' for{cpu,grad_of_scan_fn}.1 [id B] ''
>Elemwise{add,no_inplace} [id CI] '' >Elemwise{add,no_inplace} [id CM] ''
> |Elemwise{mul} [id CJ] '' > |Elemwise{mul} [id CN] ''
> | |<TensorType(float64, vector)> [id CK] -> [id BL] > | |<TensorType(float64, vector)> [id CO] -> [id BL]
> | |A_copy [id CL] -> [id P] > | |A_copy [id CP] -> [id P]
> |<TensorType(float64, vector)> [id CM] -> [id BL] > |<TensorType(float64, vector)> [id CQ] -> [id BL]
>Elemwise{add,no_inplace} [id CN] '' >Elemwise{add,no_inplace} [id CR] ''
> |Elemwise{mul} [id CO] '' > |Elemwise{mul} [id CS] ''
> | |<TensorType(float64, vector)> [id CK] -> [id BL] > | |<TensorType(float64, vector)> [id CO] -> [id BL]
> | |<TensorType(float64, vector)> [id CP] -> [id Z] > | |<TensorType(float64, vector)> [id CT] -> [id Z]
> |<TensorType(float64, vector)> [id CQ] -> [id CC] > |<TensorType(float64, vector)> [id CU] -> [id CE]
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CR] '' >Elemwise{mul,no_inplace} [id CV] ''
> |<TensorType(float64, vector)> [id CP] -> [id H] > |<TensorType(float64, vector)> [id CT] -> [id H]
> |A_copy [id CL] -> [id P] > |A_copy [id CP] -> [id P]
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CR] '' >Elemwise{mul,no_inplace} [id CV] ''
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CR] '' >Elemwise{mul,no_inplace} [id CV] ''
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CR] '' >Elemwise{mul,no_inplace} [id CV] ''
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CR] ''""" >Elemwise{mul,no_inplace} [id CV] ''"""
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
......
...@@ -40,7 +40,6 @@ e-mail thread "What is gof?". ...@@ -40,7 +40,6 @@ e-mail thread "What is gof?".
from theano.gof.cc import CLinker, OpWiseCLinker, DualLinker, HideC from theano.gof.cc import CLinker, OpWiseCLinker, DualLinker, HideC
from theano.gof.fg import ( from theano.gof.fg import (
CachedConstantError,
InconsistencyError, InconsistencyError,
MissingInputError, MissingInputError,
FunctionGraph, FunctionGraph,
......
...@@ -20,17 +20,6 @@ from theano.misc.ordered_set import OrderedSet ...@@ -20,17 +20,6 @@ from theano.misc.ordered_set import OrderedSet
NullType = None NullType = None
class CachedConstantError(Exception):
"""
An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
cached constant in other FunctionGraph.
"""
pass
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
This exception should be thrown by listeners to FunctionGraph when the This exception should be thrown by listeners to FunctionGraph when the
...@@ -186,15 +175,7 @@ class FunctionGraph(utils.object2): ...@@ -186,15 +175,7 @@ class FunctionGraph(utils.object2):
self.__setup_r__(input) self.__setup_r__(input)
self.variables.add(input) self.variables.add(input)
# Setup a Variable #
def __setup_r__(self, r): def __setup_r__(self, r):
# sets up r so it belongs to this fgraph
if getattr(r, "cached", False):
raise CachedConstantError(
"You manually constructed a FunctionGraph, but you passed it a"
" graph that has a cached constant. This should not happen."
" Clone the graph before building the FunctionGraph."
)
if hasattr(r, "fgraph") and r.fgraph is not None and r.fgraph is not self: if hasattr(r, "fgraph") and r.fgraph is not None and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r) raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self r.fgraph = self
......
...@@ -93,12 +93,7 @@ class Optimizer(object): ...@@ -93,12 +93,7 @@ class Optimizer(object):
""" """
self.add_requirements(fgraph) self.add_requirements(fgraph)
try: ret = self.apply(fgraph, *args, **kwargs)
orig = theano.tensor.basic.constant.enable
theano.tensor.basic.constant.enable = False
ret = self.apply(fgraph, *args, **kwargs)
finally:
theano.tensor.basic.constant.enable = orig
return ret return ret
def __call__(self, fgraph): def __call__(self, fgraph):
......
...@@ -26,7 +26,6 @@ from theano.tensor.var import ( ...@@ -26,7 +26,6 @@ from theano.tensor.var import (
AsTensorError, AsTensorError,
TensorVariable, TensorVariable,
TensorConstant, TensorConstant,
TensorConstantSignature,
_tensor_py_operators, _tensor_py_operators,
) )
from theano.tensor.type import TensorType, values_eq_approx_always_true from theano.tensor.type import TensorType, values_eq_approx_always_true
...@@ -217,7 +216,7 @@ as_tensor = as_tensor_variable ...@@ -217,7 +216,7 @@ as_tensor = as_tensor_variable
def constant(x, name=None, ndim=None, dtype=None): def constant(x, name=None, ndim=None, dtype=None):
"""Return a symbolic `Constant` with value `x`. """Return a `TensorConstant` with value `x`.
Raises Raises
------ ------
...@@ -226,16 +225,6 @@ def constant(x, name=None, ndim=None, dtype=None): ...@@ -226,16 +225,6 @@ def constant(x, name=None, ndim=None, dtype=None):
ValueError ValueError
`x` could not be expanded to have ndim dimensions. `x` could not be expanded to have ndim dimensions.
Notes
-----
We create a small cache of frequently used constant.
This speed up the Merge optimization for big graph.
We want to cache all scalar to don't merge as frequently constants.
But we don't want to cache too much stuff.
So we cache integer with dtype [u]int and float where the value is
between -10 and 10.
We cache all broadcast pattern for scalar.
""" """
x_ = scal.convert(x, dtype=dtype) x_ = scal.convert(x, dtype=dtype)
...@@ -252,40 +241,11 @@ def constant(x, name=None, ndim=None, dtype=None): ...@@ -252,40 +241,11 @@ def constant(x, name=None, ndim=None, dtype=None):
try: try:
ttype = TensorType(dtype=x_.dtype, broadcastable=bcastable) ttype = TensorType(dtype=x_.dtype, broadcastable=bcastable)
if not constant.enable: return TensorConstant(ttype, x_, name=name)
return TensorConstant(ttype, x_, name=name)
sig = TensorConstantSignature((ttype, x_))
if sig in constant_cache:
return constant_cache[sig]
ret = TensorConstant(ttype, x_, name=name)
if (
x_.size == 1
and (-10) <= x_ <= 10
and (
x_.dtype in int_dtypes
or x_.dtype in uint_dtypes
or (
x_.dtype in float_dtypes
and
# Limit the size of the cache.
len(constant_cache) < 10000
)
)
):
constant_cache[sig] = ret
# This is needed to raise a good error to the user.
ret.cached = True
return ret
except Exception: except Exception:
raise TypeError("Could not convert %s to TensorType" % x, type(x)) raise TypeError("Could not convert %s to TensorType" % x, type(x))
constant.enable = True
constant_cache = {}
def _obj_is_wrappable_as_tensor(x): def _obj_is_wrappable_as_tensor(x):
try: try:
constant(x) constant(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论