提交 959d7cd4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1521 from nouiz/fix_tutorial_int32

Fix the tutorial to make it work with python 32 and 64 bit.
...@@ -254,10 +254,11 @@ for the purpose of one particular function. ...@@ -254,10 +254,11 @@ for the purpose of one particular function.
.. theano/tests/test_tutorial.py:T_examples.test_examples_8 .. theano/tests/test_tutorial.py:T_examples.test_examples_8
>>> fn_of_state = state * 2 + inc >>> fn_of_state = state * 2 + inc
>>> foo = T.iscalar() # the type (lscalar) must match the shared variable we >>> # The type of foo must match the shared variable we are replacing
>>> # are replacing with the ``givens`` list >>> # with the ``givens``
>>> foo = T.scalar(dtype=state.dtype)
>>> skip_shared = function([inc, foo], fn_of_state, >>> skip_shared = function([inc, foo], fn_of_state,
givens=[(state, foo)]) givens=[(state, foo)])
>>> skip_shared(1, 3) # we're using 3 for the state, not state.value >>> skip_shared(1, 3) # we're using 3 for the state, not state.value
array(7) array(7)
>>> state.get_value() # old state still there, but we didn't use it >>> state.get_value() # old state still there, but we didn't use it
......
...@@ -19,9 +19,13 @@ The following output depicts the pre- and post- compilation graphs. ...@@ -19,9 +19,13 @@ The following output depicts the pre- and post- compilation graphs.
.. code-block:: python .. code-block:: python
import numpy
import theano import theano
import theano.tensor as T import theano.tensor as T
import numpy
import os
rng = numpy.random rng = numpy.random
N = 400 N = 400
...@@ -52,16 +56,16 @@ The following output depicts the pre- and post- compilation graphs. ...@@ -52,16 +56,16 @@ The following output depicts the pre- and post- compilation graphs.
train = theano.function( train = theano.function(
inputs=[x, y], inputs=[x, y],
outputs=[prediction, xent], outputs=[prediction, xent],
updates={w: w - 0.01 * gw, b: b - 0.01 * gb}, updates=[(w, w - 0.01 * gw), (b, b - 0.01 * gb)],
name="train") name="train")
predict = theano.function(inputs=[x], outputs=prediction, predict = theano.function(inputs=[x], outputs=prediction,
name="predict") name="predict")
if any( [x.op.__class__.__name__=='Gemv' for x in if any([x.op.__class__.__name__ in ['Gemv', 'CGemv'] for x in
train.maker.fgraph.toposort()]): train.maker.fgraph.toposort()]):
print 'Used the cpu' print 'Used the cpu'
elif any( [x.op.__class__.__name__=='GpuGemm' for x in elif any([x.op.__class__.__name__ == 'GpuGemm' for x in
train.maker.fgraph.toposort()]): train.maker.fgraph.toposort()]):
print 'Used the gpu' print 'Used the gpu'
else: else:
print 'ERROR, not able to tell if theano used the cpu or the gpu' print 'ERROR, not able to tell if theano used the cpu or the gpu'
...@@ -82,6 +86,8 @@ The following output depicts the pre- and post- compilation graphs. ...@@ -82,6 +86,8 @@ The following output depicts the pre- and post- compilation graphs.
# Print the picture graphs # Print the picture graphs
# after compilation # after compilation
if not os.path.exists('pics'):
os.mkdir('pics')
theano.printing.pydotprint(predict, theano.printing.pydotprint(predict,
outfile="pics/logreg_pydotprint_predic.png", outfile="pics/logreg_pydotprint_predic.png",
var_with_name_simple=True) var_with_name_simple=True)
......
...@@ -957,8 +957,21 @@ def pydotprint_variables(vars, ...@@ -957,8 +957,21 @@ def pydotprint_variables(vars,
for nd in vars: for nd in vars:
if nd.owner: if nd.owner:
plot_apply(nd.owner, depth) plot_apply(nd.owner, depth)
try:
g.write_png(outfile, prog='dot') g.write_png(outfile, prog='dot')
except pd.InvocationException, e:
# Some version of pydot are bugged/don't work correctly with
# empty label. Provide a better user error message.
if pd.__version__ == "1.0.28" and "label=]" in e.message:
raise Exception("pydot 1.0.28 is know to be bugged. Use another "
"working version of pydot")
elif "label=]" in e.message:
raise Exception("Your version of pydot " + pd.__version__ +
" returned an error. Version 1.0.28 is known"
" to be bugged and 1.0.25 to be working with"
" Theano. Using another version of pydot could"
" fix this problem. The pydot error is: " +
e.message)
print 'The output file is available at', outfile print 'The output file is available at', outfile
......
...@@ -585,7 +585,7 @@ class T_examples(unittest.TestCase): ...@@ -585,7 +585,7 @@ class T_examples(unittest.TestCase):
from theano import shared from theano import shared
# Force the dtype to int64 to work correctly on 32 bit computer. # Force the dtype to int64 to work correctly on 32 bit computer.
# Otherwise, it create by default a int32 on 32 bit computer. # Otherwise, it create by default a int32 on 32 bit computer.
state = shared(numpy.int64(0)) state = shared(0)
inc = T.iscalar('inc') inc = T.iscalar('inc')
accumulator = function([inc], state, updates=[(state, state+inc)]) accumulator = function([inc], state, updates=[(state, state+inc)])
...@@ -604,10 +604,11 @@ class T_examples(unittest.TestCase): ...@@ -604,10 +604,11 @@ class T_examples(unittest.TestCase):
assert state.get_value() == array(0) assert state.get_value() == array(0)
fn_of_state = state * 2 + inc fn_of_state = state * 2 + inc
foo = T.lscalar() # the type (lscalar) must match the shared variable we # The type of foo must match the shared variable we are replacing
# are replacing with the ``givens`` list # with the ``givens``
foo = T.scalar(dtype=state.dtype)
skip_shared = function([inc, foo], fn_of_state, skip_shared = function([inc, foo], fn_of_state,
givens=[(state, foo)]) givens=[(state, foo)])
assert skip_shared(1, 3) == array(7) assert skip_shared(1, 3) == array(7)
assert state.get_value() == array(0) assert state.get_value() == array(0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论