Source code for nd2py.generator.eq.metaai_generator

# Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
from __future__ import annotations
import random
import logging
import numpy as np
from numpy.random import RandomState, default_rng
from typing import Tuple, List, Generator, Dict, Set, TYPE_CHECKING
from collections import defaultdict
from ...utils import AttrDict
from ... import core as nd
if TYPE_CHECKING:
    from ...core.nettype import NetType
    from ...core.symbols import *


[docs] class MetaAIGenerator:
[docs] def __init__( self, variables: List[Variable], binary: List[str|Symbol] = [nd.Add, nd.Sub, nd.Mul, nd.Div], unary: List[str|Symbol] = [nd.Abs, nd.Inv, nd.Sqrt, nd.Log, nd.Exp, nd.Sin, nd.Arcsin, nd.Cos, nd.Arccos, nd.Tan, nd.Arctan, nd.Pow2, nd.Pow3], operators_to_downsample='Div:0,Arcsin:0,Arccos:0,Tan:0.2,Arctan:0.2,Sqrt:5,Pow2:3,Inv:3', rng: RandomState = None, edge_list: Tuple[List[int], List[int]] = None, num_nodes: int = None, scalar_number_only=True, ): self.binary = binary self.unary = unary self.symbols = self.binary + self.unary self.variables = variables self.scalar_number_only = scalar_number_only self._rng = rng or default_rng() prob_dict = defaultdict(lambda: 1.0) for item in operators_to_downsample.split(","): if item != "": op, prob = item.split(':') prob_dict[getattr(core, op)] = float(prob) self.binary_prob = [prob_dict[op] for op in self.binary] self.binary_prob = np.array(self.binary_prob) / sum(self.binary_prob) self.unary_prob = [prob_dict[op] for op in self.unary] self.unary_prob = np.array(self.unary_prob) / sum(self.unary_prob) if num_nodes is None and edge_list is not None: num_nodes = np.reshape(edge_list, (-1,)).max() + 1 self.num_nodes = num_nodes self.edge_list = edge_list
[docs] def sample(self, nettypes: Set[NetType], n_operators: int = None, n_var: int = None) -> nd.Symbol: if isinstance(nettypes, str): nettypes = {nettypes} sentinel = nd.Identity(); # 哨兵节点 # construct unary-binary tree empty_nodes = [*sentinel.operands] next_en = -1 n_empty = 1 while n_operators > 0: next_pos, arity = self.generate_next_pos(n_empty, n_operators) op = self.generate_ops(arity) next_en += next_pos + 1 n_empty -= next_pos + 1 empty_nodes[next_en] = empty_nodes[next_en].replace(op()) empty_nodes.extend(empty_nodes[next_en].operands) n_empty += op.n_operands n_operators -= 1 # fill variables n_used_var = 0 for n in empty_nodes: if isinstance(n, nd.Empty): sym, n_used_var = self.generate_leaf(n_var, n_used_var) n.replace(sym) return sentinel.operands[0]
[docs] def dist(self, n_op, n_emp): """ `max_ops`: maximum number of operators Enumerate the number of possible unary-binary trees that can be generated from empty nodes. D[n][e] represents the number of different binary trees with n nodes that can be generated from e empty nodes, using the following recursion: D(n, 0) = 0 D(0, e) = 1 D(n, e) = D(n, e - 1) + p_1 * D(n - 1, e) + D(n - 1, e + 1) p1 = 0 if binary trees, or 1 if unary-binary trees """ if not hasattr(self, 'dp_cache'): self.dp_cache = [[0]] p1 = 1 if self.unary else 0 if len(self.dp_cache) <= n_op + n_emp: for _ in range(len(self.dp_cache), n_op + n_emp + 1): self.dp_cache[0].append(1) for r, row in enumerate(self.dp_cache[1:], 1): row.append(row[-1] + p1 * self.dp_cache[r-1][-2] + self.dp_cache[r-1][-1]) self.dp_cache.append([0]) return self.dp_cache[n_op][n_emp]
[docs] def generate_leaf(self, n_var:int, n_used_var:int) -> Tuple[Symbol, int]: if n_used_var < n_var: return nd.Variable(f"x_{n_used_var+1}"), n_used_var+1 else: idx = np.random.randint(1, n_var + 1) return nd.Variable(f"x_{idx}"), n_used_var
[docs] def generate_ops(self, n_operands:int) -> Symbol: if n_operands == 1: return np.random.choice(self.unary, p=self.unary_prob) elif n_operands == 2: return np.random.choice(self.binary, p=self.binary_prob) else: raise ValueError(f"Unsupported number of operands: {n_operands}")
[docs] def generate_next_pos(self, n_empty, n_operators): """ Sample the position of the next node (binary case). Sample a position in {0, ..., `n_empty` - 1}. """ assert n_empty > 0 assert n_operators > 0 probs = [self.dist(n_operators - 1, n_empty - i + 1) for i in range(n_empty)] if self.unary: probs += [self.dist(n_operators - 1, n_empty - i) for i in range(n_empty)] probs = np.array(probs, dtype=np.float64) / self.dist(n_operators, n_empty) next_pos = np.random.choice(len(probs), p=probs) n_operands = 1 if next_pos >= n_empty else 2 next_pos %= n_empty return next_pos, n_operands