提交 a032cfbe authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

RootOp: Respect `use_vectorized_jac`

上级 e2cb1e2d
...@@ -831,6 +831,7 @@ class RootOp(ScipyVectorWrapperOp): ...@@ -831,6 +831,7 @@ class RootOp(ScipyVectorWrapperOp):
) )
self.fgraph = FunctionGraph([variables, *args], [equations]) self.fgraph = FunctionGraph([variables, *args], [equations])
self.use_vectorized_jac = use_vectorized_jac
if jac: if jac:
jac_wrt_x = jacobian( jac_wrt_x = jacobian(
...@@ -914,12 +915,15 @@ class RootOp(ScipyVectorWrapperOp): ...@@ -914,12 +915,15 @@ class RootOp(ScipyVectorWrapperOp):
inner_fx = self.fgraph.outputs[0] inner_fx = self.fgraph.outputs[0]
df_dx = ( df_dx = (
jacobian(inner_fx, inner_x, vectorize=True) jacobian(inner_fx, inner_x, vectorize=self.use_vectorized_jac)
if not self.jac if not self.jac
else self.fgraph.outputs[1] else self.fgraph.outputs[1]
) )
df_dtheta_columns = jacobian( df_dtheta_columns = jacobian(
inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True inner_fx,
inner_args,
disconnected_inputs="ignore",
vectorize=self.use_vectorized_jac,
) )
grad_wrt_args = implict_optimization_grads( grad_wrt_args = implict_optimization_grads(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论