nd2py.core.transform package

Submodules

nd2py.core.transform.bfgs_fit module

nd2py.core.transform.bfgs_fit.collect_numbers(expression)
class nd2py.core.transform.bfgs_fit.BFGSFit(*args: Any, **kwargs: Any)

Bases: BaseEstimator, RegressorMixin

__init__(expression: Symbol, edge_list=None, num_nodes=None, use_eps=1e-08, method='BFGS', tol=1e-06, options=None, fold_constant=False)
fit(X, y=None)
predict(X)

Evaluates the fitted expression on new inputs X and returns the outputs.
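
BFGSFit's method, tol and options parameters share their names with scipy.optimize.minimize, which suggests (but this page does not confirm) that fit() minimizes a loss over the numeric constants of the expression. The sketch below illustrates that idea without nd2py. The expression c0 * x**2 + c1, the mean-squared-error loss and the helper fit_constants are hypothetical.

```python
import numpy as np
from scipy.optimize import minimize


# Hypothetical expression with two fitable constants: f(x; c) = c0 * x**2 + c1
def expression(c: np.ndarray, x: np.ndarray) -> np.ndarray:
    return c[0] * x**2 + c[1]


def fit_constants(x: np.ndarray, y: np.ndarray, init: np.ndarray,
                  method: str = "BFGS", tol: float = 1e-6, options=None) -> np.ndarray:
    """Fit the constants by minimizing the mean squared error (assumed loss)."""
    def loss(c: np.ndarray) -> float:
        return float(np.mean((expression(c, x) - y) ** 2))

    result = minimize(loss, init, method=method, tol=tol, options=options)
    return result.x


# Noise-free synthetic data generated with c0 = 2.5 and c1 = 1.3
x = np.linspace(-2.0, 2.0, 50)
y = 2.5 * x**2 + 1.3
params = fit_constants(x, y, init=np.array([1.0, 1.0]))
print(params)  # approximately [2.5, 1.3]

# The counterpart of predict(): evaluate the fitted expression on new inputs
x_new = np.array([0.0, 1.0])
print(expression(params, x_new))  # approximately [1.3, 3.8]
```

Because the gradient is not supplied, SciPy estimates it by finite differences; a real implementation could pass an analytic gradient through the jac argument instead.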

nd2py.core.transform.fix_nettype module

class nd2py.core.transform.fix_nettype.FixNetType

Bases: Visitor

generic_visit(node, *args, **kwargs) → _Type

direction = 'top-down': the nettype of each node is determined by kwargs['nettype'].
direction = 'bottom-up': the nettype of each node is determined by its operands. This mode only guarantees that every node's operation is valid; it is not obliged to honor kwargs['nettype'].

visit_Number(node: Number, *args, **kwargs) → _Type
visit_Variable(node: Variable, *args, **kwargs) → _Type
visit_BinaryOp(node, *args, **kwargs) → _Type
visit_Add(node, *args, **kwargs) → _Type
visit_Sub(node, *args, **kwargs) → _Type
visit_Mul(node, *args, **kwargs) → _Type
visit_Div(node, *args, **kwargs) → _Type
visit_Pow(node, *args, **kwargs) → _Type
visit_Max(node, *args, **kwargs) → _Type
visit_Min(node, *args, **kwargs) → _Type
visit_Aggr(node, *args, **kwargs) → _Type
visit_Rgga(node, *args, **kwargs) → _Type
visit_Sour(node, *args, **kwargs) → _Type
visit_Targ(node, *args, **kwargs) → _Type
visit_Readout(node, *args, **kwargs) → _Type
fix_nettype(node: Symbol, *args, **kwargs) → Symbol
edge_to_node(node: Symbol, *args, **kwargs) → Symbol
node_to_edge(node: Symbol, *args, **kwargs) → Symbol
edge_to_scalar(node: Symbol, *args, **kwargs) → Symbol
node_to_scalar(node: Symbol, *args, **kwargs) → Symbol
scalar_to_node(node: Symbol, *args, **kwargs) → Symbol
scalar_to_edge(node: Symbol, *args, **kwargs) → Symbol
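
The listing suggests, though it does not state, that visit_Add through visit_Min share the implementation of visit_BinaryOp, as do visit_Rgga with visit_Aggr and visit_Targ with visit_Sour. The sketch below illustrates the 'bottom-up' direction under that reading. It is not nd2py code: the node classes are hypothetical stand-ins named after nd2py's operators, and the nettype names ('node', 'edge', 'scalar') and combination rules (a scalar operand adopts the other operand's nettype, mixing 'node' with 'edge' is rejected, Aggr maps edge to node, Sour maps node to edge) are assumptions.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical minimal node types (not nd2py's classes)
@dataclass
class Variable:
    name: str
    nettype: str  # assumed values: 'node', 'edge' or 'scalar'


@dataclass
class Number:
    value: float


@dataclass
class Add:
    lhs: "Expr"
    rhs: "Expr"


@dataclass
class Mul:
    lhs: "Expr"
    rhs: "Expr"


@dataclass
class Aggr:  # assumed: aggregates an edge-valued operand onto nodes
    operand: "Expr"


@dataclass
class Sour:  # assumed: reads a node-valued operand at each edge's source node
    operand: "Expr"


Expr = Union[Variable, Number, Add, Mul, Aggr, Sour]


class InferNetType:
    """Bottom-up nettype inference: each node's nettype comes from its operands."""

    def visit(self, node: Expr) -> str:
        # Dispatch on the class name, e.g. Add -> visit_Add
        method = getattr(self, f"visit_{type(node).__name__}", self.generic_visit)
        return method(node)

    def generic_visit(self, node: Expr) -> str:
        raise TypeError(f"no rule for {type(node).__name__}")

    def visit_Number(self, node: Number) -> str:
        return "scalar"

    def visit_Variable(self, node: Variable) -> str:
        return node.nettype

    def visit_BinaryOp(self, node) -> str:
        left, right = self.visit(node.lhs), self.visit(node.rhs)
        if left == right or right == "scalar":
            return left
        if left == "scalar":
            return right
        raise TypeError(f"cannot combine {left!r} with {right!r}")

    # Aliases: every binary operator shares one rule
    visit_Add = visit_BinaryOp
    visit_Mul = visit_BinaryOp

    def visit_Aggr(self, node: Aggr) -> str:
        if self.visit(node.operand) != "edge":
            raise TypeError("Aggr expects an edge-valued operand")
        return "node"

    def visit_Sour(self, node: Sour) -> str:
        if self.visit(node.operand) != "node":
            raise TypeError("Sour expects a node-valued operand")
        return "edge"


x = Variable("x", "node")
e = Variable("e", "edge")
infer = InferNetType()
print(infer.visit(Add(x, Number(2.0))))  # node
print(infer.visit(Mul(Aggr(e), x)))      # node
print(infer.visit(Add(Sour(x), e)))      # edge
```

Here Add(x, e) raises TypeError. A pass that fixes nettypes, rather than only inferring them, would presumably insert a conversion such as node_to_edge at that point instead of failing.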

nd2py.core.transform.fold_constant module

class nd2py.core.transform.fold_constant.FoldConstant(fold_fitable: bool = True, fold_constant: bool = True)

Bases: Visitor

A visitor that folds each subexpression of an expression that contains no Number into a Constant.

__init__(fold_fitable: bool = True, fold_constant: bool = True)
generic_visit(node, *args, **kwargs)
visit_Empty(node: Symbol, *args, **kwargs)
visit_Number(node: Number, *args, **kwargs)
visit_Variable(node: Variable, *args, **kwargs)
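
A pass like FoldConstant builds on the general technique of constant folding: a subtree with no free variables is evaluated once and replaced by a single value. The sketch below shows that general technique on a hypothetical tuple-based tree (strings are variables, ints and floats are numbers, tuples are (op, *operands)). It does not model FoldConstant's exact folding criterion, its fold_fitable and fold_constant flags, or nd2py's Constant class.

```python
import operator
from typing import Union

# Hypothetical tree: strings are variables, ints/floats are numbers,
# tuples are (op, *operands)
Expr = Union[str, int, float, tuple]

OPS = {"add": operator.add, "sub": operator.sub, "mul": operator.mul,
       "div": operator.truediv, "neg": operator.neg}


def fold(expr: Expr) -> Expr:
    """Replace every variable-free subtree with the number it evaluates to."""
    if not isinstance(expr, tuple):
        return expr  # a leaf: variable or number
    op, *args = expr
    folded = [fold(a) for a in args]
    if all(isinstance(a, (int, float)) for a in folded):
        return OPS[op](*folded)  # every operand is a number: evaluate now
    return (op, *folded)


print(fold(("add", "x", ("mul", 2, ("sub", 5, 2)))))  # ('add', 'x', 6)
print(fold(("mul", ("neg", 3), "y")))                 # ('mul', -3, 'y')
```

Folding children before the parent means a single post-order walk collapses nested constant subtrees of any depth.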

nd2py.core.transform.reduce module

class nd2py.core.transform.reduce.ReduceRule(source: Symbol, target: Symbol)

Bases: object

__init__(source: Symbol, target: Symbol)
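
A ReduceRule pairs a source expression with a target. One plausible way to apply such rules, sketched below as an assumption rather than the nd2py implementation, is to treat the variables in a rule's source as pattern wildcards and rewrite a tree bottom-up until no rule matches. The Rule class, the tuple encoding (strings are variables, tuples are (op, *operands)) and the helper names are hypothetical.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

Expr = Union[str, int, float, tuple]


@dataclass(frozen=True)
class Rule:  # hypothetical stand-in for ReduceRule(source, target)
    source: Expr
    target: Expr


def match(pattern: Expr, expr: Expr, binding: Dict[str, Expr]) -> bool:
    """Match expr against pattern; pattern variables (strings) bind to subtrees."""
    if isinstance(pattern, str):
        if pattern in binding:
            return binding[pattern] == expr  # a repeated variable must match the same subtree
        binding[pattern] = expr
        return True
    if isinstance(pattern, tuple):
        return (isinstance(expr, tuple) and len(expr) == len(pattern)
                and pattern[0] == expr[0]
                and all(match(p, e, binding) for p, e in zip(pattern[1:], expr[1:])))
    return pattern == expr  # numbers must match exactly


def substitute(template: Expr, binding: Dict[str, Expr]) -> Expr:
    if isinstance(template, str):
        return binding[template]
    if isinstance(template, tuple):
        return (template[0], *(substitute(t, binding) for t in template[1:]))
    return template


def rewrite_once(expr: Expr, rules: List[Rule]) -> Optional[Expr]:
    for rule in rules:
        binding: Dict[str, Expr] = {}
        if match(rule.source, expr, binding):
            return substitute(rule.target, binding)
    return None


def apply_rules(expr: Expr, rules: List[Rule]) -> Expr:
    """Reduce the children first, then keep rewriting this node while a rule applies."""
    if isinstance(expr, tuple):
        expr = (expr[0], *(apply_rules(a, rules) for a in expr[1:]))
    new = rewrite_once(expr, rules)
    return expr if new is None else apply_rules(new, rules)


rules = [
    Rule(("mul", "x0", 1), "x0"),   # x0 * 1 -> x0
    Rule(("add", "x0", 0), "x0"),   # x0 + 0 -> x0
    Rule(("sub", "x0", "x0"), 0),   # x0 - x0 -> 0
]
print(apply_rules(("add", ("mul", "a", 1), ("sub", "b", "b")), rules))  # a
```

The loop terminates only if every rule makes expressions strictly smaller, which is the natural property of rules whose target is the shortest known equivalent of the source.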
class nd2py.core.transform.reduce.Reduce(n_variables=4, constants=[0, 1, -1, 3.141592653589793, 2.718281828459045], binary=None, unary=None, max_online_iterations: int = 10, load_cache: bool = True, num_anchors=100)

Bases: object

__init__(n_variables=4, constants=[0, 1, -1, 3.141592653589793, 2.718281828459045], binary=None, unary=None, max_online_iterations: int = 10, load_cache: bool = True, num_anchors=100)
prepare_rule_dict(l_max: int = 8, force_rebuild: bool = False, save_cache: bool = True, show_progress: bool = True)

Offline stage: discovers simplification rules with a variant of Kruskal's minimum spanning forest algorithm and builds the reduce_rules dictionary.

This is the single-process version; it processes the expressions one at a time.

Parameters:
  • l_max – maximum expression length

  • force_rebuild – whether to force a rebuild, ignoring any cache

  • save_cache – whether to save the result to a cache file

  • show_progress – whether to display a progress bar
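
The num_anchors constructor argument and the signatures of get_array_hash(val, n=8) and process_tau(tau, anchors, hash_dict) suggest that candidate expressions are compared numerically. In that reading, each candidate is evaluated on a shared set of random anchor points, the rounded results are hashed, and expressions with equal hashes are treated as likely equivalent. The sketch below illustrates only that grouping step and turns each group into rules that map longer expressions to the shortest one. It does not reproduce the Kruskal-style forest construction; the candidate list, the rounding scheme and the helper array_hash are hypothetical.

```python
import hashlib
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

import numpy as np


def array_hash(val: np.ndarray, n: int = 8) -> str:
    # Round to n significant digits (adding 0.0 turns -0.0 into 0.0), then hash the text
    text = ",".join(f"{float(v) + 0.0:.{n}g}" for v in val)
    return hashlib.sha1(text.encode()).hexdigest()


# Hypothetical candidates: (printed form, length, evaluator over the variable x)
candidates: List[Tuple[str, int, Callable[[np.ndarray], np.ndarray]]] = [
    ("x", 1, lambda x: x),
    ("x*1", 3, lambda x: x * 1),
    ("x+0", 3, lambda x: x + 0),
    ("x+x", 3, lambda x: x + x),
    ("2*x", 3, lambda x: 2 * x),
    ("x*x", 3, lambda x: x * x),
    ("x/x", 3, lambda x: x / x),
    ("1", 1, lambda x: np.ones_like(x)),
]

rng = np.random.default_rng(0)
anchors = rng.uniform(0.5, 2.0, size=100)  # positive anchors avoid 0/0 in x/x

# Group candidates whose values agree on every anchor point
groups: Dict[str, List[Tuple[str, int]]] = defaultdict(list)
for form, length, evaluate in candidates:
    groups[array_hash(evaluate(anchors))].append((form, length))

# In each group, map every member to the shortest one (ties keep list order)
rules: List[Tuple[str, str]] = []
for members in groups.values():
    members.sort(key=lambda m: m[1])
    target = members[0][0]
    rules.extend((form, target) for form, _ in members[1:])

for source, target in rules:
    print(f"{source} -> {target}")
```

With these candidates the sketch prints four rules: x*1 -> x, x+0 -> x, 2*x -> x+x and x/x -> 1. Matching hashes are evidence, not proof, of equivalence, which is why the number of anchor points matters.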

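prepare_rule_dict_parallel(l_max: int = 8, n_jobs: int = -1, batch_size: int = 10000, force_rebuild: bool = False, save_cache: bool = True, show_progress: bool = True)

Offline stage: discovers simplification rules with a variant of Kruskal's minimum spanning forest algorithm and builds the reduce_rules dictionary.

This is the multi-process version; it uses joblib.Parallel to speed up processing of the expressions that share the same length.

Parameters:
  • l_max – maximum expression length

  • n_jobs – number of parallel processes; -1 uses all CPU cores

  • force_rebuild – whether to force a rebuild, ignoring any cache

  • save_cache – whether to save the result to a cache file

  • show_progress – whether to display a progress bar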
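
The parallel variant uses joblib.Parallel for expressions of the same length and also takes a batch_size argument. A common way to combine the two, sketched below under that assumption and not taken from the library, is to cut the candidate list into batches, group each batch by hash in a worker, and merge the per-batch groups in the parent. The helpers array_hash, process_batch and group_in_parallel and the candidate strings are hypothetical.

```python
import hashlib
from collections import defaultdict
from typing import Dict, List

import numpy as np
from joblib import Parallel, delayed


def array_hash(val: np.ndarray, n: int = 8) -> str:
    # Round to n significant digits (adding 0.0 turns -0.0 into 0.0), then hash the text
    text = ",".join(f"{float(v) + 0.0:.{n}g}" for v in val)
    return hashlib.sha1(text.encode()).hexdigest()


def process_batch(forms: List[str], anchors: np.ndarray) -> Dict[str, List[str]]:
    """Worker: evaluate each candidate on the anchors and group the forms by hash."""
    local: Dict[str, List[str]] = defaultdict(list)
    for form in forms:
        # Toy evaluator: eval of trusted strings only, never of untrusted input
        value = eval(form, {"__builtins__": {}}, {"x": anchors})
        local[array_hash(np.broadcast_to(value, anchors.shape))].append(form)
    return dict(local)


def group_in_parallel(forms: List[str], anchors: np.ndarray,
                      n_jobs: int = -1, batch_size: int = 10000) -> Dict[str, List[str]]:
    batches = [forms[i:i + batch_size] for i in range(0, len(forms), batch_size)]
    partial = Parallel(n_jobs=n_jobs)(delayed(process_batch)(b, anchors) for b in batches)
    merged: Dict[str, List[str]] = defaultdict(list)
    for groups in partial:  # joblib returns results in batch order
        for key, members in groups.items():
            merged[key].extend(members)
    return dict(merged)


anchors = np.random.default_rng(0).uniform(0.5, 2.0, size=100)
forms = ["x", "x*1", "x+x", "2*x", "x*x", "x**2", "x/x", "1"]
groups = group_in_parallel(forms, anchors, n_jobs=2, batch_size=2)
for members in groups.values():
    print(members)
```

Batching amortizes the cost of shipping work to each process; merging in the parent keeps the workers independent, because no worker needs to see another worker's hash table.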

static has_variable(tau: Symbol) → bool

Checks whether the expression contains any variable.

static get_start_length(hash_dict: Dict, reduce_rules: List[ReduceRule]) → int

Returns the expression length at which processing should start, skipping lengths that are already cached.

static get_array_hash(val, n=8)
static process_tau(tau: Symbol, anchors: Dict[str, np.ndarray], hash_dict: Dict[str, List[Tuple['Symbol', np.ndarray]]])
static process_tau_batch(tau_list: List['Symbol'], anchors: Dict[str, np.ndarray], hash_dict: Dict[str, List[Tuple['Symbol', np.ndarray]]])

nd2py.core.transform.simplify module

class nd2py.core.transform.simplify.Simplify

Bases: Visitor

generic_visit(node: Symbol, *args, **kwargs) → _Type
remove_nested_unary(node: Symbol, *args, **kwargs) → _Type
visit_Sin(node: Symbol, *args, **kwargs) → _Type
visit_Cos(node: Symbol, *args, **kwargs) → _Type
visit_Tanh(node: Symbol, *args, **kwargs) → _Type
visit_Sigmoid(node: Symbol, *args, **kwargs) → _Type
visit_Sqrt(node: Symbol, *args, **kwargs) → _Type
visit_SqrtAbs(node: Symbol, *args, **kwargs) → _Type
visit_Exp(node: Symbol, *args, **kwargs) → _Type
visit_Log(node: Symbol, *args, **kwargs) → _Type
visit_LogAbs(node: Symbol, *args, **kwargs) → _Type
visit_Readout(node: Readout, *args, **kwargs) → _Type
visit_Number(node: Number, *args, **kwargs) → _Type
visit_Variable(node: Variable, *args, **kwargs) → _Type
visit_Add(node: Add, *args, **kwargs) → _Type
visit_Sub(node: Add, *args, **kwargs) → _Type
visit_Mul(node: Mul, *args, **kwargs) → _Type
visit_Div(node: Mul, *args, **kwargs) → _Type
visit_Neg(node: Neg, *args, **kwargs) → _Type
visit_Inv(node: Inv, *args, **kwargs) → _Type
visit_Aggr(node: Aggr, *args, **kwargs) → _Type
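
Simplify is a Visitor with one visit_* method per node type, so each rewrite stays local to the node being visited. The sketch below shows the shape of such a pass on a hypothetical tuple encoding. The two rewrites it applies (flattening nested additions, and cancelling double negation and double inversion) are illustrative assumptions, not a list of what Simplify actually does, and ToySimplify is a hypothetical name.

```python
from typing import List, Union

# Hypothetical tree: strings are variables, numbers are leaves,
# tuples such as ('add', a, b, ...), ('neg', a) and ('inv', a) are operators
Expr = Union[str, int, float, tuple]


class ToySimplify:
    def visit(self, node: Expr) -> Expr:
        if not isinstance(node, tuple):
            return node  # leaves are already simple
        # Dispatch on the operator name, falling back to generic_visit
        method = getattr(self, f"visit_{node[0]}", self.generic_visit)
        return method(node)

    def generic_visit(self, node: tuple) -> Expr:
        # Simplify the operands and otherwise rebuild the node unchanged
        return (node[0], *(self.visit(arg) for arg in node[1:]))

    def visit_add(self, node: tuple) -> Expr:
        terms: List[Expr] = []
        for arg in node[1:]:
            arg = self.visit(arg)
            # Flatten add(a, add(b, c)) into add(a, b, c)
            if isinstance(arg, tuple) and arg[0] == "add":
                terms.extend(arg[1:])
            else:
                terms.append(arg)
        return ("add", *terms)

    def _cancel_double(self, node: tuple) -> Expr:
        inner = self.visit(node[1])
        # neg(neg(a)) -> a and inv(inv(a)) -> a
        if isinstance(inner, tuple) and inner[0] == node[0]:
            return inner[1]
        return (node[0], inner)

    # Several node types share one implementation
    visit_neg = _cancel_double
    visit_inv = _cancel_double


s = ToySimplify()
print(s.visit(("add", "x", ("add", ("neg", ("neg", "y")), "z"))))  # ('add', 'x', 'y', 'z')
print(s.visit(("mul", ("inv", ("inv", "a")), "b")))                # ('mul', 'a', 'b')
```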

nd2py.core.transform.split_by_add module

class nd2py.core.transform.split_by_add.SplitByAdd

Bases: Visitor

generic_visit(node: Symbol, *args, **kwargs) → _Type
visit_Add(node: Add, *args, **kwargs) → _Type
visit_Sub(node: Sub, *args, **kwargs) → _Type
visit_Mul(node: Mul, *args, **kwargs) → _Type
visit_Div(node: Div, *args, **kwargs) → _Type
visit_Neg(node: Neg, *args, **kwargs) → _Type
visit_Sour(node: Sour, *args, **kwargs) → _Type
visit_Targ(node: Targ, *args, **kwargs) → _Type
visit_Aggr(node: Aggr, *args, **kwargs) → _Type
visit_Rgga(node: Rgga, *args, **kwargs) → _Type
visit_Readout(node: Readout, *args, **kwargs) → _Type
merge_bias(items: List[Symbol], *args, **kwargs) → List[Symbol]

Merge bias terms in the node.
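
Judging from its method names, SplitByAdd breaks an expression into additive terms (walking through Add, Sub and Neg), and merge_bias then combines the bias terms. The sketch below illustrates that idea on a hypothetical tuple encoding under two assumptions: a bias is a purely numeric term, and merging means summing all such terms into one. It ignores the Mul, Div and graph operators (Sour, Targ, Aggr, Rgga, Readout) that the real visitor also handles, and split_terms, negate and merge_bias_terms are hypothetical names.

```python
from typing import List, Union

# Hypothetical tree: strings are variables, numbers are leaves, tuples are (op, *operands)
Expr = Union[str, int, float, tuple]


def negate(term: Expr) -> Expr:
    if isinstance(term, (int, float)):
        return -term
    if isinstance(term, tuple) and term[0] == "neg":
        return term[1]  # -(-a) -> a
    return ("neg", term)


def split_terms(expr: Expr) -> List[Expr]:
    """Return the additive terms of expr, pushing subtraction and negation into the terms."""
    if isinstance(expr, tuple):
        op = expr[0]
        if op == "add":
            return [t for arg in expr[1:] for t in split_terms(arg)]
        if op == "sub":
            return split_terms(expr[1]) + [negate(t) for t in split_terms(expr[2])]
        if op == "neg":
            return [negate(t) for t in split_terms(expr[1])]
    return [expr]  # anything else is a single term


def merge_bias_terms(items: List[Expr]) -> List[Expr]:
    """Sum all numeric terms into one trailing bias term (assumed behavior)."""
    bias = sum(t for t in items if isinstance(t, (int, float)))
    others = [t for t in items if not isinstance(t, (int, float))]
    return (others + [bias]) if bias != 0 else others


# (x + 3) - (2*y + (-1))
expr = ("sub", ("add", "x", 3), ("add", ("mul", 2, "y"), ("neg", 1)))
terms = split_terms(expr)
print(terms)                    # ['x', 3, ('neg', ('mul', 2, 'y')), 1]
print(merge_bias_terms(terms))  # ['x', ('neg', ('mul', 2, 'y')), 4]
```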

nd2py.core.transform.split_by_mul module

class nd2py.core.transform.split_by_mul.SplitByMul

Bases: Visitor

generic_visit(node: Symbol, *args, **kwargs) → _Type
visit_Mul(node: Mul, *args, **kwargs) → _Type
visit_Div(node: Div, *args, **kwargs) → _Type
merge_coefficients(items: List[Symbol], *args, **kwargs) → List[Symbol]

Merge coefficients from the symbols.
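
SplitByMul appears to be the multiplicative counterpart: it splits an expression into factors through Mul and Div, and merge_coefficients then combines the numeric coefficients. The sketch below illustrates that reading on a hypothetical tuple encoding. It assumes that dividing contributes inverted factors and that merging means multiplying every numeric factor into one leading coefficient; split_factors, invert and merge_numeric_coefficients are hypothetical names.

```python
from typing import List, Union

# Hypothetical tree: strings are variables, numbers are leaves, tuples are (op, *operands)
Expr = Union[str, int, float, tuple]


def invert(factor: Expr) -> Expr:
    if isinstance(factor, (int, float)):
        return 1 / factor
    if isinstance(factor, tuple) and factor[0] == "inv":
        return factor[1]  # 1/(1/a) -> a
    return ("inv", factor)


def split_factors(expr: Expr) -> List[Expr]:
    """Return the multiplicative factors of expr; a divisor contributes inverted factors."""
    if isinstance(expr, tuple):
        if expr[0] == "mul":
            return [f for arg in expr[1:] for f in split_factors(arg)]
        if expr[0] == "div":
            return split_factors(expr[1]) + [invert(f) for f in split_factors(expr[2])]
    return [expr]  # anything else is a single factor


def merge_numeric_coefficients(items: List[Expr]) -> List[Expr]:
    """Multiply all numeric factors into one leading coefficient (assumed behavior)."""
    coef = 1.0
    others: List[Expr] = []
    for item in items:
        if isinstance(item, (int, float)):
            coef *= item
        else:
            others.append(item)
    return others if coef == 1 else [coef] + others


# (4 * x * y) / (2 * z)
expr = ("div", ("mul", 4, "x", "y"), ("mul", 2, "z"))
factors = split_factors(expr)
print(factors)                              # [4, 'x', 'y', 0.5, ('inv', 'z')]
print(merge_numeric_coefficients(factors))  # [2.0, 'x', 'y', ('inv', 'z')]
```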