提交 0380b760 authored 作者: Frederic Bastien's avatar Frederic Bastien

allow to put int and float into input and output pattern of PatternSub

上级 efa90388
......@@ -14,6 +14,7 @@ import op
from copy import copy
from theano.gof.python25 import any, all
from theano.configparser import AddConfigVar, BoolParam, config
import theano
#if sys.version_info[:2] >= (2,5):
# from collections import defaultdict
......@@ -517,10 +518,14 @@ class PatternSub(LocalOptimizer):
sub_pattern ::= input_pattern
sub_pattern ::= string
sub_pattern ::= a Constant instance
sub_pattern ::= int
sub_pattern ::= float
constraint ::= lambda env, expr: additional matching condition
output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string
output_pattern ::= int
output_pattern ::= float
Each string in the input pattern is a variable that will be set to
whatever expression is found in its place. If the same string is
......@@ -635,6 +640,11 @@ class PatternSub(LocalOptimizer):
return retry_with_equiv()
else:
u = u.merge(expr, v)
elif isinstance(pattern, (int, float)) and isinstance(expr, graph.Constant):
if all(theano.tensor.constant(pattern).value==expr.value):
return u
else:
return retry_with_equiv()
elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr):
return u
else:
......@@ -647,6 +657,8 @@ class PatternSub(LocalOptimizer):
return pattern[0](*args)
elif isinstance(pattern, str):
return u[unify.Var(pattern)]
elif isinstance(pattern, (int,float)):
return pattern
else:
return pattern.clone()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论