# Python Symbolic Information Theoretic Inequality Prover
# Copyright (C) 2020 Cheuk Ting Li
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""
Python Symbolic Information Theoretic Inequality Prover
Version 1.1.6
Copyright (C) 2020 Cheuk Ting Li

Based on the general method of using linear programming for proving information
theoretic inequalities described in the following work:

R. W. Yeung, "A new outlook on Shannon's information measures,"
IEEE Trans. Inform. Theory, vol. 37, pp. 466-474, May 1991.

R. W. Yeung, "A framework for linear information inequalities,"
IEEE Trans. Inform. Theory, vol. 43, pp. 1924-1934, Nov 1997.

Z. Zhang and R. W. Yeung, "On characterization of entropy function via
information inequalities," IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452,
Jul 1998.

This linear programming approach was used in the ITIP software developed by
Raymond W. Yeung and Ying-On Yan ( http://user-www.ie.cuhk.edu.hk/~ITIP/ ).


Usage:

This library requires Python 3. Python 2 is not supported.
This library requires either PuLP, Pyomo or scipy for sparse linear programming.
Open an interactive Python console in the directory containing psitip.py.

from psitip import *
X, Y, Z = rv("X", "Y", "Z")

# Check for always true statements
# Use H(X + Y | Z + W) for the entropy H(X,Y|Z,W)
# Use I(X & Y + Z | W) for the mutual information I(X;Y,Z|W)

# As is the case in ITIP, a False return value does not mean the inequality is
# false. It only means that it cannot be deduced by Shannon-type inequalities.

bool(I(X & Y) - H(X + Z) <= 0)  # return True
bool(I(X & Y) == H(X) - H(X | Y))  # return True

# Each constraint (e.g. I(X & Y) - H(X + Z) <= 0 above) generates a Region object.
# Region objects can be intersected with each other (using the "&" operator).
# Casting a Region to bool returns whether the constraints in the Region are
# always satisfied.

# Use A >> B or A <= B to check if the conditions in A imply those in B.
# The "<=" operator checks whether the region on the left is a subset of the
# region on the right, so A <= B means "A implies B".
# Note that "<=" should NOT be read as a reverse-implication arrow: the
# implication goes from the left side to the right side.
bool(((I(X & Y) == 0) & (I(X+Y & Z) == 0)) >> (I(X & Z) == 0))  # return True

See test.py for more usage examples.

WARNING: Nested implication may produce incorrect results for some cases.
Use at your own risk. It is advisable to check whether the output auxiliary
random variables are indeed valid.
""" import itertools import collections import array import fractions import warnings import types import functools import math import time import logging import random import contextlib import io import heapq try: import os import os.path except ImportError: os = None try: import numpy import numpy.linalg except ImportError: numpy = None try: import scipy import scipy.sparse import scipy.optimize import scipy.spatial import scipy.special import scipy.linalg except ImportError: scipy = None try: # Suppress stdout of pulp during import # with contextlib.redirect_stdout(open(os.devnull, "w")): with contextlib.redirect_stdout(io.StringIO()): import pulp except ImportError: pulp = None except Exception as err: warnings.warn(str(err), RuntimeWarning) pulp = None try: import pyomo.environ as pyo from pyomo.opt import SolverFactory logging.getLogger("pyomo.core").setLevel(logging.ERROR) except ImportError: pyo = None try: import ortools import ortools.linear_solver import ortools.linear_solver.pywraplp except ImportError: ortools = None try: import cdd except ImportError: cdd = None try: import z3 except ImportError: z3 = None try: import graphviz except ImportError: graphviz = None try: import matplotlib.pyplot as plt import matplotlib.patches import matplotlib.lines import matplotlib.patheffects import matplotlib.colors from matplotlib.collections import PatchCollection except ImportError: matplotlib = None plt = None try: import torch import torch.optim except ImportError: torch = None try: import torch.linalg except ImportError: pass try: import IPython import IPython.display except ImportError: IPython = None try: import lark from lark import v_args as lark_v_args from lark import Transformer as lark_Transformer except ImportError: lark = None lark_v_args = lambda inline: (lambda x: x) lark_Transformer = object # try: # import sympy # import sympy.matrices.normalforms # except ImportError: # sympy = None # try: # import smithnormalform # import smithnormalform.matrix 
# import smithnormalform.snfproblem # import smithnormalform.z # except ImportError: # smithnormalform = None class LinearProgType: NIL = 0 H = 1 # H(X_C) HC1BN = 2 # H(X | Y_C) HMIN = 3 class PsiOpts: """ Options Attributes: solver : The linear programming solver used "scipy" : scipy.optimize.linprog "pulp.cbc" : PuLP with CBC "pulp.glpk" : PuLP with GLPK "pyomo.glpk" : Pyomo with GLPK str_style : Style of string conversion STR_STYLE_STANDARD : I(X,Y;Z|W) STR_STYLE_PSITIP : I(X+Y&Z|W) STR_STYLE_LATEX : I(X,Y;Z|W) lptype : Linear programming mode LinearProgType.H : Classical LinearProgType.HC1BN : Bayesian Network optimization verbose_auxsearch : Print auxiliary RV search progress verbose_lp : Print linear programming parameters rename_char : Char appended when renaming auxiliary RV eps : Epsilon for float comparison """ STR_STYLE_STANDARD = 1 STR_STYLE_PSITIP = 2 STR_STYLE_LATEX = 4 STR_STYLE_LATEX_ARRAY = 1 << 5 STR_STYLE_LATEX_FRAC = 1 << 6 STR_STYLE_LATEX_QUANTAFTER = 1 << 7 STR_STYLE_MARKOV = 1 << 8 STR_STYLE_REM_SYMBOLS = 1 << 9 SFRL_LEVEL_SINGLE = 1 SFRL_LEVEL_MULTIPLE = 2 global_index = None settings = { "ent_coeff": 1.0 / math.log(2.0), "random": None, "quantum": False, "hge0": True, "hcge0": True, "ige0": True, "icge0": True, "figsize": None, "eps": 1e-10, "eps_lp": 1e-5, "eps_check": 1e-6, "eps_violate_cutoff": 1e-1, "max_denom": 1000000, "max_denom_mul": 10000, "max_denom_lp": 10000, "max_denom_try": 12, "condition_included": True, "avoid_enabled": True, "str_style": 2 + (1 << 8), "str_style_std": 1 + (1 << 8), "str_style_latex": 4 + (1 << 8), "str_style_repr": 2 + (1 << 8), "str_tosort": False, "str_lhsreal": True, "str_eqn_prefer_ge": False, "str_eps_hide": True, "str_proof_note": False, "rename_char": "_", "fcn_suffix": "@@32768@@#fcn", "meta_subs_criteria": False, "truth": None, "proof": None, "use_global_index": True, "indreg_enabled": True, "maxent_lex_enabled": True, "example_opt_num_points_mul": 1, "timer_start": None, "timer_end": None, 
"stop_file": None, "stop_file_last_check": None, "stop_file_exists_cache": False, "solver": None, "lptype": LinearProgType.HMIN, "lptype_H_if_proof": False, "lp_zero_group": True, "lp_no_zero_group_if_proof": True, "lp_bounded": False, "lp_ubound": 1e4, "lp_eps": 1e-3, "lp_eps_obj": 1e-4, "lp_zero_cutoff": -1e-5, "lp_cond_maximize": True, "lp_dual_form": False, "lp_dual_form_if_proof": True, "lp_bnet_reverse": False, "lp_bnet_hc": True, "fcn_mode": 1, "solver_scipy_maxsize": -1, "pulp_options": None, "pyomo_options": {}, "verbose_solver": False, "simplify_enabled": True, "simplify_quick": False, "simplify_reduce_coeff": True, "simplify_remove_missing_aux": True, "simplify_aux": True, "simplify_aux_combine": True, "simplify_aux_commonpart": True, "simplify_aux_xor_len": 1, "simplify_aux_empty": False, "simplify_aux_recombine": False, "simplify_pair": True, "simplify_redundant": True, "simplify_redundant_full": True, "simplify_bayesnet": True, "simplify_expr_exhaust": True, "simplify_redundant_op": True, "simplify_union": False, "simplify_aux_relax": True, "simplify_sort": True, "simplify_regterm": True, "simplify_aux_hull": False, "simplify_aux_hull_lower_complexity": True, "simplify_aux_eq": False, "simplify_aux_strengthen": False, "simplify_num_iter": 1, "term_allow": None, "istorch": (torch is not None), "opt_optimizer": "SLSQP", #"sgd", "opt_eps_denom": 1e-8, "opt_eps_tol": 1e-7, "opt_learnrate": 0.04, "opt_learnrate2": 0.02, "opt_momentum": 0.0, "opt_num_iter": 800, "opt_num_iter2": 0, "opt_num_points": 1, "opt_eps_converge": 1e-9, "opt_num_hop": 1, "opt_hop_temp": 0.05, "opt_hop_prob": 0.2, "opt_alm_rho": 1.0, "opt_alm_rho_pow": 1.0, "opt_alm_step": 5, "opt_alm_penalty": 20.0, "opt_aux_card": 5, "opt_example_card": (2, 4), "imp_noncircular": True, "imp_noncircular_allaux": False, "imp_simplify": False, "prefer_expand": True, "tensorize_simplify": False, "eliminate_rays": False, "ignore_must": False, "forall_multiuse": True, "forall_multiuse_numsave": 128, 
"auxsearch_local": True, "auxsearch_leaveone": False, "auxsearch_leaveone_add_ineq": True, "auxsearch_strengthen": False, "auxsearch_aux_strengthen": True, "init_leaveone": True, "auxsearch_max_iter": 0, "auxsearch_op_casesteplimit": 16, "auxsearch_op_caselimit": 512, "auxsearch_sandwich": True, "auxsearch_sandwich_inc": False, "auxsearch_max_numupdate_one_if_proof": True, "flatten_minmax_elim": False, "flatten_distribute": True, "flatten_distribute_multi": False, "presolve_aux_hull": False, "presolve_aux_hull_quick": True, "presolve_aux_eq": False, "presolve_aux_eq_quick": True, "presolve_simplify": False, "alg_group_nit": 10, "alg_normalize": True, "bayesnet_semigraphoid_iter": 1000, "proof_enabled": False, "proof_nowrite": False, "proof_step_dualsum": True, "proof_step_dualsum_short": True, "proof_step_dualsum_exhaust": True, "proof_step_dualsum_prog": False, "proof_step_chain": True, "proof_step_optimize": True, "proof_step_simplify": False, "proof_step_term_simplify": False, "proof_step_compress": True, "proof_step_term_nshuffle": 1, "proof_step_bayesnet": True, "proof_deficit_separate": True, "proof_repeat_implicant": False, "proof_yield_one": False, "proof_note": True, "proof_note_color": None, "proof_note_newline": 80, "proof_note_skip_trivial": True, "proof_step_expand_def": False, "proof_noskip": False, "repr_simplify": True, "repr_check": False, "repr_latex": False, "venn_latex": False, "discover_hull_frac_enabled": True, "discover_hull_frac_denom": 1000, "discover_max_facet": None, # 1000000, "discover_num_simplex": 0, "str_ineq_sep": ",", "str_brace_l": "{", "str_brace_r": "}", "str_eq": "==", "str_float": False, "str_float_dp": 5, "solve_display_reg": None, "latex_H": "H", "latex_I": "I", "latex_rv_delim": ",", "latex_cond": "|", "latex_sup": "\\sup", "latex_inf": "\\inf", "latex_max": "\\max", "latex_min": "\\min", "latex_exists": "\\exists", "latex_forall": "\\forall", "latex_quantifier_sep": ":\\,", "latex_indep": "{\\perp\\!\\!\\perp}", 
"latex_markov": "\\leftrightarrow", "latex_mi_delim": ";", "latex_list_bracket_l": "[", "latex_list_bracket_r": "]", "latex_matimplies": "\\Rightarrow", "latex_equiv": "\\Leftrightarrow", "latex_implies": "\\Rightarrow", "latex_times": "\\cdot", "latex_prob": "\\mathbf{P}", "latex_rv_empty": "\\emptyset", "latex_region_universe": "\\top", "latex_region_empty": "\\emptyset", "latex_tautology": "\\top", "latex_contradiction": "\\bot", "latex_unknown": "?", "latex_or": "\\vee", "latex_and": "\\wedge", "latex_infty": "\\infty", "latex_eps": "\\epsilon", "latex_because": "\\because", "latex_therefore": "\\therefore", "latex_is_typical": "\\in\\mathcal{T}", "latex_subs": ":=", "latex_subs_bracket_l": "\\{", "latex_subs_bracket_r": ".", "latex_group_mul": "", # "\\cdot", "latex_group_id": "\\mathrm{id}", "latex_group_add": "+", "latex_group_minus": "-", "latex_group_additive": False, "latex_color": None, "latex_line_len": None, "verbose_lp": False, "verbose_lp_cons": False, "verbose_auxsearch": False, "verbose_auxsearch_step": False, "verbose_auxsearch_result": False, "verbose_auxsearch_cache": False, "verbose_auxsearch_step_cached": False, "verbose_auxsearch_op": False, "verbose_auxsearch_op_step": False, "verbose_auxsearch_op_detail": False, "verbose_auxsearch_op_detail2": False, "verbose_subset": False, "verbose_sfrl": False, "verbose_flatten": False, "verbose_eliminate": False, "verbose_eliminate_toreal": False, "verbose_semigraphoid": False, "verbose_proof": False, "verbose_discover": False, "verbose_discover_detail": False, "verbose_discover_outer": False, "verbose_discover_terms": False, "verbose_discover_terms_inner": False, "verbose_discover_terms_outer": False, "verbose_opt": False, "verbose_opt_step": False, "verbose_opt_step_var": False, "verbose_float_dp": 8, "verbose_commmodel": False, "verbose_codingmodel": False, "verbose_proof_step": False, "verbose_aux_reduced": False, "sfrl_level": 0, "sfrl_maxsize": 1, "sfrl_gap": "" } @staticmethod def 
set_setting_dict(d, key, value): if key.endswith("_all"): keyprefix = key[:-len("_all")] for key2 in d: if key2.startswith(keyprefix): PsiOpts.set_setting_dict(d, key2, value) return if key == "sfrl": if value == "no": d["sfrl_level"] = 0 else: d["sfrl_level"] = max(d["sfrl_level"], PsiOpts.SFRL_LEVEL_SINGLE) if value == "frl": d["sfrl_gap"] = "" elif value.startswith("sfrl_gap."): d["sfrl_gap"] = value[value.index(".") + 1 :] elif value == "sfrl_nogap": d["sfrl_gap"] = "zero" elif key == "str_style": d["str_style"] = iutil.convert_str_style(value) if isinstance(value, str) and value.lower() == "aitip": d["str_ineq_sep"] = "" d["str_brace_l"] = " " d["str_brace_r"] = "" d["str_eq"] = "=" d["str_float"] = True elif key == "timer" or key == "timelimit": if value is None: d["timer_start"] = None d["timer_end"] = None else: if isinstance(value, str): value = iutil.get_duration_ms(value) curtime = time.time() * 1000 d["timer_start"] = curtime d["timer_end"] = curtime + float(value) elif key == "stop_file": if value != d["stop_file"]: d["stop_file"] = value d["stop_file_last_check"] = None elif key == "simplify_level": d["simplify_enabled"] = value >= 1 d["simplify_quick"] = value <= 2 # d["simplify_remove_missing_aux"] = True d["simplify_aux"] = value >= 4 d["simplify_aux_combine"] = value >= 4 d["simplify_aux_commonpart"] = value >= 5 if value >= 9: d["simplify_aux_xor_len"] = 4 elif value >= 7: d["simplify_aux_xor_len"] = 3 else: d["simplify_aux_xor_len"] = 1 d["simplify_aux_empty"] = value >= 6 d["simplify_aux_recombine"] = value >= 10 d["simplify_pair"] = value >= 2 d["simplify_redundant"] = value >= 3 d["simplify_redundant_full"] = value >= 4 d["simplify_bayesnet"] = value >= 4 d["simplify_redundant_op"] = value >= 4 d["simplify_union"] = value >= 8 d["simplify_aux_hull_lower_complexity"] = value >= 5 d["simplify_num_iter"] = 2 if value >= 9 else 1 elif key == "simplify_strengthen": d["simplify_aux_strengthen"] = value d["simplify_aux_eq"] = value elif key == 
"simplify_relax": d["simplify_aux_relax"] = value d["simplify_aux_hull"] = value elif key == "auxsearch_level": d["auxsearch_leaveone"] = value >= 6 d["auxsearch_strengthen"] = value >= 9 d["presolve_aux_hull"] = value >= 7 d["presolve_aux_hull_quick"] = value >= 4 d["presolve_aux_eq"] = value >= 8 d["presolve_aux_eq_quick"] = value >= 5 d["presolve_simplify"] = value >= 10 elif key == "level": PsiOpts.set_setting_dict(d, "simplify_level", value) PsiOpts.set_setting_dict(d, "auxsearch_level", value) elif key == "ent_base": d["ent_coeff"] = 1.0 / math.log(value) elif key == "term_allow": if isinstance(value, str): value = value.lower() if value == "h": d["term_allow"] = TermAllowType.H elif value == "i": d["term_allow"] = TermAllowType.H | TermAllowType.I elif value == "hc" or value == "ch": d["term_allow"] = TermAllowType.H | TermAllowType.HC elif value == "ic" or value == "ci": d["term_allow"] = TermAllowType.H | TermAllowType.HC | TermAllowType.I | TermAllowType.IC elif value == "all": d["term_allow"] = None else: d["term_allow"] = value elif key == "term_allow_hc": if value: d["term_allow"] |= TermAllowType.HC else: if d["term_allow"] is None: d["term_allow"] = TermAllowType.DEFAULT d["term_allow"] &= ~(TermAllowType.HC) elif key == "term_allow_i": if value: d["term_allow"] |= TermAllowType.I else: if d["term_allow"] is None: d["term_allow"] = TermAllowType.DEFAULT d["term_allow"] &= ~(TermAllowType.I) elif key == "term_allow_ic": if value: d["term_allow"] |= TermAllowType.IC else: if d["term_allow"] is None: d["term_allow"] = TermAllowType.DEFAULT d["term_allow"] &= ~(TermAllowType.IC) elif key == "term_allow_i3": if value: d["term_allow"] |= TermAllowType.I3 else: if d["term_allow"] is None: d["term_allow"] = TermAllowType.DEFAULT d["term_allow"] &= ~(TermAllowType.I3) elif key == "verbose_auxsearch_all": d["verbose_auxsearch"] = value d["verbose_auxsearch_step"] = value d["verbose_auxsearch_result"] = value elif key == "verbose_proof": d["verbose_proof"] = 
value if value: d["proof_enabled"] = value if d["proof"] is None: d["proof"] = ProofObj.empty() elif key == "proof_enabled": d["proof_enabled"] = value if value: if d["proof"] is None: d["proof"] = ProofObj.empty() elif key == "pulp_solver": d["solver"] = "pulp.other" iutil.pulp_solver = value elif key == "lptype": if value == "H": d[key] = LinearProgType.H elif value == "HMIN": d[key] = LinearProgType.HMIN elif isinstance(value, str): d[key] = LinearProgType.HC1BN else: d[key] = value elif key == "truth": if value is None: d["truth"] = None else: d["truth"] = value.copy() elif key == "truth_add": if value is not None: if d["truth"] is None: d["truth"] = value.copy() else: d["truth"] = d["truth"] & value elif key == "cases": d["auxsearch_leaveone"] = value elif key == "proof_add": if d["proof"] is None: d["proof"] = ProofObj.empty() if PsiOpts.settings.get("verbose_proof", False): print(value.tostring(prev = d["proof"])) print("") d["proof"] += value elif key == "proof_clear": if value: if d["proof"] is not None: d["proof"].clear() elif key == "proof_new": if value: d["proof_enabled"] = value d["proof"] = ProofObj.empty() if isinstance(value, str): PsiOpts.set_setting_dict(d, "proof_option", value) elif key == "proof_shorten": d["lp_dual_form_if_proof"] = value d["lp_no_zero_group_if_proof"] = value elif key == "proof_detail": d["lptype_H_if_proof"] = value d["proof_noskip"] = value d["proof_step_compress"] = not value elif key == "proof_option": for c in value.split(","): c = c.strip().lower() PsiOpts.set_setting_dict(d, "proof_" + c, True) elif key == "proof_branch": if value: d["proof_enabled"] = value if d["proof"] is None: d["proof"] = ProofObj.empty() else: d["proof"] = d["proof"].copy() if isinstance(value, str): PsiOpts.set_setting_dict(d, "proof_option", value) elif key == "proof_step_in": if d["proof"] is None: d["proof"] = ProofObj.empty() d["proof"] = d["proof"].step_in(value) elif key == "proof_step_out": if d["proof"] is None: d["proof"] = 
ProofObj.empty() d["proof"] = d["proof"].step_out() elif key == "note": d["str_proof_note"] = value elif key == "random_seed": rnd = numpy.random.default_rng(value) d["random"] = rnd elif key == "opt_singlepass": d["opt_learnrate"] = 0.04 d["opt_num_iter"] = 800 d["opt_num_iter2"] = 0 d["opt_num_points"] = 1 d["opt_num_hop"] = 1 elif key == "opt_basinhopping": if value: d["opt_learnrate"] = 0.12 d["opt_learnrate2"] = 0.02 #0.001 d["opt_num_points"] = 5 d["opt_num_iter"] = 15 d["opt_num_iter2"] = 500 d["opt_num_hop"] = 20 else: PsiOpts.set_setting_dict(d, "opt_singlepass", True) elif key == "opt_learnrate_mul": d["opt_learnrate"] *= value d["opt_learnrate2"] *= value elif key == "opt_num_iter_mul": d["opt_num_iter"] = int(d["opt_num_iter"] * value) d["opt_num_iter2"] = int(d["opt_num_iter2"] * value) elif key == "opt_num_points_mul": d["opt_num_points"] = int(d["opt_num_points"] * value) else: if key not in d: raise KeyError("Option '" + str(key) + "' not found.") d[key] = value @staticmethod def apply_dict(d): IBaseObj.set_repr_latex(d["repr_latex"]) @staticmethod def set_setting(**kwargs): for key, value in kwargs.items(): PsiOpts.set_setting_dict(PsiOpts.settings, key, value) PsiOpts.apply_dict(PsiOpts.settings) @staticmethod def setting(**kwargs): PsiOpts.set_setting(**kwargs) @staticmethod def get_setting(key, defaultval = None): if key in PsiOpts.settings: return PsiOpts.settings[key] return defaultval @staticmethod def get_proof(): return PsiOpts.settings["proof"] @staticmethod def set_proof(value): if PsiOpts.settings["proof"] is None: PsiOpts.settings["proof"] = value else: PsiOpts.settings["proof"].copy_(value) @staticmethod def get_random(): rnd = PsiOpts.settings["random"] if rnd is None: # rnd = random.Random() rnd = numpy.random.default_rng() PsiOpts.settings["random"] = rnd return rnd @staticmethod def get_truth(): return PsiOpts.settings["truth"] @staticmethod def timer_left(): if PsiOpts.settings["timer_end"] is None: return None curtime = 
time.time() * 1000 return PsiOpts.settings["timer_end"] - curtime @staticmethod def timer_left_sec(): r = PsiOpts.timer_left() if r is None: return None return int(round(r / 1000.0)) @staticmethod def is_timer_ended_time(): if PsiOpts.settings["timer_end"] is None: return False curtime = time.time() * 1000 return curtime > PsiOpts.settings["timer_end"] @staticmethod def is_timer_ended_file(): if PsiOpts.settings["stop_file"] is None: return False curtime = time.time() * 1000 if PsiOpts.settings["stop_file_last_check"] is not None and curtime <= PsiOpts.settings["stop_file_last_check"] + 5 * 1000: return PsiOpts.settings["stop_file_exists_cache"] PsiOpts.settings["stop_file_last_check"] = curtime PsiOpts.settings["stop_file_exists_cache"] = os.path.exists(PsiOpts.settings["stop_file"]) return PsiOpts.settings["stop_file_exists_cache"] @staticmethod def is_timer_ended(): if PsiOpts.is_timer_ended_time(): return True if PsiOpts.is_timer_ended_file(): return True return False @staticmethod def has_timer(): return PsiOpts.settings["timer_end"] is not None or PsiOpts.settings["stop_file"] is not None @staticmethod def get_pyomo_options(): r = dict(PsiOpts.settings["pyomo_options"]) # if "tee" not in r: # r["tee"] = PsiOpts.settings["verbose_solver"] # if "verbose" not in r: # r["verbose"] = False timelimit = PsiOpts.timer_left_sec() if timelimit is not None: csolver = iutil.get_solver() if csolver == "pyomo.glpk": r["tmlim"] = timelimit elif csolver == "pyomo.cplex": r["timelimit"] = timelimit elif csolver == "pyomo.gurobi": r["TimeLimit"] = timelimit elif csolver == "pyomo.cbc": r["seconds"] = timelimit return r @staticmethod def setting_strengthen_sign(s): if s.startswith("simplify_"): s = s[len("simplify_"):] if s == "aux_eq" or s == "aux_strengthen" or s == "strengthen": return 1 elif s == "aux_relax" or s == "aux_hull" or s == "relax": return -1 return 0 @staticmethod def setting_strengthen_split(d): r0 = dict() r1 = dict() for a, b in d.items(): if b is True: t = 
PsiOpts.setting_strengthen_sign(a) if t >= 0: r0[a] = b if t <= 0: r1[a] = b else: r0[a] = b r1[a] = b return (r0, r1) def __init__(self, **kwargs): """ Options. Parameters ---------- **kwargs : TYPE DESCRIPTION. Returns ------- None. """ self.cur_settings = PsiOpts.settings.copy() for key, value in kwargs.items(): PsiOpts.set_setting_dict(self.cur_settings, key, value) def __enter__(self): PsiOpts.settings, self.cur_settings = self.cur_settings, PsiOpts.settings PsiOpts.apply_dict(PsiOpts.settings) return PsiOpts.settings def __exit__(self, exc_type, exc_value, exc_traceback): PsiOpts.settings, self.cur_settings = self.cur_settings, PsiOpts.settings PsiOpts.apply_dict(PsiOpts.settings) class iutil: """Common utilities """ solver_list = ["ortools.GLOP", "pulp.glpk", "pyomo.glpk", "pulp.cbc", "scipy", "z3"] pulp_solver = None pulp_solvers = {} cur_count = 0 cur_count_name = {} @staticmethod def display_latex(s, ismath = True, metadata = None): color = PsiOpts.settings["latex_color"] if color is not None: s = "\\color{" + color + "}{" + s + "}" if ismath: r = IPython.display.Math(s, metadata = metadata) else: r = IPython.display.Latex(s, metadata = metadata) IPython.display.display(r) # return r @staticmethod def float_tostr(x, style = 0, bracket = True, force_float = False): if x == numpy.inf or x == -numpy.inf or numpy.isnan(x): if style & PsiOpts.STR_STYLE_LATEX: if x == numpy.inf: return "\infty" elif x == -numpy.inf: return "-\infty" else: return "?" else: return str(x) ceps = 1e-10 if abs(x) <= ceps: return "0" elif abs(x - round(x)) <= ceps: return str(int(round(x))) else: to_force_float = False to_try_float = False if force_float is True or (isinstance(force_float, int) and force_float >= 10) or PsiOpts.settings["str_float"]: to_force_float = True elif isinstance(force_float, int) and force_float >= 5: to_try_float = True if to_force_float: # return str(x) return ("{:." 
+ str(PsiOpts.settings["str_float_dp"]) + "f}").format(x) denom = 1 if to_try_float: denom = PsiOpts.settings["max_denom_try"] else: denom = PsiOpts.settings["max_denom"] frac = fractions.Fraction(abs(x)).limit_denominator(denom) if to_try_float and abs(x - float(frac)) > ceps: return ("{:." + str(PsiOpts.settings["str_float_dp"]) + "f}").format(x) if x > 0: if style & PsiOpts.STR_STYLE_LATEX_FRAC: return "\\frac{" + str(frac.numerator) + "}{" + str(frac.denominator) + "}" else: if bracket: return "(" + str(frac) + ")" else: return str(frac) else: if style & PsiOpts.STR_STYLE_LATEX_FRAC: return "-\\frac{" + str(frac.numerator) + "}{" + str(frac.denominator) + "}" else: if bracket: return "-(" + str(frac) + ")" else: return "-" + str(frac) @staticmethod def float_snap(x, denom = None, force = False, eps = None): if denom is None: denom = PsiOpts.settings["max_denom"] if eps is None: eps = PsiOpts.settings["eps"] t = float(fractions.Fraction(x).limit_denominator(denom)) if force or abs(x - t) <= eps: return t return x @staticmethod def isconstzero(x, eps = None): if eps is None: eps = PsiOpts.settings["eps"] if isinstance(x, Expr): x = x.get_const() if isinstance(x, bool): return not x if isinstance(x, (int, fractions.Fraction)): return x == 0 if isinstance(x, float): return abs(x) <= eps return False @staticmethod def tostr_verbose(x): if isinstance(x, float): dp = PsiOpts.settings["verbose_float_dp"] if dp is None: return str(x) return ("{:." 
+ str(dp) + "f}").format(x) elif isinstance(x, list): return "[" + ", ".join(iutil.tostr_verbose(a) for a in x) + "]" elif isinstance(x, tuple): return "(" + ", ".join(iutil.tostr_verbose(a) for a in x) + ")" else: return str(x) @staticmethod def float_toz3(x): if z3 is None: return None if abs(x - round(x)) <= 1e-10: return z3.RealVal(int(round(x))) else: frac = fractions.Fraction(abs(x)).limit_denominator( PsiOpts.settings["max_denom"]) return z3.Q(frac.numerator, frac.denominator) @staticmethod def num_open_brackets(s): return s.count("(") + s.count("[") + s.count("{") - ( s.count("}") + s.count("]") + s.count(")")) @staticmethod def split_comma_old(s, delim = ", "): if not isinstance(s, str): return [s] t = s.split(delim) c = "" r = [] for a in t: if c != "": c += delim c += a if iutil.num_open_brackets(c) == 0: r.append(c) c = "" if c != "": r.append(c) return r @staticmethod def split_comma(s): if not isinstance(s, str): return [s] r = [] nopen = 0 for a in s: if nopen == 0 and (a == "," or a == " "): if len(r) == 0 or r[-1] != "": r.append("") else: if a == "(" or a == "[" or a == "{": nopen += 1 elif a == ")" or a == "]" or a == "}": nopen -= 1 if len(r) == 0: r.append("") r[-1] += a return r @staticmethod def get_count(counter_name = None, add = True): if counter_name is None: if add: iutil.cur_count += 1 return iutil.cur_count if counter_name not in iutil.cur_count_name: iutil.cur_count_name[counter_name] = 0 if add: iutil.cur_count_name[counter_name] += 1 return iutil.cur_count_name[counter_name] @staticmethod def gcd(a, b): while b > 0: a, b = b, a % b return a @staticmethod def lcm(a, b): return (a // iutil.gcd(a, b)) * b @staticmethod def hasinstance(a, t): if isinstance(a, t): return True if isinstance(a, (tuple, list)): return any(iutil.hasinstance(x, t) for x in a) @staticmethod def convert_algtype(algtype): if algtype is None: return 0 if isinstance(algtype, str): if algtype == "": return 0 if algtype in ("semigroup", "monoid"): return 
AlgType.SEMIGROUP if algtype == "group": return AlgType.GROUP if algtype == "abelian": return AlgType.ABELIAN if algtype in ("torsionfree", "vector", "real"): return AlgType.REAL raise ValueError("Algebraic structure \"" + algtype + "\" not supported. Options are " + ", ".join("\"" + x + "\"" for x in ("semigroup", "group", "abelian", "torsionfree", "vector", "real"))) return 0 return algtype @staticmethod def convert_str_style(style): if isinstance(style, str): style = style.lower() if style == "standard" or style == "std": return PsiOpts.settings["str_style_std"] elif style == "aitip": return PsiOpts.settings["str_style_std"] & ~(PsiOpts.STR_STYLE_MARKOV) | PsiOpts.STR_STYLE_REM_SYMBOLS elif style == "psitip" or style == "code": return PsiOpts.settings["str_style_repr"] elif style == "latex": return PsiOpts.settings["str_style_latex"] | PsiOpts.STR_STYLE_LATEX_ARRAY | PsiOpts.STR_STYLE_LATEX_FRAC elif style == "latex_noarray": return PsiOpts.settings["str_style_latex"] | PsiOpts.STR_STYLE_LATEX_FRAC else: return style @staticmethod def reverse_eqnstr(eqnstr): if eqnstr == "<=": return ">=" elif eqnstr == ">=": return "<=" elif eqnstr == "<": return ">" elif eqnstr == ">": return "<" else: return eqnstr @staticmethod def eqnstr_style(eqnstr, style): if eqnstr == "": return "" if style & PsiOpts.STR_STYLE_STANDARD: if eqnstr == "==": return PsiOpts.settings["str_eq"] else: return eqnstr if style & PsiOpts.STR_STYLE_LATEX: if eqnstr == "<=": return "\\le" elif eqnstr == ">=": return "\\ge" elif eqnstr == "==": return "=" elif eqnstr == "!=": return "\\neq" else: return eqnstr return eqnstr @staticmethod def op_str(eqnstr, a, b, isle = None): if eqnstr == "<=": return a <= b elif eqnstr == ">=": return a >= b elif eqnstr == "<": return a < b elif eqnstr == ">": return a > b elif eqnstr == "==" or eqnstr == "=": if isle is not None: if isle: return a == b else: return b == a return a == b elif eqnstr == "!=": return a != b elif eqnstr == ">>": return a >> b elif 
eqnstr == "<<": return a << b elif eqnstr == "+": return a + b elif eqnstr == "-": return a - b elif eqnstr == "*": return a * b elif eqnstr == "/": return a / b elif eqnstr == "**": return a ** b elif eqnstr == "//": return a // b elif eqnstr == "^": return a ^ b elif eqnstr == "&": return a & b elif eqnstr == "|": return a | b elif eqnstr == "%": return a % b return None @staticmethod def latex_len(s): s = s.replace("\\left", "") s = s.replace("\\right", "") s = s.replace("\\hat", "") s = s.replace("\\bar", "") s = s.replace("\\tilde", "") s = s.replace("\\displaystyle", "") s = s.replace("\\!", "") r = 0.0 i = 0 while i < len(s): c = s[i] if c == "\\": r += 1 i += 1 while i < len(s) and s[i].isalpha(): i += 1 i -= 1 elif c in {"_", "^", "{", "}"}: pass elif c in {" ", ",", "(", ")"}: r += 0.5 else: r += 1 i += 1 return r @staticmethod def latex_split_line(strs, line_len, slstr = ""): if line_len is None: line_len = 100000000000 if isinstance(strs, str): strs = [strs] rlines = [] for s in strs: r = [] c = 0 for i in range(len(s)): sc = s[i:] if (sc.startswith("+") or sc.startswith("-") or sc.startswith("<") or sc.startswith(">") or sc.startswith("\\le") or sc.startswith("\\ge") or sc.startswith("=") or sc.startswith("\\neq")): r.append(s[c: i]) c = i r.append(s[c:]) lines = [] for x in r: if lines and iutil.latex_len(lines[-1]) + iutil.latex_len(x) <= line_len: lines[-1] += x else: lines.append(x) rlines += lines lines = rlines if len(lines) == 0: return "" elif len(lines) == 1: return lines[0] else: return "\\begin{array}{l}\n" + ("\\\\\n" + slstr).join(lines) + "\n\\end{array}" @staticmethod def strip_match(s, a, b): if len(s) >= len(a) + len(b) and s.startswith(a) and s.endswith(b): return s[len(a):-len(b)] return None @staticmethod def latex_concat(style, strs): nlstr = "\n" if style & PsiOpts.STR_STYLE_LATEX: nlstr = "\\\\\n" if not style & PsiOpts.STR_STYLE_LATEX: return nlstr.join(strs) ma0 = "\\begin{align*}\n" ma1 = "\\end{align*}\n" strs = list(strs) 
cmai = [False] * len(strs) for i in range(len(strs)): s1 = iutil.strip_match(strs[i], ma0, ma1) if s1 is not None: cmai[i] = True strs[i] = s1 if any(cmai): for i in range(len(strs)): if not cmai[i]: strs[i] = "& " + strs[i] return ma0 + nlstr.join(strs) + ma1 return nlstr.join(strs) @staticmethod def str_list_concat(slists): r = [] for slist in slists: if isinstance(slist, str): slist = [slist] if r: r += [",", " "] r += slist return r @staticmethod def meta_concat(metas): if all(meta is None for meta in metas): return None metas = list(metas) metas.sort(key = lambda meta: 0 if meta is None else meta.get("pf_note_priority", 0)) r = dict() for meta in metas: if meta is None: continue for key, value in meta.items(): if key in r: if key == "pf_note": r[key] = iutil.str_list_concat([r[key], value]) elif key == "pf_note_priority": r[key] = r[key] + value else: r[key] = value if "pf_note_priority" in r: r["pf_note_priority"] *= 1.0 / len(metas) return iutil.copy(r) @staticmethod def pf_note_str(s, style, note_color = None, add_space = 0, add_bracket = True): if note_color is None: note_color = PsiOpts.settings["proof_note_color"] if isinstance(s, str): s = [s] if add_bracket: s = [" " * add_space, "(", "since", " "] + s + [")"] else: s = [" " * add_space] + s r = "" if style & PsiOpts.STR_STYLE_LATEX: if note_color is not None: r += "{\\color{" + str(note_color) + "}{" r += iutil.tostring_join(s, style) if style & PsiOpts.STR_STYLE_LATEX: if note_color is not None: r += "}}" return r @staticmethod def str_python_multiline(s): s = str(s) return "(\"" + "\\n\"\n\"".join(s.split("\n")) + "\")" @staticmethod def istensor(a): return not isinstance(a, Comp) and hasattr(a, "shape") # return isinstance(a, (numpy.array, IBaseArray)) or (torch is not None and isinstance(a, torch.Tensor)) @staticmethod def hash_short(s): s = str(s) return hash(s) % 99991 @staticmethod def z3_vardict(rvs, auxs, auxis, reals): if z3 is None: return None n = len(rvs) + len(auxis) r = {} r["#n"] = n s 
= z3.Solver() if n: x = z3.BitVec("TMPVARX", n) y = z3.BitVec("TMPVARY", n) H = z3.Function("H", x.sort(), z3.RealSort()) s.add(H(z3.BitVecVal(0, n)) == 0) s.add(z3.ForAll([x, y], H(x) <= H(x|y))) s.add(z3.ForAll([x, y], H(x) + H(y) >= H(x|y) + H(x&y))) r["#H"] = H else: r["#H"] = None r["#solver"] = s for i, t in enumerate(rvs): r[t.get_name()] = z3.BitVecVal(1 << i, n) for t in auxs + auxis: r[t.get_name()] = z3.BitVec(t.get_name(), n) for t in reals: r[t.get_name()] = z3.Real(t.get_name()) return r @staticmethod def get_solver(psolver = None): csolver_list = [(x, False) for x in iutil.solver_list] setting_solver = PsiOpts.settings["solver"] if setting_solver is None: setting_solver = "" if isinstance(setting_solver, str): setting_solver = [setting_solver] csolver_list = [(x, True) for x in setting_solver if x != ""] + csolver_list if psolver is not None: csolver_list = [(psolver, True)] + csolver_list warn_list = [] for s, iswarn in csolver_list: sel = False if s == "scipy" and (scipy is not None): sel = True elif s.startswith("pulp.") and (pulp is not None): sel = True elif s.startswith("pyomo.") and (pyo is not None): sel = True elif s.startswith("ortools.") and (ortools is not None): sel = True elif s == "z3" and (z3 is not None): sel = True if sel: if warn_list: warnings.warn("Solver " + ", ".join(warn_list) + " not found. Falling back to " + s + ". 
Use PsiOpts.setting(solver=\"" + s + "\") to stop this warning.", RuntimeWarning) return s else: if iswarn: warn_list.append(s) if warn_list: warnings.warn("Solver " + ", ".join(warn_list) + " not found.", RuntimeWarning) return "" @staticmethod def pulp_get_solver(solver): coptions = PsiOpts.settings["pulp_options"] msg = PsiOpts.settings["verbose_solver"] copt = solver[solver.index(".") + 1 :].upper() if copt == "OTHER": return iutil.pulp_solver if copt in iutil.pulp_solvers: return iutil.pulp_solvers[copt] r = None if copt == "GLPK": #r = pulp.solvers.GLPK(msg = 0, options = coptions) r = pulp.GLPK(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions) elif copt == "CBC" or copt == "PULP_CBC_CMD": #r = pulp.solvers.PULP_CBC_CMD(options = coptions) r = pulp.PULP_CBC_CMD(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions) elif copt == "GUROBI": r = pulp.GUROBI(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions) elif copt == "CPLEX": r = pulp.CPLEX(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions) elif copt == "MOSEK": r = pulp.MOSEK(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions) elif copt == "CHOCO_CMD": r = pulp.CHOCO_CMD(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions) iutil.pulp_solvers[copt] = r return r @staticmethod def istorch(x): return torch is not None and isinstance(x, torch.Tensor) @staticmethod def ensure_torch(x): if hasattr(x, "get_x"): x = x.get_x() if iutil.istorch(x): return x return torch.tensor(x, dtype=torch.float64) @staticmethod def ensure_comp(x, strict = True): if x is None: return None if isinstance(x, (BayesNet, FcnRelation)): x = x.get_region() if isinstance(x, Comp): return x if isinstance(x, str): return Comp.rv(x) if isinstance(x, int): return Comp.empty() if isinstance(x, IBaseObj): return x.allcomprv_noaux() if x is False: return Comp.empty() r = Comp.empty() for y in x: t = iutil.ensure_comp(y, strict = strict) if t is not 
None: r += t return r @staticmethod def ensure_expr(x, strict = True): if x is None: return None if isinstance(x, Expr): return x if isinstance(x, Term): return Expr.fromterm(x) if isinstance(x, BayesNet): if strict: raise ValueError("Cannot convert BayesNet to Expr.") x = x.get_region() if isinstance(x, FcnRelation): if strict: raise ValueError("Cannot convert FcnRelation to Expr.") x = x.get_region() if isinstance(x, Region): if strict: raise ValueError("Cannot convert Region to Expr.") return x.expr() if isinstance(x, Comp): if strict: raise ValueError("Cannot convert Comp to Expr.") return Expr.H(x) if isinstance(x, ConcReal): x = float(x) if isinstance(x, (bool, int, float)): return Expr.const(float(x)) if isinstance(x, fractions.Fraction): return Expr.const(float(x)) if iutil.istensor(x): return Expr.const(float(x)) if isinstance(x, str): return Expr.parse(x) if strict: raise ValueError("Cannot convert " + str(type(x)) + " to Expr.") r = Expr.zero() for y in x: t = iutil.ensure_expr(y, strict = strict) if t is not None: r += t return r @staticmethod def ensure_region(x, strict = True): if x is None: return None if isinstance(x, Region): return x if isinstance(x, Comp): raise ValueError("Cannot convert Comp to Region.") # return Region.universe() if isinstance(x, Expr): raise ValueError("Cannot convert Expr to Region.") # return x == 0 if isinstance(x, (bool, int, float)): if x: return Region.universe() else: return Region.empty() if isinstance(x, (BayesNet, FcnRelation)): return x.get_region() if isinstance(x, ConcModel): return x.get_region(vals = True) if isinstance(x, ProofObj): return iutil.ensure_region(x.step_regions()) if isinstance(x, str): return Region.parse(x) if isinstance(x, tuple): if all(isinstance(y, Comp) for y in x): # return BayesNet([x]).get_region() return markov(*x) if isinstance(x, (list, set)): if all(isinstance(y, (tuple, Comp)) for y in x): return BayesNet(list(x)).get_region() r = Region.universe() for y in x: t = 
iutil.ensure_region(y, strict = strict) if t is not None: r &= t return r @staticmethod def type_coerce(a): if any(isinstance(x, Region) for x in a): return [iutil.ensure_region(x) for x in a] if any(isinstance(x, Expr) for x in a): return [iutil.ensure_expr(x) for x in a] if any(isinstance(x, Comp) for x in a): return [iutil.ensure_comp(x) for x in a] return a @staticmethod def log(x): if iutil.istorch(x): return torch.log(x) else: return numpy.log(x) @staticmethod def xlogxoy(x, y): if iutil.istorch(x) or iutil.istorch(y): ceps_d = PsiOpts.settings["opt_eps_denom"] return x * iutil.log((x + ceps_d) / (y + ceps_d)) else: ceps = PsiOpts.settings["eps"] if x <= ceps: return 0.0 else: return x * numpy.log(x / y) @staticmethod def xlogxoy2(x, y): if iutil.istorch(x) or iutil.istorch(y): ceps_d = PsiOpts.settings["opt_eps_denom"] return x * (iutil.log((x + ceps_d) / (y + ceps_d)) ** 2) else: ceps = PsiOpts.settings["eps"] if x <= ceps: return 0.0 else: return x * (numpy.log(x / y) ** 2) @staticmethod def sqrt(x): if iutil.istorch(x): return torch.sqrt(x) else: return numpy.sqrt(x) @staticmethod def product(x): r = 1 for a in x: r *= a return r @staticmethod def bitcount(x): r = 0 while x != 0: x &= x - 1 r += 1 return r @staticmethod def strpad(*args): r = "" tgtlen = 0 for i in range(0, len(args), 2): r += str(args[i]) if i + 1 < len(args): tgtlen += int(args[i + 1]) while len(r) < tgtlen: r += " " return r @staticmethod def list_tostr(x, tuple_delim = ", ", list_delim = ", ", inden = 0): r = " " * inden if isinstance(x, list): if len([a for a in x if isinstance(a, list) or isinstance(a, tuple)]) > 0: r += "[" for i in range(len(x)): if i == 0: r += iutil.list_tostr(x[i], tuple_delim, list_delim, inden + 2)[inden + 1:] else: r += list_delim + "\n" + iutil.list_tostr(x[i], tuple_delim, list_delim, inden + 2) r += " ]" return r else: r += "[" + list_delim.join([iutil.list_tostr(a, tuple_delim, list_delim, 0) for a in x]) + "]" return r elif isinstance(x, tuple): r += 
"(" + tuple_delim.join([iutil.list_tostr(a, tuple_delim, list_delim, 0) for a in x]) + ")" return r r += str(x) #r += x.tostring() return r @staticmethod def list_tostr_std(x): return iutil.list_tostr(x, tuple_delim = ": ", list_delim = "; ") @staticmethod def list_iscomplex(x): if not isinstance(x, list): return True for a in x: if not isinstance(a, tuple): return True if len(a) != 2: return True if isinstance(a[1], list): return True return False @staticmethod def str_inden(s, ninden, spacestr = " ", slstr = ""): return slstr + spacestr * ninden + s.replace("\n", "\n" + slstr + spacestr * ninden) @staticmethod def enum_partition(n): def enum_partition_recur(mask): if mask == 0: return [[]] r = [] for i in range(n): if mask & (1 << i) != 0: mask2 = mask - (1 << i) while True: mask3 = (1 << i) | mask2 r += [[mask3] + a for a in enum_partition_recur(mask - mask3)] if mask2 == 0: break mask2 = (mask2 - 1) & (mask - (1 << i)) break return r return enum_partition_recur((1 << n) - 1) @staticmethod def tsort(x): """Topological sort.""" n = len(x) ninc = [0] * n for i in range(n): for j in range(n): if x[i][j]: ninc[j] += 1 cstack = [i for i in range(n) if ninc[i] == 0] r = [] while len(cstack) > 0: i = cstack.pop() r.append(i) for j in range(n): if x[i][j]: ninc[j] -= 1 if ninc[j] == 0: cstack.append(j) return r @staticmethod def iscyclic(x): return len(iutil.tsort(x)) < len(x) @staticmethod def signal_type(x): if isinstance(x, tuple) and len(x) > 0 and isinstance(x[0], str): return x[0] return "" @staticmethod def mhash(x): if isinstance(x, list) or isinstance(x, tuple): return hash(tuple(iutil.mhash(y) for y in x)) return hash(x) @staticmethod def list_unique(x): r = [] s = set() for a in x: h = iutil.mhash(a) if h not in s: s.add(h) r.append(a) return r @staticmethod def list_sorted_unique(x): x = sorted(x, key = lambda a: iutil.mhash(a)) return [x[i] for i in range(len(x)) if i == 0 or not x[i] == x[i - 1]] @staticmethod def list_interleave(x): return [a for t in 
itertools.zip_longest(*x) for a in t if a is not None] @staticmethod def sumlist(x): if isinstance(x, list) or isinstance(x, tuple): return sum(iutil.sumlist(a) for a in x) return x @staticmethod def split_number_text(s): r = [""] for c in s: if c == " ": if r[-1] != "": r.append("") else: if c.isdigit(): if r[-1] != "" and not r[-1][-1].isdigit(): r.append("") else: if r[-1] != "" and r[-1][-1].isdigit(): r.append("") r[-1] += c return r @staticmethod def get_duration_ms(s): t = iutil.split_number_text(s) r = 0 strs_s = {"s", "sec", "secs", "second", "seconds"} strs_m = {"m", "min", "mins", "minute", "minutes"} strs_h = {"h", "hr", "hrs", "hour", "hours"} strs_d = {"d", "day", "days"} strs_w = {"w", "week", "weeks"} strs_mon = {"mon", "month", "months"} strs_y = {"y", "yr", "yrs", "year", "years"} for i in range(len(t)): if t[i].isdigit(): if i + 1 >= len(t): r += int(t[i]) elif t[i + 1].lower() in strs_s: r += int(t[i]) * 1000 elif t[i + 1].lower() in strs_m: r += int(t[i]) * 1000 * 60 elif t[i + 1].lower() in strs_h: r += int(t[i]) * 1000 * 60 * 60 elif t[i + 1].lower() in strs_d: r += int(t[i]) * 1000 * 60 * 60 * 24 elif t[i + 1].lower() in strs_w: r += int(t[i]) * 1000 * 60 * 60 * 24 * 7 elif t[i + 1].lower() in strs_mon: r += int(t[i]) * 1000 * 60 * 60 * 24 * 7 * 30 elif t[i + 1].lower() in strs_y: r += int(t[i]) * 1000 * 60 * 60 * 24 * 7 * 365 else: r += int(t[i]) return r @staticmethod def remove_format(s): return s.replace("{", "").replace("}", "").replace("_", "") @staticmethod def find_similarity(s, x): if s == x: return 10000 * 3 t = s.split("@@") x2 = x.split("@@") r = 0 for i in range(0, len(t), 2): for j in range(0, len(x2), 2): if x2[j] != "": if x2[j] in t[i]: r = max(r, 10000 + len(x2[j]) - len(t[i])) else: x2r = iutil.remove_format(x2[j]) tr = iutil.remove_format(t[i]) if x2r in tr: r = max(r, 5000 + len(x2r) - len(tr)) return r @staticmethod def set_suffix_num(s, k, schar, replace_mode = "set", style = None, ensure_latex = True): t = 
s.split("@@") if ensure_latex and len(t) >= 1 and ("@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@") not in s: return iutil.set_suffix_num(s + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + t[0], k, schar, replace_mode, style = style, ensure_latex = False) if len(t) >= 2: for i in range(0, len(t), 2): t[i] = iutil.set_suffix_num(t[i], k, schar, replace_mode, style = None if i == 0 else int(t[i - 1]), ensure_latex = False) return "@@".join(t) if replace_mode != "suffix": s, v0 = iutil.break_subscript_latex(s) if replace_mode == "append": v0 += str(k) elif replace_mode == "set": v0 = str(k) elif replace_mode == "add": if v0.isdigit(): v0 = str(int(v0) + k) else: v0 += str(k) if len(v0) > 1 and style == PsiOpts.STR_STYLE_LATEX: return s + "_{" + v0 + "}" else: return s + "_" + v0 # i = s.rfind(schar) # if i >= 0: # if replace_mode == "append": # return s + str(k) # elif replace_mode == "set": # return s[:i] + schar + str(k) # else: # if s[i + 1 :].isdigit(): # if replace_mode == "add": # return s[:i] + schar + str(int(s[i + 1 :]) + k) # else: # return s[:i] + schar + str(k) return s + schar + str(k) @staticmethod def get_name_fromcat(s, style): t = s.split("@@") r = t[0] for i in range(1, len(t) - 1, 2): if int(t[i]) & style: r = t[i + 1] break if style & PsiOpts.STR_STYLE_REM_SYMBOLS: r = r.replace("_", "").replace("{", "").replace("}", "").replace("\\", "").replace("^", "") return r @staticmethod def name_prefix_suffix(s, s_pre, s_suf): t = s.split("@@") for i in range(0, len(t), 2): t[i] = s_pre + t[i] + s_suf return "@@".join(t) @staticmethod def break_subscript_latex(s): t = s.split("_") if len(t) == 0: return "", "" if len(t) == 1: return t[0], "" v = "_".join(t[1:]) if v.startswith("{") and v.endswith("}"): return t[0], v[1:-1] return t[0], v @staticmethod def is_single_term(x): if isinstance(x, (int, float, tuple, list, IVar)): return True if isinstance(x, (Comp, Expr)): return len(x) == 1 return False @staticmethod def fcn_name_maker(name, v, pname = None, lname = 
None, cropi = False, infix = False, latex_group = False, fcn_suffix = True, coeff_mul = False): if not isinstance(v, list) and not isinstance(v, tuple): v = [v] r = "" for style in [PsiOpts.STR_STYLE_STANDARD, PsiOpts.STR_STYLE_PSITIP, PsiOpts.STR_STYLE_LATEX]: ccoeff_mul = False if isinstance(coeff_mul, bool): ccoeff_mul = coeff_mul elif isinstance(coeff_mul, int): ccoeff_mul = (style == coeff_mul) else: ccoeff_mul = (style in coeff_mul) if style != PsiOpts.STR_STYLE_STANDARD: r += "@@" + str(style) + "@@" for i, a in enumerate(v): apow = None if isinstance(a, tuple): apow = a[1] a = a[0] a_single_term = iutil.is_single_term(a) def sel(*args): for ca in args: if ca is None: continue if isinstance(ca, tuple): if apow is None or apow >= 0: return ca[0] else: return ca[1] else: return ca return None if (i == 0) ^ infix: if style & PsiOpts.STR_STYLE_STANDARD: r += sel(name) elif style & PsiOpts.STR_STYLE_PSITIP: r += sel(pname, name) elif style & PsiOpts.STR_STYLE_LATEX: r += sel(lname, name) if isinstance(a, str) and a == "#break": break if not infix: if i == 0: r += "(" if i: r += "," else: if latex_group and style & PsiOpts.STR_STYLE_LATEX: r += "{" if not a_single_term: r += "(" t = "" if isinstance(a, (IVar, Comp, Term, Expr, Region)): if (style & PsiOpts.STR_STYLE_STANDARD or style & PsiOpts.STR_STYLE_LATEX) and isinstance(a, Comp): t = a.tostring(style = style, add_bracket = True) else: t = a.tostring(style = style) elif isinstance(a, str): if style & PsiOpts.STR_STYLE_PSITIP: t = repr(a) else: t = str(a) else: t = str(a) if apow is not None: apow_o = apow apow = iutil.float_tostr(apow) if ccoeff_mul: if apow == "1": pass elif apow == "-1": if i == 0: t = "-" + t else: t = iutil.float_tostr(abs(apow_o)) + t if i == 0 and apow_o < 0: t = "-" + t else: if style & PsiOpts.STR_STYLE_PSITIP: t += "**" + apow elif style & PsiOpts.STR_STYLE_LATEX: t += "^" if len(apow) != 1: t += "{" t += apow if len(apow) != 1: t += "}" else: t += "^" + apow if cropi and 
isinstance(a, Term) and len(t) >= 3 and (t.startswith("I(") or t.startswith("H(")): r += t[2:-1] else: r += t if not infix: if i == len(v) - 1: r += ")" else: if not a_single_term: r += ")" if latex_group and style & PsiOpts.STR_STYLE_LATEX: r += "}" if fcn_suffix: r += PsiOpts.settings["fcn_suffix"] return r @staticmethod def add_subscript_latex(s, v): s, v0 = iutil.break_subscript_latex(s) if len(v0): v0 += "," + str(v) else: v0 = str(v) if len(v0) > 1: return s + "_{" + v0 + "}" else: return s + "_" + v0 @staticmethod def copy(x): if (x is None or isinstance(x, int) or isinstance(x, float) or isinstance(x, str)): return x if isinstance(x, list): return [iutil.copy(a) for a in x] if isinstance(x, tuple): return tuple(iutil.copy(a) for a in x) if isinstance(x, dict): return {iutil.copy(a): iutil.copy(b) for a, b in x.items()} if isinstance(x, set): return {iutil.copy(a) for a in x} if isinstance(x, numpy.ndarray): return numpy.copy(x) if torch is not None and isinstance(x, torch.Tensor): return x.detach().clone() if isinstance(x, IBaseObj): return x.copy() return x @staticmethod def allcomprv(x): if x is None or isinstance(x, (int, float, bool)): return Comp.empty() if isinstance(x, (list, tuple, set)): r = Comp.empty() for a in x: r += iutil.allcomprv(a) return r return x.allcomprv() @staticmethod def allcomprealvar_exprlist(x): if x is None or isinstance(x, (int, float, bool)): return ExprArray.make([]) if isinstance(x, (list, tuple, set)): r = ExprArray.make([]) for a in x: r.iadd_noduplicate(iutil.allcomprealvar_exprlist(a)) return r return x.allcomprealvar_exprlist() @staticmethod def allcompreal_exprlist(x): if x is None or isinstance(x, (int, float, bool)): return ExprArray.make([]) if isinstance(x, (list, tuple, set)): r = ExprArray.make([]) for a in x: r.iadd_noduplicate(iutil.allcompreal_exprlist(a)) return r return x.allcompreal_exprlist() @staticmethod def is_solvable_inteqn(A, b): """Solve A x = b for integer vector x. 
""" h = len(A) if h == 0: return True w = len(A[0]) if w == 0: return all(x == 0 for x in b) A2 = smithnormalform.matrix.Matrix(h, w, [smithnormalform.z.Z(int(round(x))) for Ar in A for x in Ar]) prob = smithnormalform.snfproblem.SNFProblem(A2) print(A2) prob.computeSNF() if not prob.isValid(): return False b2 = smithnormalform.matrix.Matrix(h, 1, [smithnormalform.z.Z(int(round(x))) for x in b]) c = prob.S * b2 for i in range(h): t = 0 if i < w: t = prob.J.get(i, i).a if t == 0: if c.get(i, 0).a != 0: return False else: if c.get(i, 0).a % t != 0: return False return True # bv = numpy.array(b) # Av = numpy.array(A) # Sv = numpy.reshape(numpy.array([x.a for x in prob.S.elements]), (prob.S.h, prob.S.w)) # Tv = numpy.reshape(numpy.array([x.a for x in prob.T.elements]), (prob.T.h, prob.T.w)) # Jv = numpy.reshape(numpy.array([x.a for x in prob.J.elements]), (prob.J.h, prob.J.w)) # cv = Sv.dot(bv) @staticmethod def check_meta_subs_criteria(v0, v1, cobj = None): mode = PsiOpts.settings["meta_subs_criteria"] if mode is False or mode == "never": return False elif mode is True or mode == "always": return True else: if (isinstance(v0, Comp) and isinstance(v1, Comp)) or (isinstance(v0, Expr) and isinstance(v1, Expr)): if mode == "eqtype": return True elif mode == "eqtype_nonempty": return (isinstance(v1, Comp) and not v1.isempty()) or (isinstance(v1, Expr) and not v1.iszero()) return False def substitute(x, v0, v1): if x is None: return if isinstance(x, (list, tuple)): for a in x: iutil.substitute(a, v0, v1) elif isinstance(x, dict): for a, b in x.items(): iutil.substitute(b, v0, v1) elif isinstance(x, IBaseObj): x.substitute(v0, v1) def substitute_whole(x, v0, v1): if x is None: return if isinstance(x, (list, tuple)): for a in x: iutil.substitute_whole(a, v0, v1) elif isinstance(x, dict): for a, b in x.items(): iutil.substitute_whole(b, v0, v1) elif isinstance(x, IBaseObj): x.substitute_whole(v0, v1) @staticmethod def isnumeric(x): return isinstance(x, (int, float, 
fractions.Fraction)) @staticmethod def str_join(x, delim = ""): if isinstance(x, list) or isinstance(x, tuple): return delim.join(iutil.str_join(a, delim) for a in x) return str(x) @staticmethod def tostring_join(x, style, delim = "", nlstr = "\n"): if isinstance(x, list) or isinstance(x, tuple): r = "" cmds = [] prev_comp = Comp.empty() for a in x: if r != "": r += delim if isinstance(a, str) and a.startswith("#CMD_"): cmds.append(a) else: if isinstance(a, Comp): if "#CMD_EXCLUDE_PREV" in cmds: a = a - prev_comp cmds.remove("#CMD_EXCLUDE_PREV") prev_comp = a.copy() r += iutil.tostring_join(a, style, delim, nlstr) return r # return delim.join(iutil.tostring_join(a, style, delim, nlstr) for a in x) if hasattr(x, "tostring"): return x.tostring(style = style) if isinstance(x, str): if style & PsiOpts.STR_STYLE_LATEX: if x == " ": return "\\," if all(a == " " for a in x): return "\\;" * len(x) if x == "(": return "\\left(" if x == ")": return "\\right)" if x == "because" or x == "since": return PsiOpts.settings["latex_because"] if x == "therefore": return PsiOpts.settings["latex_therefore"] if x == "exists": return PsiOpts.settings["latex_exists"] if x == "for all": return PsiOpts.settings["latex_forall"] if x == "is typical": return PsiOpts.settings["latex_is_typical"] if x == "\n": return nlstr if style & PsiOpts.STR_STYLE_LATEX: return "\\text{" + str(x) + "}" return str(x) @staticmethod def bit_reverse(x, n): r = 0 for i in range(n): if x & (1 << i): r += 1 << (n - 1 - i) return r @staticmethod def bin_to_gray(x): """Binary to Gray code. From en.wikipedia.org/wiki/Gray_code """ return x ^ (x >> 1) @staticmethod def gray_to_bin(x): """Gray code to binary. From en.wikipedia.org/wiki/Gray_code """ m = x while m: m >>= 1 x ^= m return x @staticmethod def gbsearch(fcn, num_iter = 30): """Binary search. 
""" lb = None ub = None for it in range(num_iter): m = 0.0 if it == 0: m = 0.0 elif ub is None: m = max(lb * 2, 1.0) elif lb is None: m = min(ub * 2, -1.0) else: m = (lb + ub) * 0.5 if fcn(m): ub = m else: lb = m if ub is None: return lb if lb is None: return ub return (lb + ub) * 0.5 @staticmethod def polygon_frompolar(poly, inf_value = 1e6): """ Convert polygon from polar representation to positive and negative portions. Parameters ---------- poly : list List of 3-tuples of vertices. inf_value : float, optional The value used as infinity. The default is 1e6. Returns ------- r : list List of two polygons (positive and negative portions, both are lists of 2-tuples). """ r = [[], []] x = [[0 if abs(a[0]) <= 1e-10 else 1 if a[0] > 0 else -1] + a[1:] for a in poly] x.append(x[0]) for i in range(len(x) - 1): if x[i][0] > 0: r[0].append(x[i][1:]) elif x[i][0] < 0: r[1].append([-y for y in x[i][1:]]) ray = None if x[i][0] == 0: ray = x[i] elif x[i][0] * x[i + 1][0] < 0: ray = [y0 + y1 for y0, y1 in zip(x[i], x[i+1])] if ray is not None: ray = ray[1:] norm = max(abs(y) for y in ray) ray = [y / norm * inf_value for y in ray] r[0].append(ray) r[1].append([-y for y in ray]) return r class MHashList: def __init__(self, x = None, ha = None): if x is None: self.x = [] else: self.x = x if ha is None: self.ha = [] else: self.ha = ha def add(self, y): h = iutil.mhash(y) i = 0 while i < len(self.x): if self.ha[i] == h: self.x.pop(i) self.ha.pop(i) else: i += 1 self.x.append(y) self.ha.append(h) def clear(self): self.x[:] = [] self.ha[:] = [] def __len__(self): return len(self.x) def __getitem__(self, key): return self.x[key] def copy(self): return MHashList(self.x[:], self.ha[:]) class MHashSet: def __init__(self, x = None, s = None): if x is None: self.x = [] else: self.x = x if s is None: self.s = set() else: self.s = s def add(self, y): h = iutil.mhash(y) if h in self.s: return False self.x.append(y) self.s.add(h) return True def clear(self): self.x[:] = [] self.s.clear() def 
__iadd__(self, other): for y in other: self.add(y) return self def __len__(self): return len(self.x) def __getitem__(self, key): return self.x[key] #return self.x[len(self.x) - 1 - key] def copy(self): return MHashSet(self.x[:], self.s.copy()) def __eq__(self, other): return self.s == other.s def __ne__(self, other): return self.s != other.s def __hash__(self): return hash(frozenset(self.s)) def fcn_substitute(fcn): @functools.wraps(fcn) def wrapper(cself, *args, **kwargs): clist = [] def fcn2(cself, key, val): if isinstance(key, str): key = cself.find(key) if not ((isinstance(key, Comp) and not key.isempty()) or (isinstance(key, Expr) and not key.iszero())): return if isinstance(val, str): if val == "": val = None else: val = cself.find(val) if not ((isinstance(val, Comp) and not val.isempty()) or (isinstance(val, Expr) and not val.iszero())): return if val is None: if isinstance(key, Comp): val = Comp.empty() elif isinstance(key, Expr): val = Expr.zero() else: return if isinstance(key, Expr) and not isinstance(val, Comp): t = iutil.ensure_expr(val) if t is not None: val = t # fcn(cself, key, val) clist.append([key, val]) i = 0 while i < len(args): if isinstance(args[i], dict): for key, val in args[i].items(): fcn2(cself, key, val) i += 1 elif isinstance(args[i], CompArray): for key, val in args[i].to_dict().items(): fcn2(cself, key, val) i += 1 elif isinstance(args[i], list): for key, val in args[i]: fcn2(cself, key, val) i += 1 elif i + 1 < len(args): fcn2(cself, args[i], args[i + 1]) i += 2 else: i += 1 for key, val in kwargs.items(): fcn2(cself, key, val) # print(clist) for i, v in enumerate(clist): if i < len(clist) - 1: if isinstance(v[0], Expr): v.insert(1, Expr.real("#TMPSUB" + str(i))) else: v.insert(1, Comp.rv("#TMPSUB" + str(i))) # print(clist) for it in range(2): for i, v in enumerate(clist): if it + 1 < len(v): # print(str(v[it]) + " " + str(v[it + 1])) fcn(cself, v[it], v[it + 1]) return wrapper def latex_postprocess(fcn): @functools.wraps(fcn) def 
wrapper(*args, **kwargs): r = fcn(*args, **kwargs) color = PsiOpts.settings["latex_color"] if color is not None: r = "{\\color{" + str(color) + "}{" + r + "}}" return r return wrapper def fcn_list_to_list(fcn): @functools.wraps(fcn) def wrapper(*args, **kwargs): islist = False maxlen = -1 maxshape = tuple() for a in itertools.chain(args, kwargs.values()): if CompArray.isthis(a) or ExprArray.isthis(a): islist = True # maxlen = max(maxlen, len(a)) if len(a) > maxlen: maxlen = len(a) if isinstance(a, list): maxshape = (len(a),) else: maxshape = a.shape if not islist: return fcn(*args, **kwargs) r = [] for i in range(maxlen): targs = [] for a in args: if CompArray.isthis(a): if i < len(a): targs.append(a[i]) else: targs.append(Comp.empty()) elif ExprArray.isthis(a): if i < len(a): targs.append(a[i]) else: targs.append(Expr.zero()) else: targs.append(a) tkwargs = dict() for key, a in kwargs.items(): if CompArray.isthis(a): if i < len(a): tkwargs[key] = a[i] else: tkwargs[key] = Comp.empty() elif ExprArray.isthis(a): if i < len(a): tkwargs[key] = a[i] else: tkwargs[key] = Expr.zero() else: tkwargs[key] = a r.append(fcn(*targs, **tkwargs)) if len(r) > 0: if isinstance(r[0], Region): return alland(r) elif isinstance(r[0], Comp): return CompArray(r, maxshape) elif isinstance(r[0], Expr): return ExprArray(r, maxshape) else: return r else: return None return wrapper class PsiRec: num_lpprob = 0 class IVarType: NIL = 0 RV = 1 REAL = 2 class AlgType: NIL = 0 SEMIGROUP = 1 GROUP = 2 ABELIAN = 3 REAL = 4 class IBaseObj: """Base class of objects """ def __init__(self): pass def add_meta(self, key, value): if self.meta is None: self.meta = {} self.meta[key] = value return self def get_meta(self, key): if self.meta is None: return None if key not in self.meta: return None return self.meta[key] def remove_meta(self, key): if self.meta is None: return self self.meta.pop(key, None) return self @latex_postprocess def _latex_(self): return "" def latex(self, skip_simplify = False): 
"""LaTeX code """ if skip_simplify: r = "" with PsiOpts(repr_simplify = False): r = self._latex_() return r return self._latex_() def display(self, skip_simplify = False, **kwargs): """IPython display """ if kwargs: r = None with PsiOpts(**{key: val for key, val in kwargs.items()}): r = self.display(skip_simplify = skip_simplify) return r iutil.display_latex(self.latex(skip_simplify = skip_simplify)) def print(self, **kwargs): """Print this object """ if kwargs: with PsiOpts(**{key: val for key, val in kwargs.items()}): print(self) return print(self) def display_bool(self, s = "{region} \\;\\mathrm{{is}}\\;\\mathrm{{{truth}}}", skip_simplify = True): """IPython display, show bool """ iutil.display_latex(s.format(region = self.latex(skip_simplify = skip_simplify), truth = str(bool(self)))) def display_truth_value(self, s = "{region} \\;\\mathrm{{is}}\\;\\mathrm{{{truth}}}", skip_simplify = True): """IPython display, show truth value """ iutil.display_latex(s.format(region = self.latex(skip_simplify = skip_simplify), truth = str(self.truth_value()))) def substituted(self, *args, **kwargs): """Substitute variable v0 by v1 (v1 can be compound), return result""" r = self.copy() r.substitute(*args, **kwargs) return r def substituted_whole(self, *args, **kwargs): """Substitute variable v0 by v1 (v1 can be compound), return result""" r = self.copy() r.substitute_whole(*args, **kwargs) return r def subs(self, *args, **kwargs): """Alias of substituted_whole """ return self.substituted_whole(*args, **kwargs) def subs_aux(self, *args, **kwargs): """Alias of substituted_aux """ return self.substituted_aux(*args, **kwargs) def substituted_swap(self, x, y = None): """Swap two variablex x, y """ if y is not None: x = [(x, y)] return self.substituted_whole(list(x) + [(b, a) for a, b in x]) def subs_swap(self, *args, **kwargs): """Alias of substituted_swap """ return self.substituted_swap(*args, **kwargs) def reg(self, *args, **kwargs): """Alias of get_region """ return 
self.get_region(*args, **kwargs) def ensure_region(self, *args, **kwargs): if isinstance(self, Region): return self return self.get_region(*args, **kwargs) def defn(self, *args, **kwargs): """Alias of definition """ return self.definition(*args, **kwargs) def __copy__(x): return x.copy() def __deepcopy__(x, memo): return x.copy() @staticmethod def set_repr_latex(enabled): hasa = hasattr(IBaseObj, "_repr_latex_") if enabled and not hasa: setattr(IBaseObj, "_repr_latex_", lambda s: "$" + s.latex() + "$") if not enabled and hasa: delattr(IBaseObj, "_repr_latex_") def simplified_truth(self, reg = None, quick = False): truth = PsiOpts.settings["truth"] if truth is not None: if reg is None: reg = truth.copy() else: reg = reg & truth with PsiOpts(truth = None): if isinstance(self, Expr): return self.simplified(reg = reg, bnet = reg.get_bayesnet()) else: if quick: return self.simplified_quick(reg = reg) else: return self.simplified(reg = reg) if quick: return self.simplified_quick(reg = reg) else: return self.simplified(reg = reg) def mark_rv(self, x, *args, **kwargs): y = x.copy() y.mark(*args, **kwargs) for a, b in zip(x, y): self.substitute(a, b) return self def marked_rv(self, x, *args, **kwargs): r = self.copy() r.mark_rv(x, *args, **kwargs) return r def allcomp(self): index = IVarIndex() self.record_to(index) return index.comprv + index.compreal def allcomprv(self): index = IVarIndex() self.record_to(index) return index.comprv def getauxall(self): return Comp.empty() def allcomprv_noaux(self): return self.allcomprv() - self.getauxall() def allcompreal(self): index = IVarIndex() self.record_to(index) return index.compreal def allcompreal_exprlist(self): t = self.allcompreal() return ExprArray.make([Expr.real(v) for v in t]) def allcomprealvar_exprlist(self): t = self.allcomprealvar() return ExprArray.make([Expr.real(v) for v in t]) def simplify_regterm(self, reg = None): regterms = {} self.regtermmap(regterms, False) for (name, x) in regterms.items(): if isinstance(x, 
Term): y = x.copy() t = y.simplify_regterm_expr(reg = reg) if t is not None: self.substitute(Expr.fromterm(x), t) @property def rvs(self): return self.allcomprv_noaux() @property def reals(self): return self.allcomprealvar_exprlist() PsiOpts.setting(repr_latex = True) # Latex display default on class IVar(IBaseObj): """Random variable or real variable Do NOT use this class directly. Use Comp instead """ def __init__(self, vartype, name, reg = None, reg_det = False, markers = None, algtype = 0, alglist = None): self.vartype = vartype self.name = name self.reg = reg self.reg_det = reg_det self.markers = markers self.algtype = algtype self.alglist = alglist @staticmethod def rv(name): return IVar(IVarType.RV, name) @staticmethod def index(name): r = IVar(IVarType.RV, name) r.markers = [("index_shift", 0)] return r @staticmethod def real(name): return IVar(IVarType.REAL, name) @staticmethod def eps(): return IVar(IVarType.REAL, "EPS") @staticmethod def one(): return IVar(IVarType.REAL, "ONE") @staticmethod def inf(): return IVar(IVarType.REAL, "INF") def isrealvar(self): return self.vartype == IVarType.REAL and self.name != "ONE" and self.name != "EPS" and self.name != "INF" def get_marker_key(self, key): if self.markers is not None: for v, w in reversed(self.markers): if v == key: return w return None def tostring(self, style = 0): r = "" if style & PsiOpts.STR_STYLE_LATEX: if self.name == "EPS": r = PsiOpts.settings["latex_eps"] elif self.name == "INF": r = PsiOpts.settings["latex_infty"] if r == "": r = iutil.get_name_fromcat(self.name, style) shift = self.get_marker_key("index_shift") if shift is not None and shift != 0: if shift > 0: r += "+" r += str(shift) return r def __str__(self): return self.tostring(PsiOpts.settings["str_style"]) def __repr__(self): return self.tostring(PsiOpts.settings["str_style_repr"]) @latex_postprocess def _latex_(self): return self.tostring(iutil.convert_str_style("latex")) def __hash__(self): return hash(self.name) def __eq__(self, 
other): return self.name == other.name def copy(self): return IVar(self.vartype, self.name, None if self.reg is None else self.reg.copy(), self.reg_det, None if self.markers is None else self.markers[:], self.algtype, iutil.copy(self.alglist)) def copy_noreg(self): return IVar(self.vartype, self.name, None, False, None if self.markers is None else self.markers[:], self.algtype, iutil.copy(self.alglist)) @staticmethod def word_normalize(algtype, x): if algtype in (AlgType.ABELIAN, AlgType.REAL): x.sort(key = lambda a: str(a[0])) @staticmethod def word_combine_to(algtype, x, y): if algtype in (AlgType.ABELIAN, AlgType.REAL): for b, bc in y: for i in range(len(x)): if x[i][0] == b: x[i] = (x[i][0], x[i][1] + bc) break else: x.append((iutil.copy(b), bc)) elif algtype in (AlgType.GROUP, AlgType.SEMIGROUP): y = iutil.copy(y) while x and y and x[-1][0] == y[0][0] and x[-1][1] + y[0][1] == 0: x.pop() y.pop(0) if x and y and x[-1][0] == y[0][0]: x[-1] = (x[-1][0], x[-1][1] + y[0][1]) y.pop(0) x += y x[:] = [(a, ac) for a, ac in x if ac != 0] @staticmethod def word_combine(algtype, x, y): x = iutil.copy(x) IVar.word_combine_to(algtype, x, y) return x @staticmethod def word_pow(algtype, x, p): if p == 0: return [] x = iutil.copy(x) if p == 1: return x if algtype in (AlgType.ABELIAN, AlgType.REAL): return [(a, ac * p) for a, ac in x] if p < 0: x = [(a, -ac) for a, ac in reversed(x)] p = -p r = list(x) for i in range(p - 1): IVar.word_combine_to(algtype, r, x) return r @staticmethod def word_name(algtype, x): if len(x) == 0: return iutil.fcn_name_maker("id", ["#break"], lname = PsiOpts.settings["latex_group_id"]) if PsiOpts.settings["latex_group_additive"]: return iutil.fcn_name_maker("*", [(a, ac) if ac != 1 else a for a, ac in x], lname = (PsiOpts.settings["latex_group_add"] + " ", PsiOpts.settings["latex_group_minus"] + " "), infix = True, coeff_mul = PsiOpts.STR_STYLE_LATEX) else: return iutil.fcn_name_maker("*", [(a, ac) if ac != 1 else a for a, ac in x], lname = 
PsiOpts.settings["latex_group_mul"] + " ", infix = True) def get_alglist(self): if self.algtype == 0: return [] if self.alglist is None: return [[(self, 1)]] return self.alglist def __mul__(self, other): if isinstance(other, (int, float, fractions.Fraction)) and other == 1: return self.copy() if self.algtype == 0: raise ValueError("Cannot multiply non-group-valued variables.") return if self.algtype != other.algtype: raise ValueError("Cannot multiply variables with different types.") return r = self.copy() r.alglist = [] for x in self.get_alglist(): for y in other.get_alglist(): r.alglist.append(IVar.word_combine(self.algtype, x, y)) if PsiOpts.settings["alg_normalize"]: for a in r.alglist: IVar.word_normalize(r.algtype, a) r.name = IVar.word_name(r.algtype, r.alglist[0]) return r def __pow__(self, other): if self.algtype == 0: raise ValueError("Cannot take power of non-group-valued variables.") return r = self.copy() r.alglist = iutil.copy(r.get_alglist()) for x in r.alglist: x[:] = IVar.word_pow(r.algtype, x, other) if PsiOpts.settings["alg_normalize"]: for a in r.alglist: IVar.word_normalize(r.algtype, a) r.name = IVar.word_name(r.algtype, r.alglist[0]) return r def __truediv__(self, other): return self * (other ** -1) def __rtruediv__(self, other): if isinstance(other, (int, float, fractions.Fraction)) and other == 1: return self ** -1 return other * (self ** -1) def __rmul__(self, other): return self * other @staticmethod def word_recordto(algtype, index, x): for a, ac in x: index.record(Comp([a])) @staticmethod def word_toarray(algtype, index, x): r = numpy.zeros(len(index.comprv)) for a, ac in x: r[index.get_index(a)] += ac return r @staticmethod def word_tointlist(algtype, index, x): r = [0] * len(index.comprv) for a, ac in x: r[index.get_index(a)] += ac return r @staticmethod def word_toid(index, x): return [(index.get_index(a), ac) for a, ac in x] @staticmethod def word_hash(x): return hash(tuple(x)) @staticmethod def word_complexity(x): return 
sum(abs(ac) + 1 for a, ac in x) @staticmethod def word_checkfcn(algtype, xs, y): index = IVarIndex() IVar.word_recordto(algtype, index, y) if len(index.comprv) == 0: return True if len(xs) == 0: return False for x in xs: IVar.word_recordto(algtype, index, x) if algtype == AlgType.REAL: m = numpy.vstack([IVar.word_toarray(algtype, index, x) for x in xs]) m2 = numpy.vstack([m, IVar.word_toarray(algtype, index, y)]) return numpy.linalg.matrix_rank(m2) <= numpy.linalg.matrix_rank(m) # elif algtype == AlgType.ABELIAN: # A = [IVar.word_tointlist(algtype, index, x) for x in xs] # return iutil.is_solvable_inteqn([[A[j][i] for j in range(len(A))] for i in range(len(A[0]))], # IVar.word_tointlist(algtype, index, y)) elif algtype in (AlgType.GROUP, AlgType.SEMIGROUP, AlgType.ABELIAN): xs = [IVar.word_toid(index, x) for x in xs] y = IVar.word_toid(index, y) wlist = [] vis = set() nit = [0] maxnit = (len(xs) + 1) * PsiOpts.settings["alg_group_nit"] found = [False] start_inv = (algtype == AlgType.GROUP or algtype == AlgType.ABELIAN) push_inv = start_inv push_left = (algtype == AlgType.GROUP or algtype == AlgType.SEMIGROUP) def trypush(a): if found[0]: return False if start_inv: if len(a) == 0: found[0] = True return False else: if a == y: found[0] = True return False if nit[0] > maxnit: return False ahash = tuple(a) # IVar.word_hash(a) if ahash in vis: return False heapq.heappush(wlist, (IVar.word_complexity(a), a)) vis.add(ahash) nit[0] += 1 return True if start_inv: trypush(IVar.word_pow(algtype, y, -1)) else: trypush([]) while wlist: if found[0]: break acomp, a = heapq.heappop(wlist) for x in xs: if found[0]: break trypush(IVar.word_combine(algtype, a, x)) if push_left: trypush(IVar.word_combine(algtype, x, a)) if push_inv: xinv = IVar.word_pow(algtype, x, -1) trypush(IVar.word_combine(algtype, a, xinv)) if push_left: trypush(IVar.word_combine(algtype, xinv, a)) return found[0] return False @staticmethod def word_fcnrelation(algtype, xs): fcns = FcnRelation() index = 
IVarIndex() for x in xs: IVar.word_recordto(algtype, index, x) xmasks = [] for x in xs: index2 = IVarIndex() IVar.word_recordto(algtype, index2, x) xmasks.append(index.get_mask(index2.comprv)) n = len(xs) for i in range(n): for jmask in igen.subset_mask((1 << n) - 1 - (1 << i)): if algtype == AlgType.REAL: if iutil.bitcount(jmask) > len(index.comprv): continue jmask2 = 0 for k, a in enumerate(xmasks): if jmask & (1 << k): jmask2 |= a if jmask2 | xmasks[i] != jmask2: continue if fcns.check_fcn(jmask, 1 << i): continue if IVar.word_checkfcn(algtype, [a for k, a in enumerate(xs) if (1 << k) & jmask], xs[i]): fcns.add_fcn(jmask, 1 << i) fcns.simplify() return fcns class Comp(IBaseObj): """Compound random variable or real variable """ def __init__(self, varlist): self.varlist = varlist @staticmethod def empty(): """ The empty random variable. Returns ------- Comp """ return Comp([]) @staticmethod def rv(name): """ Random variable. Parameters ---------- name : str Name of the random variable. Returns ------- Comp """ return Comp([IVar(IVarType.RV, name)]) @staticmethod def index(name): """ Random variable used as an index. Parameters ---------- name : str Name of the random variable. 
Returns ------- Comp """ return Comp([IVar.index(name)]) @staticmethod def rv_reg(a, reg, reg_det = False): r = a.copy_noreg() for i in range(len(r.varlist)): r.varlist[i].reg = reg.copy() r.varlist[i].reg_det = reg_det return r #return Comp([IVar(IVarType.RV, str(a), reg.copy(), reg_det)]) @staticmethod def real(name): return Comp([IVar(IVarType.REAL, name)]) @staticmethod def array(name, st, en = None): if isinstance(st, int): if en is None: en = st st = 0 st = range(st, en) t = [] for i in st: istr = str(i) s = name + "_" + istr s += "@@" + str(PsiOpts.STR_STYLE_LATEX) s += "@@" + iutil.add_subscript_latex(name, istr) t.append(IVar(IVarType.RV, s)) return Comp(t) def get_name(self): if len(self.varlist) == 0: return "" return self.varlist[0].name def find(self, *args): args = [x for b in args for x in iutil.split_comma(b)] r = [] numrv = 0 numreal = 0 for carg0 in args: cmaxa0 = None for carg in carg0.split("+"): cmax = 0 cmaxa = None for a in self.varlist: t = iutil.find_similarity(a.name, carg.strip()) if t > cmax: cmax = t cmaxa = Comp([a.copy()]) if cmaxa is not None: if cmaxa0 is None: cmaxa0 = cmaxa else: cmaxa0 += cmaxa if cmaxa0 is None: continue if cmaxa0.get_type() == IVarType.RV: r.append(cmaxa0) numrv += 1 elif cmaxa0.get_type() == IVarType.REAL: r.append(Expr.fromcomp(cmaxa0)) numreal += 1 if len(r) == 0: return None if len(r) == 1: return r[0] if numreal == 0 and numrv >= 2: return CompArray(r) if numrv == 0 and numreal >= 2: return ExprArray(r) return r def get_type(self): if len(self.varlist) == 0: return IVarType.NIL return self.varlist[0].vartype def allcomp(self): return self.copy() def complexity(self): return len(self.varlist) def sorting_priority(self): return self.complexity() def sorting_tuple(self): s = str(self) return (s[0] if len(s) >= 1 else "", self.sorting_priority(), s) def subsets(self, minsize = 0, maxsize = 100000, size = None, reverse = False): """Subsets of this random variable, as a generator """ return igen.subset(self, 
minsize = minsize, maxsize = maxsize, size = size, reverse = reverse) @staticmethod def parse(s): """Parse a string, e.g. "X+Y, Z W" """ r = RegionParser.parse_default("H(" + s + ")") return r.allcomp() def swapped_id(self, i, j): if i >= len(self.varlist) or j >= len(self.varlist): return self.copy() r = self.copy() r.varlist[i], r.varlist[j] = r.varlist[j], r.varlist[i] return r def set_algtype(self, algtype): algtype = iutil.convert_algtype(algtype) for a in self.varlist: a.algtype = algtype return self def set_markers(self, markers): for a in self.varlist: if markers is None: a.markers = None else: a.markers = markers[:] return self def add_markers(self, markers): for a in self.varlist: if a.markers is None: a.markers = [] a.markers += markers return self def add_marker(self, key, value = 1): self.add_markers([(key, value)]) return self def add_marker_id(self, key): return self.add_marker(key, iutil.get_count()) def mark(self, *args, **kwargs): for a in args: if a == "symm" or a == "disjoint" or a == "nonsubset": self.add_marker_id(a) else: self.add_marker(a) for key, value in kwargs.items(): self.add_marker(key, value) return self def get_marker_key(self, key): for a in self.varlist: if a.markers is not None: for v, w in reversed(a.markers): if v == key: return w return None def write_marker_key(self, key, value): for a in self.varlist: if a.markers is None: a.markers = [] found = False for i in range(len(a.markers)): if a.markers[i][0] == key: a.markers[i] = (key, value) found = True if not found: a.markers.append((key, value)) return None def get_markers(self): r = [] for a in self.varlist: if a.markers is not None: for v, w in a.markers: if (v, w) not in r: r.append((v, w)) return r def set_card(self, m): self.write_marker_key("card", m) return self def get_card(self): r = 1 for a in self: t = a.get_marker_key("card") if t is None: return None r *= t return r def get_shape(self): r = [] for a in self: t = a.get_card() if t is None: raise 
ValueError("Cardinality of " + str(a) + " not set. Use " + str(a) + ".set_card(m) to set cardinality.") return r.append(t) return tuple(r) def get_shape_in(self): return tuple() def get_shape_out(self): return self.get_shape() @property def card(self): return self.get_card() @card.setter def card(self, value): self.set_card(value) @property def shape(self): return self.get_shape() def get_index_shift(self): return self.get_marker_key("index_shift") def set_index_shift(self, x): self.write_marker_key("index_shift", x) def inc_index_shift(self, x): c = self.get_index_shift() if c is None: c = 0 self.write_marker_key("index_shift", c + x) def add_suffix(self, csuffix): for a in self.varlist: a.name += csuffix def added_suffix(self, csuffix): r = self.copy() r.add_suffix(csuffix) return r def __getitem__(self, key): r = self.varlist[key] if isinstance(r, list): return Comp(r) return Comp([r]) def __setitem__(self, key, value): if value.isempty(): del self.varlist[key] self.varlist[key] = value.varlist[0] def __delitem__(self, key): del self.varlist[key] def __iter__(self): for a in self.varlist: yield Comp([a]) def copy(self): return Comp([a.copy() for a in self.varlist]) def copy_noreg(self): return Comp([a.copy_noreg() for a in self.varlist]) def addvar(self, x): if x in self.varlist: return self.varlist.append(x.copy()) def removevar(self, x): self.varlist = [a for a in self.varlist if a != x] def reg_excluded(self): r = Comp.empty() for a in self.varlist: if a.reg is None: r.varlist.append(a) return r def index_of(self, x): for i, a in enumerate(self.varlist): if x.ispresent(a): return i return -1 def ispresent_shallow(self, x): if isinstance(x, str): if x == "real": return self.get_type() == IVarType.REAL if x == "realvar": for a in self.varlist: if a.isrealvar(): return True return False if isinstance(x, Comp): for y in x.varlist: if y in self.varlist: return True return False return x in self.varlist def ispresent(self, x): if isinstance(x, Expr) or 
isinstance(x, Region): x = x.allcomp() for a in self.varlist: if a.reg is not None and a.reg.ispresent(x): return True return self.ispresent_shallow(x) def __iadd__(self, other): if isinstance(other, int): c = self.get_index_shift() if c is not None: self.set_index_shift(c + other) return self for i in range(len(other.varlist)): self.addvar(other.varlist[i]) return self def __add__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return NotImplemented if isinstance(other, Expr): return other.__radd__(self) r = self.copy() if isinstance(other, (Comp, int)): r += other return r def __radd__(self, other): r = self.copy() if isinstance(other, (Comp, int)): r += other return r def __isub__(self, other): if isinstance(other, int): self += -other return self for i in range(len(other.varlist)): self.removevar(other.varlist[i]) return self def __sub__(self, other): r = self.copy() r -= other return r def __floordiv__(self, other): return BayesNet([self]) // other def __xor__(self, other): return BayesNet([self]) ^ other def avoid(self, *args, samesuffix = True): reg = Comp.empty() for a in args: if isinstance(a, IBaseObj): reg += a.allcomp() r = Region.universe().exists(self) r.aux_avoid_from(reg, samesuffix = samesuffix) return r.aux def __imul__(self, other): if iutil.isconstzero(other): self.varlist = [] return self if isinstance(other, Comp): self.varlist = [a * b for a in self.varlist for b in other.varlist] return self def __mul__(self, other): r = self.copy() r *= other return r def __rmul__(self, other): r = self.copy() r *= other return r def __ipow__(self, other): self.varlist = [a ** other for a in self.varlist] return self def __pow__(self, other): r = self.copy() r **= other return r def __truediv__(self, other): return self * (other ** -1) def __rtruediv__(self, other): if isinstance(other, (int, float, fractions.Fraction)) and other == 1: return self ** -1 return other * (self ** -1) def product(self): if len(self.varlist) == 0: 
return 1 r = None for a in self.varlist: if r is None: r = Comp([a]).copy() else: r *= Comp([a]) return r def alg_region(self): ws = [] was = [] algtypes = set(a.algtype for a in self.varlist) r = Expr.zero() for algtype in algtypes: if algtype == 0: continue for a in self.varlist: if a.algtype == algtype: for x in a.get_alglist(): ws.append(x) was.append(Comp([a]).copy()) fcns = IVar.word_fcnrelation(algtype, ws) for t in fcns.fcn: t2 = [] for x in t: ct = Comp.empty() for i, a in enumerate(was): if x & (1 << i): ct += a t2.append(ct) r += Expr.Hc(t2[1], t2[0]) if r.iszero(): return Region.universe() return r <= 0 def get_indreg(self, skip_abscont = False): if skip_abscont: return Region.universe() return self.alg_region() def inter(self, other): """Intersection.""" return Comp([a for a in self.varlist if a in other.varlist]) def interleaved(self, other): r = Comp([]) for i in range(max(len(self.varlist), len(other.varlist))): if i < len(self.varlist): r.varlist.append(self.varlist[i].copy()) if i < len(other.varlist): r.varlist.append(other.varlist[i].copy()) return r def size(self): return len(self.varlist) def __len__(self): return len(self.varlist) def __bool__(self): return bool(self.varlist) def isempty(self): """Whether self is empty.""" return (len(self.varlist) == 0) def from_mask(self, mask): """Return subset using bit mask.""" r = [] for i in range(len(self.varlist)): if mask & (1 << i) != 0: r.append(self.varlist[i]) return Comp(r) # Get bit mask of Comp def get_mask(self, x): r = 0 for i in range(len(self.varlist)): if self.varlist[i] in x: r |= (1 << i) return r def super_of(self, other): """Whether self is a superset of other.""" for i in range(len(other.varlist)): if not (other.varlist[i] in self.varlist): return False return True def super_of_index(self, other): r = len(self.varlist) for i in range(len(other.varlist)): if not (other.varlist[i] in self.varlist): return None r = min(r, self.varlist.index(other.varlist[i])) return r def 
disjoint(self, other): """Whether self is disjoint from other.""" for i in range(len(other.varlist)): if other.varlist[i] in self.varlist: return False return True def __contains__(self, other): if isinstance(other, IVar): return other in self.varlist if isinstance(other, Comp): return self.super_of(other) return False def __ge__(self, other): return self.super_of(other) def __le__(self, other): return other.super_of(self) def __eq__(self, other): return {a.name for a in self.varlist} == {a.name for a in other.varlist} #return self.super_of(other) and other.super_of(self) def __ne__(self, other): return {a.name for a in self.varlist} != {a.name for a in other.varlist} def __gt__(self, other): return self.super_of(other) and not other.super_of(self) def __lt__(self, other): return other.super_of(self) and not self.super_of(other) def tolist(self): return CompArray([a.copy() for a in self]) def z3(self, vardict): if z3 is None: return None n = vardict["#n"] r = None for a in self.varlist: if r is None: r = vardict[a.name] else: r = r | vardict[a.name] if r is None: return z3.BitVecVal(0, n) return r def mean(self, f = None, name = None): """ Returns the expectation of the function f. Parameters ---------- f : function, numpy.array or torch.Tensor If f is a function, the number of arguments must match the number of random variables. If f is an array or tensor, shape must match the shape of the joint distribution. Returns ------- r : Expr The expectation as an expression. """ if name is None: name = "mean_" + str(iutil.get_count("mean")) def fcncall(xdist): return xdist.mean(f) R = Expr.real(name) return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), Region.universe(), 0, fcncall, [self.copy()])) def prob(self, *args): """ Returns the probability mass function at *args. Parameters ---------- *args : int The indices to query. E.g. (X+Y).prob(2,3) is P(X=2, Y=3). Returns ------- r : Expr The probability as an expression. 
""" args = tuple(args) name = "P(" + ",".join((a.tostring(style = PsiOpts.STR_STYLE_STANDARD) + "=" + str(b)) for a, b in zip(self, args)) + ")" name += "@@" + str(PsiOpts.STR_STYLE_PSITIP) + "@@" if len(self) == 1: name += repr(self) else: name += "(" + repr(self) + ")" name += ".prob(" + ",".join(str(b) for b in args) + ")" name += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" name += PsiOpts.settings["latex_prob"] + "(" + ",".join((a.tostring(style = PsiOpts.STR_STYLE_LATEX) + "=" + str(b)) for a, b in zip(self, args)) + ")" def fcncall(xdist): return xdist[args] R = Expr.real(name) return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), Region.universe(), 0, fcncall, [self.copy()])) def pmf(self, simplex = True): """ Returns the probability mass function as ExprArray. """ shape_in = self.get_shape_in() shape_out = self.get_shape_out() n_shape_out = iutil.product(shape_out) r = ExprArray.zeros(shape_in + shape_out) for xs in itertools.product(*[range(x) for x in shape_in]): tsum = Expr.zero() for iy, ys in enumerate(itertools.product(*[range(y) for y in shape_out])): if iy == n_shape_out - 1 and simplex: r[xs + ys] = 1 - tsum else: r[xs + ys] = self.prob(*(xs + ys)) if simplex: tsum += r[xs + ys] return r def sort(self): self.varlist.sort(key = lambda a: a.name) def tostring(self, style = 0, tosort = False, add_bracket = False): """Convert to string Parameters: style : Style of string conversion STR_STYLE_STANDARD : I(X,Y;Z|W) STR_STYLE_PSITIP : I(X+Y&Z|W) """ style = iutil.convert_str_style(style) namelist = [a.tostring(style) for a in self.varlist] if len(namelist) == 0: if style & PsiOpts.STR_STYLE_PSITIP: return "rv()" elif style & PsiOpts.STR_STYLE_LATEX: return PsiOpts.settings["latex_rv_empty"] return "!" 
if tosort: namelist.sort() r = "" if add_bracket and len(namelist) > 1: r += "(" if style & PsiOpts.STR_STYLE_PSITIP: r += "+".join(namelist) elif style & PsiOpts.STR_STYLE_LATEX: r += (PsiOpts.settings["latex_rv_delim"] + " ").join(namelist) else: r += ",".join(namelist) if add_bracket and len(namelist) > 1: r += ")" return r def __str__(self): return self.tostring(PsiOpts.settings["str_style"], tosort = PsiOpts.settings["str_tosort"]) def __repr__(self): return self.tostring(PsiOpts.settings["str_style_repr"]) @latex_postprocess def _latex_(self): return self.tostring(iutil.convert_str_style("latex")) def __hash__(self): #return hash(self.tostring(tosort = True)) return hash(frozenset(a.name for a in self.varlist)) def isregtermpresent(self): for b in self.varlist: if b.reg is not None: return True return False def __and__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return NotImplemented return Term.H(self) & Term.H(other) def __or__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return NotImplemented if isinstance(other, int): other = Comp.empty() return Term.H(self) | Term.H(other) def rename_var(self, name0, name1): for a in self.varlist: if a.name == name0: a.name = name1 for a in self.varlist: if a.reg is not None: a.reg.rename_var(name0, name1) def rename_map(self, namemap): """Rename according to name map """ for a in self.varlist: a.name = namemap.get(a.name, a.name) for a in self.varlist: if a.reg is not None: a.reg.rename_map(namemap) return self @fcn_substitute def substitute(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound)""" if len(v0) > 1: for v0c in v0: self.substitute(v0c, v1) return v0s = v0.get_name() for i in range(len(self.varlist)): if self.varlist[i].name == v0s: nameset = set() for j in range(len(self.varlist)): if j != i: nameset.add(self.varlist[j].name) self.varlist = self.varlist[:i] + [t.copy() for t in v1.varlist if t.name not in nameset] + 
self.varlist[i+1:] break # r = Comp.empty() # for a in self.varlist: # if v0.ispresent_shallow(a): # r += v1 # else: # r += Comp([a]) # self.varlist = r.varlist for a in self.varlist: if a.reg is not None: a.reg.substitute(v0, v1) @fcn_substitute def substitute_whole(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound)""" t = self.super_of_index(v0) if t is None: return self -= v0 t = min(t, len(self.varlist)) for x in v1.varlist: if x not in self.varlist: self.varlist.insert(t, x.copy()) t += 1 @staticmethod def substitute_list(a, vlist, suffix = "", isaux = False): """Substitute variables in vlist into a""" if isinstance(vlist, list): for v in vlist: Comp.substitute_list(a, v, suffix, isaux) elif isinstance(vlist, tuple): if len(vlist) >= 2: w = vlist[1] if isinstance(w, list): w = w[0] if len(w) > 0 else Comp.empty() if suffix != "": if isaux: a.substitute_aux(Comp.rv(vlist[0].get_name() + suffix), w) else: a.substitute(Comp.rv(vlist[0].get_name() + suffix), w) else: if isaux: a.substitute_aux(vlist[0], w) else: a.substitute(vlist[0], w) @staticmethod def substitute_list_to_dict(vlist, multi = False): def add_dict(r, key, value): if isinstance(value, list): for t in value: add_dict(r, key, t) return if key not in r: r[key] = value else: if not isinstance(r[key], list): r[key] = [r[key]] if value not in r[key]: r[key].append(value) r = dict() if isinstance(vlist, list): for v in vlist: t = Comp.substitute_list_to_dict(v, multi = multi) if multi: for key, value in t.items(): add_dict(r, key, value) else: r.update(t) elif isinstance(vlist, tuple): if len(vlist) >= 2: if multi: for w in vlist[1:]: add_dict(r, vlist[0], w) else: w = vlist[1] if isinstance(w, list): w = w[0] if len(w) > 0 else Comp.empty() r[vlist[0]] = w return r @staticmethod def substitute_dict_ismulti(d): return any(isinstance(value, list) for key, value in d.items()) def record_to(self, index): index.record(self) for a in self.varlist: if a.reg is not None: 
index.record(a.reg.allcomprv_noaux())

    def fcn_of(self, b):
        # Constraint Expr.Hc(self, b) == 0: the conditional entropy of self
        # given b vanishes, i.e. self is a function of b.
        return Expr.Hc(self, b) == 0

    def table(self, *args, **kwargs):
        """Plot the information diagram as a Karnaugh map.
        """
        return universe().table(self, *args, **kwargs)

    def venn(self, *args, **kwargs):
        """Plot the information diagram as a Venn diagram.
        Can handle up to 5 random variables (uses Branko Grunbaum's Venn diagram for n=5).
        """
        return universe().venn(self, *args, **kwargs)


class IVarIndex:
    """Store index of variables
    Do NOT use this class directly

    Random variables (IVarType.RV) and all other variables are indexed
    separately: `dictrv` maps an RV name to its position in `comprv`, and
    `dictreal` maps any other variable name to its position in `compreal`.
    """
    def __init__(self):
        self.dictrv = {}             # RV name -> index into comprv
        self.comprv = Comp.empty()   # indexed RVs, in insertion order
        self.dictreal = {}           # non-RV name -> index into compreal
        self.compreal = Comp.empty() # indexed non-RV variables, in insertion order
        self.prefavoid = "@@"        # names with this prefix are never indexed

    def copy(self):
        """Return an independent copy of this index."""
        r = IVarIndex()
        r.dictrv = self.dictrv.copy()
        r.comprv = self.comprv.copy()
        r.dictreal = self.dictreal.copy()
        r.compreal = self.compreal.copy()
        r.prefavoid = self.prefavoid
        return r

    def record(self, x):
        """Append every variable of Comp `x` that is not yet indexed.

        Variables whose name starts with `self.prefavoid` are skipped.
        """
        for i in range(len(x.varlist)):
            if not x.varlist[i].name.startswith(self.prefavoid):
                if x.varlist[i].vartype == IVarType.RV:
                    if not (x.varlist[i].name in self.dictrv):
                        # New RV gets the next free index.
                        self.dictrv[x.varlist[i].name] = self.comprv.size()
                        self.comprv.varlist.append(x.varlist[i])
                else:
                    if not (x.varlist[i].name in self.dictreal):
                        self.dictreal[x.varlist[i].name] = self.compreal.size()
                        self.compreal.varlist.append(x.varlist[i])

    def add_varindex(self, x):
        """Record all variables of another IVarIndex `x` into this one."""
        self.record(x.comprv)
        self.record(x.compreal)

    # Get index of IVar
    # Returns -1 if absent. For a Comp, only its first variable is looked up.
    def get_index(self, x):
        if isinstance(x, Comp):
            x = x.varlist[0]
        if x.vartype == IVarType.RV:
            if x.name in self.dictrv:
                return self.dictrv[x.name]
            return -1
        else:
            if x.name in self.dictreal:
                return self.dictreal[x.name]
            return -1

    def __contains__(self, other):
        return self.get_index(other) >= 0

    def ispresent(self, x):
        """Return whether any variable in x appears here"""
        for y in x:
            if self.get_index(y) >= 0:
                return True
        return False

    # Get index of name
    # RV names are looked up before non-RV names; -1 if unknown.
    def get_index_name(self, name):
        if name in self.dictrv:
            return self.dictrv[name]
        if name in self.dictreal:
            return self.dictreal[name]
        return -1

    # Get object of name
    # RVs are returned as Comp; other variables are wrapped in an Expr.
    def get_obj_name(self, name):
        if name in self.dictrv:
            return self.comprv[self.dictrv[name]]
        if name in self.dictreal:
            return Expr.fromcomp(self.compreal[self.dictreal[name]])
        return None

    def allcomp(self):
        """All indexed variables: RVs first, then the others."""
        return self.comprv + self.compreal

    def find(self, *args):
        return self.allcomp().find(*args)

    # Get bit mask of Comp
    # Bit k is set for each RV of x at index k. Returns 0 when x is not of
    # RV type and -1 when some variable of x is not indexed.
    def get_mask(self, x):
        if x.get_type() != IVarType.RV:
            return 0
        r = 0
        for a in x.varlist:
            k = self.get_index(a)
            if k < 0:
                return -1
            r |= (1 << k)
        return r

    def from_mask(self, m):
        """Inverse of get_mask, delegated to comprv.from_mask."""
        return self.comprv.from_mask(m)

    def num_rv(self):
        return self.comprv.size()

    def num_real(self):
        return self.compreal.size()

    def size(self):
        return self.comprv.size() + self.compreal.size()

    def name_avoid_old(self, name0):
        # Older variant of name_avoid (appends rename_char until the name is
        # unused). NOTE(review): not referenced in this part of the file.
        name1 = name0
        while self.get_index_name(name1) >= 0:
            name1 += PsiOpts.settings["rename_char"]
        return name1

    def name_avoid(self, name0):
        """Return `name0`, or `name0` with an increasing numeric suffix
        (via iutil.set_suffix_num), so that the result is not yet indexed."""
        name1 = name0
        rename_char = PsiOpts.settings["rename_char"]
        k = 1
        while self.get_index_name(name1) >= 0:
            name1 = iutil.set_suffix_num(name0, k, rename_char, replace_mode = "append")
            k += 1
        return name1

    def calc_rename_map(self, a):
        """Map each name in `a` to a fresh, unused name.

        Each fresh name is recorded as an RV in this index, so later names
        in `a` also avoid it. Returns a dict {old name: new name}.
        """
        m = dict()
        for x in a:
            xstr = ""
            if isinstance(x, Comp):
                xstr = x.get_name()
            else:
                xstr = str(x)
            tname = self.name_avoid(xstr)
            self.record(Comp.rv(tname))
            m[xstr] = tname
        return m


PsiOpts.global_index = IVarIndex()


class TermType:
    # Kind of a Term, as returned by Term.get_type().
    NIL = 0
    IC = 1
    REAL = 2
    REGION = 3


class TermAllowType:
    # Bit flags restricting which term forms Term.try_isub / try_isub_flipsign
    # may produce.
    H = 1
    HC = 2
    I = 4
    IC = 8
    I3 = 16
    DEFAULT = 1 + 2 + 4 + 8 + 16


class Term(IBaseObj):
    """A term in an expression
    Do NOT use this class directly.
Use Expr instead """ def __init__(self, x, z = None, reg = None, sn = 0, fcncall = None, fcnargs = None, reg_outer = None, z_present = False, termtname = None): self.x = x if z is None: self.z = Comp.empty() else: self.z = z self.reg = reg self.sn = sn self.fcncall = fcncall self.fcnargs = fcnargs self.reg_outer = reg_outer self.z_present = z_present self.termtname = termtname def copy(self): return Term([a.copy() for a in self.x], self.z.copy(), iutil.copy(self.reg), self.sn, self.fcncall, iutil.copy(self.fcnargs), iutil.copy(self.reg_outer), self.z_present, self.termtname) def copy_noreg(self): r = Term([a.copy_noreg() for a in self.x], self.z.copy_noreg(), None, 0) r.termtname = self.termtname return r @staticmethod def zero(): return Term([], Comp.empty()) def isempty(self): """Whether self is empty.""" return len(self.x) == 0 or any(len(a) == 0 for a in self.x) def setzero(self): self.x = [] self.z = Comp.empty() self.reg = None self.reg_outer = None self.sn = 0 self.fcncall = None self.fcnargs = None self.z_present = False self.termtname = None def iszero(self): if self.get_type() == TermType.REGION: return False else: if len(self.x) == 0: return True for a in self.x: if a.isempty(): return True return False @staticmethod def fromcomp(x): return Term([x.copy()], Comp.empty()) @staticmethod def H(x): return Term([x.copy()], Comp.empty()) @staticmethod def I(x, y): return Term([x.copy(), y.copy()], Comp.empty()) @staticmethod def Hc(x, z): return Term([x.copy()], z.copy()) @staticmethod def Ic(x, y, z): return Term([x.copy(), y.copy()], z.copy()) @staticmethod def fcn(fcnname, fcncall, fcnargs): cname = fcnname return Term(Comp.real(cname), Comp.empty(), reg = Region.universe(), sn = 0, fcncall = fcncall, fcnargs = fcnargs) @staticmethod def eps(): """Epsilon.""" return Term([Comp([IVar.eps()])], Comp.empty()) @staticmethod def one(): """One.""" return Term([Comp([IVar.one()])], Comp.empty()) @staticmethod def inf(): """Infinity.""" return 
Term([Comp([IVar.inf()])], Comp.empty()) def allcomp(self): r = Comp.empty() for a in self.x: r += a r += self.z return r def allcomprv_shallow(self): r = Comp.empty() for a in self.x: r += a r += self.z return Comp([a for a in r.varlist if a.vartype == IVarType.RV]) def get_name(self): return self.allcomp().get_name() def get_maxent_comp(self): if self.termtname is None or self.termtname != "H0": return None if len(self.fcnargs) != 1: return None a = self.fcnargs[0] if not isinstance(a, Term): return None if len(a.x) != 1 or not a.z.isempty(): return None return a.x[0] def size(self): r = self.z.size() for a in self.x: r += a.size() return r def complexity(self): # return (len(self.z) + sum(len(a) for a in self.x)) * 2 + len(self.x) + (not self.z.isempty()) r = len(self.z) + sum(len(a) for a in self.x) * 2 if len(self.x) == 1: r += 2 elif len(self.x) >= 3: r += len(self.x) * 2 if self.get_type() == TermType.REGION: if self.reg is not None: r += self.reg.complexity() return r def sorting_priority(self): return self.complexity() def get_type(self): if self.reg is not None: return TermType.REGION if len(self.x) == 0: return TermType.NIL if self.x[0].get_type() == IVarType.REAL: return TermType.REAL return TermType.IC def iseps(self): if self.get_type() != TermType.REAL: return False return self.x[0].varlist[0].name == "EPS" def isone(self): if self.get_type() != TermType.REAL: return False return self.x[0].varlist[0].name == "ONE" def isinf(self): if self.get_type() != TermType.REAL: return False return self.x[0].varlist[0].name == "INF" def isrealvar(self): if self.get_type() != TermType.REAL: return False return self.x[0].varlist[0].name != "EPS" and self.x[0].varlist[0].name != "ONE" and self.x[0].varlist[0].name != "INF" def isnonneg(self): if self.get_type() == TermType.IC: if PsiOpts.settings.get("quantum", False): if len(self.x) == 1: return self.z.isempty() elif len(self.x) == 2: return self.z.isempty() or (self.x[0].inter(self.x[1])).isempty() return False 
else: if len(self.x) == 0: return True elif len(self.x) == 1: if self.z.isempty(): return PsiOpts.settings.get("hge0", False) else: return PsiOpts.settings.get("hcge0", False) elif len(self.x) == 2: if self.z.isempty(): return PsiOpts.settings.get("ige0", False) else: return PsiOpts.settings.get("icge0", False) else: return False # return len(self.x) <= 2 if self.isone() or self.iseps() or self.isinf(): return True return False def isic2(self): if self.get_type() == TermType.IC: if len(self.x) != 2: return False return (self.x[0]-self.z).disjoint(self.x[1]-self.z) #return ((self.x[0]+self.x[1]+self.z).size() # == self.x[0].size()+self.x[1].size()+self.z.size()) return False def isic3(self): if self.get_type() == TermType.IC: if len(self.x) != 3: return False return True return False def isicle2(self): if self.get_type() == TermType.IC: return len(self.x) <= 2 return False def ishc(self): if self.get_type() == TermType.IC: if len(self.x) != 1: return False return True return False def isihc2(self): if self.get_type() == TermType.IC: if len(self.x) != 2: return False return True return False def ish(self): if self.get_type() == TermType.IC: if len(self.x) != 1: return False if not self.z.isempty(): return False return True return False def restricted(self, a): if self.get_type() == TermType.IC: r = self.copy() r.x = [y.inter(a) for y in r.x] r.z = r.z.inter(a) return r return None def record_to(self, index): if self.get_type() == TermType.REGION: if self.reg is not None: index.record(self.reg.allcomprv_noaux()) if self.reg_outer is not None: index.record(self.reg_outer.allcomprv_noaux()) if self.fcnargs is not None: for a in self.fcnargs: if isinstance(a, (IVar, Comp, Term, Expr, Region)): a.record_to(index) for a in self.x: a.record_to(index) self.z.record_to(index) def definition(self): """Return the definition of this term. 
""" if self.get_type() == TermType.REGION: cname = self.x[0].get_name() if cname.find(PsiOpts.settings["fcn_suffix"]) >= 0: return self.substituted(self.x[0], Comp.real(cname.replace(PsiOpts.settings["fcn_suffix"], ""))) return self.copy() def z3(self, vardict): if z3 is None: return None n = vardict["#n"] H = vardict["#H"] r = None if self.get_type() == TermType.IC: for mask in range(1 << len(self.x)): csum = sum((self.x[i] for i in range(len(self.x)) if mask & (1 << i)), self.z.copy()) if csum.isempty(): continue cH = H(csum.z3(vardict)) issub = iutil.bitcount(mask) % 2 == 0 if r is None: if issub: r = -cH else: r = cH else: if issub: r = r - cH else: r = r + cH return r elif self.get_type() == TermType.REAL: if self.isone(): return z3.RealVal(1) return vardict[self.x[0].get_name()] elif self.get_type() == TermType.REGION: raise ValueError("Optimization quantities are not supported for Z3. Use another solver.") def get_shape(self): r = [] for a in itertools.chain(self.z, *(self.x)): t = a.get_card() if t is None: raise ValueError("Cardinality of " + str(a) + " not set. Use " + str(a) + ".set_card(m) to set cardinality.") return r.append(t) return tuple(r) def get_shape_in(self): r = [] for a in self.z: t = a.get_card() if t is None: raise ValueError("Cardinality of " + str(a) + " not set. Use " + str(a) + ".set_card(m) to set cardinality.") return r.append(t) return tuple(r) def get_shape_out(self): r = [] for a in itertools.chain(*(self.x)): t = a.get_card() if t is None: raise ValueError("Cardinality of " + str(a) + " not set. Use " + str(a) + ".set_card(m) to set cardinality.") return r.append(t) return tuple(r) def prob(self, *args): """ Returns the conditional probability mass function at *args. Parameters ---------- *args : int The indices to query. E.g. (X+Y|Z).prob(2,3,4) is P(X=3, Y=4 | Z=2). Returns ------- r : Expr The conditional probability as an expression. 
""" cx = [] for a in self.x: cx += list(a) args = tuple(args) name = "P(" + ",".join((a.tostring(style = PsiOpts.STR_STYLE_STANDARD) + "=" + str(b)) for a, b in zip(cx, args[len(self.z):])) + "|" name += ",".join((a.tostring(style = PsiOpts.STR_STYLE_STANDARD) + "=" + str(b)) for a, b in zip(self.z, args[:len(self.z)])) + ")" name += "@@" + str(PsiOpts.STR_STYLE_PSITIP) + "@@" name += "(" + "&".join(repr(a) for a in cx) + "|" + repr(self.z) + ")" name += ".prob(" + ",".join(str(b) for b in args) + ")" name += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" name += PsiOpts.settings["latex_prob"] + "(" + ",".join((a.tostring(style = PsiOpts.STR_STYLE_LATEX) + "=" + str(b)) for a, b in zip(cx, args[len(self.z):])) + "|" name += ",".join((a.tostring(style = PsiOpts.STR_STYLE_LATEX) + "=" + str(b)) for a, b in zip(self.z, args[:len(self.z)])) + ")" def fcncall(xdist): return xdist[args] R = Expr.real(name) return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), Region.universe(), 0, fcncall, [self.copy()])) # def fcncall(P): # return P[self][args] # R = Expr.real(name) # return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), Region.universe(), 0, fcncall, ["model"])) def pmf(self, simplex = True): """ Returns the probability mass function as ExprArray. 
""" shape_in = self.get_shape_in() shape_out = self.get_shape_out() n_shape_out = iutil.product(shape_out) r = ExprArray.zeros(shape_in + shape_out) for xs in itertools.product(*[range(x) for x in shape_in]): tsum = Expr.zero() for iy, ys in enumerate(itertools.product(*[range(y) for y in shape_out])): if iy == n_shape_out - 1 and simplex: r[xs + ys] = 1 - tsum else: r[xs + ys] = self.prob(*(xs + ys)) if simplex: tsum += r[xs + ys] return r @staticmethod def fcneval(fcncall, fcnargs): if fcncall == "*": return fcnargs[0] * fcnargs[1] elif fcncall == "/": return fcnargs[0] / fcnargs[1] elif fcncall == "**": return fcnargs[0] ** fcnargs[1] else: return fcncall(*fcnargs) def get_fcneval(self, fcnargs): return Term.fcneval(self.fcncall, fcnargs) def value(self, method = "", num_iter = 30, prog = None): if self.isone(): return 1.0 if self.iseps(): return 0.0 if self.isinf(): return float("inf") if self.reg is None: return None if self.sn == 0: return None ms = method.split(",") if isinstance(self.reg, RegionOp): ms.append("bsearch") if "bsearch" in ms: selfterm = Expr.real(str(self)) return self.sn * iutil.gbsearch( lambda x: self.reg.implies(self.sn * selfterm <= x), num_iter = num_iter) else: selfterm = Expr.real(str(self)) cs = self.reg.consonly().imp_flipped() index = IVarIndex() cs.record_to(index) r = [] dual_enabled = None val_enabled = None if "dual" in ms: dual_enabled = True if "val" in ms: val_enabled = True cprog = cs.init_prog(index, lp_bounded = False, dual_enabled = dual_enabled, val_enabled = val_enabled) cprog.checkexpr_ge0(-self.sn * selfterm, optval = r) if prog is not None: prog.append(cprog) if len(r) == 0: return None return r[0] * -self.sn def get_reg_sgn_bds(self): if self.get_type() != TermType.REGION: return None if self.sn == 0: return None reg = self.reg.copy() if isinstance(reg, RegionOp): reg = reg.tosimple() if reg is None: return None sn = self.sn tbds = reg.get_lb_ub_eq(self) reg.eliminate_term(self) reg.simplify_quick() if sn > 0: 
return (reg, sn, tbds[1]) else: return (reg, sn, tbds[0]) def sort(self): for a in self.x: a.sort() self.z.sort() # self.x.sort(key = lambda a: (a.sorting_priority(), str(a))) self.x.sort(key = lambda a: a.sorting_tuple()) def tostring(self, style = 0, tosort = False, add_bracket = False): """Convert to string Parameters: style : Style of string conversion STR_STYLE_STANDARD : I(X,Y;Z|W) STR_STYLE_PSITIP : I(X+Y&Z|W) """ style = iutil.convert_str_style(style) termType = self.get_type() if termType == TermType.NIL: return "0" elif termType == TermType.REAL: return self.x[0].tostring(style = style, tosort = tosort) elif termType == TermType.IC: r = "" if len(self.x) == 1: if style & PsiOpts.STR_STYLE_LATEX: r += PsiOpts.settings["latex_H"] else: r += "H" else: if style & PsiOpts.STR_STYLE_LATEX: r += PsiOpts.settings["latex_I"] else: r += "I" r += "(" namelist = [a.tostring(style = style, tosort = tosort) for a in self.x] if tosort: namelist.sort() if style & PsiOpts.STR_STYLE_PSITIP: r += "&".join(namelist) elif style & PsiOpts.STR_STYLE_LATEX: r += (PsiOpts.settings["latex_mi_delim"] + " ").join(namelist) else: r += ";".join(namelist) if self.z.size() > 0: if style & PsiOpts.STR_STYLE_LATEX: r += PsiOpts.settings["latex_cond"] else: r += "|" r += self.z.tostring(style = style, tosort = tosort) r += ")" return r elif termType == TermType.REGION: if self.x[0].varlist[0].name.find(PsiOpts.settings["fcn_suffix"]) >= 0: return self.x[0].tostring(style = style, tosort = tosort) reg = self.reg sn = self.sn bds = [self.copy_noreg()] rsb = self.get_reg_sgn_bds() if rsb is not None and (not style & PsiOpts.STR_STYLE_PSITIP or len(rsb[2]) == 1 or rsb[0].isuniverse()): reg, sn, bds = rsb # elif rsb is not None and rsb[0].isuniverse(): # reg, sn, bds = rsb # sn *= -1 reg_universe = reg.isuniverse(canon = True) if len(bds) == 0: if style & PsiOpts.STR_STYLE_LATEX: if sn > 0: return PsiOpts.settings["latex_infty"] else: if add_bracket: return "(-" + 
PsiOpts.settings["latex_infty"] + ")" else: return "-" + PsiOpts.settings["latex_infty"] else: if sn > 0: return "INF" else: if add_bracket: return "(-" + "INF" + ")" else: return "-" + "INF" if style & PsiOpts.STR_STYLE_LATEX: r = "" if not reg_universe: if sn > 0: r += PsiOpts.settings["latex_sup"] else: r += PsiOpts.settings["latex_inf"] r += "_{" r += reg.tostring(style = style & ~PsiOpts.STR_STYLE_LATEX_ARRAY, tosort = tosort, small = True, skip_outer_exists = True) r += "}" if len(bds) > 1: if sn > 0: r += PsiOpts.settings["latex_min"] else: r += PsiOpts.settings["latex_max"] r += "\\left(" r += ",\\, ".join(b.tostring(style = style, tosort = tosort, add_bracket = len(bds) == 1 and (add_bracket or not reg_universe)) for b in bds) if len(bds) > 1: r += "\\right)" return r else: r = "" if not reg_universe: r += "(" r += reg.tostring(style = style, tosort = tosort, small = True) r += ")" if sn > 0: r += ".maximum" else: r += ".minimum" r += "(" if len(bds) > 1: if style & PsiOpts.STR_STYLE_PSITIP: if sn > 0: r += "emin" else: r += "emax" else: if sn > 0: r += "min" else: r += "max" r += "(" r += ", ".join(b.tostring(style = style, tosort = tosort, add_bracket = len(bds) == 1 and (add_bracket or not reg_universe)) for b in bds) if len(bds) > 1: r += ")" if not reg_universe: r += ")" return r return "" def __str__(self): return self.tostring(PsiOpts.settings["str_style"], tosort = PsiOpts.settings["str_tosort"]) def __repr__(self): return self.tostring(PsiOpts.settings["str_style_repr"]) @latex_postprocess def _latex_(self): return self.tostring(iutil.convert_str_style("latex")) def __hash__(self): #return hash(self.tostring(tosort = True)) return hash((frozenset(hash(a) for a in self.x), hash(self.z))) def simplify(self, reg = None, bnet = None): if self.get_type() == TermType.IC: for i in range(len(self.x)): self.x[i] -= self.z for i in range(len(self.x)): if self.x[i].isempty(): self.x = [] self.z = Comp.empty() return self for i in range(len(self.x)): for j in 
range(len(self.x)): if i != j and (self.x[j] is not None) and self.x[i].super_of(self.x[j]): self.x[i] = None break self.x = [a for a in self.x if a is not None] if bnet is not None: for i in range(len(self.x) + 1): cc = None if i == len(self.x): cc = self.z else: if len(self.x) == 1: continue cc = self.x[i] j = 0 while j < len(cc.varlist): if bnet.check_ic(Expr.Ic(cc[j], sum((self.x[i2] for i2 in range(len(self.x)) if i2 != i), Comp.empty()), (Comp.empty() if i == len(self.x) else self.z) + sum((cc[j2] for j2 in range(len(cc.varlist)) if j2 != j), Comp.empty()))): cc.varlist.pop(j) else: j += 1 if len(self.x) == 3: for i in range(len(self.x)): if bnet.check_ic(Expr.Ic(self.x[(i + 1) % 3], self.x[(i + 2) % 3], self.x[i] + self.z)): self.x.pop(i) break if len(self.x) == 2: for i in range(2): j = 0 while j < len(self.x[i]): if bnet.check_ic(Expr.Ic(self.x[i][j], self.x[1 - i], self.z)): self.z.varlist.append(self.x[i].varlist[j]) self.x[i].varlist.pop(j) else: j += 1 return self def simplified(self, reg = None, bnet = None): r = self.copy() r.simplify(reg, bnet) return r def simplify_quick(self, **kwargs): return self.simplify(**kwargs) def simplified_quick(self, **kwargs): return self.simplified(**kwargs) def simplify_regterm_perform(self, reg = None): # print("A") # print(self) # print(self.reg) did = False if self.reg is not None: self.reg.simplify(reg = reg) did = True if self.reg_outer is not None: self.reg_outer.simplify(reg = reg) did = True # print(self.reg) return did def simplify_regterm_expr(self, reg = None): if self.simplify_regterm_perform(reg = reg): if self.reg_outer is None: rsb = self.get_reg_sgn_bds() if rsb is not None: # print(rsb) reg, sn, bds = rsb if len(bds) <= 1 and reg.isuniverse(canon = True): if len(bds) == 0: return Expr.inf() * (1 if sn > 0 else -1) else: return bds[0] return Expr.fromterm(self) return None def match_x(self, other): viss = [-1] * len(self.x) viso = [-1] * len(other.x) for i in range(len(self.x)): for j in 
range(len(other.x)): if viso[j] < 0 and self.x[i] == other.x[j]: viss[i] = j viso[j] = i break return (viss, viso) def __eq__(self, other): if self.z != other.z: return False if len(self.x) != len(other.x): return False viso = [-1] * len(other.x) for i in range(len(self.x)): found = False for j in range(len(other.x)): if viso[j] < 0 and self.x[i] == other.x[j]: found = True viso[j] = i break if not found: return False # (viss, viso) = self.match_x(other) # if -1 in viso: # return False return True def try_remove(self, x, sn): if not self.ispresent(x): return True if self.get_type() == TermType.IC: if len(self.x) >= 3: return False if len(self.x) == 2 and self.z.ispresent(x): return False if sn > 0: if any(x2.ispresent(x) for x2 in self.x): return False if sn < 0: if self.z.ispresent(x): return False for x2 in self.x: x2 -= x self.z -= x return True else: return False def z_can_extend_to(self, target, bnet = None): if self.z == target: return True if bnet is None: return False if not target.super_of(self.z): return False return bnet.check_ic(Expr.Ic(target - self.z, sum(self.x, Comp.empty()), self.z)) def try_iadd(self, other, bnet = None, term_allow = None): if self.iseps() and other.iseps(): return True # if self.isinf(): # return True if self.get_type() == TermType.IC and other.get_type() == TermType.IC: # H(X|Y) + I(X;Y) = H(X) if len(self.x) + 1 == len(other.x): (viss, viso) = self.match_x(other) if viso.count(-1) == 1: j = viso.index(-1) selfz = self.z + other.x[j] + other.z if self.z_can_extend_to(selfz, bnet): otherz = selfz - other.x[j] if other.z_can_extend_to(otherz, bnet): self.z = otherz return True # if other.x[j].disjoint(other.z) and self.z == other.x[j] + other.z: # self.z -= other.x[j] # return True # H(X|Y) + H(Y) = H(X,Y) if len(self.x) == len(other.x): (viss, viso) = self.match_x(other) if viss.count(-1) == 1 and viso.count(-1) == 1: i = viss.index(-1) j = viso.index(-1) selfz = self.z + other.x[j] + other.z if self.z_can_extend_to(selfz, bnet): 
otherz = selfz - other.x[j] if other.z_can_extend_to(otherz, bnet): self.x[i] += other.x[j] self.z = otherz return True # if other.x[j].disjoint(other.z) and self.z == other.x[j] + other.z: # self.x[i] += other.x[j] # self.z -= other.x[j] # return True return False def try_isub(self, other, bnet = None, term_allow = None): if self.get_type() == TermType.IC and other.get_type() == TermType.IC: # H(X) - I(X;Y) = H(X|Y) if len(self.x) + 1 == len(other.x) and (term_allow is None or term_allow & TermAllowType.HC): selfz = self.z + other.z if self.z_can_extend_to(selfz, bnet) and other.z_can_extend_to(selfz, bnet): (viss, viso) = self.match_x(other) if viso.count(-1) == 1: j = viso.index(-1) self.z = selfz + other.x[j] return True # H(X) - H(X|Y) = I(X;Y) if len(self.x) == len(other.x) and (term_allow is None or (term_allow & TermAllowType.I and (len(self.x) < 2 or term_allow & TermAllowType.I3))): otherz = other.z + self.z if other.z_can_extend_to(otherz, bnet): (viss, viso) = self.match_x(other) if viso.count(-1) == 0: self.x.append(otherz - self.z) return True # H(X,Y) - H(X) = H(Y|X) if len(self.x) == len(other.x) and (term_allow is None or term_allow & TermAllowType.HC): selfz = self.z + other.z if self.z_can_extend_to(selfz, bnet) and other.z_can_extend_to(selfz, bnet): (viss, viso) = self.match_x(other) if viss.count(-1) == 1 and viso.count(-1) == 1: i = viss.index(-1) j = viso.index(-1) if (self.x[i].super_of(other.x[j]) or (len(self.x) == 2 and bnet is not None and bnet.check_ic(Expr.Ic(other.x[j], self.x[1 - i], self.x[i] + selfz)))): self.x[i] -= other.x[j] self.z = selfz + other.x[j] return True # H(X,Y) - H(X|Y) = H(Y) if len(self.x) == len(other.x): (viss, viso) = self.match_x(other) if viss.count(-1) == 1 and viso.count(-1) == 1: i = viss.index(-1) j = viso.index(-1) if self.x[i].super_of(other.x[j]): otherz = (self.x[i] - other.x[j]) + self.z if other.z_can_extend_to(otherz, bnet): self.x[i] -= other.x[j] return True # if self.x[i] - other.x[j] == other.z 
- self.z: # self.x[i] -= other.x[j] # return True return False def try_isub_flipsign(self, other, bnet = None, term_allow = None): if self.get_type() == TermType.IC and other.get_type() == TermType.IC and (term_allow is None or term_allow & TermAllowType.IC): # I(X+U & Y+Z) - I(X & Y) = I(U & Y | X) + I(X+U & Z | Y) if len(self.x) == 2 and len(other.x) == 2: selfz = self.z + other.z if self.z_can_extend_to(selfz, bnet) and other.z_can_extend_to(selfz, bnet): s0 = self.x[0] s1 = self.x[1] for it in range(2): o0 = other.x[it] o1 = other.x[1 - it] if s0.super_of(o0) and s1.super_of(o1): self.x[0] = s0 - o0 self.x[1] = o1 self.z = selfz + o0 other.x[0] = s0 other.x[1] = s1 - o1 other.z = selfz + o1 return True return False def __and__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return NotImplemented if isinstance(other, Comp): other = Term.H(other) return Term([a.copy() for a in self.x] + [a.copy() for a in other.x], self.z + other.z, z_present = self.z_present or other.z_present) def __or__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return NotImplemented if isinstance(other, Comp): other = Term.H(other) if isinstance(other, int): other = Term.H(Comp.empty()) return Term([a.copy() for a in self.x], self.z + other.allcomp(), z_present = True) def symbol_split(self): r = [] for a in self.x: if len(r): r.append("&") r.append(a.copy()) if (not self.z.isempty()) or self.z_present: r.append("|") if not self.z.isempty(): r.append(self.z.copy()) return r @staticmethod def from_symbols(symbols, prefer_multi = False): cs = [] for t in symbols: if isinstance(t, Term): t2 = t.symbol_split() cs += t2 elif isinstance(t, Comp): cs.append(t.copy()) elif isinstance(t, str) and (t == "&" or t == "|"): cs.append(t) elif isinstance(t, ConcDist): return t else: cs.append(iutil.ensure_comp(t).copy()) r = Term.H(Comp.empty()) z_present = False for t in cs: if isinstance(t, str): if t == "&": r.x.append(Comp.empty()) elif 
t == "|": z_present = True elif isinstance(t, Comp): if z_present: r.z += t else: r.x[-1] += t if prefer_multi and len(symbols) >= 2 and len(r.x) <= 1: rsymbols = [] have_bar = False for t in symbols: if rsymbols and not have_bar: rsymbols.append("&") rsymbols.append(t) if isinstance(t, str) and t == "|": have_bar = True elif isinstance(t, Term) and len(t.z): have_bar = True return Term.from_symbols(rsymbols, prefer_multi=False) return r def istight(self, canon = False): if self.get_type() == TermType.REGION: if self.reg_outer is None: return True if canon: return False else: return self.reg_outer.implies(self.reg) return True def tighten(self): if self.istight(canon = False): self.reg_outer = None def lu_bound(self, sn, name = None): if self.get_type() != TermType.REGION: return self.copy() r = self.copy() if name is None: name = self.x[0].get_name() if sn > 0: name += "_LB" else: name += "_UB" r.substitute(self.x[0], Comp.real(name)) if sn * self.sn < 0: if r.reg_outer is not None: r.reg = r.reg_outer r.reg_outer = None return r def lower_bound(self, name = None): return self.lu_bound(1, name = name) def upper_bound(self, name = None): return self.lu_bound(-1, name = name) def ispresent(self, x): """Return whether any variable in x appears here""" if isinstance(x, IVar): x = Comp([x]) if not isinstance(x, str) and not isinstance(x, Comp): x = x.allcomp() if self.get_type() == TermType.REGION: if self.reg is not None and self.reg.ispresent(x): return True if self.reg_outer is not None and self.reg_outer.ispresent(x): return True if self.fcnargs is not None: for a in self.fcnargs: if isinstance(a, (IVar, Comp, Term, Expr, Region)): if a.ispresent(x): return True xlist = [] if isinstance(x, str): xlist = [x] else: xlist = x.varlist for y in xlist: for a in self.x: if a.ispresent(y): return True if self.z.ispresent(y): return True return False def __contains__(self, other): condition_included = PsiOpts.settings.get("condition_included", False) x = other if 
isinstance(x, IVar): x = Comp([x]) if self.get_type() == TermType.REGION: if self.reg is not None and x in self.reg: return True if self.reg_outer is not None and x in self.reg_outer: return True if self.fcnargs is not None: for a in self.fcnargs: if isinstance(a, (IVar, Comp, Term, Expr, Region)): if x in a: return True if condition_included: for a in self.x: if x in a + self.z: return True else: for a in self.x: if x in a: return True if x in self.z: return True return False def rename_var(self, name0, name1): if self.get_type() == TermType.REGION: if self.reg is not None: self.reg.rename_var(name0, name1) if self.reg_outer is not None: self.reg_outer.rename_var(name0, name1) if self.fcnargs is not None: for a in self.fcnargs: if isinstance(a, (IVar, Comp, Term, Expr, Region)): a.rename_var(name0, name1) for a in self.x: a.rename_var(name0, name1) self.z.rename_var(name0, name1) def rename_map(self, namemap): """Rename according to name map """ if self.get_type() == TermType.REGION: if self.reg is not None: self.reg.rename_map(namemap) if self.reg_outer is not None: self.reg_outer.rename_map(namemap) if self.fcnargs is not None: for a in self.fcnargs: if isinstance(a, (IVar, Comp, Term, Expr, Region)): a.rename_map(namemap) for a in self.x: a.rename_map(namemap) self.z.rename_map(namemap) return self @fcn_substitute def substitute(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound)""" if self.get_type() == TermType.REGION: if self.reg is not None: self.reg.substitute(v0, v1) if self.reg_outer is not None: self.reg_outer.substitute(v0, v1) if self.fcnargs is not None: for a in self.fcnargs: if isinstance(a, (IVar, Comp, Term, Expr, Region)): a.substitute(v0, v1) for a in self.x: a.substitute(v0, v1) self.z.substitute(v0, v1) @fcn_substitute def substitute_whole(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound)""" condition_included = PsiOpts.settings.get("condition_included", False) if self.get_type() == TermType.REGION: if 
self.reg is not None: self.reg.substitute_whole(v0, v1) if self.reg_outer is not None: self.reg_outer.substitute_whole(v0, v1) if self.fcnargs is not None: for a in self.fcnargs: if isinstance(a, (IVar, Comp, Term, Expr, Region)): a.substitute_whole(v0, v1) if condition_included: for a in self.x: a += self.z for a in self.x: a.substitute_whole(v0, v1) self.z.substitute_whole(v0, v1) if condition_included: for a in self.x: a -= self.z def get_var_avoid(self, v): r = None for a in self.x + [Comp.empty()]: b = a + self.z if b.ispresent(v): b = b - v if r is None: r = b else: r = r.inter(b) return r class Expr(IBaseObj): """An expression """ def __init__(self, terms, mhash = None, meta = None): self.terms = terms self.mhash = mhash self.meta = meta def copy(self): return Expr([(a.copy(), c) for (a, c) in self.terms], self.mhash, iutil.copy(self.meta)) def copy_(self, other): self.terms = [(a.copy(), c) for (a, c) in other.terms] self.mhash = other.mhash self.meta = iutil.copy(other.meta) def copy_noreg(self): return Expr([(a.copy_noreg(), c) for (a, c) in self.terms], None) @staticmethod def parse(s): """Parse a string, e.g. I(X;Y,Z|W) + 2H(X Z) """ return RegionParser.parse_default(s) @staticmethod def fromcomp(x): return Expr([(Term.fromcomp(x), 1.0)]) @staticmethod def fromterm(x): return Expr([(x.copy(), 1.0)]) @staticmethod def fcn(fcncall, name = None): """Wrap any function mapping a ConcModel to a number as an Expr. E.g. the Hamming distortion is given by Expr.fcn(lambda P: P[X+Y].mean(lambda x, y: float(x != y))). For optimization using PyTorch, the return value should be a scalar torch.Tensor with gradient information. 
"""
        if name is None:
            # No explicit name given: generate a fresh one from the "fcn" counter.
            name = "fcn_" + str(iutil.get_count("fcn"))
        return Expr.fromterm(Term([Comp.real(name)], Comp.empty(), Region.universe(), 0, fcncall, ["model"]))

    def find(self, *args):
        """Delegate to find() on the Comp of all variables in this expression."""
        return self.allcomp().find(*args)

    def get_const(self, only = True):
        """Return the total coefficient of the constant (isone) terms.

        If only is True and any non-constant term is present, return None
        instead (i.e. the expression is not a pure constant).
        """
        r = 0.0
        for (a, c) in self.terms:
            if a.isone():
                r += c
            else:
                if only:
                    return None
        return r

    def split_ic_real(self):
        """Split the terms into [Expr of IC-type terms, Expr of all other terms].

        The (term, coefficient) pairs are shared with self, not copied.
        """
        r = [Expr.zero(), Expr.zero()]
        for (a, c) in self.terms:
            if a.get_type() == TermType.IC:
                r[0].terms.append((a, c))
            else:
                r[1].terms.append((a, c))
        return r

    def get_name(self):
        """Return get_name() of the Comp of all variables in this expression."""
        return self.allcomp().get_name()

    def __len__(self):
        """Number of (term, coefficient) pairs."""
        return len(self.terms)

    def len_iccount(self):
        """Count terms; an IC-type term counts as len(a.x), any other term as 1."""
        r = 0
        for (a, c) in self.terms:
            if a.get_type() == TermType.IC:
                r += len(a.x)
            else:
                r += 1
        return r

    def get_maxent_comp(self):
        """For a single-term expression delegate to that term, otherwise return None."""
        if len(self.terms) != 1:
            return None
        return self.terms[0][0].get_maxent_comp()

    def __bool__(self):
        """True iff the term list is non-empty (coefficients are not inspected)."""
        return bool(self.terms)

    def __getitem__(self, key):
        """Index or slice the term list; the selection is returned as an Expr."""
        r = self.terms[key]
        if isinstance(r, list):
            # A slice key yields a list of (term, coefficient) pairs.
            return Expr(r)
        return Expr([r])

    def allcomprealvar(self):
        """Sum (as a Comp) of the variables of real-variable terms that are not function calls."""
        r = Comp.empty()
        for a, c in self.terms:
            if a.isrealvar() and a.fcncall is None:
                r += a.x[0]
        return r

    def allcompreal_exprlist(self):
        """ExprArray of the REAL-type terms (each wrapped as an Expr), without duplicates."""
        r = ExprArray([])
        for a, c in self.terms:
            if a.get_type() == TermType.REAL:
                r.iadd_noduplicate([Expr.fromterm(a)])
        return r

    def allcomprealvar_exprlist(self):
        """ExprArray of the real-variable terms that are not function calls, without duplicates."""
        r = ExprArray([])
        for a, c in self.terms:
            if a.isrealvar() and a.fcncall is None:
                r.iadd_noduplicate([Expr.fromterm(a)])
        return r

    def __iadd__(self, other):
        """In-place addition.

        A Comp operand is first converted with Expr.fromcomp; the terms of
        other are copied before being appended.
        """
        if isinstance(other, Comp):
            other = Expr.fromcomp(other)
        other = iutil.ensure_expr(other)
        self.terms += [(a.copy(), c) for (a, c) in other.terms]
        # The cached hash (see __hash__) is stale after mutation.
        self.mhash = None
        return self

    def __add__(self, other):
        """Return self + other as a new Expr (terms copied, meta concatenated).

        Returns NotImplemented for CompArray/ExprArray so that Python falls
        back to the reflected operator of the array type.
        """
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return NotImplemented
        if isinstance(other, Comp):
            other = Expr.fromcomp(other)
        other = iutil.ensure_expr(other)
        return Expr([(a.copy(), c) for (a, c) in self.terms]
                    + [(a.copy(), c) for (a, c) in other.terms],
                    meta = iutil.meta_concat([self.meta, other.meta]))

    def __radd__(self, other):
        """Return other + self as a new Expr (other's terms first)."""
        if isinstance(other, Comp):
            other = Expr.fromcomp(other)
        other = iutil.ensure_expr(other)
        return Expr([(a.copy(), c) for (a, c) in other.terms]
                    + [(a.copy(), c) for (a, c) in self.terms],
                    meta = iutil.meta_concat([other.meta, self.meta]))

    def __neg__(self):
        """Return -self as a new Expr (terms and meta copied)."""
        return Expr([(a.copy(), -c) for (a, c) in self.terms], None, iutil.copy(self.meta))

    def __isub__(self, other):
        """In-place subtraction; the negated, copied terms of other are appended."""
        if isinstance(other, Comp):
            other = Expr.fromcomp(other)
        other = iutil.ensure_expr(other)
        self.terms += [(a.copy(), -c) for (a, c) in other.terms]
        self.mhash = None
        return self

    def __sub__(self, other):
        """Return self - other as a new Expr; NotImplemented for CompArray/ExprArray."""
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return NotImplemented
        if isinstance(other, Comp):
            other = Expr.fromcomp(other)
        other = iutil.ensure_expr(other)
        return Expr([(a.copy(), c) for (a, c) in self.terms]
                    + [(a.copy(), -c) for (a, c) in other.terms],
                    meta = iutil.meta_concat([self.meta, other.meta]))

    def __rsub__(self, other):
        """Return other - self as a new Expr."""
        if isinstance(other, Comp):
            other = Expr.fromcomp(other)
        other = iutil.ensure_expr(other)
        return Expr([(a.copy(), c) for (a, c) in other.terms]
                    + [(a.copy(), -c) for (a, c) in self.terms],
                    meta = iutil.meta_concat([other.meta, self.meta]))

    def __imul__(self, other):
        """In-place multiplication by a constant.

        If other is a non-constant Expr, this falls back to self * other and
        returns that new object instead of modifying self.
        """
        if isinstance(other, Expr):
            tother = other.get_const()
            if tother is None:
                return self * other
                # raise ValueError("Multiplication with non-constant expression is not supported.")
            else:
                other = tother
        self.terms = [(a, c * other) for (a, c) in self.terms]
        self.mhash = None
        return self

    def __mul__(self, other):
        """Return self * other.

        A literal int/float 0 gives Expr.zero(). If both operands are
        non-constant Exprs, the product is represented as a single real-valued
        function-call term with fcncall "*" and fcnargs [self, other].
        """
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return NotImplemented
        if isinstance(other, int) or isinstance(other, float):
            if other == 0:
                return Expr.zero()
        if isinstance(other, Expr):
            tother = other.get_const()
            if tother is None:
                tself = self.get_const()
                if tself is None:
                    # raise ValueError("Multiplication with non-constant expression is not supported.")
                    return Expr.fromterm(Term(Comp.real(
                        iutil.fcn_name_maker("*", [self, other], lname = PsiOpts.settings["latex_times"] + " ", infix = True)
                        ), reg = Region.universe(), fcncall = "*", fcnargs = [self, other]))
                else:
                    # self is a constant: scale the non-constant other instead.
                    return other * tself
            else:
                other = tother
        return Expr([(a.copy(), c * other) for (a, c) in self.terms], None, iutil.copy(self.meta))

    def __rmul__(self, other):
        """Return other * self.

        Unlike __mul__, self is not checked for being a constant: a
        non-constant Expr other always yields a "*" function-call term.
        """
        if isinstance(other, int) or isinstance(other, float):
            if other == 0:
                return Expr.zero()
        if isinstance(other, Expr):
            tother = other.get_const()
            if tother is None:
                # raise ValueError("Multiplication with non-constant expression is not supported.")
                return Expr.fromterm(Term(Comp.real(
                    iutil.fcn_name_maker("*", [other, self], lname = PsiOpts.settings["latex_times"] + " ", infix = True)
                    ), reg = Region.universe(), fcncall = "*", fcnargs = [other, self]))
            else:
                other = tother
        return Expr([(a.copy(), c * other) for (a, c) in self.terms], None, iutil.copy(self.meta))

    def __itruediv__(self, other):
        """In-place division by a constant.

        A non-constant Expr divisor falls back to self / other (a new object).
        """
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return NotImplemented
        if isinstance(other, Expr):
            tother = other.get_const()
            if tother is None:
                return self / other
                # raise ValueError("In-place division with non-constant expression is not supported.")
            else:
                other = tother
        self.terms = [(a, c / other) for (a, c) in self.terms]
        self.mhash = None
        return self

    def __pow__(self, other):
        """Return self ** other, always as a function-call term with fcncall "**"."""
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return NotImplemented
        return Expr.fromterm(Term(Comp.real(
            iutil.fcn_name_maker("**", [self, other], lname = "^", infix = True, latex_group = True)
            ), reg = Region.universe(), fcncall = "**", fcnargs = [self, other]))

    def __ipow__(self, other):
        """Return self ** other (a new object; self is not modified)."""
        return self ** other

    def __rpow__(self, other):
        """Return other ** self, with other wrapped by Expr.const."""
        return Expr.const(other) ** self

    def __invert__(self):
        """~expr is the Region expr == 0."""
        return self == 0

    def record_to(self, index):
        """Record the variables of every term into index."""
        for (a, c) in self.terms:
            a.record_to(index)

    def regtermmap(self, cmap, recur):
        """Delegate to regtermmap of the Region self >= 0."""
        return (self >= 0).regtermmap(cmap, recur)

    def allcomp(self):
        """Sum of allcomp() over all terms."""
        r = Comp.empty()
        for (a, c) in self.terms:
            r += a.allcomp()
        return r

    def allcomprv_shallow(self):
        """Sum of allcomprv_shallow() over all terms."""
        r = Comp.empty()
        for (a, c) in self.terms:
            r += a.allcomprv_shallow()
        return r

    def size(self):
        """Number of terms (same as len(self))."""
        return len(self.terms)

    def iszero(self):
        """Whether the expression is zero (has no terms; coefficients are not inspected)"""
        return len(self.terms) == 0

    def setzero(self):
        """Set expression to zero"""
        self.terms = []
        self.mhash = None

    def isnonneg(self):
        """Whether the expression is always nonnegative.

        Conservative check: every non-negligible term must have a positive
        coefficient and satisfy a.isnonneg().
        """
        for (a, c) in self.terms:
            if abs(c) <= PsiOpts.settings["eps"]:
                continue
            if c < 0:
                return False
            if not a.isnonneg():
                return False
        return True

    def isnonpos(self):
        """Whether the expression is always nonpositive.

        Conservative check: every non-negligible term must have a negative
        coefficient and satisfy a.isnonneg().
        """
        for (a, c) in self.terms:
            if abs(c) <= PsiOpts.settings["eps"]:
                continue
            if c > 0:
                return False
            if not a.isnonneg():
                return False
        return True

    def isnonneg_ic2(self):
        """Like isnonneg, but every non-negligible term must satisfy a.isic2()."""
        for (a, c) in self.terms:
            if abs(c) <= PsiOpts.settings["eps"]:
                continue
            if c < 0:
                return False
            if not a.isic2():
                return False
        return True

    def isnonpos_ic2(self):
        """Like isnonpos, but every non-negligible term must satisfy a.isic2()."""
        for (a, c) in self.terms:
            if abs(c) <= PsiOpts.settings["eps"]:
                continue
            if c > 0:
                return False
            if not a.isic2():
                return False
        return True

    def isnonpos_hc(self):
        """Like isnonpos, but every non-negligible term must satisfy a.ishc()."""
        for (a, c) in self.terms:
            if abs(c) <= PsiOpts.settings["eps"]:
                continue
            if c > 0:
                return False
            if not a.ishc():
                return False
        return True

    def isrealvar(self):
        """Whether the expression is exactly one real-variable term with coefficient 1."""
        if len(self.terms) != 1:
            return False
        a, c = self.terms[0]
        if abs(c - 1.0) > PsiOpts.settings["eps"]:
            return False
        return a.isrealvar()

    @staticmethod
    def zero():
        """The constant zero expression."""
        return Expr([])

    @staticmethod
    def H(x):
        """Entropy."""
        return Expr([(Term.H(x), 1.0)])

    @staticmethod
    def I(x, y):
        """Mutual information."""
        return Expr([(Term.I(x, y), 1.0)])

    @staticmethod
    def Hc(x, z):
        """Conditional entropy."""
        return Expr([(Term.Hc(x, z), 1.0)])

    @staticmethod
    def Ic(x, y, z):
        """Conditional mutual information."""
        return Expr([(Term.Ic(x, y, z), 1.0)])

    @staticmethod
    def real(name):
        """Real variable.

        name may be an IVar, a Comp, or a name passed to Comp.real.
        """
        if isinstance(name, IVar):
            return Expr([(Term([Comp([name])], Comp.empty()), 1.0)])
        if isinstance(name, Comp):
            return Expr([(Term([name], Comp.empty()), 1.0)])
        return Expr([(Term([Comp.real(name)], Comp.empty()), 1.0)])

    @staticmethod
    def eps():
        """Epsilon."""
        return Expr([(Term([Comp([IVar.eps()])], Comp.empty()), 1.0)])

    @staticmethod
    def one():
        """One."""
        return Expr([(Term([Comp([IVar.one()])], Comp.empty()), 1.0)])

    @staticmethod
    def inf():
        """Infinity."""
        return Expr([(Term([Comp([IVar.inf()])], Comp.empty()), 1.0)])

    @staticmethod
    def const(c):
        """Constant; values within settings["eps"] of zero give Expr.zero()."""
        if abs(c) <= PsiOpts.settings["eps"]:
            return Expr.zero()
        return Expr([(Term([Comp([IVar.one()])], Comp.empty()), float(c))])

    def commonpart_coeff(self, v, forbid_h = True):
        """Return the coefficient of the "common part" of v in this expression.

        Only IC-type terms contribute. v is either a single variable, or a
        list/tuple of variables:
        - single v: add c for each term where v is in every argument of a.x
          and not in the condition a.z;
        - list/tuple v: terms with more than two arguments are expanded
          recursively as Term(x[1:], z) - Term(x[1:], z + x[0]). For other
          terms, the term contributes c once, unless some remaining variable
          appears in none of the arguments. A variable appearing in every
          argument adds c * numpy.inf when forbid_h is True.
        Returns None if a recursive call returns None.
        """
        r = 0.0
        for (a, c) in self.terms:
            if abs(c) <= PsiOpts.settings["eps"]:
                continue
            if a.get_type() == TermType.IC:
                if not isinstance(v, list) and not isinstance(v, tuple):
                    if not a.z.ispresent(v) and all(t.ispresent(v) for t in a.x):
                        r += c
                else:
                    if all(not t.ispresent(vt) for t in a.x for vt in v):
                        continue
                    # if len(a.x) > 2:
                    #     return None
                    # Variables already in the condition do not take part.
                    v2 = [vt for vt in v if not a.z.ispresent(vt)]
                    if len(v2) == 0:
                        continue
                    if any(all(not t.ispresent(vt) for vt in v2) for t in a.x):
                        continue
                    if len(a.x) > 2:
                        # NOTE(review): forbid_h is not forwarded to this recursive
                        # call (the default True is used) — confirm this is intended.
                        tr = (Expr.fromterm(Term(a.x[1:], a.z)) - Expr.fromterm(Term(a.x[1:], a.z+a.x[0]))).commonpart_coeff(v)
                        if tr is None:
                            return None
                        r += c * tr
                    else:
                        for vt in v2:
                            tpres = [t.ispresent(vt) for t in a.x]
                            if all(tpres):
                                if forbid_h:
                                    r += c * numpy.inf
                            elif not any(tpres):
                                break
                        else:
                            # for-else: reached only when the loop did not break.
                            r += c
        return r

    def ent_coeff(self, v):
        """Sum of coefficients of IC terms with v in every argument and not in the condition."""
        r = 0.0
        for (a, c) in self.terms:
            if abs(c) <= PsiOpts.settings["eps"]:
                continue
            if a.get_type() == TermType.IC:
                if all(t.ispresent(v) for t in a.x) and not a.z.ispresent(v):
                    r += c
        return r

    def var_mi_only(self, v):
        """Whether ent_coeff(v) is zero up to settings["eps"]."""
        return abs(self.ent_coeff(v)) <= PsiOpts.settings["eps"]

    def z3(self, vardict):
        """Convert to a z3 expression, converting each term with a.z3(vardict).

        Returns None when the name z3 is None (presumably the optional z3
        import is unavailable — confirm). Returns None for an empty expression.
        """
        if z3 is None:
            return None
        r = None
        for (a, c) in self.terms:
            t = a.z3(vardict)
            if abs(c) != 1:
                t = t * iutil.float_toz3(abs(c))
            # Signs are applied with +/- so that only abs(c) is converted.
            if c >= 0:
                if r is None:
                    r = t
                else:
                    r = r + t
            else:
                if r is None:
                    r = -t
                else:
                    r = r - t
        return r

    def istight(self, canon = False):
        """Whether every term is tight (a.istight(canon))."""
        return all(a.istight(canon) for a, c in self.terms)

    def tighten(self):
        """Call tighten() on every term."""
        for a, c in self.terms:
            a.tighten()

    def lu_bound(self, sn, name = None):
        """Replace each term by a.lu_bound, with direction 1 or -1 from the sign of sn * c."""
        return Expr([(a.lu_bound(1 if sn * c >= 0 else -1, name = name), c) for a, c in self.terms])

    def lower_bound(self, name = None):
        """lu_bound with sn = 1."""
        return self.lu_bound(1, name = name)

    def upper_bound(self, name = None):
        """lu_bound with sn = -1."""
        return self.lu_bound(-1, name = name)

    def get_reg_sgn_bds(self):
        """Combine get_reg_sgn_bds() of the terms into (reg, sign, bounds).

        At most one non-negligible term may return a non-None triple (otherwise
        None is returned). Its bounds are scaled by that term's coefficient, its
        sign is flipped for a negative coefficient, and all other terms are added
        to every bound. Without such a term the result is
        (Region.universe(), 0, []).
        """
        reg = Region.universe()
        sn = 0
        bds = []
        rest = Expr.zero()
        for (a, c) in self.terms:
            if abs(c) <= PsiOpts.settings["eps"]:
                continue
            t = a.get_reg_sgn_bds()
            if t is None:
                rest.terms.append((a, c))
            else:
                if sn != 0:
                    return None
                reg, sn, bds = t
                bds = [b * c for b in bds]
                if c < 0:
                    sn = -sn
        return (reg, sn, [b + rest for b in bds])

    def sort(self):
        """Sort each term, then order terms by descending coefficient (rounded to 1e-3),
        sorting_priority() and string form. In place."""
        for x, c in self.terms:
            x.sort()
        self.terms.sort(key = lambda a: (-round(a[1] * 1000.0), a[0].sorting_priority(), str(a[0])))

    def sorting_tuple_eqn(self):
        """Sort key (sorting_priority(), string of the ">=" equation with real variables on the lhs)."""
        s = self.tostring_eqn(">=", lhsvar = "real")
        return (self.sorting_priority(), s)

    def tostring(self, style = 0, tosort = False, add_bracket = False, tosort_pm = False):
        """Convert to string
        Parameters:
            style       : Style of string conversion
                          STR_STYLE_STANDARD : I(X,Y;Z|W)
                          STR_STYLE_PSITIP : I(X+Y&Z|W)
            tosort      : Sort terms by descending coefficient, then by string
            add_bracket : Enclose in brackets when there are at least 2 terms
            tosort_pm   : Put nonnegative coefficients before negative ones
        Terms with coefficients within settings["eps"] of zero are skipped;
        an empty result is rendered as "0".
        """
        style = iutil.convert_str_style(style)
        termlist = self.terms
        if tosort:
            termlist = sorted(termlist, key=lambda a: (-round(a[1] * 1000.0), a[0].tostring(style = style, tosort = tosort)))
        elif tosort_pm:
            termlist = sorted(termlist, key=lambda a: a[1] < 0)

        use_bracket = add_bracket and len(termlist) >= 2
        float_style = self.get_meta("float_style")

        r = ""
        if use_bracket:
            if style & PsiOpts.STR_STYLE_LATEX:
                r += "\\left("
            else:
                r += "("

        first = True
        for (a, c) in termlist:
            if abs(c) <= PsiOpts.settings["eps"]:
                continue
            if c > 0.0 and not first:
                r += "+"
            if a.isone():
                # Constant term: print the coefficient alone.
                r += iutil.float_tostr(c, style, bracket = False, force_float = float_style)
            else:
                need_bracket = False
                if abs(c - 1.0) < PsiOpts.settings["eps"]:
                    pass
                elif abs(c + 1.0) < PsiOpts.settings["eps"]:
                    r += "-"
                    need_bracket = True
                else:
                    r += iutil.float_tostr(c, style, force_float = float_style)
                    if style & PsiOpts.STR_STYLE_PSITIP:
                        r += "*"
                    need_bracket = True
                r += a.tostring(style = style, tosort = tosort, add_bracket = need_bracket)
            first = False

        if r == "":
            return "0"

        if use_bracket:
            if style & PsiOpts.STR_STYLE_LATEX:
                r += "\\right)"
            else:
                r += ")"

        return r

    def __str__(self):
        """String in settings["str_style"], sorted according to settings["str_tosort"]."""
        return self.tostring(PsiOpts.settings["str_style"],
                             tosort = PsiOpts.settings["str_tosort"])

    def __repr__(self):
        """String in settings["str_style_repr"], optionally simplified first ("repr_simplify")."""
        if PsiOpts.settings.get("repr_simplify", False):
            return self.simplified_quick().tostring(PsiOpts.settings["str_style_repr"])
        return self.tostring(PsiOpts.settings["str_style_repr"])

    @latex_postprocess
    def _latex_(self):
        """LaTeX string, optionally simplified first ("repr_simplify")."""
        if PsiOpts.settings.get("repr_simplify", False):
            return self.simplified_quick().tostring(iutil.convert_str_style("latex"))
        return self.tostring(iutil.convert_str_style("latex"))

    def __hash__(self):
        """Hash of the sorted (term hash, coefficient) pairs, cached in mhash.

        Mutating methods reset mhash to None to invalidate the cache.
        """
        if self.mhash is None:
            #self.mhash = hash(self.tostring(tosort = True))
            self.mhash = hash(tuple(sorted((hash(a), c) for a, c in self.terms)))

        return self.mhash

    def table(self, *args, **kwargs):
        """Plot the information diagram as a Karnaugh map.
        """
        return universe().table(*args, self, **kwargs)

    def venn(self, *args, **kwargs):
        """Plot the information diagram as a Venn diagram.
        Can handle up to 5 random variables (uses Branko Grunbaum's Venn diagram for n=5).
"""
        return universe().venn(*args, self, **kwargs)

    def isregtermpresent(self):
        """Whether any term is a REGION term or involves a random variable whose reg is not None."""
        for (a, c) in self.terms:
            rvs = a.allcomprv_shallow()
            for b in rvs.varlist:
                if b.reg is not None:
                    return True
            if a.get_type() == TermType.REGION:
                return True
        return False

    def sortIc(self):
        """Sort terms in place: ic2 terms ordered by size(), all other terms last."""
        def sortkey(a):
            x = a[0]
            if x.isic2():
                return x.size()
            else:
                # Sentinel key so non-ic2 terms sort last (assumes size() < 100000).
                return 100000
        self.terms.sort(key=sortkey)
        self.mhash = None

    def __le__(self, other):
        """self <= other, as a Region with other - self in its inequality (>= 0) list."""
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return NotImplemented
        other = iutil.ensure_expr(other)
        return Region([other - self], [], Comp.empty(), Comp.empty(), Comp.empty())

    def __lt__(self, other):
        """self < other, encoded as other - self - eps >= 0 using Expr.eps()."""
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return NotImplemented
        other = iutil.ensure_expr(other)
        return Region([other - self - Expr.eps()], [], Comp.empty(), Comp.empty(), Comp.empty())

    def __ge__(self, other):
        """self >= other, as a Region with self - other in its inequality (>= 0) list."""
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return NotImplemented
        other = iutil.ensure_expr(other)
        return Region([self - other], [], Comp.empty(), Comp.empty(), Comp.empty())

    def __gt__(self, other):
        """self > other, encoded as self - other - eps >= 0 using Expr.eps()."""
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return NotImplemented
        other = iutil.ensure_expr(other)
        return Region([self - other - Expr.eps()], [], Comp.empty(), Comp.empty(), Comp.empty())

    def __eq__(self, other):
        """self == other, as a Region with other - self in its equality list.

        Note: this returns a Region, not a bool.
        """
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return NotImplemented
        other = iutil.ensure_expr(other)
        return Region([], [other - self], Comp.empty(), Comp.empty(), Comp.empty())

    def __ne__(self, other):
        """self != other, as the complement (~) of the Region self == other."""
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return NotImplemented
        #return RegionOp.union([self > other, self < other])
        return ~RegionOp.inter([self == other])

    def equiv(self, other):
        """Whether self is equal to other (both self <= other and other <= self check)"""
        return (self <= other).check() and (other <= self).check()

    def real_present(self):
        """Whether any term is of type TermType.REAL."""
        for (a, c) in self.terms:
            if a.get_type() == TermType.REAL:
                return True
        return False

    def complexity(self):
        """Heuristic complexity score (used by sorting_priority).

        Each term adds min(|numerator| + |denominator|, 8) of its coefficient
        (as a Fraction limited to settings["max_denom"]) plus 4 * a.complexity().
        500 is added once if any term is ishc(), and 1000 per entry of allcomp().
        """
        max_denom = PsiOpts.settings["max_denom"]
        r = 0
        for (a, c) in self.terms:
            frac = fractions.Fraction(c).limit_denominator(max_denom)
            r += min(abs(frac.numerator) + abs(frac.denominator), 8)
            r += a.complexity() * 4
        if any(a.ishc() for a, c in self.terms):
            r += 500
        r += len(self.allcomp()) * 1000
        return r

    def sorting_priority(self):
        """Sort key: expressions containing a REAL term first, then by complexity()."""
        return int(not self.real_present()) * 100000 + self.complexity()

    def get_coeff(self, b, get_pos = False):
        """Get coefficient of Term b

        b may be:
        - a Term: sum of coefficients of equal terms;
        - a Comp: sum of coefficients of ishc() terms whose a.x[0] - a.z is a
          superset of b;
        - an Expr: the common scalar multiple of b contained in self. This is the
          same-sign ratio of smallest magnitude; 0.0 if some term of b is absent
          or the ratios have mixed signs.
        If get_pos is True, return (coefficient, position) instead, where
        position is an index into self.terms (or None).
        """
        if isinstance(b, Term):
            r = 0.0
            pos = None
            for i, (a, c) in enumerate(self.terms):
                if a == b:
                    r += c
                    pos = i
            if get_pos:
                return (r, pos)
            return r

        if isinstance(b, Comp):
            r = 0.0
            pos = None
            for i, (a, c) in enumerate(self.terms):
                if a.ishc() and (a.x[0] - a.z).super_of(b):
                    r += c
                    pos = i
            if get_pos:
                return (r, pos)
            return r

        r = None
        pos = None
        for a2, c2 in b.terms:
            coeff2 = b.get_coeff(a2)
            if coeff2 == 0:
                continue
            coeff1, tpos = self.get_coeff(a2, get_pos = True)
            if coeff1 == 0:
                if get_pos:
                    return (0.0, None)
                return 0.0
            if pos is None:
                pos = tpos
            elif tpos is not None:
                pos = min(pos, tpos)
            t = coeff1 / coeff2
            if r is None:
                r = t
            else:
                # Keep the ratio of smallest magnitude; a sign change means no common multiple.
                if r > 0:
                    if t < 0:
                        if get_pos:
                            return (0.0, None)
                        return 0.0
                    r = min(r, t)
                else:
                    if t > 0:
                        if get_pos:
                            return (0.0, None)
                        return 0.0
                    r = max(r, t)
        if r is None:
            if get_pos:
                return (0.0, None)
            return 0.0
        if get_pos:
            return (r, pos)
        return r

    def __contains__(self, other):
        """Term/Expr: nonzero get_coeff; otherwise whether other is in any term."""
        if isinstance(other, (Term, Expr)):
            return self.get_coeff(other) != 0
        for a, c in self.terms:
            if other in a:
                return True
        return False

    def partition(self, p):
        """Replace each term by its restrictions a.restricted(px) to every part px of p, in place.

        A term is kept unchanged if any of its restrictions is None. The
        result is then simplified with simplify_quick().
        """
        self.mhash = None
        t = self.terms
        self.terms = []
        for a, c in t:
            a2 = [a.restricted(px) for px in p]
            if any(b is None for b in a2):
                self.terms.append((a, c))
                continue
            for b in a2:
                self.terms.append((b, c))
        self.simplify_quick()

    def partitioned(self, p):
        """Return a partitioned copy (see partition)."""
        r = self.copy()
        r.partition(p)
        return r

    def isbalanced(self, v = None):
        """Whether get_coeff(x) is zero (within eps) for every x in v (default: all random variables)."""
        if v is None:
            v = self.allcomprv_shallow()
        # return all(self.get_coeff(x) == 0 for x in v)
        ceps = PsiOpts.settings["eps"]
        return all(abs(self.get_coeff(x)) <= ceps for x in v)

    def balanced(self, v = None, w = None, sn = 1):
        """Return a copy with conditional entropy terms added so that
        get_coeff(x) is brought to zero for each x in v.

        Requires get_coeff(x) * sn <= eps for all x in v; otherwise returns
        None. Terms added are Hc(tx, w - tx) * t * sn, where tx is the union
        of the variables still needing correction. v and w default to all
        random variables. The result is simplified with simplified_quick().
        """
        ceps = PsiOpts.settings["eps"]
        if v is None:
            v = self.allcomprv_shallow()
        if w is None:
            w = self.allcomprv_shallow()
        r = self.copy()
        cs = []
        for x in v:
            c = -r.get_coeff(x) * sn
            if c < -ceps:
                return None
            if c > ceps:
                cs.append((x, c))

        while cs:
            # Add the smallest remaining deficit jointly to all pending variables.
            t = min(c for x, c in cs)
            tx = sum((x for x, c in cs), Comp.empty())
            tz = w - tx
            r += Expr.Hc(tx, tz) * t * sn
            cs = [(x, c - t) for x, c in cs if c - t > ceps]

        return r.simplified_quick()

    def get_sign(self, b):
        """Return whether Expr is increasing or decreasing in random variable b

        Returns 1 (increasing), -1 (decreasing) or 0 (unknown/mixed). Only IC
        terms with one or two arguments in which b appears exactly once are
        understood. Appearance in the condition flips the sign, but only for
        single-argument (entropy) terms.
        """
        r = 0
        for (a, c) in self.terms:
            sn = 0
            if abs(c) <= PsiOpts.settings["eps"] or not a.ispresent(b):
                continue
            sn = 1 if c > 0 else -1
            if a.get_type() == TermType.IC:
                if len(a.x) > 2 or len(a.x) == 0:
                    return 0
                nx = 0
                nz = 0
                if a.z.ispresent(b):
                    nz += 1
                for x in a.x:
                    if x.ispresent(b):
                        nx += 1
                if nx + nz != 1:
                    return 0
                if len(a.x) == 2 and nz == 1:
                    return 0
                if nz == 1:
                    sn = -sn
            else:
                return 0
            if r != 0 and r != sn:
                return 0
            r = sn
        return r

    def get_var_avoid(self, x):
        """Intersection (Comp.inter) of the non-None a.get_var_avoid(x) over all terms."""
        r = None
        for (a, c) in self.terms:
            t = a.get_var_avoid(x)
            if r is None:
                r = t
            elif t is not None:
                r = r.inter(t)
        return r

    def name_avoid(self, name0):
        """Return IVarIndex.name_avoid(name0) over all variables recorded from this expression."""
        index = IVarIndex()
        self.record_to(index)
        return index.name_avoid(name0)

    def isconvex(self, v = None, bnet = None):
        """Check whether expression is convex with respect to random variables v and real variables.
        False return value does NOT necessarily mean expression is not convex.
        """
        # Test convexity of the epigraph region self <= t with a fresh real variable t.
        tmpvar = Expr.real(self.name_avoid("t"))
        return (self <= tmpvar).isconvex(v, bnet)

    def isconcave(self, v = None, bnet = None):
        """Check whether expression is concave with respect to random variables v and real variables.
        False return value does NOT necessarily mean expression is not concave.
""" tmpvar = Expr.real(self.name_avoid("t")) return (self >= tmpvar).isconvex(v, bnet) def isaffine(self, v = None, bnet = None): """Check whether expression is affine with respect to random variables v and real variables. False return value does NOT necessarily mean expression is not affine. """ tmpvar = Expr.real(self.name_avoid("t")) return (self <= tmpvar).isconvex(v, bnet) and (self >= tmpvar).isconvex(v, bnet) def concave_envelope(self, v = None, bnet = None, q = None): """Compute the upper concave envelope with respect to random variables v. C. Nair, "Upper concave envelopes and auxiliary random variables," International Journal of Advances in Engineering Sciences and Applied Mathematics, vol. 5, no. 1, pp. 12-20, 2013. """ name = "env_" + str(iutil.get_count("env")) tmpvar = Expr.real(self.name_avoid(name)) return (self >= tmpvar).convexified(v, bnet, q = q).maximum(tmpvar, None, allow_reuse = True) def convex_envelope(self, v = None, bnet = None, q = None): """Compute the lower convex envelope with respect to random variables v. C. Nair, "Upper concave envelopes and auxiliary random variables," International Journal of Advances in Engineering Sciences and Applied Mathematics, vol. 5, no. 1, pp. 12-20, 2013. 
"""
        name = "env_" + str(iutil.get_count("env"))
        tmpvar = Expr.real(self.name_avoid(name))
        # Minimize a fresh real variable over the convexified region self <= tmpvar.
        return (self <= tmpvar).convexified(v, bnet, q = q).minimum(tmpvar, None, allow_reuse = True)

    def remove_term(self, b):
        """Remove Term b in place."""
        self.terms = [(a, c) for (a, c) in self.terms if a != b]
        self.mhash = None
        return self

    def removed_term(self, b):
        """Remove Term b, return Expr after removal (remaining terms are shared, not copied)."""
        return Expr([(a, c) for (a, c) in self.terms if a != b], meta = iutil.copy(self.meta))

    def symm_sort(self, terms):
        """Sort the random variables in terms assuming symmetry among those terms.

        Each variable gets a score from the coefficient-weighted occurrences
        (arguments weigh 2, conditions 1), and the variables are then permuted
        according to the descending score order. In place.
        """
        self.mhash = None
        index = IVarIndex()
        terms.record_to(index)
        terms = index.comprv
        n = len(terms.varlist)
        v = [0] * n
        for (a, c) in self.terms:
            cint = int(round(c * 1000))
            for b, bc in [(t, 2) for t in a.x] + [(a.z, 1)]:
                mask = index.get_mask(b)
                count = iutil.bitcount(mask)
                for i in range(n):
                    if mask & (1 << i):
                        v[i] += cint * bc + count * 5 + 11
        vs = sorted(list(range(n)), key = lambda k: v[k], reverse = True)
        # Two-phase substitution through temporaries so the permutation
        # does not collide with variables that are still to be renamed.
        tmpvar = Comp.array("#TMPVAR", 0, n)
        for i in range(n):
            self.substitute(terms[i], tmpvar[i])
        for i in range(n):
            self.substitute(tmpvar[i], terms[vs[i]])

    def coeff_sum(self):
        """Sum of coefficients"""
        return sum([c for (a, c) in self.terms])

    def coeff_sign(self):
        """Sign of coefficients: 1 if all nonzero ones are positive, -1 if all are negative, else 0"""
        p0 = False
        p1 = False
        for a, c in self.terms:
            if c > 0:
                p0 = True
            elif c < 0:
                p1 = True
        if p0 and p1:
            return 0
        if p0:
            return 1
        if p1:
            return -1
        return 0

    def simplify_mul(self, mul_allowed = 0):
        """Rescale the coefficients to small integers, in place, when mul_allowed > 0.

        Coefficients are multiplied by the lcm of the denominators (of the
        non-constant coefficients, limited to settings["max_denom"]) when that
        lcm is at most settings["max_denom_mul"]. With mul_allowed >= 2 the
        sign may also be flipped so that the coefficient sum becomes
        nonnegative. The result is then divided by the gcd of the numerators.
        NOTE(review): presumably mul_allowed == 1 permits positive scaling only
        and 2 any nonzero scaling — confirm against callers.
        """
        self.mhash = None
        if mul_allowed > 0:
            max_denom = PsiOpts.settings["max_denom"]
            max_denom_mul = PsiOpts.settings["max_denom_mul"]
            denom = 1
            for (a, c) in self.terms:
                if a.isone():
                    continue
                denom = iutil.lcm(fractions.Fraction(c).limit_denominator(
                    max_denom).denominator, denom)
                if denom > max_denom_mul:
                    break

            if denom > 0 and denom <= max_denom_mul:
                if mul_allowed >= 2:
                    if self.coeff_sum() < 0:
                        denom = -denom
                if all(a.isone() or abs(c * denom - round(c * denom)) <= PsiOpts.settings["eps"] for (a, c) in self.terms):
                    self.terms = [(a, iutil.float_snap(c * denom)) for (a, c) in self.terms]
                    num = None
                    for (a, c) in self.terms:
                        tnum = abs(fractions.Fraction(c).limit_denominator(max_denom).numerator)
                        if tnum == 0:
                            continue
                        if num is None:
                            num = tnum
                        else:
                            num = iutil.gcd(num, tnum)
                    if num is not None and num > 1:
                        self.terms = [(a, iutil.float_snap(c / num)) for (a, c) in self.terms]

    def mi_disjoint(self):
        """Make the arguments of each multi-argument IC term disjoint, in place.

        The common part xt of the arguments is split off as a new term
        Hc(xt, z) with the same coefficient, inserted right after the term.
        xt is then removed from every argument and added to the condition.
        """
        i = 0
        while i < len(self.terms):
            a, c = self.terms[i]
            if a.get_type() == TermType.IC and len(a.x) >= 2:
                xt = a.x[0]
                for j in range(1, len(a.x)):
                    xt = xt.inter(a.x[j])
                if not xt.isempty():
                    self.terms.insert(i + 1, (Term.Hc(xt, a.z), c))
                    for j in range(len(a.x)):
                        a.x[j] = a.x[j] - xt
                    a.z = a.z + xt
                    # Skip over the newly inserted Hc term.
                    i += 1
            i += 1

    def combine_same_terms(self):
        """Merge equal terms by summing coefficients, then drop negligible coefficients and zero terms."""
        ceps = PsiOpts.settings["eps"]
        for i in range(len(self.terms)):
            for j in range(i):
                if self.terms[i][0] == self.terms[j][0]:
                    self.terms[j] = (self.terms[j][0], self.terms[j][1] + self.terms[i][1])
                    self.terms[i] = (self.terms[i][0], 0.0)
                    break
        self.terms = [(a, c) for (a, c) in self.terms if abs(c) > ceps and not a.iszero()]

    def simplify(self, reg = None, bnet = None, quick = False, term_allow = None):
        """Simplify the expression in place

        Repeatedly merges equal terms and tries to combine pairs of terms with
        equal (try_iadd) or opposite (try_isub / try_isub_flipsign)
        coefficients. With settings["simplify_reduce_coeff"], terms with
        unequal coefficients are partially combined. Returns self.
        """
        ceps = PsiOpts.settings["eps"]
        reduce_coeff = PsiOpts.settings.get("simplify_reduce_coeff", False)
        self.mhash = None

        if not quick and PsiOpts.settings.get("simplify_regterm", False):
            self.simplify_regterm(reg)

        for (a, c) in self.terms:
            a.simplify(reg, bnet)

        self.mi_disjoint()

        did = True
        while did:
            did = False
            # Merge identical terms (same logic as combine_same_terms).
            for i in range(len(self.terms)):
                for j in range(i):
                    if self.terms[i][0] == self.terms[j][0]:
                        self.terms[j] = (self.terms[j][0], self.terms[j][1] + self.terms[i][1])
                        self.terms[i] = (self.terms[i][0], 0.0)
                        did = True
                        break

            self.terms = [(a, c) for (a, c) in self.terms if abs(c) > ceps and not a.iszero()]

            # Pairwise combination; a term absorbed into another gets coefficient 0.
            for i in range(len(self.terms)):
                if abs(self.terms[i][1]) > ceps:
                    for j in range(len(self.terms)):
                        if i != j and abs(self.terms[j][1]) > ceps:
                            ci = self.terms[i][1]
                            cj = self.terms[j][1]
                            if abs(ci - cj) <= ceps:
                                if self.terms[i][0].try_iadd(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                    self.terms[j] = (self.terms[j][0], 0.0)
                                    did = True
                            elif reduce_coeff and ci * cj > 0:
                                if abs(ci) > abs(cj):
                                    ti = self.terms[i][0].copy()
                                    if self.terms[i][0].try_iadd(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                        self.terms[i] = (self.terms[i][0], cj)
                                        self.terms[j] = (ti, ci - cj)
                                        did = True
                                else:
                                    if self.terms[i][0].try_iadd(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                        self.terms[j] = (self.terms[j][0], cj - ci)
                                        did = True
                            elif abs(ci + cj) <= ceps:
                                if self.terms[i][0].try_isub(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                    self.terms[j] = (self.terms[j][0], 0.0)
                                    did = True
                                elif self.terms[i][0].try_isub_flipsign(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                    self.terms[j] = (self.terms[j][0], ci)
                                    did = True
                            elif reduce_coeff and ci * cj < 0:
                                if abs(ci) > abs(cj):
                                    ti = self.terms[i][0].copy()
                                    if self.terms[i][0].try_isub(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                        self.terms[i] = (self.terms[i][0], -cj)
                                        self.terms[j] = (ti, ci + cj)
                                        did = True
                                else:
                                    if self.terms[i][0].try_isub(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                        self.terms[j] = (self.terms[j][0], cj + ci)
                                        did = True

            self.terms = [(a, c) for (a, c) in self.terms if abs(c) > ceps and not a.iszero()]

            if did:
                for (a, c) in self.terms:
                    a.simplify(reg, bnet)

        #self.terms = [(a, iutil.float_snap(c)) for (a, c) in self.terms
        #              if abs(c) > ceps and not a.iszero()]

        if term_allow is None and PsiOpts.settings["term_allow"] is not None:
            self.simplify_break_allow(reg, bnet, quick, PsiOpts.settings["term_allow"])

        return self

    def simplify_break_allow(self, reg = None, bnet = None, quick = False, term_allow = None):
        """Rewrite terms that the TermAllowType flags in term_allow do not permit, in place.

        - If I or I3 terms are not allowed, a multi-argument IC term is split
          by the chain rule into Term(x without x_k, z) - Term(x without x_k,
          z + x_k). The k giving the least complex result is kept; repeat until
          no change.
        - If HC or IC terms are not allowed, a conditional term is rewritten as
          the term with z merged into every argument, minus H(z), and the
          result is simplified again.
        Returns self.
        """
        if (not term_allow & TermAllowType.I) or (not term_allow & TermAllowType.I3):
            did = True
            while did:
                did = False
                for i in range(len(self.terms)):
                    a, c = self.terms[i]
                    if a.get_type() != TermType.IC:
                        continue
                    if len(a.x) <= 1:
                        continue
                    if term_allow & TermAllowType.I and len(a.x) <= 2:
                        continue
                    minc = 0
                    mint = None
                    for k in range(len(a.x)):
                        t = list(self.terms)
                        a2 = a.copy()
                        a2.x.pop(k)
                        a3 = a2.copy()
                        a3.z += a.x[k]
                        t[i:i+1] = [(a2, c), (a3, -c)]
                        t = Expr(t)
                        t.simplify(reg, bnet, quick, term_allow)
                        tc = t.complexity()
                        if mint is None or tc < minc:
                            minc = tc
                            mint = t
                    self.copy_(mint)
                    did = True
                    break

        if (not term_allow & TermAllowType.HC) or (not term_allow & TermAllowType.IC):
            did = False
            t = list(self.terms)
            i = 0
            while i < len(t):
                a, c = t[i]
                if a.get_type() != TermType.IC:
                    i += 1
                    continue
                if len(a.z) == 0:
                    i += 1
                    continue
                if len(a.x) <= 1:
                    if term_allow & TermAllowType.HC:
                        i += 1
                        continue
                else:
                    if term_allow & TermAllowType.IC:
                        i += 1
                        continue
                a2 = a.copy()
                for k in range(len(a2.x)):
                    a2.x[k] += a.z
                a2.z = Comp.empty()
                t[i:i+1] = [(a2, c), (Term.H(a.z.copy()), -c)]
                # Skip both replacement terms.
                i += 2
                did = True
            if did:
                t = Expr(t)
                self.copy_(t)
                self.simplify(reg, bnet, quick, term_allow)

        return self

    def simplified(self, reg = None, bnet = None, quick = False):
        """Simplify the expression, return simplified expression"""
        r = self.copy()
        r.simplify(reg, bnet, quick)
        return r

    def simplify_more_cond(self, bnet = None):
        """Placeholder; always returns None."""
        return None

    def simplified_exhaust_inner(self, bnet = None):
        """Try to find a lower-complexity form by splitting one variable out of one argument.

        For each IC term, each argument with several variables is split at a
        single variable x into two terms. The expression is re-simplified with
        either piece substituted first. Returns the first result whose
        complexity() is smaller, or None (also for expressions with <= 2 terms).
        """
        if len(self.terms) <= 2:
            return None

        scom = self.complexity()
        for i, (a, c) in enumerate(self.terms):
            if a.get_type() == TermType.IC:
                a2 = a.copy()
                for ix in range(len(a2.x)):
                    if len(a2.x[ix]) <= 1:
                        continue
                    for x in a2.x[ix]:
                        ks = [Term([p - x if ip == ix else p.copy() for ip, p in enumerate(a2.x)], a2.z.copy()),
                              Term([x.copy() if ip == ix else p.copy() for ip, p in enumerate(a2.x)], a2.z + (a2.x[ix] - x))]
                        # print(str(a2) + " " + str(ks[0]) + " " + str(ks[1]))
                        for it in range(2):
                            expr = Expr([(a4.copy(), c4) if i4 != i else (ks[it].copy(), c) for i4, (a4, c4) in enumerate(self.terms)])
                            # print(" " + str(expr))
                            expr.simplify_quick(bnet = bnet)
                            # print(" " + str(expr))
                            expr += Expr([(ks[1 - it].copy(), c)])
                            # print(" " + str(expr))
                            expr.simplify_quick(bnet = bnet)
                            # print(" " + str(expr))
                            # print()
                            if expr.complexity() < scom:
                                return expr

        return None

    def break_hc(self):
        """Rewrite each conditional entropy H(X|Z) (non-empty Z) as H(X) - I(X;Z), in place."""
        self.mhash = None
        olen = len(self.terms)
        for i in range(olen):
            a, c = self.terms[i]
            if a.ishc() and not a.z.isempty():
                self.terms.append((Term.I(a.x[0], a.z), -c))
                a.z = Comp.empty()

    def simplify_break_hc(self, bnet = None):
        """Apply break_hc() plus simplify_quick() if that lowers complexity(), in place."""
        r = self.copy()
        r.break_hc()
        r.simplify_quick(bnet = bnet)
        if r.complexity() < self.complexity():
            self.copy_(r)
        self.mhash = None

    def simplified_break_hc(self, bnet = None):
        """Return a copy processed by simplify_break_hc."""
        r = self.copy()
        r.simplify_break_hc(bnet = bnet)
        return r

    def simplify_exhaust(self, bnet = None):
        """Simplify in place by applying simplified_exhaust_inner until no improvement is found."""
        self.terms.sort(key = lambda a: a[0].complexity())
        self.simplify_break_hc(bnet = bnet)
        self.terms.sort(key = lambda a: a[0].complexity())
        self.mhash = None
        while True:
            t = self.simplified_exhaust_inner(bnet = bnet)
            if t is None:
                break
            self.terms = t.terms
            self.mhash = None

    def simplified_exhaust(self, bnet = None):
        """Return a copy processed by simplify_exhaust."""
        r = self.copy()
        r.simplify_exhaust(bnet = bnet)
        return r

    def simplify_target(self, target, bnet = None):
        """For each Expr in target whose first term is IC, rewrite self as
        simplified(self - term) + term if that is not more complex. In place; returns self."""
        for term in target:
            if term.terms[0][0].get_type() != TermType.IC:
                continue
            t = (self - term).simplified_quick(bnet = bnet) + term
            if t.complexity() <= self.complexity():
                self.copy_(t)
        return self

    def simplify_quick(self, reg = None, **kwargs):
        """simplify() with quick = True."""
        return self.simplify(reg = reg, **kwargs, quick = True)

    def simplified_quick(self, reg = None, **kwargs):
        """simplified() with quick = True."""
        return self.simplified(reg = reg, **kwargs, quick = True)

    def simplified_prog(self, prog = None, reg = None):
        """Simplify using the dual solution of a linear program.

        The IC part is checked for nonnegativity with prog (proofs disabled) and
        replaced by prog.get_dual_sum_meta(); the non-IC part is added back
        unchanged. If no dual sum is available, a copy of self is returned.
        If prog is None it is built from reg, which must then be given.
        """
        if prog is None:
            index = IVarIndex()
            reg.record_to(index)
            self.record_to(index)
            prog = reg.init_simplified_prog(index)

        cself, creal = self.split_ic_real()

        prog.clear_dual()
        with PsiOpts(proof_enabled = False):
            prog.checkexpr_ge0(cself)
        r = prog.get_dual_sum_meta()
        if r is None:
            return self.copy()
        return r.simplified_quick() + creal
        # return r + creal

    def get_ratio(self, other, skip_simplify = False):
        """Try dividing self by other, return None if self is not scalar multiple of other

        Both are simplified with simplified_quick() first unless skip_simplify.
        Returns 0.0 if self is zero. Otherwise, every term must match a distinct
        term of other with a ratio agreeing within eps; returns the average
        of the extreme ratios.
        """
        es = self
        eo = other
        if not skip_simplify:
            es = self.simplified_quick()
            eo = other.simplified_quick()
        if es.iszero():
            return 0.0
        if len(es.terms) != len(eo.terms):
            return None
        if eo.iszero():
            return None

        rmax = -1e12
        rmin = 1e12
        vis = [False] * len(eo.terms)
        for i in range(len(es.terms)):
            found = False
            for j in range(len(eo.terms)):
                if not vis[j] and abs(eo.terms[j][1]) > PsiOpts.settings["eps"] and es.terms[i][0] == eo.terms[j][0]:
                    cr = es.terms[i][1] / eo.terms[j][1]
                    rmax = max(rmax, cr)
                    rmin = min(rmin, cr)
                    if rmax > rmin + PsiOpts.settings["eps"]:
                        return None
                    vis[j] = True
                    found = True
                    break
            if not found:
                return None
        if rmax <= rmin + PsiOpts.settings["eps"]:
            return (rmax + rmin) * 0.5
        else:
            return None

    def __truediv__(self, other):
        """Return self / other.

        For a constant divisor the coefficients are divided. If other is a
        non-constant Expr and self is a scalar multiple of it, that scalar
        (a float, not an Expr) is returned. Otherwise a function-call term is
        created.
        """
        if isinstance(other, Expr):
            t = other.get_const()
            if t is None:
                t = self.get_ratio(other)
                if t is not None:
                    return t
                else:
                    # NOTE(review): the term is named with "/" but uses fcncall = "*",
                    # the same as __mul__ — confirm that its evaluation really divides.
                    return Expr.fromterm(Term(Comp.real(
                        iutil.fcn_name_maker("/", [self, other], lname = "/", infix = True)
                        ), reg = Region.universe(), fcncall = "*", fcnargs = [self, other]))
            other = t
        # NOTE(review): unlike __mul__ / __neg__, meta is not copied to the result.
        return Expr([(a.copy(), c / other) for (a, c) in self.terms])

    def __rtruediv__(self, other):
        """Return other / self, wrapping a non-Expr other with Expr.const."""
        if isinstance(other, Expr):
            return other / self
        return Expr.const(other) / self

    def ispresent(self, x):
        """Return whether any variable in x appears here"""
        for (a, c) in self.terms:
            if a.ispresent(x):
                return True
        return False

    def affine_present(self):
        """Return whether any of the special variables one, eps or inf appears,
        i.e. whether the expression has a constant / eps / inf part."""
        return self.ispresent((Expr.one() + Expr.eps() + Expr.inf()).allcomp())

    def try_remove(self, x, sn):
        """Call a.try_remove(x, sn * c) on each term; return False at the first failure,
        otherwise simplify_quick() and return True."""
        for (a, c) in self.terms:
            if not a.try_remove(x, sn * c):
                return False
        self.simplify_quick()
        return True

    def rename_var(self, name0, name1):
        """Rename variable name0 to name1 in every term, in place."""
        for (a, c) in self.terms:
            a.rename_var(name0, name1)
        self.mhash = None

    def rename_map(self, namemap):
        """Rename according to name map
        """
        for (a, c) in self.terms:
            a.rename_map(namemap)
        self.mhash = None
        return self

    def definition(self):
        """Return the definition of this expression.
"""
        return Expr([(a.definition(), c) for a, c in self.terms])

    def substitute_rate(self, v0, v1):
        """Replace the variables of ``v0`` by the expression ``v1``, in place.

        All variables of ``v0`` are first merged into ``v0[0]``. The
        coefficient of ``v0[0]`` (``commonpart_coeff``) is recorded, ``v0[0]``
        is substituted by the empty variable, and ``v1`` times that
        coefficient is added. This runs with the ``meta_subs_criteria``
        option set to False. Returns self.
        """
        with PsiOpts(meta_subs_criteria = False):
            self.mhash = None
            for i, v0c in enumerate(v0):
                if i > 0:
                    self.substitute(v0c, v0[0])
            c = self.commonpart_coeff(v0[0])
            self.substitute(v0[0], Comp.empty())
            if c != 0:
                self += v1 * c
        return self

    @fcn_substitute
    def substitute(self, v0, v1):
        """Substitute variable v0 by v1 (v1 can be compound), in place

        - If v0 is an Expr, terms equal to its first term are replaced by
          v1 times their coefficient.
        - Otherwise, if v1 is an Expr, ``substitute_rate`` is used.
        - Otherwise every term substitutes v0 by v1 itself.

        The meta information is also substituted when
        ``iutil.check_meta_subs_criteria`` allows it. Returns self.
        """
        self.mhash = None
        if isinstance(v0, Expr):
            if len(v0.terms) > 0:
                t = v0.terms[0][0]
                tmpterms = self.terms
                self.terms = []
                for (a, c) in tmpterms:
                    if a == t:
                        self += v1 * c
                    else:
                        self.terms.append((a, c))
        elif isinstance(v1, Expr):
            self.substitute_rate(v0, v1)
        else:
            for (a, c) in self.terms:
                a.substitute(v0, v1)
        if iutil.check_meta_subs_criteria(v0, v1, self):
            iutil.substitute(self.meta, v0, v1)
        return self

    @fcn_substitute
    def substitute_whole(self, v0, v1):
        """Substitute variable v0 by v1 (v1 can be compound)

        Terms are wrapped into Exprs first. If v0 is an Expr, its coefficient
        (from ``get_coeff``) is moved onto v1: ``v1 * coeff`` is inserted at
        the position of the match, ``-coeff * v0`` is appended, and like terms
        are then combined. Returns self.
        """
        if isinstance(v0, Term):
            v0 = Expr.fromterm(v0)
        if isinstance(v1, Term):
            v1 = Expr.fromterm(v1)
        self.mhash = None
        if isinstance(v0, Expr):
            if len(v0.terms) > 0:
                coeff, pos = self.get_coeff(v0, get_pos = True)
                if coeff != 0 and pos is not None:
                    self.terms[pos:pos] = (v1 * coeff).terms
                    self.terms.extend((v0 * (-coeff)).terms)
                    self.combine_same_terms()
        elif isinstance(v1, Expr):
            self.substitute_rate(v0, v1)
        else:
            for (a, c) in self.terms:
                a.substitute_whole(v0, v1)
        if iutil.check_meta_subs_criteria(v0, v1):
            iutil.substitute_whole(self.meta, v0, v1)
        return self

    def condition(self, b):
        """Condition on random variable b, in place (only IC-type terms are affected)"""
        for (a, c) in self.terms:
            if a.get_type() == TermType.IC:
                a.z += b
        self.mhash = None
        return self

    def conditioned(self, b):
        """Condition on random variable b, return result"""
        r = self.copy()
        r.condition(b)
        return r

    def var_neighbors(self, v):
        """Return v together with every variable that shares an IC term with it
        (or is reported by a REGION term's ``var_neighbors``)."""
        r = v.copy()
        for (a, c) in self.terms:
            if a.get_type() == TermType.IC:
                t = sum(a.x, Comp.empty()) + a.z
                if t.ispresent(v):
                    r += t
            elif a.get_type() == TermType.REGION:
                r += a.reg.var_neighbors(v)
        return r

    def __abs__(self):
        return eabs(self)
        #return emax(self, -self)

    def value(self, method = "", num_iter = 30, prog = None):
        """Numerically evaluate this expression.

        Constant (one) terms contribute their coefficient. The other terms are
        evaluated by ``Term.value`` with the same keyword arguments. Returns
        None if any term evaluates to None.
        """
        r = 0.0
        for (a, c) in self.terms:
            if a.isone():
                r += c
            else:
                t = a.value(method = method, num_iter = num_iter, prog = prog)
                if t is None:
                    return None
                r += c * t
        return r

    def solve_prog(self, method = "", num_iter = 30):
        """Return ``(value, prog)``, where prog is the first program appended by
        ``value()`` (or None if none was appended)."""
        prog = []
        optval = self.value(method = method, num_iter = num_iter, prog = prog)
        if len(prog) > 0:
            return (optval, prog[0])
        else:
            return (optval, None)

    def __call__(self, method = "", num_iter = 30, prog = None):
        return self.value(method = method, num_iter = num_iter, prog = prog)

    def __float__(self):
        return float(self.value())

    def __int__(self):
        # Rounds to the nearest integer rather than truncating.
        return int(round(float(self.value())))

    def split_lhs(self, lhsvar):
        """Split the terms into (terms sharing a variable with lhsvar, others).

        Overlap is detected by checking whether recording ``lhsvar`` into the
        term's index grows it by less than the size of ``lhsvar`` alone. This
        presumably relies on ``IVarIndex.record`` ignoring already-recorded
        variables.
        """
        lhs = []
        rhs = []
        for (a, c) in self.terms:
            index = IVarIndex()
            a.record_to(index)
            t = index.size()
            index2 = IVarIndex()
            lhsvar.record_to(index)
            lhsvar.record_to(index2)
            if index.size() != t + index2.size():
                lhs.append((a, c))
            else:
                rhs.append((a, c))
        return (Expr(lhs), Expr(rhs))

    def split_posneg(self):
        """Return (non-positive-coefficient terms, positive-coefficient terms)."""
        lhs = []
        rhs = []
        for (a, c) in self.terms:
            if c > 0:
                rhs.append((a, c))
            else:
                lhs.append((a, c))
        return (Expr(lhs), Expr(rhs))

    def split_present(self, eqnstr, lhsvar = None, prefer_ge = None):
        """Split ``self eqnstr 0`` into ``(lhs, rhs, eqnstr)`` for display.

        ``lhs eqnstr rhs`` is equivalent to ``self eqnstr 0``. When the
        ``str_eps_hide`` setting is on, the eps term is dropped and a
        non-strict relation becomes strict where the sign of eps requires it.
        Both sides may be negated (with the relation reversed) so that the
        displayed coefficients are preferably positive.
        """
        if isinstance(lhsvar, str) and lhsvar == "real":
            lhsvar = self.allcomprealvar()
        if prefer_ge is None:
            prefer_ge = PsiOpts.settings["str_eqn_prefer_ge"]
        eps_hide = PsiOpts.settings["str_eps_hide"]
        cs = self
        if eps_hide and eqnstr != "==":
            eps_coeff = cs.get_coeff(Term.eps())
            cs = cs.substituted(Expr.eps(), Expr.zero())
            # e - c*eps >= 0 with c > 0 implies e > 0 (and symmetrically for <=).
            if eqnstr == ">=" and eps_coeff < 0:
                eqnstr = ">"
            elif eqnstr == "<=" and eps_coeff > 0:
                eqnstr = "<"
        lhs = cs
        rhs = Expr.zero()
        if lhsvar is not None:
            lhs, rhs = cs.split_lhs(lhsvar)
            if lhs.iszero() or rhs.iszero():
                lhs, rhs = cs.split_posneg()
                if prefer_ge:
                    lhs, rhs = rhs, lhs
            if lhs.iszero():
                lhs, rhs = rhs, lhs
            elif lhs.get_const() is not None:
                lhs, rhs = rhs, lhs
        # Here cs == lhs + rhs; move rhs across so that lhs eqnstr rhs.
        rhs *= -1.0
        # Tie-breaking cascade deciding whether to negate both sides.
        toflip = False
        lhs_sign = lhs.coeff_sign()
        if lhs_sign < 0:
            toflip = True
        elif lhs_sign == 0:
            rhs_sign = rhs.coeff_sign()
            if rhs_sign < 0:
                toflip = True
            elif rhs_sign == 0:
                lhs_sum = lhs.coeff_sum()
                if lhs_sum < 0:
                    toflip = True
                elif lhs_sum == 0:
                    rhs_sum = rhs.coeff_sum()
                    if rhs_sum < 0:
                        toflip = True
        if toflip:
            lhs *= -1.0
            rhs *= -1.0
            eqnstr = iutil.reverse_eqnstr(eqnstr)
        return (lhs, rhs, eqnstr)

    def tostring_eqn(self, eqnstr, style = 0, tosort = False, lhsvar = None, prefer_ge = None, line_len = None, pf_note = None):
        """Render ``self eqnstr 0`` as a string (sides chosen by ``split_present``).

        If ``pf_note`` (default: the ``str_proof_note`` setting) is set and a
        ``pf_note`` meta entry exists, it is appended. In LaTeX style the
        result goes through ``iutil.latex_split_line`` with ``line_len``.
        """
        if pf_note is None:
            pf_note = PsiOpts.settings["str_proof_note"]
        style = iutil.convert_str_style(style)
        if line_len is None:
            if style & PsiOpts.STR_STYLE_LATEX:
                line_len = PsiOpts.settings["latex_line_len"]
        lhs, rhs, eqnstr = self.split_present(eqnstr, lhsvar, prefer_ge)
        r = (lhs.tostring(style = style, tosort = tosort, tosort_pm = True)
             + " " + iutil.eqnstr_style(eqnstr, style) + " "
             + rhs.tostring(style = style, tosort = tosort, tosort_pm = True))
        cnote = None
        if pf_note:
            cnote = self.get_meta("pf_note")
            if cnote is not None:
                cnote = iutil.pf_note_str(cnote, style, add_space = 4)
        if cnote is not None:
            r = [r, cnote]
        if style & PsiOpts.STR_STYLE_LATEX:
            r = iutil.latex_split_line(r, line_len)
        else:
            if isinstance(r, list):
                r = "".join(r)
        # if style & PsiOpts.STR_STYLE_PSITIP:
        #     r = r[0]
        # else:
        #     r = "".join(r)
        # if style & PsiOpts.STR_STYLE_LATEX and line_len is not None:
        #     r = iutil.latex_split_line(r, line_len)
        return r

    def tostring_line_len(self, eqnstr = "", style = 0, tosort = False, line_len = None):
        """Render this expression, optionally prefixed by a relation symbol.

        In LaTeX style with a line length, the string is passed to
        ``iutil.latex_split_line`` with ``slstr`` set to three ``\\;``.
        """
        style = iutil.convert_str_style(style)
        if line_len is None:
            if style & PsiOpts.STR_STYLE_LATEX:
                line_len = PsiOpts.settings["latex_line_len"]
        r = ""
        if eqnstr != "":
            r += iutil.eqnstr_style(eqnstr, style) + " "
        r += self.tostring(style = style, tosort = tosort, tosort_pm = True)
        if style & PsiOpts.STR_STYLE_LATEX and line_len is not None:
            r = iutil.latex_split_line(r, line_len, slstr = "\\;" * 3)
        return r


class
FcnRelation(IBaseObj): """Stores functional dependencies""" def __init__(self, fcn = None): self.index = IVarIndex() self.fcn = [] if fcn is not None: self += fcn def copy(self): r = FcnRelation() r.index = self.index.copy() r.fcn = list(self.fcn) return r def check_fcn(self, x, y): if y | x == x: return True i = 0 elapsed = 0 while elapsed < len(self.fcn): if self.fcn[i][0] | x == x and self.fcn[i][1] | x != x: x |= self.fcn[i][1] if y | x == x: return True elapsed = 0 i = (i + 1) % len(self.fcn) elapsed += 1 return False def simplify(self): i = 0 while i < len(self.fcn): t = self.fcn[i] self.fcn.pop(i) if not self.check_fcn(t[0], t[1]): self.fcn.insert(i, t) i += 1 def add_fcn(self, x, y): y = y & ~x if not self.check_fcn(x, y): self.fcn.append((x, y)) def check(self, x, y): if isinstance(x, Comp): x = self.index.get_mask(x) if isinstance(y, Comp): y = self.index.get_mask(y) return self.check_fcn(x, y) def __iadd__(self, other): if isinstance(other, list): for x in other: self += x return self if isinstance(other, tuple): t = list(other) if isinstance(t[0], Comp): self.index.record(t[0]) t[0] = self.index.get_mask(t[0]) if isinstance(t[1], Comp): self.index.record(t[1]) t[1] = self.index.get_mask(t[1]) t[1] &= ~(t[0]) self.fcn.append(tuple(t)) elif isinstance(other, FcnRelation): for a in other.fcn: self += (other.index.from_mask(a[0]), other.index.from_mask(a[1])) return self elif isinstance(other, IBaseObj): other.record_to(self.index) return self def __add__(self, other): r = self.copy() r += other return r def get_hc(self): r = Expr.zero() for a in self.fcn: r += Expr.Hc(self.index.from_mask(a[1]), self.index.from_mask(a[0])) return r def get_region(self): return self.get_hc() <= 0 def __bool__(self): return bool(self.get_region()) class BayesNet(IBaseObj): """Bayesian network""" def __init__(self, edges = None): self.index = IVarIndex() self.parent = [] self.child = [] self.fcn = [] if edges is not None: self += edges def copy(self): r = BayesNet() r.index = 
self.index.copy() r.parent = [list(x) for x in self.parent] r.child = [list(x) for x in self.child] r.fcn = list(self.fcn) return r def allcomp(self): return self.index.comprv.copy() def get_parents(self, x): """ Get the parents of node x. Parameters ---------- x : Comp Returns ------- Comp """ i = self.index.get_index(x) if i < 0: return None return sum((self.index.comprv[x] for x in self.parent[i]), Comp.empty()) def get_children(self, x): """ Get the children of node x. Parameters ---------- x : Comp Returns ------- Comp """ i = self.index.get_index(x) if i < 0: return None return sum((self.index.comprv[x] for x in self.child[i]), Comp.empty()) def get_ancestors(self, x, descendant = False, include_self = True): """ Get the ancestors of node x. Parameters ---------- x : Comp Returns ------- Comp """ n = self.index.comprv.size() vis = [False] * n i = self.index.get_index(x) vis[i] = include_self cstack = [i] r = Comp.empty() while len(cstack): x = cstack.pop() if vis[x]: r += self.index.comprv[x] for y in (self.child[x] if descendant else self.parent[x]): if not vis[y]: vis[y] = True cstack.append(y) return r def get_descendants(self, x, **kwargs): """ Get the descendants of node x. Parameters ---------- x : Comp Returns ------- Comp """ return self.get_ancestors(x, descendant = True, **kwargs) def edges(self): """ Generator over the edges of the network. Yields ------ Pairs of Comp representing the edges. 
""" n = self.index.comprv.size() for i in range(n): for j in self.child[i]: yield (self.index.comprv[i], self.index.comprv[j]) def add_edge_id(self, i, j): if i < 0 or j < 0 or i == j: return if i not in self.parent[j]: self.parent[j].append(i) if j not in self.child[i]: self.child[i].append(j) def remove_edge_id(self, i, j): if i < 0 or j < 0 or i == j: return self.parent[j].remove(i) self.child[i].remove(j) def record(self, x): self.index.record(x) n = self.index.comprv.size() while len(self.parent) < n: self.parent.append([]) while len(self.child) < n: self.child.append([]) while len(self.fcn) < n: self.fcn.append(False) @fcn_list_to_list def set_fcn(self, x, v = True): """Mark variables in x to be functions of their parents.""" self.record(x) for xa in x.varlist: i = self.index.get_index(xa) if i >= 0: self.fcn[i] = v def is_fcn(self, x): """Query whether x is a function of their parents.""" for xa in x.varlist: i = self.index.get_index(xa) if i >= 0: if not self.fcn[i]: return False return True @fcn_list_to_list def add_edge(self, x, y): """Add edges from every variable in x to every variable in y. Also add edges among variables in y. 
""" self.record(x) self.record(y) for xa in x.varlist: for ya in y.varlist: self.add_edge_id(self.index.get_index(xa), self.index.get_index(ya)) for yi in range(len(y.varlist)): for yj in range(yi + 1, len(y.varlist)): self.add_edge_id(self.index.get_index(y.varlist[yi]), self.index.get_index(y.varlist[yj])) return y def __iadd__(self, other): if isinstance(other, list): for x in other: self += x return self if isinstance(other, tuple): for i in range(len(other) - 1): self.add_edge(other[i], other[i + 1]) return self elif isinstance(other, Term): self.add_edge(other.z, other.x[0]) return self elif isinstance(other, Comp): self.add_edge(Comp.empty(), other) return self elif isinstance(other, BayesNet): for x, y in other.edges(): self.add_edge(x, y) for x in other.allcomp(): if other.is_fcn(x): self.set_fcn(x) return self return self def __add__(self, other): r = self.copy() r += other return r def join(self, other): if not isinstance(other, BayesNet): other = BayesNet([other.allcomp()]) r = self + other for x in self.allcomp(): if self.get_children(x).isempty(): for y in other.allcomp(): if other.get_parents(y).isempty(): r += (x, y) return r def __floordiv__(self, other): return self.join(other) def __xor__(self, other): return self + other def communicate(self, a, b): return self.get_ancestors(a).ispresent(b) and self.get_ancestors(b).ispresent(a) def scc(self): n = self.index.comprv.size() r = BayesNet() vis = [False] * n for i in range(n): if vis[i]: continue cgroup = Comp.empty() cparent = Comp.empty() for j in range(i, n): if self.communicate(self.index.comprv[i], self.index.comprv[j]): vis[j] = True cgroup += self.index.comprv[j] cparent += self.get_parents(self.index.comprv[j]) cparent -= cgroup for k in range(len(cgroup)): tparent = cparent + cgroup[:k] r.add_edge(tparent, cgroup[k]) if self.is_fcn(cgroup[k]) and tparent.super_of(self.get_parents(cgroup[k])): r.set_fcn(cgroup[k]) return r def get_components(self): n = self.index.comprv.size() vis = [False] 
* n cstack = [] r = [] for s in range(n): if vis[s]: continue cstack = [s] vis[s] = True r.append(self.index.comprv[s]) while cstack: i = cstack.pop() for j in self.child[i] + self.parent[i]: if not vis[j]: cstack.append(j) vis[j] = True r[-1] += self.index.comprv[j] return r def indep_components(self): n = self.index.comprv.size() vis = [False] * n r = [] did = True while did: did = False for i in range(n): if vis[i]: continue if len(self.parent[i]) == 0: r.append(1 << i) vis[i] = True did = True continue for j in range(len(r)): if all(r[j] & (1 << p) for p in self.parent[i]): r[j] |= 1 << i vis[i] = True did = True break return [self.index.from_mask(m) for m in r] def tsorted(self): n = self.index.comprv.size() cstack = [] cnparent = [0] * n nrec = 0 r = BayesNet() for i in range(n - 1, -1, -1): cnparent[i] = len(self.parent[i]) if cnparent[i] == 0: cstack.append(i) while len(cstack) > 0: i = cstack.pop() r.record(self.index.comprv[i]) nrec += 1 for j in reversed(self.child[i]): if cnparent[j] > 0: cnparent[j] -= 1 if cnparent[j] == 0: cstack.append(j) if nrec < n: return None for i in range(n): for j in self.parent[i]: r.add_edge(self.index.comprv[j], self.index.comprv[i]) for i in range(n): if self.fcn[i]: r.set_fcn(self.index.comprv[i]) return r def reversed(self): n = self.index.comprv.size() r = BayesNet() for x in reversed(self.index.comprv): r += x for i in range(n): for j in self.parent[i]: r.add_edge(self.index.comprv[j], self.index.comprv[i]) for i in range(n): if self.fcn[i]: r.set_fcn(self.index.comprv[i]) return r def iscyclic(self): return self.tsorted() is None def contracted_node(self, x): self = self.tsorted() k = self.index.get_index(x) if k < 0: return self.copy() n = self.index.comprv.size() r = BayesNet() child = sorted(self.child[k]) for i0 in range(len(child)): i = child[i0] for j0 in range(i0 + 1, len(child)): j = child[j0] r.add_edge(self.index.comprv[i], self.index.comprv[j]) for i in range(n): if i == k: continue for j in 
self.parent[i]: if j == k: for j2 in self.parent[k]: if j2 == k: continue r.add_edge(self.index.comprv[j2], self.index.comprv[i]) continue r.add_edge(self.index.comprv[j], self.index.comprv[i]) for i in range(n): if i == k: continue if self.fcn[i]: if self.fcn[k] or (k not in self.parent[i]): r.set_fcn(self.index.comprv[i]) return r def eliminated(self, x): r = self.copy() for a in x: r = r.contracted_node(a) return r def __sub__(self, other): return self.eliminated(other) def check_hc_mask(self, x, z): n = self.index.comprv.size() cstack = [] vis = [False] * n x &= ~z for i in range(n): if x & (1 << i): cstack.append(i) vis[i] = True if z & (1 << i): vis[i] = True while cstack: i = cstack.pop() if not self.fcn[i]: return False for j in self.parent[i]: if not vis[j]: cstack.append(j) vis[j] = True return True def fcn_descendants_mask(self, x): n = self.index.comprv.size() did = True while did: did = False for i in range(n): if self.fcn[i] and not x & (1 << i): if all(x & j for j in self.parent[i]): x |= 1 << i did = True return x def fcn_descendants(self, x): return self.index.from_mask(self.fcn_descendants_mask(self.index.get_mask(x))) def check_ic_mask(self, x, y, z, icset = None): if x < 0 or (icset is None and y < 0) or z < 0: return False z = self.fcn_descendants_mask(z) x &= ~z y &= ~z if icset is None: if x & y != 0: # if not self.check_hc_mask(x & y, z): # return False # z |= x & y # x &= ~z # y &= ~z return False if x == 0 or y == 0: return True else: if x == 0: icset[0] = (1 << self.index.comprv.size()) - 1 - z return True n = self.index.comprv.size() desc = z cstack = [] for i in range(n): if z & (1 << i) != 0: cstack.append(i) while len(cstack) > 0: i = cstack.pop() for j in self.parent[i]: if desc & (1 << j) == 0: desc |= (1 << j) cstack.append(j) vis = [0, x] cstack = [] for i in range(n): if x & (1 << i) != 0: cstack.append((1, i)) cicset = 0 while len(cstack) > 0: (d, i) = cstack.pop() if icset is None: if y & (1 << i) != 0: return False else: 
cicset |= 1 << i if z & (1 << i) == 0: for j in self.child[i]: if vis[0] & (1 << j) == 0: vis[0] |= (1 << j) cstack.append((0, j)) if (d == 0 and desc & (1 << i) != 0) or (d == 1 and z & (1 << i) == 0): for j in self.parent[i]: if vis[1] & (1 << j) == 0: vis[1] |= (1 << j) cstack.append((1, j)) if icset is not None: icset[0] = (1 << n) - 1 - (x | z | cicset) return True def check_ic(self, icexpr): for a, c in icexpr.terms: if a.isihc2(): if not self.check_ic_mask(self.index.get_mask(a.x[0]), self.index.get_mask(a.x[1]), self.index.get_mask(a.z)): return False elif a.ishc(): if not self.check_hc_mask(self.index.get_mask(a.x[0]), self.index.get_mask(a.z)): return False else: return False return True def max_ic_set(self, x, z): icset = [0] self.check_ic_mask(self.index.get_mask(x), 0, self.index.get_mask(z), icset) return self.index.from_mask(icset[0]) def relative_children_mask(self, x, y, children = True): cstack = [] vis = 0 n = self.index.comprv.size() for i in range(n): if x & (1 << i): cstack.append(i) vis |= 1 << i r = 0 while cstack: i = cstack.pop() if y & (1 << i): r |= 1 << i continue for j in (self.child[i] if children else self.parent[i]): if not vis & (1 << j): vis |= 1 << j cstack.append(j) return r def markov_blanket_mask(self, x, y): y &= ~x c = self.relative_children_mask(x, y) | x p = self.relative_children_mask(c, y & ~c, children=False) r = (c | p) & ~x if self.check_ic_mask(x, y & ~r, r): return r return y def markov_blanket(self, x, y = None): if y is None: y = self.index.comprv - x r = self.markov_blanket_mask(self.index.get_mask(x), self.index.get_mask(y)) return self.index.from_mask(r) def from_ic_inplace(self, icexpr, roots = None, add_var = None, add_hc = True): n_root = 0 if roots is not None: self.record(roots) n_root = self.index.comprv.size() if add_var is not None: self.record(add_var) ics = [] for (a, c) in icexpr.terms: if a.isic2(): self.record(a.x[0]) self.record(a.x[1]) self.record(a.z) x0 = self.index.get_mask(a.x[0]) x1 = 
self.index.get_mask(a.x[1]) z = self.index.get_mask(a.z) x0 &= ~z x1 &= ~z x0 &= ~x1 ics.append((x0, x1, z)) elif add_hc and a.ishc(): self.record(a.x[0]) self.record(a.z) x0 = self.index.get_mask(a.x[0]) z = self.index.get_mask(a.z) x0 &= ~z ics.append((x0, -1, z)) n = self.index.comprv.size() ics = [(x0, x1 if x1 >= 0 else (1 << n) - 1 - x0 - z, z) for x0, x1, z in ics] numcond = [0] * n for x0, x1, z in ics: for i in range(n): if z & (1 << i): numcond[i] += 1 ilist = list(range(n_root, n)) ilist.sort(key = lambda i: numcond[i]) n2 = n - n_root xk = 0 zk = 0 vis = 0 np2 = (1 << n2) dp = [100000000000] * np2 dpi = [-1] * np2 dped = [-1] * np2 dp[np2 - 1] = 0 for tvis in range(np2 - 2, -1, -1): vis = (tvis << n_root) | ((1 << n_root) - 1) nvis = iutil.bitcount(vis) for i in ilist: if vis & (1 << i) == 0: nedge = nvis * (10000 - numcond[i]) + dp[tvis | (1 << (i - n_root))] # nedge = nvis + dp[tvis | (1 << (i - n_root))] if nedge < dp[tvis]: dp[tvis] = nedge dpi[tvis] = i dped[tvis] = vis for (x0, x1, z) in ics: if z & ~vis != 0: continue if x0 & (1 << i) != 0: xk = x1 zk = (z + x0 - (1 << i)) & vis elif x1 & (1 << i) != 0: xk = x0 zk = (z + x1 - (1 << i)) & vis else: continue if vis & ~(zk | xk) != 0: continue nedge = iutil.bitcount(zk) * (10000 - numcond[i]) + dp[tvis | (1 << (i - n_root))] # nedge = iutil.bitcount(zk) + dp[tvis | (1 << (i - n_root))] if nedge < dp[tvis]: dp[tvis] = nedge dpi[tvis] = i dped[tvis] = zk #for vis in range(np2): # print("{0:b}".format(vis) + " " + str(dp[vis]) + " " + str(dpi[vis]) + " " + "{0:b}".format(dped[vis])) cvis = 0 for it in range(n2): i = dpi[cvis] ed = dped[cvis] for j in range(n): if ed & (1 << j) != 0: self.add_edge_id(j, i) cvis |= (1 << (i - n_root)) def from_ic(icexpr, roots = None, add_var = None, add_hc = True): """Construct Bayesian network from the sum of conditional mutual information terms (Expr). 
"""
        # Called through the class (``BayesNet.from_ic(...)``); there is no self.
        r = BayesNet()
        r.from_ic_inplace(icexpr, roots, add_var = add_var, add_hc = add_hc)
        return r

    def from_ic_list(icexpr, roots = None, add_var = None):
        """Construct a list of Bayesian networks from the sum of conditional mutual information terms (Expr).

        Repeatedly builds a network from the remaining terms and removes the
        terms it satisfies. When a network satisfies none of them, the first
        remaining term is handled on its own by a recursive call. Called
        through the class (there is no self).
        """
        r = []
        icexpr = icexpr.copy()
        while not icexpr.iszero():
            t = BayesNet.from_ic(icexpr, roots = roots, add_var = add_var).tsorted()
            olen = len(icexpr.terms)
            icexpr.terms = [(a, c) for a, c in icexpr.terms if not t.check_ic(Expr.fromterm(a))]
            icexpr.mhash = None
            if len(icexpr.terms) == olen:
                # NOTE(review): assumes a network built from a single term
                # satisfies that term; otherwise no progress would be made.
                tl = BayesNet.from_ic_list(Expr.fromterm(icexpr.terms[0][0]), roots = roots, add_var = add_var)
                icexpr.terms = [(a, c) for a, c in icexpr.terms if not any(t.check_ic(Expr.fromterm(a)) for t in tl)]
                icexpr.mhash = None
                r += tl
                continue
            r.append(t)
        return r

    def get_markov(self):
        """Get Markov chains as a list of lists.

        Works on the topologically sorted copy. Each chain is returned as a
        list of Comp.
        """
        cs = self.tsorted()
        n = cs.index.comprv.size()
        r = []

        # Smallest node index reachable from i by following parent links (i included).
        def parent_min(i):
            r = i
            for j in cs.parent[i]:
                r = min(r, parent_min(j))
            return r

        # Returns m if the parents of i are exactly m..i-1 and each j in
        # m+1..i-1 has parents parent[m] | {m..j-1}; otherwise -1.
        def parent_segment(i):
            if not cs.parent[i]:
                return -1
            m = min(cs.parent[i])
            if len(cs.parent[i]) != i - m:
                return -1
            for j in range(m + 1, i):
                if set(cs.parent[j]) != set(cs.parent[m]).union(range(m, j)):
                    return -1
            return m

        # Smallest m such that parent[i] == parent[j] | {j..i-1} for every j in m..i-1.
        def node_segment(i):
            for j in range(i - 1, -1, -1):
                if set(cs.parent[i]) != set(cs.parent[j]).union(range(j, i)):
                    return j + 1
            return 0

        # Appends to r the chains describing nodes st..en-1 (presumably
        # working backwards from the last node; see the recursion below).
        def recur(st, en):
            if st >= en:
                return
            i = en - 1
            cms = [en, parent_min(i)]
            if cms[-1] > st:
                # The nodes split into blocks not connected to earlier nodes.
                while cms[-1] > st:
                    t = parent_min(cms[-1] - 1)
                    cms.append(t)
                tl = []
                for i in range(len(cms) - 1):
                    if i:
                        tl.append([])
                    tl.append(list(range(cms[i + 1], cms[i])))
                r.append(tl)
                for i in range(len(cms) - 1):
                    recur(cms[i + 1], cms[i])
                return
            if len(cs.parent[i]) >= i - st:
                recur(st, i)
                return
            m = node_segment(i)
            t = [list(range(m, i + 1))]
            i = m
            while i >= st:
                t.append(cs.parent[i])
                m = parent_segment(i)
                if m < 0:
                    break
                i = m
            t.append([j for j in range(st, i) if j not in cs.parent[i]])
            while not t[-1]:
                t.pop()
            if len(t) >= 3:
                r.append(t)
            recur(st, i)

        recur(0, n)
        return [[sum((cs.index.comprv[a] for a in b), Comp.empty()) for b in reversed(tl)] for tl in reversed(r)]

    def get_ic_sorted(self):
        """Sum of I(X_i; earlier non-parents | parents of X_i) over the nodes,
        taken in the current node order (assumed topological)."""
        n = self.index.comprv.size()
        r = Expr.zero()
        compvis = Comp.empty()
        for i in range(n):
            ps = Comp.empty()
            for j in self.parent[i]:
                ps += self.index.comprv[j]
            y = compvis - ps
            if y.size() > 0:
                r += Expr.Ic(self.index.comprv[i], y, ps)
            compvis += self.index.comprv[i]
        return r

    def get_ic_exhaust(self):
        """Sum of I(X_i; S | parents of X_i) over every node i and every
        nonempty set S of non-parent nodes passing ``check_ic_mask``.

        Enumerates subsets, so the cost is exponential in the number of nodes.
        """
        n = self.index.comprv.size()
        r = Expr.zero()
        for i in range(n):
            ps = Comp.empty()
            pmask = 0
            for j in self.parent[i]:
                ps += self.index.comprv[j]
                pmask |= 1 << j
            for mask in igen.subset_mask((1 << n) - 1 - (1 << i) - pmask):
                if mask == 0:
                    continue
                if self.check_ic_mask(1 << i, mask, pmask):
                    r += Expr.Ic(self.index.comprv[i], self.index.from_mask(mask), ps)
        return r

    def get_ic(self):
        # tsorted() returns None for a cyclic network, so this fails there.
        return self.tsorted().get_ic_sorted()

    def get_region(self, exhaust = False):
        """Return the constraint ``r == 0``.

        Here r is the sum of H(X_i | parents) over function-flagged nodes plus
        the independence terms (``get_ic_exhaust`` if ``exhaust``, else
        ``get_ic``).
        """
        n = self.index.comprv.size()
        r = Expr.zero()
        for i in range(n):
            if self.fcn[i]:
                ps = Comp.empty()
                for j in self.parent[i]:
                    ps += self.index.comprv[j]
                r += Expr.Hc(self.index.comprv[i], ps)
        if exhaust:
            r += self.get_ic_exhaust()
        else:
            r += self.get_ic()
        return r == 0

    def assume(self):
        """Assume this Bayesian network is true in the current context.
        """
        self.get_region().assume()

    def assume_only(self):
        """Assume this Bayesian network is true in the current context.
        Overwrite existing assumptions.
        """
        self.get_region().assume_only()

    def assumed(self):
        """Create a context where this Bayesian network is assumed to be true.
        Use "with bnet.assumed(): ..."
        """
        return self.get_region().assumed()

    def assumed_only(self):
        """Create a context where this Bayesian network is assumed to be true.
        Overwrite existing assumptions.
        Use "with bnet.assumed_only(): ..."
""" return self.get_region().assumed_only() @latex_postprocess def _latex_(self): ms = self.get_markov() eqnlist = [Region.markov_tostring(cm, PsiOpts.STR_STYLE_LATEX, False) for cm in ms] if len(eqnlist) == 0: return "" if len(eqnlist) == 1: return eqnlist[0] return "\\begin{array}{l}\n" + "\\\\\n".join(eqnlist) + "\\\\\n\\end{array}" def _repr_svg_(self): if graphviz is None: return None return self.graph()._repr_svg_() def _repr_latex_(self): if graphviz is not None: return None if PsiOpts.settings.get("repr_latex", False): return self._latex_() return None def __bool__(self): return bool(self.get_region()) def __or__(self, other): return self.get_region() | other def __and__(self, other): return self.get_region() & other def __lshift__(self, other): return self.get_region() << other def __rshift__(self, other): return self.get_region() >> other def get_basis(self, more_vars = None): """Get a basis of the entropy region of this Bayesian network (may not be minimal). """ return self.get_region().get_basis(more_vars = more_vars) def tostring(self, tsort = True): if tsort: tself = self.tsorted() if tself is not None: return tself.tostring(tsort = False) n = self.index.comprv.size() r = "" for i in range(n): first = True for j in self.parent[i]: if not first: r += "," r += self.index.comprv.varlist[j].tostring() first = False r += " -> " + self.index.comprv.varlist[i].tostring() + ("*" if self.fcn[i] else "") + "\n" return r def __str__(self): return self.tostring() def __repr__(self): return self.tostring() def __hash__(self): return hash(self.tostring()) def graph(self, tsort = True, shape = "plaintext", lr = True, groups = None, ortho = False, **kwargs): """Return the graphviz digraph of the network that can be displayed in the console. """ if graphviz is None: raise RuntimeError("Requires graphviz. 
Please install it first.") if tsort: return self.tsorted().graph(tsort = False, shape = shape, lr = lr, groups = groups) n = self.index.comprv.size() r = graphviz.Digraph() if lr: r.graph_attr["rankdir"] = "LR" if ortho: r.graph_attr["splines"] = "ortho" for key, value in kwargs.items(): r.graph_attr[key] = str(value) if groups is None: groups = [] remrv = self.index.comprv.copy() for gi, g in enumerate(groups): with r.subgraph(name = "cluster_" + str(gi)) as rs: rs.attr(color = "blue") for c in g: if not remrv.ispresent(c): continue remrv -= c i = self.index.get_index(c) rs.node(self.index.comprv[i].get_name(), str(self.index.comprv[i]) + ("*" if self.fcn[i] else ""), shape = shape) for i in range(n): if not remrv.ispresent(self.index.comprv[i]): continue r.node(self.index.comprv[i].get_name(), str(self.index.comprv[i]) + ("*" if self.fcn[i] else ""), shape = shape) for i in range(n): for j in self.parent[i]: r.edge(self.index.comprv[j].get_name(), self.index.comprv[i].get_name()) return r class ValIndex: def __init__(self, v = None): if v is None: self.v = None self.vmap = {} else: self.v = v self.vmap = {} for i, x in enumerate(v): self.vmap[x] = i def get_index(self, x): return self.vmap.get(x, -1) class ConcDist(IBaseObj): """Concrete distributions / conditional distributions of random variables.""" def convert_shape(x): if x is None: return tuple() if isinstance(x, int): return (x,) if isinstance(x, Comp): return x.get_shape() if isinstance(x, ConcDist): return x.shape_out return tuple(x) def convert_shape_pair(x): if x is None: return (tuple(), tuple()) if isinstance(x, int): return (tuple(), (x,)) if isinstance(x, Term): return (x.z.get_shape(), sum((t.get_shape() for t in x.x), tuple())) if isinstance(x, Comp) or isinstance(x, ConcDist): return (tuple(), ConcDist.convert_shape(x)) if isinstance(x, tuple) and (len(x) == 0 or isinstance(x[0], int)): return (tuple(), tuple(x)) return (ConcDist.convert_shape(x[0]), ConcDist.convert_shape(x[1])) def 
__init__(self, p = None, num_in = None, shape = None, shape_in = None, shape_out = None, isvar = False, randomize = False, isfcn = False, check_valid = False):
        """Create a (conditional) distribution table.

        Parameters
        ----------
        p : array-like, list, ConcDist, ExprArray, list containing Expr, or None
            The table. The first ``num_in`` axes index the conditioning
            (input) variables and the rest the output variables. A symbolic
            (non-constant) table is kept in ``self.expr``; only its shape is
            used, and the numeric table is then initialized as if ``p`` were
            None.
        num_in : int, optional
            Number of input axes.
        shape : optional
            Spec accepted by ``ConcDist.convert_shape_pair``; overrides
            ``shape_in`` / ``shape_out``.
        shape_in, shape_out : optional
            Input / output shapes, used when there is no numeric ``p``.
        isvar : bool
            Keep a torch tensor ``v`` with ``requires_grad=True`` as the free
            parameters (see ``copy_torch``). Requires pytorch.
        randomize : bool
            Without a numeric ``p``, start from a random table instead of a
            uniform one.
        isfcn : bool
            Force a 0/1 (deterministic) table via ``clamp_fcn``.
        check_valid : bool
            With a numeric ``p`` and ``isvar`` False, raise ValueError if
            ``isvalid()`` fails.
        """
        self.force_float = True
        self.isvar = isvar
        self.iscache = False
        self.v = None  # free parameters (torch tensor) when isvar
        # if p is None and shape_in is None and shape_out is None:
        #     self.p = None
        #     return
        self.expr = None
        self.isfcn = isfcn
        # A list containing Expr entries is treated as a symbolic table.
        if isinstance(p, list) and iutil.hasinstance(p, Expr):
            p = ExprArray(p)
        if isinstance(p, ExprArray) and p.isconst():
            p = p.to_numpy()
        if isinstance(p, ExprArray):
            self.expr = p
            if num_in is None:
                if shape_in is None:
                    num_in = 0
                else:
                    num_in = len(shape_in)
            if num_in is not None:
                shape_in = p.shape[:num_in]
                shape_out = p.shape[num_in:]
            p = None
        if self.isvar and torch is None:
            raise RuntimeError("Requires pytorch. Please install it first.")
        if shape is not None:
            shape_in, shape_out = ConcDist.convert_shape_pair(shape)
        if isinstance(p, ConcDist):
            if num_in is None and shape is None and shape_in is None:
                num_in = p.get_num_in()
            p = p.p
        self.sublens = []
        if p is None:
            self.shape_in = ConcDist.convert_shape(shape_in)
            self.shape_out = ConcDist.convert_shape(shape_out)
            self.shape = self.shape_in + self.shape_out
            if randomize:
                self.randomize()
            else:
                self.set_uniform()
        else:
            if isinstance(p, list):
                p = numpy.array(p)
            self.p = p
            if num_in is None:
                if shape_in is not None:
                    num_in = len(shape_in)
                else:
                    num_in = 0
            tshape = p.shape
            self.shape_in = tshape[:num_in]
            self.shape_out = tshape[num_in:]
            self.shape = self.shape_in + self.shape_out
            if self.isvar:
                self.normalize()
            else:
                if check_valid and not self.isvalid():
                    raise ValueError("Invalid probability distribution. Must contain nonnegative entries that sum to 1.")
        if isfcn:
            self.clamp_fcn()

    def copy(self):
        """Return a deep copy; the torch parameters are rebuilt from the copied table."""
        r = ConcDist(p = iutil.copy(self.p), shape_in = self.shape_in, shape_out = self.shape_out, isvar = self.isvar, isfcn = self.isfcn)
        r.force_float = self.force_float
        r.p = iutil.copy(self.p)
        r.expr = iutil.copy(self.expr)
        r.v = None
        r.copy_torch()
        return r

    def clamp_fcn(self, randomize = False):
        """Turn the table into a deterministic one, in place.

        For each input index, the largest output entry (the first on ties)
        becomes 1 and all others 0. With ``randomize``, entries are first
        divided by independent exponential draws. This picks the kept output
        with probability proportional to its entry.
        """
        rnd = PsiOpts.get_random()
        for xs in itertools.product(*[range(x) for x in self.shape_in]):
            mzs = None
            m = -1.0
            for zs in itertools.product(*[range(z) for z in self.shape_out]):
                t = float(self.p[xs + zs])
                if randomize:
                    t /= rnd.exponential()
                self.p[xs + zs] = 0.0
                if t > m:
                    m = t
                    mzs = zs
            self.p[xs + mzs] = 1.0
        self.copy_torch()

    def fraction_snap(self, denom = None, eps = None):
        """Snap every entry with ``iutil.float_snap`` (given denom / eps), then renormalize."""
        for xs in itertools.product(*[range(x) for x in self.shape_in]):
            for zs in itertools.product(*[range(z) for z in self.shape_out]):
                self.p[xs + zs] = iutil.float_snap(float(self.p[xs + zs]), denom = denom, eps = eps)
        self.normalize()

    def istorch(self):
        return torch is not None and isinstance(self.p, torch.Tensor)

    def is_placeholder(self):
        return self.p is None

    def get_num_in(self):
        """Number of input (conditioning) axes."""
        return len(self.shape_in)

    def card_out(self):
        """Number of joint output values (product of ``shape_out``)."""
        r = 1
        for a in self.shape_out:
            r *= a
        return r

    def __getitem__(self, key):
        return self.p[key]

    def __setitem__(self, key, value):
        self.p[key] = value

    def copy_(self, other):
        """Copy content of other to self.
""" if isinstance(other, ConcDist): self.p = other.p else: self.p = other self.copy_torch() def flattened_sublen(self): shape_its = [] shape_out_sub = [] c = 0 sublens = list(self.sublens) + [len(self.shape_out) - sum(self.sublens)] for l in sublens: shape_out_sub.append(iutil.product(self.shape_out[c:c+l])) c += l shape_out_sub = tuple(shape_out_sub) r = numpy.zeros(self.shape_in + shape_out_sub) if torch is not None and isinstance(self.p, torch.Tensor): r = torch.tensor(r, dtype=torch.float64) for xs in itertools.product(*[range(x) for x in self.shape_in]): for zs in itertools.product(*[range(z) for z in self.shape_out]): ids = [] c = 0 for l in sublens: t = 0 for i in range(c, c + l): t = t * self.shape_out[i] + zs[i] ids.append(t) c += l r[xs + tuple(ids)] = self.p[xs + zs] r = ConcDist(r, num_in = self.get_num_in()) r.sublens = [1] * (len(shape_out_sub) - 1) return r def calc_torch(self): if not self.isvar: return card_out = self.card_out() self.p = torch.zeros(self.shape_in + self.shape_out, dtype=torch.float64) for xs in itertools.product(*[range(x) for x in self.shape_in]): s = 0.0 ci = 0 for zs in itertools.product(*[range(z) for z in self.shape_out]): if ci < card_out - 1: self.p[xs + zs] = self.v[xs + (ci,)] s += self.v[xs + (ci,)] else: self.p[xs + zs] = 1.0 - s ci += 1 def copy_torch(self): if not self.isvar: self.v = None return card_out = self.card_out() # self.p = self.p.numpy() if self.v is None: self.v = numpy.zeros(self.shape_in + (card_out - 1,)) self.v = torch.tensor(self.v, dtype=torch.float64, requires_grad = True) with torch.no_grad(): for xs in itertools.product(*[range(x) for x in self.shape_in]): ci = 0 for zs in itertools.product(*[range(z) for z in self.shape_out]): if ci < card_out - 1: self.v[xs + (ci,)] = float(self.p[xs + zs]) ci += 1 self.calc_torch() def normalize(self): if self.isfcn: self.clamp_fcn() return ceps = PsiOpts.settings["eps"] ceps_d = PsiOpts.settings["opt_eps_denom"] card_out = self.card_out() if self.isvar and 
torch is not None and isinstance(self.p, torch.Tensor): self.p = self.p.detach().numpy() for xs in itertools.product(*[range(x) for x in self.shape_in]): s = 0.0 for zs in itertools.product(*[range(z) for z in self.shape_out]): s += self.p[xs + zs] if s > ceps: for zs in itertools.product(*[range(z) for z in self.shape_out]): self.p[xs + zs] /= s else: for zs in itertools.product(*[range(z) for z in self.shape_out]): self.p[xs + zs] = 1.0 / card_out self.copy_torch() def clamp(self): if not self.isvar: return if self.isfcn: self.clamp_fcn(randomize = True) return ceps = PsiOpts.settings["eps"] ceps_d = PsiOpts.settings["opt_eps_denom"] card_out = self.card_out() vt = self.v.detach().numpy() for xs in itertools.product(*[range(x) for x in self.shape_in]): for z in range(card_out - 1): t = vt[xs + (z,)] if numpy.isnan(t): self.randomize() return vt[xs + (z,)] = max(t, 0.0) if card_out > 2: while True: s = 0.0 minpos = 1e20 numpos = 0 for z in range(card_out - 1): if vt[xs + (z,)] > 0: s += vt[xs + (z,)] minpos = min(minpos, vt[xs + (z,)]) numpos += 1 if s <= 1.0 + ceps: break tored = (s - 1.0) / numpos good = False if tored <= minpos: good = True else: tored = minpos for z in range(card_out - 1): if vt[xs + (z,)] > 0: vt[xs + (z,)] -= tored if good: break for z in range(card_out - 1): vt[xs + (z,)] = min(vt[xs + (z,)], 1.0) # for z in range(card_out - 1): # self.v[xs + (z,)] = vt[xs + (z,)] with torch.no_grad(): self.v.copy_(torch.tensor(vt, dtype=torch.float64)) # self.v.copy_(torch.tensor(vt)) # with torch.no_grad(): # for xs in itertools.product(*[range(x) for x in self.shape_in]): # for z in range(card_out - 1): # self.v[xs + (z,)] = vt[xs + (z,)] # self.v = torch.tensor(self.v, requires_grad = True) self.calc_torch() def set_uniform(self): card_out = self.card_out() self.p = numpy.ones(self.shape_in + self.shape_out) * (1.0 / card_out) if self.isvar: self.normalize() def randomize(self): rnd = PsiOpts.get_random() self.p = rnd.exponential(size = self.shape) 
self.normalize()

    def hop(self, prob):
        """Re-draw each conditional row with probability ``prob``, then renormalize.

        A selected row gets fresh i.i.d. exponential weights. Other rows are
        kept as they are.
        """
        rnd = PsiOpts.get_random()
        if torch is not None and isinstance(self.p, torch.Tensor):
            self.p = self.p.detach().numpy()
        for xs in itertools.product(*[range(x) for x in self.shape_in]):
            if rnd.uniform() >= prob:
                continue
            for zs in itertools.product(*[range(z) for z in self.shape_out]):
                self.p[xs + zs] = rnd.exponential()
        self.normalize()

    def get_p(self):
        """Return the probability table (numpy array or torch tensor)."""
        return self.p

    def get_v(self):
        """Return the free-parameter tensor (``None`` for a non-variable distribution)."""
        return self.v

    def numpy(self):
        """Convert to numpy array."""
        if iutil.istorch(self.p):
            return self.p.detach().numpy()
        return self.p

    def torch(self):
        """Convert to torch.Tensor."""
        # Inside method bodies the name `torch` still refers to the module,
        # because class-level names are not in scope there.
        if iutil.istorch(self.p):
            return self.p
        return torch.tensor(self.p, dtype=torch.float64)

    def entropy(self):
        """Entropy of this distribution.

        Sums ``-p * log(p)`` over every cell of ``self.shape``, scales the
        total by ``PsiOpts.settings["ent_coeff"]``, and returns a ConcReal.
        For a torch table the logarithm is smoothed as
        ``log((p + d) / (1 + d))`` with ``d = opt_eps_denom``, which keeps
        zero cells finite. For a numpy table, cells not above ``eps`` are
        skipped.
        """
        ceps = PsiOpts.settings["eps"]
        ceps_d = PsiOpts.settings["opt_eps_denom"]
        loge = PsiOpts.settings["ent_coeff"]
        istorch = torch is not None and isinstance(self.p, torch.Tensor)
        r = 0.0
        for xs in itertools.product(*[range(m) for m in self.shape]):
            c = self.p[xs]
            if istorch:
                r -= c * torch.log((c + ceps_d) / (1.0 + ceps_d)) * loge
            else:
                if c > ceps:
                    r -= c * numpy.log(c) * loge
        return ConcReal(r)

    def items(self):
        """Yield every probability in the table, input index outermost."""
        for xs in itertools.product(*[range(x) for x in self.shape_in]):
            for zs in itertools.product(*[range(z) for z in self.shape_out]):
                yield self.p[xs + zs]

    def convert(x):
        """Wrap ``x`` in a ConcDist unless it already is one (called as ``ConcDist.convert``)."""
        if not isinstance(x, ConcDist):
            x = ConcDist(x)
        return x

    def __add__(self, other):
        """Cell-wise sum of two tables with identical input and output shapes (no renormalization)."""
        other = ConcDist.convert(other)
        if (self.shape_in, self.shape_out) != (other.shape_in, other.shape_out):
            raise ValueError("Shape mismatch.")
            return  # unreachable
        return ConcDist(self.p + other.p, num_in = self.get_num_in())

    def __radd__(self, other):
        return self + other

    def __mul__(self, other):
        """Scale by a number, or multiply two channels that share the same inputs.

        For ``P(Y|X) * Q(Z|X)`` the result is ``R(Y,Z|X) = P(Y|X) Q(Z|X)``,
        i.e. Y and Z are conditionally independent given X.
        """
        if isinstance(other, int) or isinstance(other, float):
            return ConcDist(self.p * float(other), num_in = self.get_num_in())
        other = ConcDist.convert(other)
        if self.shape_in != other.shape_in:
            raise ValueError("Shape mismatch.")
            return  # unreachable
        r = None
        if torch is not None and (isinstance(self.p, torch.Tensor) or isinstance(other.p, torch.Tensor)):
            r = torch.zeros(self.shape_in + self.shape_out + other.shape_out, dtype=torch.float64)
        else:
            r = numpy.zeros(self.shape_in + self.shape_out + other.shape_out)
        for xs in itertools.product(*[range(x) for x in self.shape_in]):
            for zs in itertools.product(*[range(z) for z in self.shape_out]):
                for ws in itertools.product(*[range(w) for w in other.shape_out]):
                    r[xs + zs + ws] += self.p[xs + zs] * other.p[xs + ws]
        return ConcDist(r, num_in = self.get_num_in())

    def __rmul__(self, other):
        """Scaling is commutative; otherwise compute ``ConcDist(other) * self``."""
        if isinstance(other, int) or isinstance(other, float):
            return self * other
        other = ConcDist.convert(other)
        return other * self

    def __pow__(self, other):
        """``other``-fold product ``self * self * ...`` (integer exponent).

        Starts from a ConcDist with the same inputs and no outputs.
        """
        r = ConcDist(shape_in = self.shape_in, shape_out = tuple())
        for i in range(other):
            r = r * self
        return r

    def chan_product(self, other):
        """Product channel.

        For ``P(Y|X)`` and ``Q(Z|W)``, returns ``R(Y,Z|X,W) = P(Y|X) Q(Z|W)``.
        The inputs of ``self`` come before those of ``other``.
        """
        other = ConcDist.convert(other)
        r = None
        if torch is not None and (isinstance(self.p, torch.Tensor) or isinstance(other.p, torch.Tensor)):
            r = torch.zeros(self.shape_in + other.shape_in + self.shape_out + other.shape_out, dtype=torch.float64)
        else:
            r = numpy.zeros(self.shape_in + other.shape_in + self.shape_out + other.shape_out)
        for xs in itertools.product(*[range(x) for x in self.shape_in]):
            for x2s in itertools.product(*[range(x2) for x2 in other.shape_in]):
                for zs in itertools.product(*[range(z) for z in self.shape_out]):
                    for ws in itertools.product(*[range(w) for w in other.shape_out]):
                        r[xs + x2s + zs + ws] += self.p[xs + zs] * other.p[x2s + ws]
        return ConcDist(r, num_in = self.get_num_in() + other.get_num_in())

    def chan_power(self, other):
        """n-product channel.

        Applies chan_product ``other`` times, starting from a ConcDist with
        no inputs and no outputs.
""" r = ConcDist(shape_in = tuple(), shape_out = tuple()) for i in range(other): r = r.chan_product(self) return r def __matmul__(self, other): other = ConcDist.convert(other) if self.shape_out != other.shape_in: raise ValueError("Shape mismatch.") return r = None if torch is not None and (isinstance(self.p, torch.Tensor) or isinstance(other.p, torch.Tensor)): r = torch.zeros(self.shape_in + other.shape_out, dtype=torch.float64) else: r = numpy.zeros(self.shape_in + other.shape_out) for xs in itertools.product(*[range(x) for x in self.shape_in]): for zs in itertools.product(*[range(z) for z in self.shape_out]): for ws in itertools.product(*[range(w) for w in other.shape_out]): r[xs + ws] += self.p[xs + zs] * other.p[zs + ws] return ConcDist(r, num_in = self.get_num_in()) def __truediv__(self, other): if isinstance(other, int) or isinstance(other, float): return ConcDist(self.p * float(1.0 / other), num_in = self.get_num_in()) if self.shape_in != other.shape_in: raise ValueError("Shape mismatch.") return cshape_in = self.shape_out[:len(other.shape_out)] if cshape_in != other.shape_out: raise ValueError("Shape mismatch.") return cshape_out = self.shape_out[len(other.shape_out):] ceps = PsiOpts.settings["eps"] ceps_d = PsiOpts.settings["opt_eps_denom"] zsn = 1 for k in cshape_out: zsn *= k cepsdzsn = ceps_d / zsn r = None if torch is not None and (isinstance(self.p, torch.Tensor) or isinstance(other.p, torch.Tensor)): r = torch.zeros(self.shape_in + cshape_in + cshape_out, dtype=torch.float64) for xs in itertools.product(*[range(x) for x in self.shape_in]): for zs in itertools.product(*[range(z) for z in cshape_in]): for ws in itertools.product(*[range(w) for w in cshape_out]): r[xs + zs + ws] = (self.p[xs + zs + ws] + cepsdzsn) / (other.p[xs + zs] + ceps_d) else: r = numpy.zeros(self.shape_in + cshape_in + cshape_out) for xs in itertools.product(*[range(x) for x in self.shape_in]): for zs in itertools.product(*[range(z) for z in cshape_in]): if other.p[xs + zs] > 
ceps: for ws in itertools.product(*[range(w) for w in cshape_out]): r[xs + zs + ws] = self.p[xs + zs + ws] / other.p[xs + zs] else: for ws in itertools.product(*[range(w) for w in cshape_out]): r[xs + zs + ws] = 1.0 / zsn return ConcDist(r, num_in = len(self.shape_in + cshape_in)) def semidirect(self, other, ids = None): """Semidirect product. """ if ids is None: ids = range(len(other.shape_in)) r = None if torch is not None and (isinstance(self.p, torch.Tensor) or isinstance(other.p, torch.Tensor)): r = torch.zeros(self.shape_in + self.shape_out + other.shape_out, dtype=torch.float64) else: r = numpy.zeros(self.shape_in + self.shape_out + other.shape_out) for xs in itertools.product(*[range(x) for x in self.shape_in]): for zs in itertools.product(*[range(z) for z in self.shape_out]): for ws in itertools.product(*[range(w) for w in other.shape_out]): r[xs + zs + ws] = self.p[xs + zs] * other.p[tuple(zs[i] for i in ids) + ws] return ConcDist(r, num_in = self.get_num_in()) def marginal(self, *args): """ Marginal distribution. Parameters ---------- *args : int Indices of the random variables of interest. E.g. for P(Y0,Y1,Y2|X), P.marginal(0,2) gives P(Y0,Y2|X) Returns ------- ConcDist The marginal distribution. 
""" ids = args if isinstance(ids, int): ids = [ids] r = None cshape = tuple(self.shape_out[i] for i in ids) if torch is not None and isinstance(self.p, torch.Tensor): r = torch.zeros(self.shape_in + cshape, dtype=torch.float64) else: r = numpy.zeros(self.shape_in + cshape) for xs in itertools.product(*[range(x) for x in self.shape_in]): for zs in itertools.product(*[range(z) for z in cshape]): wrange = [range(w) for w in self.shape_out] for i in range(len(cshape)): if len(wrange[ids[i]]) == 1 and wrange[ids[i]][0] != zs[i]: break wrange[ids[i]] = [zs[i]] else: for ws in itertools.product(*wrange): r[xs + zs] += self.p[xs + ws] return ConcDist(r, num_in = self.get_num_in()) def reorder(self, ids): r = None cshape = tuple(self.shape_out[i] for i in ids) if torch is not None and isinstance(self.p, torch.Tensor): r = torch.zeros(self.shape_in + cshape, dtype=torch.float64) else: r = numpy.zeros(self.shape_in + cshape) for xs in itertools.product(*[range(x) for x in self.shape_in]): for zs in itertools.product(*[range(z) for z in self.shape_out]): r[xs + tuple(zs[ids[i]] for i in range(len(ids)))] = self.p[xs + zs] return ConcDist(r, num_in = self.get_num_in()) def given(self, *args): """ For a conditional distribution P(Y|X), give the distribution P(Y|X=x). Parameters ---------- *args : int or None. The values to substitute to X. Must have the same number of arguments as the number of random variables conditioned. Arguments are either int (value of RV) or None if the RV is not substituted. Returns ------- ConcDist The distribution after substitution. 
""" r = None cshape = tuple(self.shape_in[i] for i in range(len(self.shape_in)) if args[i] is None) if torch is not None and isinstance(self.p, torch.Tensor): r = torch.zeros(cshape + self.shape_out, dtype=torch.float64) else: r = numpy.zeros(cshape + self.shape_out) for xs in itertools.product(*[range(x) for x in cshape]): xs2 = [0] * len(self.shape_in) xsi = 0 for i in range(len(self.shape_in)): if args[i] is None: xs2[i] = xs[xsi] xsi += 1 else: xs2[i] = args[i] xs2 = tuple(xs2) for zs in itertools.product(*[range(z) for z in self.shape_out]): r[xs + zs] = self.p[xs2 + zs] return ConcDist(r, num_in = len(cshape)) def mean(self, f = None): """ Returns the expectation of the function f. Parameters ---------- f : function, numpy.array or torch.Tensor If f is a function, the number of arguments must match the number of dimensions (random variables) of the joint distribution. If f is an array or tensor, shape must match the shape of the distribution. Returns ------- r : float or torch.Tensor The expectation. Type is torch.Tensor if self or f is torch.Tensor. """ if f is None: f = lambda x: x r = None for xs in itertools.product(*[range(x) for x in self.shape_in]): for zs in itertools.product(*[range(z) for z in self.shape_out]): if callable(f): t = self.p[xs + zs] * f(*(xs + zs)) else: t = self.p[xs + zs] * f[xs + zs] if r is not None: r += t else: r = t return r def __str__(self): return str(self.p) def __repr__(self): r = "" r += "ConcDist(" t = repr(self.p) if t.find("\n") >= 0: r += "\n" r += t if self.get_num_in() > 0: r += ", num_in=" + str(self.get_num_in()) r += ")" return r @latex_postprocess def _latex_(self): t = ExprArray(self.p) t.set_float(self.force_float) return t._latex_() def tostring(self, style = 0): """Convert to string. 
Parameters:
            style : Style of string conversion
                STR_STYLE_STANDARD : I(X,Y;Z|W)
                STR_STYLE_PSITIP : I(X+Y&Z|W)

        Only the LaTeX flag changes the result; any other style returns
        ``str(self)``.
        """
        style = iutil.convert_str_style(style)
        if style & PsiOpts.STR_STYLE_LATEX:
            return self._latex_()
        return str(self)

    def fcn(fcncall, shape):
        """Build a ConcDist whose table is ``ExprArray.fcn(fcncall, shape_in + shape_out)``.

        Called as ``ConcDist.fcn``. ``shape`` is split into input and output
        shapes by ``ConcDist.convert_shape_pair``.
        """
        shape_in, shape_out = ConcDist.convert_shape_pair(shape)
        return ConcDist(ExprArray.fcn(fcncall, shape_in + shape_out), shape_in = shape_in, shape_out = shape_out)

    def det_fcn(fcncall, shape, isvar = False):
        """Deterministic channel ``Y = fcncall(X)``.

        ``fcncall(*xs)`` is evaluated at every input index. It may return an
        int, a bool/float (cast to int), or a tuple of ints (one per output).
        Output lengths in ``shape`` that are None or too small are enlarged
        to fit the largest value returned. Called as ``ConcDist.det_fcn``.
        """
        shape_in, shape_out = ConcDist.convert_shape_pair(shape)
        shape_out = list(shape_out)
        ys = []
        for xs in itertools.product(*[range(x) for x in shape_in]):
            t = fcncall(*xs)
            if isinstance(t, (bool, float)):
                t = int(t)
            if isinstance(t, int):
                t = (t,)
            for i in range(len(shape_out)):
                if shape_out[i] is None or shape_out[i] < t[i] + 1:
                    shape_out[i] = t[i] + 1
            ys.append(t)
        shape_out = tuple(shape_out)
        p = numpy.zeros(shape_in + shape_out)
        for xs, t in zip(itertools.product(*[range(x) for x in shape_in]), ys):
            p[xs + t] = 1.0
        return ConcDist(p, num_in = len(shape_in), isvar = isvar)

    def isvalid(self):
        """Whether this distribution is valid.

        True iff every entry is at least ``-eps_check`` and every conditional
        row sums to 1 within ``eps_check``.
        """
        ceps = PsiOpts.settings["eps_check"]
        for xs in itertools.product(*[range(x) for x in self.shape_in]):
            csum = 0.0
            for zs in itertools.product(*[range(z) for z in self.shape_out]):
                c = float(self.p[xs + zs])
                if c < -ceps:
                    return False
                csum += c
            if abs(csum - 1.0) > ceps:
                return False
        return True

    def valid_region(self, skip_simplify = False):
        """For a symbolic distribution, returns the region where this is a valid distribution.

        The region requires every entry of ``self.expr`` to be nonnegative
        and every conditional row to sum to 1. It is the universe when there
        is no symbolic expression.
        """
        if self.expr is None:
            return Region.universe()
        r = Region.universe()
        for xs in itertools.product(*[range(x) for x in self.shape_in]):
            sumexpr = Expr.zero()
            for zs in itertools.product(*[range(z) for z in self.shape_out]):
                cexpr = self.expr[xs + zs]
                r.iand_norename(cexpr >= 0)
                sumexpr += cexpr
            r.iand_norename(sumexpr == 1)
        if not skip_simplify:
            return r.simplified()
        return r

    def uniform(n, isvar = False):
        """n-ary uniform distribution."""
        return ConcDist(numpy.ones(n) / n, isvar = isvar)

    def bit(isvar = False):
        """Fair bit."""
        return ConcDist.uniform(2, isvar = isvar)

    def bern(a, isvar = False):
        """Bernoulli distribution."""
        return ConcDist([1.0 - a, a], isvar = isvar)

    def random(n, isvar = False):
        """n-ary random distribution."""
        r = ConcDist(numpy.ones(n) / n, isvar = isvar)
        r.randomize()
        return r

    def symm_chan(n, crossover, isvar = False):
        """n-ary symmetric channel.

        Keeps the input with probability ``1 - crossover``; otherwise it moves
        to each of the other ``n - 1`` symbols with equal probability.
        Requires ``n >= 2``, since ``n == 1`` divides by zero.
        """
        a = 1.0 - crossover
        b = crossover / (n - 1)
        return ConcDist([[a if i == j else b for j in range(n)] for i in range(n)], num_in = 1, isvar = isvar)

    def bsc(crossover, isvar = False):
        """Binary symmetric channel."""
        return ConcDist.symm_chan(2, crossover, isvar = isvar)

    def bin_chan(cross01, cross10, isvar = False):
        """Binary channel with ``P(Y=1|X=0) = cross01`` and ``P(Y=0|X=1) = cross10``."""
        return ConcDist([[1.0 - cross01, cross01], [cross10, 1.0 - cross10]], num_in = 1, isvar = isvar)

    def erasure_chan(n, er_prob, isvar = False):
        """n-ary erasure channel.

        Outputs the input with probability ``1 - er_prob``, and the extra
        erasure symbol (index ``n``) with probability ``er_prob``.
        """
        return ConcDist([[(1.0 - er_prob) if i == j else 0.0 for j in range(n)] + [er_prob] for i in range(n)], num_in = 1, isvar = isvar)

    def bec(er_prob, isvar = False):
        """Binary erasure channel."""
        return ConcDist.erasure_chan(2, er_prob, isvar = isvar)

    def equal(shape_in, isvar = False):
        """Transition probability for two equal random vectors."""
        if isinstance(shape_in, int):
shape_in = (shape_in,) p = numpy.zeros(shape_in + shape_in) for xs in itertools.product(*[range(x) for x in shape_in]): p[xs + xs] = 1.0 return ConcDist(p, num_in = len(shape_in), isvar = isvar) def flat(shape_in, isvar = False): """Transition probability for flattening a random vector into one random variable.""" if isinstance(shape_in, int): shape_in = (shape_in,) nout = iutil.product(shape_in) p = numpy.zeros(shape_in + (nout,)) for xs in itertools.product(*[range(x) for x in shape_in]): t = 0 for a, b in zip(xs, shape_in): t = t * b + a p[xs + (t,)] = 1.0 return ConcDist(p, num_in = len(shape_in), isvar = isvar) def add(shape_in, isvar = False): """Transition probability from several random variables to their sum.""" if isinstance(shape_in, int): shape_in = (shape_in,) nout = sum(shape_in) - len(shape_in) + 1 p = numpy.zeros(shape_in + (nout,)) for xs in itertools.product(*[range(x) for x in shape_in]): p[xs + (sum(xs),)] = 1.0 return ConcDist(p, num_in = len(shape_in), isvar = isvar) def gaussian(r, l, isvar = False): """Quantized standard Gaussian distribution in the range [-r, r], divided into l cells. """ sqrt2 = numpy.sqrt(2.0) p = numpy.zeros(l) cdf = 0.0 cdf0 = 0.0 for i in range(l + 1): x = (i * 2.0 / l - 1.0) * r cdf2 = 0.5 * (1 + scipy.special.erf(x / sqrt2)) if i > 0: p[i - 1] = cdf2 - cdf cdf = cdf2 if i == 0: cdf0 = cdf for i in range(l): p[i] /= cdf - cdf0 return ConcDist(p, num_in = 0, isvar = isvar) def convolve_kernel(shape_in, kernel, isvar = False): """Transition probability from X to X+Z, where X is a random vector with a pmf of shape shape_in, and Z is independent of X and follows the distribution given by kernel. 
""" if isinstance(shape_in, int): shape_in = (shape_in,) if isinstance(kernel, ConcDist): kernel = kernel.p shape_out = tuple(a + b - 1 for a, b in zip(shape_in, kernel.shape)) p = numpy.zeros(shape_in + shape_out) for xs in itertools.product(*[range(x) for x in shape_in]): for zs in itertools.product(*[range(x) for x in kernel.shape]): p[xs + tuple(x + z for x, z in zip(xs, zs))] = kernel[zs] return ConcDist(p, num_in = len(shape_in), isvar = isvar) class ConcReal(IBaseObj): """Concrete real variable.""" def __init__(self, x = None, lbound = None, ubound = None, scale = 1.0, isvar = False, isint = False, randomize = False): self.force_float = True self.isvar = isvar self.isint = isint if self.isvar and torch is None: raise RuntimeError("Requires pytorch. Please install it first.") if x is None: x = 0.0 if isinstance(x, int): x = float(x) if isinstance(x, ConcReal): x = x.x if isinstance(lbound, int): lbound = float(lbound) if isinstance(ubound, int): ubound = float(ubound) if isinstance(scale, int): scale = float(scale) self.x = x self.v = None self.lbound = lbound self.ubound = ubound self.scale = scale self.copy_torch() if randomize: self.randomize() def copy(self): r = ConcReal(x = iutil.copy(self.x), lbound = self.lbound, ubound = self.ubound, scale = self.scale, isvar = self.isvar, isint = self.isint) r.force_float = self.force_float r.x = iutil.copy(self.x) r.v = None r.copy_torch() return r def const(x = 0.0): """Constant. """ return ConcReal(x, x, x) def convert(x): if isinstance(x, int) or isinstance(x, float): x = ConcReal.const(x) elif not isinstance(x, ConcReal): x = ConcReal(x) return x def calc_torch(self): if self.isvar: self.x = self.v def copy_torch(self): if self.isvar: if self.v is None: self.v = torch.tensor(self.x, dtype=torch.float64, requires_grad = True) with torch.no_grad(): self.v.copy_(torch.tensor(self.x, dtype=torch.float64)) self.x = self.v def copy_(self, other): """Copy content of other to self. 
""" if isinstance(other, int): other = float(other) if isinstance(other, ConcReal): self.x = other.x else: self.x = other self.copy_torch() def __add__(self, other): other = ConcReal.convert(other) return ConcReal(self.x + other.x, None if self.lbound is None or other.lbound is None else self.lbound + other.lbound, None if self.ubound is None or other.ubound is None else self.ubound + other.ubound, self.scale + other.scale) def __radd__(self, other): return self + other def __mul__(self, other): other = ConcReal.convert(other) lbound = None ubound = None for a in [self.lbound, self.ubound]: for b in [other.lbound, other.ubound]: if a is not None and b is not None: t = a * b if lbound is None or lbound > t: lbound = t if ubound is None or ubound < t: ubound = t return ConcReal(self.x * other.x, lbound, ubound, self.scale * other.scale) def __rmul__(self, other): return self * other def __sub__(self, other): return self + other * -1 def __rsub__(self, other): return other + self * -1 def __neg__(self): return self * -1 def __truediv__(self, other): return self * (1 / other) def __rtruediv__(self, other): if (isinstance(other, int) or isinstance(other, float)) and other > 0: other = float(other) lbound = None ubound = None if self.lbound is not None and self.lbound > 0: ubound = other / self.lbound if self.ubound is not None: lbound = other / self.ubound if self.ubound is not None and self.ubound < 0: lbound = other / self.ubound if self.lbound is not None: ubound = other / self.lbound return ConcReal(other / self.x, lbound, ubound, other / self.scale) return other * (1 / self) def clamp(self): if not self.isvar: return vt = float(self.v.detach().numpy()) if numpy.isnan(vt): self.randomize() return if self.lbound is not None and vt < self.lbound: vt = self.lbound if self.ubound is not None and vt > self.ubound: vt = self.ubound if self.isint: vt = round(vt) with torch.no_grad(): self.v.copy_(torch.tensor(vt, dtype=torch.float64)) def randomize(self): rnd = 
PsiOpts.get_random() if self.lbound is None or self.ubound is None: self.x = rnd.exponential() * self.scale if self.ubound is not None: self.x = self.ubound - self.x elif self.lbound is not None: self.x = self.lbound + self.x else: if rnd.uniform() < 0.5: self.x *= -1 else: self.x = rnd.uniform(self.lbound, self.ubound) self.copy_torch() def fraction_snap(self, denom = None, eps = None): self.x = iutil.float_snap(float(self.x), denom = denom, eps = eps) self.copy_torch() def hop(self, prob): rnd = PsiOpts.get_random() if rnd.uniform() < prob: self.randomize() def get_x(self): return self.x def get_v(self): return self.v def __float__(self): return float(self.x) def __int__(self): return int(round(float(self.x))) def torch(self): """Convert to torch.Tensor.""" if iutil.istorch(self.x): return self.x return torch.tensor(self.x, dtype=torch.float64) @staticmethod def unbox(x): if isinstance(x, ConcReal): return x.x return x def __str__(self): return str(self.x) def __repr__(self): r = "" r += "ConcReal(" r += str(self.x) if self.lbound is not None: r += ", lbound=" + str(self.lbound) if self.ubound is not None: r += ", ubound=" + str(self.ubound) r += ")" return r @latex_postprocess def _latex_(self): return iutil.float_tostr(float(self.x), style = PsiOpts.STR_STYLE_LATEX, force_float = self.force_float) def tostring(self, style = 0): """Convert to string. 
Parameters:
            style : Style of string conversion
                STR_STYLE_STANDARD : I(X,Y;Z|W)
                STR_STYLE_PSITIP : I(X+Y&Z|W)

        Only the LaTeX flag changes the result; any other style returns
        ``str(self)``.
        """
        style = iutil.convert_str_style(style)
        if style & PsiOpts.STR_STYLE_LATEX:
            return self._latex_()
        return str(self)


class RealRV(IBaseObj):
    """Discrete real-valued random variable."""
    # comp holds the underlying random variable(s); fcn and supp are stored as given.
    def __init__(self, x, fcn = None, supp = None):
        self.comp = x
        self.fcn = fcn
        self.supp = supp

    @property
    def x(self):
        """Alias of ``comp``."""
        return self.comp


class ConcModel(IBaseObj):
    """Concrete distributions of random variables and values of real variables."""

    def __init__(self, bnet = None, istorch = None):
        """Create an empty model over the Bayesian network ``bnet`` (a new BayesNet if None).

        ``istorch`` defaults to ``PsiOpts.settings["istorch"]``.
        """
        if istorch is None:
            self.istorch = PsiOpts.settings["istorch"]
        else:
            self.istorch = istorch
        if self.istorch and torch is None:
            raise RuntimeError("Requires pytorch. Please install it first.")
        if bnet is None:
            self.bnet = BayesNet()
        else:
            self.bnet = bnet
        v = self.bnet.index.comprv
        n = len(v)
        # psmap: (input indices, output indices) -> user-specified ConcDist.
        self.psmap = {}
        # psmap_cache: derived conditional distributions, keyed like psmap.
        self.psmap_cache = {}
        # psv[k]: the psmap key whose outputs include random variable k.
        self.psv = [None] * n
        # psmap_mask: joint distributions keyed by a bitmask of variables.
        self.psmap_mask = {}
        # card[k]: cardinality of random variable k (None if unknown).
        self.card = [None] * n
        self.pt = None
        # Real variables: their index and stored values.
        self.index_real = IVarIndex()
        self.realvars = []
        # Region recorded by the most recent optimization.
        self.opt_reg = None

    def copy_shallow(self):
        """Copy with new top-level containers whose entries (distributions, values) are shared.

        The network and real-variable index are copied via their ``copy()``
        methods. ``pt`` and ``opt_reg`` are not carried over.
        """
        r = ConcModel()
        r.istorch = self.istorch
        r.bnet = self.bnet.copy()
        r.psmap = dict(self.psmap)
        r.psmap_cache = dict(self.psmap_cache)
        r.psv = list(self.psv)
        r.psmap_mask = dict(self.psmap_mask)
        r.card = list(self.card)
        r.index_real = self.index_real.copy()
        r.realvars = list(self.realvars)
        return r

    def copy(self):
        """Like copy_shallow, but the containers' entries are copied with ``iutil.copy``."""
        r = ConcModel()
        r.istorch = self.istorch
        r.bnet = self.bnet.copy()
        r.psmap = iutil.copy(self.psmap)
        r.psmap_cache = iutil.copy(self.psmap_cache)
        r.psv = iutil.copy(self.psv)
        r.psmap_mask = iutil.copy(self.psmap_mask)
        r.card = iutil.copy(self.card)
        r.index_real = self.index_real.copy()
        r.realvars = iutil.copy(self.realvars)
        return r

    def set_real(self, x, v):
        """Set the value of real variable ``x`` to ``v`` and clear cached distributions.

        ``v`` may also be one of two strings:

        - ``"var"``: a new optimizable ConcReal.
        - ``"var,rand"``: the same, randomly initialized.
        """
        self.clear_cache()
        if v == "var":
            self.set_real(x, ConcReal(isvar = True))
            return
        if v == "var,rand":
            self.set_real(x, ConcReal(isvar = True, randomize = True))
            return
        if isinstance(x, Expr):
            x = x.allcomp()
x.record_to(self.index_real) i = self.index_real.get_index(x) while len(self.realvars) < len(self.index_real.compreal): self.realvars.append(None) self.realvars[i] = v def get_real(self, x): if isinstance(x, Expr): x = x.allcomp() i = self.index_real.get_index(x) if i < 0: return None return self.realvars[i] def clear_cache(self): self.pt = None # self.psmap = {key: item for key, item in self.psmap.items() if not item.iscache} self.psmap_mask = {} self.psmap_cache = {} def comp_to_tuple(self, x): r = [] for a in x: t = self.bnet.index.get_index(a) if t < 0: return None r.append(t) return tuple(r) def comp_to_pair(self, x): if isinstance(x, Term): return (self.comp_to_tuple(x.z), sum((self.comp_to_tuple(t) for t in x.x), tuple())) elif isinstance(x, Comp): return (tuple(), self.comp_to_tuple(x)) else: return (self.comp_to_tuple(x[0]), self.comp_to_tuple(x[1])) def tuple_to_comp(self, y): r = Comp.empty() for a in y: r += self.bnet.index.comprv[a] return r def comp_get_sublens(self, x): if isinstance(x, Term): r = [len(a) for a in x.x] r.pop() return r return [] def set_prob(self, x, p): if isinstance(p, str): opt_split = p.split(",") opt = None randomize = False isvar = False isfcn = False mode = "" for copt in opt_split: if copt == "var": isvar = True elif copt == "rand": randomize = True elif copt == "fcn": isfcn = True else: mode = copt if mode == "flat": shape_in, shape_out = self.convert_shape_pair(x) self.set_prob(x, ConcDist.flat(shape_in, isvar = isvar)) elif mode == "equal": shape_in, shape_out = self.convert_shape_pair(x) self.set_prob(x, ConcDist.equal(shape_in, isvar = isvar)) elif mode == "add": shape_in, shape_out = self.convert_shape_pair(x) self.set_prob(x, ConcDist.add(shape_in, isvar = isvar)) else: shape_in, shape_out = self.convert_shape_pair(x) self.set_prob(x, ConcDist(shape_in = shape_in, shape_out = shape_out, isvar = isvar, randomize = randomize, isfcn = isfcn)) return if isinstance(p, collections.Callable) and not isinstance(p, (ConcDist, 
list, ExprArray)):
            # A plain function: build the deterministic channel it defines.
            dist = ConcDist.det_fcn(p, self.convert_shape_pair(x))
            self.set_prob(x, dist)
            return
        self.bnet += x
        cin, cout = self.comp_to_pair(x)
        if isinstance(p, list):
            if iutil.hasinstance(p, Expr):
                p = ExprArray(p)
            else:
                p = numpy.array(p)
        shape = p.shape
        if len(shape) != len(cin) + len(cout):
            raise ValueError("Number of dimensions of prob. table = " + str(len(shape)) + " does not match number of variables = " + str(len(cin) + len(cout)) + ".")
            return
        for j in range(len(cin + cout)):
            t = self.get_card_id((cin + cout)[j])
            if t is not None and t != shape[j]:
                raise ValueError("Length of dimension " + str(self.bnet.index.comprv[(cin + cout)[j]]) + " of prob. table = " + str(shape[j]) + " does not match its cardinality = " + str(t) + ".")
                return
        # Grow the per-variable arrays to cover variables just added to bnet.
        while len(self.card) < len(self.bnet.index.comprv):
            self.card.append(None)
        while len(self.psv) < len(self.bnet.index.comprv):
            self.psv.append(None)
        for j in range(len(cin + cout)):
            self.card[(cin + cout)[j]] = shape[j]
        if not isinstance(p, ConcDist):
            p = ConcDist(p, num_in = len(cin), check_valid = True)
        self.psmap[(cin, cout)] = p
        for k in cout:
            self.psv[k] = (cin, cout)
        self.clear_cache()

    def calc_dist(self, p):
        """Refresh ``p.p`` from the model when ``p`` is symbolic (has an ``expr``); None is ignored."""
        if p is None:
            return
        if p.expr is not None:
            p.p = self[p.expr]

    def get_prob_mask(self, mask):
        """Joint distribution of the variables in bitmask ``mask`` (bit i = variable i), memoized.

        Let k be the highest variable in ``mask``, and P(tout|tin) the
        specified distribution whose outputs contain k. The joint of
        ``(mask | tin) - tout`` is obtained recursively and extended by
        P(tout|tin) via ``semidirect``. The result is then marginalized to
        ``mask``. Output dimensions are ordered by variable index.

        Raises ValueError if variable k has no specified distribution.
        NOTE(review): assumes ``mask != 0``.
        """
        t = self.psmap_mask.get(mask, None)
        self.calc_dist(t)
        if t is not None:
            return t
        n = len(self.bnet.index.comprv)
        # k = index of the highest set bit of mask.
        k = 0
        while (1 << (k + 1)) <= mask:
            k += 1
        if self.psv[k] is None:
            raise ValueError("Random variable " + str(self.bnet.index.comprv[k]) + " has unspecified distribution.")
            return
        tin, tout = self.psv[k]
        tp = self.psmap[(tin, tout)]
        self.calc_dist(tp)
        tin_mask = 0
        for a in tin:
            tin_mask |= 1 << a
        tout_mask = 0
        for a in tout:
            tout_mask |= 1 << a
        # Variables needed before tout can be generated from tin.
        mask1 = (mask | tin_mask) & ~tout_mask
        p2 = None
        if mask1 > 0:
            p1 = self.get_prob_mask(mask1)
            # Position of each input of tp among the (index-ordered) outputs of p1.
            p2 = p1.semidirect(tp, [iutil.bitcount(mask1 & ((1 << i) - 1)) for i in tin])
        else:
            p2 = tp
        # idinv[i]: position of variable i among the outputs of p2
        # (mask1 variables in index order, then tout).
        idinv = [None] * n
        ci = 0
        for i in range(n):
            if mask1 & (1 << i):
                idinv[i] = ci
                ci += 1
        for i in tout:
            idinv[i] = ci
            ci += 1
        p3 = p2.marginal(*[idinv[i] for i in range(n) if mask & (1 << i)])
        self.psmap_mask[mask] = p3
        return p3

    def get_prob_pair(self, cin, cout):
        """Conditional distribution of variables ``cout`` given ``cin`` (index tuples).

        Returns a user-specified entry of ``psmap`` when one matches exactly.
        Otherwise computes ``P(cin, cout) / P(cin)`` from joint marginals,
        with dimensions in the order of ``cin`` then ``cout``, and caches it
        in ``psmap_cache``.
        """
        t = self.psmap.get((cin, cout), None)
        self.calc_dist(t)
        if t is not None:
            return t
        t = self.psmap_cache.get((cin, cout), None)
        self.calc_dist(t)
        if t is not None:
            return t
        istorch = self.istorch  # NOTE(review): unused
        cinlen = [self.get_card_id(i) for i in cin]  # NOTE(review): unused
        coutlen = [self.get_card_id(i) for i in cout]  # NOTE(review): unused
        cin_mask = 0
        for a in cin:
            cin_mask |= 1 << a
        cout_mask = 0
        for a in cout:
            cout_mask |= 1 << a
        p1 = self.get_prob_mask(cin_mask | cout_mask)
        # get_prob_mask orders dimensions by variable index; put them in cin + cout order.
        p1 = p1.reorder([iutil.bitcount((cin_mask | cout_mask) & ((1 << i) - 1)) for i in cin + cout])
        r = None
        if cin_mask != 0:
            p0 = self.get_prob_mask(cin_mask)
            p0 = p0.reorder([iutil.bitcount(cin_mask & ((1 << i) - 1)) for i in cin])
            r = p1 / p0
        else:
            r = p1
        self.psmap_cache[(cin, cout)] = r
        return r

    def get_prob(self, x):
        """Distribution of ``x`` (Term, Comp or pair), with ``sublens`` set from ``x``.

        Raises ValueError when comp_to_pair reports a missing variable (None).
        """
        cin, cout = self.comp_to_pair(x)
        if cin is None or cout is None:
            raise ValueError("Some random variables are absent in the model.")
            return
        r = self.get_prob_pair(cin, cout)
        if isinstance(r, ConcDist):
            r.sublens = self.comp_get_sublens(x)
        return r

    def get_card_id(self, i):
        """Cardinality of the variable with index ``i`` (None if unknown)."""
        if i >= len(self.card):
            return None
        return self.card[i]

    def get_card(self, x):
        """Cardinality of random variable ``x`` (None if absent or unknown)."""
        i = self.bnet.index.get_index(x)
        if i < 0:
            return None
        return self.get_card_id(i)

    def get_card_default(self, x):
        """Like get_card, but falls back to ``x.get_card()`` when ``x`` is not in the model."""
        i = self.bnet.index.get_index(x)
        if i < 0:
            return x.get_card()
        return self.get_card_id(i)

    def convert_shape(self, x):
        """Shape tuple for ``x``.

        None gives ``()``; an int gives ``(x,)``; a Comp gives its
        cardinalities; a ConcDist gives its output shape; anything else
        gives ``tuple(x)``.
        """
        if x is None:
            return tuple()
        if isinstance(x, int):
            return (x,)
        if isinstance(x, Comp):
            return tuple(self.get_card_default(a) for a in x)
        if isinstance(x, ConcDist):
            return x.shape_out
        return tuple(x)

    def convert_shape_pair(self, x):
        """``(input shape, output shape)`` for ``x``; a Term gives the shapes of ``x.z`` and of ``x.x`` (concatenated)."""
        if x is None:
            return (tuple(), tuple())
        if isinstance(x, int):
            return (tuple(), (x,))
        if isinstance(x, Term):
            return (self.convert_shape(x.z), sum((self.convert_shape(t) for t in x.x), tuple()))
        if isinstance(x, Comp) or isinstance(x, ConcDist):
            return (tuple(), self.convert_shape(x))
        # A tuple of ints is an output shape; anything else is an (inputs, outputs) pair.
        if isinstance(x, tuple) and (len(x) == 0 or isinstance(x[0], int)):
            return (tuple(), tuple(x))
        return (self.convert_shape(x[0]), self.convert_shape(x[1]))

    def get_H(self, x):
        """Entropy of ``get_prob(x)`` as a plain value (``.x`` of the returned ConcReal).

        NOTE(review): ConcDist.entropy sums -p log p over every cell without
        weighting conditional rows, so ``x`` is expected to be unconditional.
        """
        p = self.get_prob(x)
        return p.entropy().x

    def get_ent_vector(self, x):
        """Entropies of the sub-collections ``x.from_mask(mask)`` for every mask in ``range(2**len(x))``."""
        n = len(x)
        r = []
        for mask in range(1 << n):
            r.append(self.get_H(x.from_mask(mask)))
        return r

    def discover(self, x, eps = None, skip_simplify = False):
        """Discover conditional independence among variables in x.

        Works from the entropy vector of ``x``, via
        ``Region.ent_vector_discover_ic``.
        """
        v = self.get_ent_vector(x)
        return Region.ent_vector_discover_ic(v, x, eps, skip_simplify)

    def get_region(self, vals = False):
        """Region of the Bayesian network.

        With ``vals``, the region additionally pins the entropy of every
        nonempty subset of the model's variables, and the value of every
        real variable, to their current numbers.
        """
        r = self.bnet.get_region()
        if vals:
            all_rv = self.bnet.index.comprv.copy()
            n = len(all_rv)
            for mask in range(1, 1 << n):
                x = all_rv.from_mask(mask)
                r.iand_norename(Expr.H(x) == float(self.get_H(x)))
            # NOTE(review): nesting under `if vals` reconstructed from a
            # whitespace-collapsed source. allcomprealvar_exprlist is not
            # defined in this chunk of ConcModel; confirm it exists.
            for y in self.allcomprealvar_exprlist():
                r.iand_norename(y == float(self[y]))
        return r

    def allcomprv(self):
        """Copy of all random variables in the model."""
        return self.bnet.index.comprv.copy()

    def allcompreal(self):
        """Copy of all real variables in the model."""
        return self.index_real.compreal.copy()

    def allcomprealvar(self):
        """Same as allcompreal."""
        return self.index_real.compreal.copy()

    def allcomp(self):
        """Random variables followed by real variables."""
        return self.allcomprv() + self.allcompreal()

    def add_reg(self, reg, other_neighbors = None):
        """Add the optimization variables needed to evaluate region ``reg`` on this model.

        Each random variable of ``reg`` missing from the model gets parents
        from a Bayesian network built from the completed semi-graphoid
        closure of ``reg``. A variable not in that network instead gets all
        previously seen variables as parents. It is marked as a function
        when ``reg`` (without auxiliaries) implies ``H(a|pa) <= 0``.

        Variables are then grouped: a variable whose parents equal an
        existing group's parents plus that group's members joins the group.
        Each group receives one randomized variable ConcDist (unknown
        cardinalities default to ``opt_aux_card``). Missing real variables
        get variable ConcReals.

        Returns ``(varlist, cons)``: the created variables, and the
        inequalities of ``reg`` not already implied by the model's region.
        """
        varlist = []
        cons = Region.universe()
        reg = reg.copy()
        regallcomprv = reg.allcomprv()
        if isinstance(reg, RegionOp):
            reg = reg.tosimple()
            if reg is None:
                raise ValueError("User-defined information quantities with RegionOp constraints cannot be optimized.")
                return None
        card0 = PsiOpts.settings["opt_aux_card"]
        reg.aux_strengthen(self.bnet.index.comprv, other_neighbors = other_neighbors)
        reg.simplify_quick(zero_group = 0)
        regcom = reg.copy()
        regcom.iand_norename(regcom.completed_semigraphoid(max_iter = 10000))
        tbnet = regcom.get_bayesnet(roots = self.bnet.index.comprv.inter(regcom.allcomprv()), skip_simplify = True)
        fcnreg = Region.universe()
        # clus: [parents, members, isfcn] groups, each getting one joint distribution.
        clus = []
        rv_all = self.bnet.index.comprv.copy()
        for a in tbnet.index.comprv + regallcomprv:
            if self.bnet.index.get_index(a) >= 0:
                continue
            pa = None
            if tbnet.index.get_index(a) >= 0:
                pa = tbnet.get_parents(a)
            else:
                pa = rv_all.copy()
            # NOTE(review): nesting of this statement reconstructed from a
            # whitespace-collapsed source (taken as loop level).
            rv_all += a
            isfcn = reg.copy_noaux().implies(Expr.Hc(a, pa) <= 0)
            if isfcn:
                fcnreg.iand_norename(Expr.Hc(a, pa) <= 0)
            for i in range(len(clus)):
                cpa, ca, cisfcn = clus[i]
                if pa == cpa + ca and cisfcn == isfcn:
                    clus[i][1] += a
                    break
            else:
                clus.append([pa, a, isfcn])
        for cpa, ca, cisfcn in clus:
            ccards = []
            for a in ca:
                ccard = a.get_card()
                if ccard is None:
                    ccard = card0
                ccards.append(ccard)
            p = ConcDist(shape = (tuple(self.get_card(t) for t in cpa), tuple(ccards)),
                         isvar = True, randomize = True, isfcn = cisfcn)
            self[ca | cpa] = p
            varlist.append(p)
        for a in reg.allcomprealvar_exprlist():
            if self.get_real(a) is None:
                t = ConcReal(isvar = True)
                self[a] = t
                varlist.append(t)
        # Keep only the constraints not already implied by the model's region,
        # the functional dependencies, and the constraints kept so far.
        csreg = self.get_region() & cons & fcnreg
        for ineq in reg:
            if not csreg.implies(ineq):
                cons.iand_norename(ineq)
                csreg.iand_norename(ineq)
        return (varlist, cons)

    def get_val_regterm(self, term, esgn):
        """Evaluate a user-defined term (function call or region-bounded quantity) on this model.

        Function term: each argument is converted before calling
        ``term.get_fcneval``:

        - Term/Comp -> its distribution;
        - Expr -> its value;
        - Region -> 1 or 0;
        - "model" -> this model.

        Otherwise ``term.get_reg_sgn_bds()`` supplies ``(reg, sgn, bounds)``,
        and None is returned if it is unavailable or has no bounds. If every
        variable involved is already in the model, the result is
        ``sgn * min(sgn * b)`` over the bounds. Otherwise auxiliary variables
        are added on a shallow copy and the value is obtained with
        ``optimize``; the copy's ``opt_reg`` is recorded on this model.
        ``esgn`` is propagated into nested evaluations.
        """
        if term.fcncall is not None:
            cargs = []
            for a in term.fcnargs:
                if isinstance(a, Term) or isinstance(a, Comp):
                    cargs.append(self.get_prob(a))
                elif isinstance(a, Expr):
                    cargs.append(self.get_val(a, 0))
                elif isinstance(a, Region):
                    cargs.append(1 if self[a] else 0)
                elif a == "model":
cargs.append(self) else: cargs.append(a) return term.get_fcneval(cargs) t = term.get_reg_sgn_bds() if t is None: return None reg, sgn, bds = t if len(bds) == 0: return None if isinstance(reg, RegionOp): reg = reg.tosimple() if reg is None: raise ValueError("User-defined information quantities with RegionOp constraints cannot be optimized.") return None rcomp = reg.allcomprv() for b in bds: rcomp += b.allcomprv_shallow() if self.bnet.index.comprv.super_of(rcomp): r = numpy.inf for b in bds: t = self.get_val(b, esgn * sgn) * sgn if float(t) < float(r): r = t return r * sgn card0 = PsiOpts.settings["opt_aux_card"] reg.simplify_quick(zero_group = 0) cs = self.copy_shallow() varlist, cons = cs.add_reg(reg) tvar = None if len(bds) == 1: tvar = bds[0] else: tvar = Expr.real("#TVAR_" + term.get_name()) tvar_r = ConcReal(isvar = True, randomize = True) cs[tvar] = tvar_r varlist.append(tvar_r) for b in bds: cons &= tvar * sgn <= b * sgn retval = cs.optimize(tvar, varlist, cons, sgn = sgn) self.opt_reg = cs.opt_reg return retval def get_val(self, expr, esgn = 0): r = 0.0 for (a, c) in expr.terms: termType = a.get_type() if termType == TermType.IC: k = len(a.x) for t in range(1 << k): csgn = -1 mask = self.bnet.index.get_mask(a.z) for i in range(k): if (t & (1 << i)) != 0: csgn = -csgn mask |= self.bnet.index.get_mask(a.x[i]) if mask != 0: # r += self.get_H_mask(mask) * csgn * c r += self.get_H(self.bnet.index.from_mask(mask)) * csgn * c elif termType == TermType.REAL or termType == TermType.REGION: if a.isone(): r += c else: t = self.get_real(a.x[0]) if t is None: if termType == TermType.REGION: t = self.get_val_regterm(a, esgn * (1 if c > 0 else -1)) if t is None: return None if isinstance(t, ConcReal): t = t.x r += t * c return ConcReal(r) def check_region(self, x, reqgt0 = None, reqgt0_reg = None): if x.aux_present() and reqgt0 is None: x = x.simplified_quick() if not x.aux_present() and reqgt0 is None: return x.evalcheck(self) if isinstance(x, RegionOp) or 
x.imp_present(): if not x.auxi.isempty(): raise ValueError("Universally quantified random variables cannot be optimized.") return False clist = (~x).to_cause_consequence() for imp, cons, auxi in clist: # if auxi: # raise ValueError("Universally quantified random variables cannot be optimized.") # return False # if isinstance(cons, RegionOp): # raise ValueError("Consequences with union cannot be optimized.") # return False creqgt0 = None if not cons.isempty(): creqgt0 = cons.expr_sum_violate() # caux = x.aux.inter(imp.allcomprv() + cons.allcomprv()) # print(auxi) # print(imp.exists(auxi)) # print(cons) # print() t = self.check_region(imp.exists(auxi), reqgt0 = creqgt0, reqgt0_reg = cons) if t: return True return False # print(x.aux) # print(x.tostring()) cs = self.copy_shallow() varlist, cons = cs.add_reg(x, other_neighbors = reqgt0_reg) optval_cutoff = -numpy.inf if reqgt0 is None: reqgt0 = Expr.zero() else: optval_cutoff = PsiOpts.settings["eps_violate_cutoff"] # print(self) retval = cs.optimize(reqgt0, varlist, cons, sgn = 1, optval_cutoff = optval_cutoff) self.opt_reg = cs.opt_reg # print(self) # print(retval) # print(cs[rv("U")]) # print(cs[H(rv("U"))]) # print(x.evalcheck(cs)) # print() # if not reqgt0.iszero(): # if not retval >= PsiOpts.settings["eps_check"]: # return False if reqgt0_reg is not None and reqgt0_reg.evalcheck(cs): return False return x.evalcheck(cs) def __call__(self, x): if isinstance(x, Expr): return self.get_val(x) elif isinstance(x, ExprArray): if self.istorch: # r = torch.tensor([self[a] for a in x], dtype=torch.float64) # return torch.reshape(r, x.shape) r = torch.hstack(tuple(iutil.ensure_torch(self[a]) for a in x)) return torch.reshape(r, x.shape) else: r = numpy.array([float(self[a]) for a in x]) return numpy.reshape(r, x.shape) elif isinstance(x, Region): return self.check_region(x) return None def __getitem__(self, x): if isinstance(x, tuple): x = Term.from_symbols(x) if isinstance(x, tuple) or isinstance(x, Comp) or isinstance(x, 
Term): return self.get_prob(x) if isinstance(x, CompArray): return self.get_prob(x.get_term()) if isinstance(x, Expr) and len(x) == 1: a, c = x.terms[0] if c == 1: termType = a.get_type() if termType == TermType.REAL or termType == TermType.REGION: t = self.get_real(a.x[0]) if t is not None: return t return self(x) def __setitem__(self, key, value): if isinstance(key, tuple): key = Term.from_symbols(key) if isinstance(key, Expr): self.set_real(key.allcomp(), value) else: self.set_prob(key, value) def convert_torch_tensors(self, x): r = [] r_dist = [] if isinstance(x, list): for a in x: tr, tdist = self.convert_torch_tensors(a) r += tr r_dist += tdist elif torch is not None and isinstance(x, torch.Tensor): r.append(x) elif isinstance(x, Comp): for a in x: i = self.bnet.index.get_index(x) if i >= 0: tr, tdist = self.convert_torch_tensors(self.ps[i]) r += tr r_dist += tdist elif isinstance(x, Expr): t = self.get_real(x) if t is not None: tr, tdist = self.convert_torch_tensors(t) r += tr r_dist += tdist elif isinstance(x, ExprArray): for x2 in x: t = self.get_real(x2) if t is not None: tr, tdist = self.convert_torch_tensors(t) r += tr r_dist += tdist elif isinstance(x, ConcDist) or isinstance(x, ConcReal): t = x.get_v() if t is not None and torch is not None and isinstance(t, torch.Tensor): r.append(t) if x.isvar: r_dist.append(x) return (r, r_dist) def tensor_copy_list(x, y): with torch.no_grad(): for i in range(len(y)): if len(x) <= i: x.append(y[i].detach().clone()) else: x[i].copy_(y[i]) def get_tensor(self, x): r = None if isinstance(x, Expr): r = self.get_val(x) else: r = x(self) if isinstance(r, ConcReal): r = r.x return r def tensor_list_to_array(varlist): r = [] for v in varlist: k = functools.reduce(lambda x, y: x*y, v.shape, 1) r += list(numpy.reshape(v.detach().numpy(), (k,))) for i in range(len(r)): if numpy.isnan(r[i]): r[i] = 0.0 return numpy.array(r) def tensor_list_grad_to_array(varlist): r = [] for v in varlist: k = functools.reduce(lambda x, y: x*y, 
v.shape, 1) if v.grad is None: r += [0.0 for i in range(k)] else: r += list(numpy.reshape(v.grad.numpy(), (k,))) for i in range(len(r)): if numpy.isnan(r[i]): r[i] = 0.0 return numpy.array(r) def tensor_list_from_array(varlist, a): c = 0 for v in varlist: k = functools.reduce(lambda x, y: x*y, v.shape, 1) with torch.no_grad(): v.copy_(torch.tensor(numpy.reshape(a[c:c+k], v.shape), dtype=torch.float64)) if v.grad is not None: v.grad.data.zero_() c += k def tensor_list_get_bds(varlist, distlist, ismat): c = 0 cons = [] num_cons = 0 cons_A_data = [] cons_A_r = [] cons_A_c = [] xsize = 0 for v in varlist: k = functools.reduce(lambda x, y: x*y, v.shape, 1) xsize += k bds = [(-numpy.inf, numpy.inf) for i in range(xsize)] for v in varlist: k = functools.reduce(lambda x, y: x*y, v.shape, 1) for d in distlist: if d.get_v() is v: if isinstance(d, ConcDist): stride = functools.reduce(lambda x, y: x*y, d.shape_out, 1) - 1 if stride >= 2: for c2 in range(c, c+k, stride): if ismat: cons_A_data += [1.0] * stride cons_A_r += [num_cons] * stride cons_A_c += list(range(c2, c2+stride)) num_cons += 1 else: def get_fcn(i0, i1): def fcn(x): return 1.0 - sum(x[i0: i1]) return fcn def get_jac(i0, i1): def fcn(x): r = numpy.zeros(xsize) r[i0:i1] = -numpy.ones(i1-i0) return r return fcn cons.append({ "type": "ineq", "fun": get_fcn(c2, c2+stride), "jac": get_jac(c2, c2+stride) }) for i in range(c, c+k): bds[i] = (0.0, 1.0) elif isinstance(d, ConcReal): if d.lbound is not None or d.ubound is not None: bds[c] = (-numpy.inf if d.lbound is None else d.lbound, numpy.inf if d.ubound is None else d.ubound) break c += k if ismat: if num_cons == 0: return (bds, []) # mat = scipy.sparse.csr_matrix((cons_A_data, (cons_A_r, cons_A_c)), shape = (num_cons, xsize)) mat = scipy.sparse.coo_matrix((cons_A_data, (cons_A_r, cons_A_c)), shape = (num_cons, xsize)).toarray() lcons = scipy.optimize.LinearConstraint(mat, -numpy.inf, 1.0) return (bds, [lcons]) else: return (bds, cons) def optimize(self, expr, vs, reg 
= None, sgn = 1, optimizer = None, learnrate = None, learnrate2 = None, momentum = None, num_iter = None, num_iter2 = None, num_points = None, num_hop = None, hop_temp = None, hop_prob = None, alm_rho = None, alm_rho_pow = None, alm_step = None, alm_penalty = None, eps_converge = None, eps_tol = None, optval_cutoff = None): """ Minimize/maximize expr with variables in the list vs, constrained in the region reg. Uses a combination of SLSQP, gradient descent, Adam (or any optimizer with the pytorch interface) and basin-hopping. Constraints are handled using augmented Lagrangian method. Kraft, D. A software package for sequential quadratic programming. 1988. Tech. Rep. DFVLR-FB 88-28, DLR German Aerospace Center - Institute for Flight Mechanics, Koln, Germany. Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic optimization." arXiv preprint arXiv:1412.6980 (2014). Wales, David J.; Doye, Jonathan P. K. (1997). "Global Optimization by Basin-Hopping and the Lowest Energy Structures of Lennard-Jones Clusters Containing up to 110 Atoms". The Journal of Physical Chemistry A. 101 (28): 5111-5116. Hestenes, M. R. (1969). "Multiplier and gradient methods". Journal of Optimization Theory and Applications. 4 (5): 303-320. Parameters ---------- expr : Expr The expression to be minimized. Can either be an Expr object or a callable accepting ConcModel as argument. vs : list The variables to be optimized over (list of ConcDist and/or ConcReal). reg : Region, optional The region where the variables are constrained. The default is None. sgn : int, optional Set to 1 for maximization, -1 for minimization. The default is 1. optimizer : closure, optional The optimizer. Either "sgd", "adam" or any function that returns a pytorch optimizer given the list of variables as argument. The default is None. learnrate : float, optional Learning rate. The default is None. learnrate2 : float, optional Learning rate in phase 2. The default is None. momentum : float, optional Momentum. 
The default is None. num_iter : int, optional Number of iterations. The default is None. num_iter2 : int, optional Number of iterations in phase 2. The default is None. num_points : int, optional Number of random starting points. The default is None. num_hop : int, optional Number of hops for basin-hopping. The default is None. hop_temp : float, optional Temperature for basin-hopping. The default is None. hop_prob : float, optional Probability of hopping for basin-hopping. The default is None. alm_rho : float, optional Rho in augmented Lagrangian method. The default is None. alm_rho_pow : float, optional Multiply rho by this amount after each iteration. The default is None. eps_converge : float, optional Convergence is declared if the change is smaller than eps_converge. The default is None. eps_tol : float, optional Tolerance for equality and inequality constraints. The default is None. optval_cutoff : float, optional Stop the algorithm when the objective goes beyond this value. The default is None. Returns ------- r : float The minimum value. 
""" verbose = PsiOpts.settings.get("verbose_opt", False) verbose_step = PsiOpts.settings.get("verbose_opt_step", False) verbose_step_var = PsiOpts.settings.get("verbose_opt_step_var", False) if optimizer is None: optimizer = PsiOpts.settings["opt_optimizer"] if learnrate is None: learnrate = PsiOpts.settings["opt_learnrate"] if learnrate2 is None: learnrate2 = PsiOpts.settings["opt_learnrate2"] if momentum is None: momentum = PsiOpts.settings["opt_momentum"] if num_iter is None: num_iter = PsiOpts.settings["opt_num_iter"] if num_iter2 is None: num_iter2 = PsiOpts.settings["opt_num_iter2"] if num_points is None: num_points = PsiOpts.settings["opt_num_points"] if num_hop is None: num_hop = PsiOpts.settings["opt_num_hop"] if hop_temp is None: hop_temp = PsiOpts.settings["opt_hop_temp"] if hop_prob is None: hop_prob = PsiOpts.settings["opt_hop_prob"] if alm_rho is None: alm_rho = PsiOpts.settings["opt_alm_rho"] if alm_rho_pow is None: alm_rho_pow = PsiOpts.settings["opt_alm_rho_pow"] if alm_step is None: alm_step = PsiOpts.settings["opt_alm_step"] if alm_penalty is None: alm_penalty = PsiOpts.settings["opt_alm_penalty"] if eps_converge is None: eps_converge = PsiOpts.settings["opt_eps_converge"] if eps_tol is None: eps_tol = PsiOpts.settings["opt_eps_tol"] rnd = PsiOpts.get_random() truth = PsiOpts.settings["truth"] cs = self.copy_shallow() if isinstance(expr, (int, float)): expr = Expr.const(expr) # if reg is not None or expr.isregtermpresent(): # tvaropt = Expr.real("#TMPVAROPT") # reg = reg & (tvaropt * sgn <= expr * sgn) # reg = reg. 
if not isinstance(vs, list): vs = [vs] if reg is None: reg = [] elif isinstance(reg, Region) or isinstance(reg, tuple): reg = [reg] elif not isinstance(reg, list): reg = [(reg, ">=")] for i in range(len(reg)): if isinstance(reg[i], RegionOp): reg[i] = reg[i].tosimple() if reg[i] is None: reg[i] = Region.universe() if truth is not None: truth_simple = truth.tosimple() if truth_simple is not None: reg.append(truth_simple) for i in range(len(reg)): if not cs.allcomprv().super_of(reg[i].allcomprv()): tvarlist, tcons = cs.add_reg(reg[i]) reg[i] = tcons vs += tvarlist def str_tuple(x): if x[0] == 0: return iutil.tostr_verbose(x[1] * -sgn) return iutil.tostr_verbose((x[0], x[1] * -sgn)) def tuple_copy1(x, y): for i in range(len(y)): if len(x) <= i: x.append([None, y[i][1]]) else: x[i][1] = y[i][1] varlist, distlist = cs.convert_torch_tensors(vs) cons = [] cons_init = [] scipy_optimizer = None if optimizer == "sgd": pass elif optimizer == "adam": pass elif isinstance(optimizer, str): scipy_optimizer = optimizer for creg in reg: if isinstance(creg, Region): creg = creg.copy() creg.simplify_redundant() for a in creg.exprs_ge: if a.isnonpos_ic2(): if self.bnet.check_ic(a): continue if scipy_optimizer is None: slack = torch.tensor(0.0, dtype=torch.float64, requires_grad = True) cons.append([a, 0.0, slack]) varlist.append(slack) else: cons.append([a, 0.0, True]) for a in creg.exprs_eq: if a.isnonpos_ic2() or a.isnonneg_ic2(): if self.bnet.check_ic(a): continue cons.append([a, 0.0]) elif isinstance(creg, tuple): if creg[1] == ">=": if scipy_optimizer is None: slack = torch.tensor(0.0, dtype=torch.float64, requires_grad = True) cons.append([creg[0], 0.0, slack]) varlist.append(slack) else: cons.append([creg[0], 0.0, True]) # if reg.isuniverse(): # reg = None ismat = False bds = None cons_list = None cons_list_fcn = None use_jac = False def get_fcn(cexpr, mul, has_f, has_d): def fcn(x): # print(str(cexpr) + " " + str(x)) cs.clear_cache() ConcModel.tensor_list_from_array(varlist, 
x) for d in distlist: d.calc_torch() cval = cs.get_tensor(cexpr) if mul != 1: cval *= mul # print(" " + str(float(cval))) if not has_d: return float(cval) if isinstance(cval, float): if has_f: return (float(cval), ConcModel.tensor_list_grad_to_array(varlist) * 0.0) else: return ConcModel.tensor_list_grad_to_array(varlist) * 0.0 cval.backward() if has_f: return (float(cval), ConcModel.tensor_list_grad_to_array(varlist)) else: return ConcModel.tensor_list_grad_to_array(varlist) return fcn if scipy_optimizer is not None: ismat = (scipy_optimizer == "trust-constr") use_jac = True for cismat in ([False, True] if ismat else [False]): bds, cons_list = ConcModel.tensor_list_get_bds(varlist, distlist, ismat = cismat) # print(bds) for a in cons: cons_dict = {} if len(a) >= 3: cons_dict["type"] = "ineq" else: cons_dict["type"] = "eq" cons_dict["fun"] = get_fcn(a[0], 1, True, False) cons_dict["jac"] = get_fcn(a[0], 1, False, True) if cismat: cons_list.append(scipy.optimize.NonlinearConstraint( cons_dict["fun"], 0.0, numpy.inf, jac = cons_dict["jac"] )) else: cons_list.append(cons_dict) if not cismat: cons_list_fcn = cons_list tuple_copy1(cons_init, cons) int_big = 100000000 r = (int_big, numpy.inf, int_big) r_v = [] r_v_d = [] t = (int_big, numpy.inf, int_big) num_iter_list = [num_iter] if num_iter2 > 0: num_iter_list.append(num_iter2) if verbose: print("======== model ========") for a in cs.bnet.index.comprv: print(str(cs.bnet.get_parents(a)) + " -> " + str(a) + " card=" + str(cs.get_card(a))) print("======== probs ========") for (cin, cout), cp in cs.psmap.items(): print(str(sum(cs.bnet.index.comprv[i] for i in cin)) + " -> " + str(sum(cs.bnet.index.comprv[i] for i in cout)) + (" var" if cp.isvar else "") + (" fcn" if cp.isfcn else "")) if sgn > 0: print("======== maximize ========") else: print("======== minimize ========") print(expr) print("======== over ========") for d in distlist: if isinstance(d, ConcDist): print("dist shape=" + str((d.shape_in, d.shape_out))) elif 
isinstance(d, ConcDist): print("real=" + str(d.x)) if len(cons): print("======== constraints ========") for a in cons: if len(a) >= 3: print(str(a[0]) + ">=0") else: print(str(a[0]) + "==0") for cpass, cur_num_iter in enumerate(num_iter_list): cur_num_points = 1 if cpass == 0: cur_num_points = num_points tocutoff = False for ip in range(cur_num_points): if PsiOpts.is_timer_ended(): break if tocutoff: break if ip > 0: for d in distlist: d.randomize() tuple_copy1(cons, cons_init) cs.clear_cache() cur_lr = learnrate cur_num_hop = num_hop if cpass > 0: cur_lr = learnrate2 cur_num_hop = 1 for ih in range(cur_num_hop): if PsiOpts.is_timer_ended(): break if verbose_step: print("Pass #" + str(cpass) + "/" + str(len(num_iter_list)) + ", Point #" + str(ip) + "/" + str(cur_num_points) + ", Hop #" + str(ih) + "/" + str(cur_num_hop)) r_start = [] r_start_d = [] v_start = t if ih > 0: ConcModel.tensor_copy_list(r_start, varlist) tuple_copy1(r_start_d, cons) for d in distlist: d.hop(hop_prob) tuple_copy1(cons, cons_init) cs.clear_cache() if scipy_optimizer is None: if optimizer == "sgd": cur_optimizer = torch.optim.SGD(varlist, lr = cur_lr, momentum = momentum) elif optimizer == "adam": cur_optimizer = torch.optim.Adam(varlist, lr = cur_lr) else: cur_optimizer = optimizer(varlist) # for a in con_eq: # a[1] = 0.0 # for a in con_ge: # a[1] = 0.0 # a[2].copy_(0.0) cr = (int_big, numpy.inf, int_big) # lastval = (int_big, numpy.inf) lastval = numpy.inf cur_num_violate = 0 if scipy_optimizer is not None: res = None resx = None resfun = None #!!!!!!!!!!!!!!!!!!!!!!!! 
# print(ConcModel.tensor_list_to_array(varlist)) # print(get_fcn(expr, -sgn, True, use_jac)(ConcModel.tensor_list_to_array(varlist))) # print() if len(varlist): with warnings.catch_warnings(): warnings.simplefilter("ignore") opts = {"maxiter": cur_num_iter, "disp": verbose_step} if scipy_optimizer == "trust-constr": opts["initial_tr_radius"] = 0.001 if verbose_step: opts["verbose"] = 3 res = scipy.optimize.minimize( get_fcn(expr, -sgn, True, use_jac), ConcModel.tensor_list_to_array(varlist), method = scipy_optimizer, jac = use_jac, bounds = bds, constraints = cons_list, tol = eps_tol, options = opts ) resx = res.x resfun = res.fun else: resx = numpy.array([]) resfun = get_fcn(expr, -sgn, True, use_jac)(numpy.array([]))[0] ConcModel.tensor_list_from_array(varlist, resx) bad = False for d in distlist: if isinstance(d, ConcDist) and d.isfcn: d.clamp() bad = True if bad: resx = ConcModel.tensor_list_to_array(varlist) resfun = get_fcn(expr, -sgn, True, use_jac)(resx)[0] num_violate = 0 sum_violate = 0.0 for a in cons_list_fcn: t = a["fun"](resx) if numpy.isnan(t): sum_violate = numpy.inf num_violate += 1 continue if a["type"] == "eq": if abs(t) > eps_tol: sum_violate += abs(t) num_violate += 1 else: if t < -eps_tol: sum_violate += -t num_violate += 1 penalty = sum_violate * alm_penalty score = resfun + penalty if numpy.isnan(score): score = numpy.inf t = (0, score, num_violate) cur_num_violate = num_violate if verbose_step: print((("pass=" + str(cpass + 1) + " ") if cpass > 0 else "") + "#pt=" + str(ip) + " val=" + str(resfun) + " violate=" + str((num_violate, sum_violate)) + " opt=" + str_tuple(r)) else: for it in range(cur_num_iter + 1): if PsiOpts.is_timer_ended(): break # cur_optimizer.zero_grad() cval = cs.get_tensor(expr) if -sgn != 1: cval *= -sgn t_cval = float(cval) if numpy.isnan(t_cval): t_cval = numpy.inf crho = numpy.power(alm_rho_pow, it // alm_step) * alm_rho crho0 = numpy.power(alm_rho_pow, (it - 1) // alm_step) * alm_rho num_violate = 0 sum_dual_step = 
0.0 penalty = 0.0 for a in cons: con_val = cs.get_tensor(a[0]) p_con_val = float(con_val) if len(a) >= 3: if p_con_val < -eps_tol: num_violate += 1 penalty += -p_con_val * alm_penalty else: if abs(p_con_val) > eps_tol: num_violate += 1 penalty += abs(p_con_val) * alm_penalty if len(a) >= 3: con_val -= a[2] t_con_val = float(con_val) if numpy.isnan(t_con_val): t_con_val = 0.0 sum_dual_step += abs(t_con_val) if it > 0 and it % alm_step == 0: t = a[1] + crho0 * t_con_val if not numpy.isnan(t): a[1] = t cval += con_val * a[1] + torch.square(con_val) * crho * 0.5 if verbose_step: if len(a) >= 3: # print(str(a[0]) + ">=0 : val=" + str(p_con_val) + " dual=" + str(a[1])) print(str(a[0]) + ">=0 : val=" + iutil.tostr_verbose(p_con_val) + " dual=" + iutil.tostr_verbose(a[1]) + " slack=" + iutil.tostr_verbose(float(a[2]))) else: print(str(a[0]) + "==0 : val=" + iutil.tostr_verbose(p_con_val) + " dual=" + iutil.tostr_verbose(a[1])) t = (num_violate, t_cval + penalty, num_violate) cur_num_violate = num_violate # print(t) cr = min(cr, t) if not numpy.isnan(cr[1]) and cr < r: r = cr ConcModel.tensor_copy_list(r_v, varlist) tuple_copy1(r_v_d, cons) if verbose_step: print((("pass=" + str(cpass + 1) + " ") if cpass > 0 else "") + "#pt=" + str(ip) + " #iter=" + str(it) + " val=" + str_tuple(t) + " opt=" + str_tuple(min(r, cr))) if verbose_step_var: for d in distlist: print(d) # if t[0] == lastval[0] and abs(t[1] - lastval[1]) <= eps_converge: # break if abs(t_cval - lastval) + sum_dual_step <= eps_converge: break if it == cur_num_iter: break # lastval = t lastval = t_cval cur_optimizer.zero_grad() cval.backward() cur_optimizer.step() for d in distlist: d.clamp() for a in cons: if len(a) >= 3: if numpy.isnan(float(a[2])) or float(a[2]) < 0: with torch.no_grad(): a[2].copy_(torch.tensor(0.0, dtype=torch.float64)) cs.clear_cache() # print(t) cr = min(cr, t) if not numpy.isnan(cr[1]) and cr < r: r = cr ConcModel.tensor_copy_list(r_v, varlist) tuple_copy1(r_v_d, cons) toreject = False if 
ih > 0: accept_prob = 1.0 if numpy.isnan(t[1]): accept_prob = 0.0 elif t > v_start: accept_prob = numpy.exp((v_start[1] - t[1]) / hop_temp) if verbose_step: print("HOP #" + str(ih) + " from=" + str_tuple(v_start) + " to=" + str_tuple(t) + " prob=" + iutil.tostr_verbose(accept_prob)) if rnd.uniform() >= accept_prob: toreject = True if verbose_step: print("HOP REJECT") ConcModel.tensor_copy_list(varlist, r_start) tuple_copy1(cons, r_start_d) for d in distlist: d.calc_torch() t = v_start cs.clear_cache() if not toreject: if verbose_step: print("HOP ACCEPT") # print(t) if optval_cutoff is not None and t[2] == 0 and t[1] <= optval_cutoff * -sgn: if verbose_step: print("CUTOFF") tocutoff = True if tocutoff: break if len(r_v): ConcModel.tensor_copy_list(varlist, r_v) tuple_copy1(cons, r_v_d) for d in distlist: d.calc_torch() cs.clear_cache() self.opt_reg = cs self.clear_cache() return (r[1] if r[2] == 0 else numpy.inf) * -sgn def minimize(self, *args, **kwargs): """ Maximize expr with variables in the list vs, constrained in the region reg. Refer to optimize for details. """ return self.optimize(*args, sgn = -1, **kwargs) def maximize(self, *args, **kwargs): """ Maximize expr with variables in the list vs, constrained in the region reg. Refer to optimize for details. """ return self.optimize(*args, sgn = 1, **kwargs) def opt_model(self): return self.opt_reg def get_bayesnet(self): return self.bnet.copy() def table(self, *args, **kwargs): """Plot the information diagram as a Karnaugh map. """ return universe().table(*args, self, **kwargs) def venn(self, *args, **kwargs): """Plot the information diagram as a Venn diagram. Can handle up to 5 random variables (uses Branko Grunbaum's Venn diagram for n=5). """ return universe().venn(*args, self, **kwargs) def graph(self, **kwargs): """Return the Bayesian network among the random variables as a graphviz digraph that can be displayed in the console. 
""" return self.get_bayesnet().graph(**kwargs) def set_force_float(self, force_float): for (cin, cout), t in self.psmap.items(): t.force_float = force_float for x in self.realvars: x.force_float = force_float def fraction_snap(self, denom = None, eps = None): for (cin, cout), t in self.psmap.items(): t.fraction_snap(denom = denom, eps = eps) for x in self.realvars: x.fraction_snap(denom = denom, eps = eps) def tostring(self, style = 0): """Convert to string. Parameters: style : Style of string conversion STR_STYLE_STANDARD : I(X,Y;Z|W) STR_STYLE_PSITIP : I(X+Y&Z|W) """ style = iutil.convert_str_style(style) r = "" nlstr = "\n" if style & PsiOpts.STR_STYLE_LATEX: nlstr = "\\\\\n" if style & PsiOpts.STR_STYLE_LATEX: r += "\\begin{array}{l}\n" for (cin, cout), t in self.psmap.items(): # self.calc_dist(t) # if t is None: # continue cinrv = self.tuple_to_comp(cin) coutrv = self.tuple_to_comp(cout) r += "P(" r += coutrv.tostring(style) if cinrv: r += "|" r += cinrv.tostring(style) r += ")" r += " = " r += t.tostring(style) r += nlstr for i, x in enumerate(self.index_real.compreal): r += x.tostring(style) r += " = " r += self.realvars[i].tostring(style) r += nlstr if style & PsiOpts.STR_STYLE_LATEX: r += "\\end{array}" return r def __str__(self): return self.tostring(PsiOpts.settings["str_style"]) @latex_postprocess def _latex_(self): return self.tostring(iutil.convert_str_style("latex")) class SparseMat: """List of lists sparse matrix. 
Do NOT use directly""" def __init__(self, width): self.width = width self.x = [] self.rowinfo = [] def copy(self): r = SparseMat(self.width) r.x = [list(a) for a in self.x] r.rowinfo = iutil.copy(self.rowinfo) return r @staticmethod def from_row(row, width): r = SparseMat(width) r.x.append(list(row)) r.rowinfo.append(None) return r @staticmethod def from_dense_row(row): ceps = PsiOpts.settings["eps"] r = SparseMat(len(row)) x = [] for i in range(len(row)): if abs(row[i]) > ceps: x.append((i, row[i])) r.x.append(x) r.rowinfo.append(None) return r def iszero(self): return all(len(a) == 0 for a in self.x) def __iadd__(self, other): if self.width != other.width or len(self.x) != len(other.x): raise RuntimeError("SparseMat iadd shape mismatch.") return self ceps = PsiOpts.settings["eps"] for a, b in zip(self.x, other.x): for j in range(len(b)): for i in range(len(a)): if a[i][0] == b[j][0]: csum = a[i][1] + b[j][1] if abs(csum) > ceps: a[i] = (a[i][0], csum) else: a.pop(i) break else: a.append(b[j]) return self def __add__(self, other): r = self.copy() r += other return r def __imul__(self, other): for a in self.x: for i in range(len(a)): a[i] = (a[i][0], a[i][1] * other) return self def __mul__(self, other): r = self.copy() r *= other return r def __isub__(self, other): self += other * -1 return self def __sub__(self, other): r = self.copy() r += other * -1 return r def ratio(self, other): ceps = PsiOpts.settings["eps"] r = None for a, b in zip(self.x, other.x): if len(a) != len(b): return None for ax, bx in zip(a, b): if ax[0] != bx[0]: return None if abs(bx[1]) <= ceps: if abs(ax[1]) <= ceps: continue else: return None t = ax[1] / bx[1] if r is not None and abs(t - r) > ceps: return None r = t return r def addrow(self, info = None): self.x.append([]) self.rowinfo.append(info) def poprow(self): self.x.pop() self.rowinfo.pop() def add_last_row(self, j, c): self.x[len(self.x) - 1].append((j, c)) def extend(self, other): self.width = max(self.width, other.width) self.x 
+= other.x self.rowinfo += other.rowinfo def simplify_row(self, i): ceps = PsiOpts.settings["eps"] self.x[i].sort() t = self.x[i] self.x[i] = [] cj = -1 cc = 0.0 for (j, c) in t: if j == cj: cc += c else: if abs(cc) > ceps: self.x[i].append((cj, cc)) cj = j cc = c if abs(cc) > ceps: self.x[i].append((cj, cc)) def simplify_last_row(self): self.simplify_row(len(self.x) - 1) def simplify(self): for i in range(len(self.x)): self.simplify_row(i) def last_row_isempty(self): return len(self.x[len(self.x) - 1]) == 0 def unique_row(self): ir = 1.61803398875 - 0.1 p = 10007 rowmap = {} for i in range(len(self.x)): a = self.x[i] h = 0 for (j, c) in a: h = h * p + hash(c + j * ir) if h in rowmap: if a == self.x[rowmap[h]]: a[:] = [] else: rowmap[h] = i self.rowinfo = [ri for (a, ri) in zip(self.x, self.rowinfo) if len(a) > 0] self.x = [a for a in self.x if len(a) > 0] def remove_empty_rows(self): self.rowinfo = [ri for (a, ri) in zip(self.x, self.rowinfo) if len(a) > 0] self.x = [a for a in self.x if len(a) > 0] def row_dense(self, i): r = [0.0] * self.width for (j, c) in self.x[i]: r[j] += c return r def nonzero_cols(self): r = [False] * self.width for i in range(len(self.x)): for (j, c) in self.x[i]: r[j] = True return r def mapcol(self, m, allowmiss = False): for i in range(len(self.x)): for k in range(len(self.x[i])): j2 = m[self.x[i][k][0]] if not allowmiss and j2 < 0: return False self.x[i][k] = (j2, self.x[i][k][1]) if allowmiss: self.x[i] = [(j, c) for (j, c) in self.x[i] if j >= 0] return True def sumrows(self, m, remove = False): r = [] for i2 in range(len(self.x)): a = self.x[i2] cr = 0.0 did = False for i in range(len(a)): if a[i][0] in m: cr += m[a[i][0]] * a[i][1] if remove: a[i] = (-1, a[i][1]) did = True if did: self.x[i2] = [(j, c) for (j, c) in self.x[i2] if j >= 0] r.append(cr) return r def tolil(self): r = scipy.sparse.lil_matrix((len(self.x), self.width)) for i in range(len(self.x)): for (j, c) in self.x[i]: r[i, j] += c return r def tonumpyarray(self): r 
= numpy.zeros((len(self.x), self.width)) for i in range(len(self.x)): for (j, c) in self.x[i]: r[i, j] += c return r def tonumpyarray_row(self): r = numpy.zeros(self.width) for (j, c) in self.x[0]: r[j] += c return r class LinearProg: """A linear programming instance. Do NOT use directly""" def __init__(self, index, lptype, bnet = None, lp_bounded = None, save_res = False, prereg = None, dual_enabled = None, val_enabled = None, dual_form = None): self.index = index self.lptype = lptype self.nvar = 0 self.nxvar = 0 self.xvarid = [] self.realshift = 0 self.cellpos = [] self.bnet = bnet self.pinfeas = False self.constmap = {} self.quantum = PsiOpts.settings.get("quantum", False) self.hge0 = PsiOpts.settings.get("hge0", False) self.hcge0 = PsiOpts.settings.get("hcge0", False) self.ige0 = PsiOpts.settings.get("ige0", False) self.icge0 = PsiOpts.settings.get("icge0", False) if self.quantum or not self.hge0 or not self.hcge0 or not self.ige0 or not self.icge0: self.lptype = LinearProgType.H self.noskip = (PsiOpts.settings["proof_noskip"] and PsiOpts.settings["proof_enabled"]) if dual_form is None: self.dual_form = PsiOpts.settings["lp_dual_form"] or (PsiOpts.settings["lp_dual_form_if_proof"] and PsiOpts.settings["proof_enabled"]) else: self.dual_form = dual_form self.dual_form_ncons = 0 self.dual_form_obj = None self.dual_form_cutoff = 0.0 self.dual_form_infeas = False self.dual_form_weights = None self.cond_maximize = PsiOpts.settings["lp_cond_maximize"] if lp_bounded is None: self.lp_bounded = PsiOpts.settings["lp_bounded"] else: self.lp_bounded = lp_bounded self.lp_ubound = PsiOpts.settings["lp_ubound"] self.lp_eps = PsiOpts.settings["lp_eps"] self.lp_eps_obj = PsiOpts.settings["lp_eps_obj"] self.zero_cutoff = PsiOpts.settings["lp_zero_cutoff"] self.eps_present = False self.affine_present = False self.fcn_mode = PsiOpts.settings["fcn_mode"] self.fcn_list = [] self.save_res = save_res self.saved_var = [] self.celllink = [] self.cellsplit = [] if prereg is not None: if 
self.fcn_mode >= 1: for x in prereg.exprs_gei: self.addExpr_ge0_fcn(x) for x in prereg.exprs_eqi: self.addExpr_ge0_fcn(x) if self.lptype == LinearProgType.H: self.nvar = (1 << self.index.num_rv()) - 1 + self.index.num_real() self.realshift = (1 << self.index.num_rv()) - 1 elif self.lptype == LinearProgType.HMIN: n = self.index.num_rv() nbnet = bnet.index.num_rv() self.celllink = [-2] * (1 << n) self.cellsplit = [None] * (1 << n) self.cellpos = [-2] * (1 << n) cpos = 0 for mask in range(1, 1 << n): mask2 = self.fcn_mask_maximize(mask) if self.celllink[mask2] < 0: self.celllink[mask2] = mask self.celllink[mask] = mask else: self.celllink[mask] = self.celllink[mask2] for mask in range(1, 1 << n): if mask < (1 << nbnet) and mask == self.celllink[mask]: for i in range(n): if not mask & (1 << i): continue maskb = bnet.markov_blanket_mask(1 << i, mask - (1 << i)) if maskb == mask - (1 << i): continue self.cellsplit[mask] = (1 << i, mask - (1 << i) - maskb, maskb) break if mask == self.celllink[mask] and self.cellsplit[mask] is None: self.cellpos[mask] = cpos cpos += 1 self.realshift = cpos self.nvar = self.realshift + self.index.num_real() elif self.lptype == LinearProgType.HC1BN: n = self.index.num_rv() nbnet = bnet.index.num_rv() self.cellpos = [-2] * (1 << n) cpos = 0 if self.cond_maximize: for i in range(n): for mask in range((1 << i) - 1, -1, -1): maski = mask + (1 << i) if self.fcn_mode >= 1: mask2 = self.fcn_mask_maximize(mask) if mask2 & (1 << i): self.cellpos[maski] = -1 continue mask2 &= (1 << i) - 1 if mask2 != mask: if mask2 < mask: print("HC1BN REVERSED!") self.cellpos[maski] = self.cellpos[mask2 + (1 << i)] continue if i < nbnet: for j in range(i): if mask & (1 << j) == 0: if bnet.check_ic_mask(1 << i, 1 << j, mask): self.cellpos[maski] = self.cellpos[maski + (1 << j)] break if self.cellpos[maski] == -2: self.cellpos[maski] = cpos cpos += 1 else: for i in range(n): for mask in range(1 << i): maski = mask + (1 << i) if self.fcn_mode >= 1: # mask_max = 
self.fcn_mask_maximize(mask) # if mask_max & (1 << i): # self.cellpos[maski] = -1 # continue # mask_min = self.fcn_mask_minimize(mask_max) # if mask_min < mask: # maskj = mask_min + (1 << i) # self.cellpos[maski] = self.cellpos[maskj] # continue if self.checkfcn_mask(1 << i, mask): self.cellpos[maski] = -1 continue for j in range(i): if mask & (1 << j) != 0: if self.checkfcn_mask(1 << j, mask - (1 << j)): maskj = mask - (1 << j) + (1 << i) self.cellpos[maski] = self.cellpos[maskj] break if self.cellpos[maski] != -2: continue if i >= nbnet: self.cellpos[maski] = cpos cpos += 1 continue for j in range(i): if mask & (1 << j) != 0: if bnet.check_ic_mask(1 << i, 1 << j, mask - (1 << j)): maskj = mask - (1 << j) + (1 << i) self.cellpos[maski] = self.cellpos[maskj] break if self.cellpos[maski] == -2: self.cellpos[maski] = cpos cpos += 1 self.realshift = cpos self.nvar = self.realshift + self.index.num_real() self.nxvar = self.nvar self.Au = SparseMat(self.nvar) self.Ae = SparseMat(self.nvar) self.solver = None if self.nvar <= PsiOpts.settings["solver_scipy_maxsize"]: self.solver = iutil.get_solver("scipy") else: self.solver = iutil.get_solver() self.solver_param = {} self.bu = [] self.be = [] self.icp = [[] for i in range(self.index.num_rv())] self.dual_u = None self.dual_e = None self.dual_pf = (PsiOpts.settings["proof_enabled"] and not PsiOpts.settings["proof_nowrite"]) if dual_enabled is not None: self.dual_enabled = dual_enabled else: self.dual_enabled = PsiOpts.settings["proof_enabled"] self.val_x = None if val_enabled is not None: self.val_enabled = val_enabled else: self.val_enabled = False self.optval = None def get_optval(self): return self.optval def addreal_id(self, A, k, c): A.add_last_row(self.realshift + k, c) def addH_mask(self, A, mask, c): if self.lptype == LinearProgType.H: A.add_last_row(mask - 1, c) elif self.lptype == LinearProgType.HC1BN: n = self.index.num_rv() for i in range(n): if mask & (1 << i) != 0: cp = self.cellpos[mask & ((1 << (i + 1)) - 
1)] if cp >= 0: A.add_last_row(cp, c) elif self.lptype == LinearProgType.HMIN: mask = self.celllink[mask] if self.cellsplit[mask] is None: A.add_last_row(self.cellpos[mask], c) else: ca, cb, cc = self.cellsplit[mask] self.addH_mask(A, ca | cc, c) self.addH_mask(A, cb | cc, c) if cc: self.addH_mask(A, cc, -c) def addIc_mask(self, A, x, y, zmask, c): if self.lptype == LinearProgType.H: if x == y: A.add_last_row(((1 << x) | zmask) - 1, c) if zmask != 0: A.add_last_row(zmask - 1, -c) else: A.add_last_row(((1 << x) | zmask) - 1, c) A.add_last_row(((1 << y) | zmask) - 1, c) A.add_last_row(((1 << x) | (1 << y) | zmask) - 1, -c) if zmask != 0: A.add_last_row(zmask - 1, -c) elif self.lptype == LinearProgType.HC1BN or self.lptype == LinearProgType.HMIN: if x == y: self.addH_mask(A, (1 << x) | zmask, c) if zmask != 0: self.addH_mask(A, zmask, -c) else: self.addH_mask(A, (1 << x) | zmask, c) self.addH_mask(A, (1 << y) | zmask, c) self.addH_mask(A, (1 << x) | (1 << y) | zmask, -c) if zmask != 0: self.addH_mask(A, zmask, -c) def addExpr(self, A, x): for (a, c) in x.terms: termType = a.get_type() if termType == TermType.IC: k = len(a.x) for t in range(1 << k): csgn = -1 mask = self.index.get_mask(a.z) for i in range(k): if (t & (1 << i)) != 0: csgn = -csgn mask |= self.index.get_mask(a.x[i]) if mask != 0: self.addH_mask(A, mask, c * csgn) elif termType == TermType.REAL or termType == TermType.REGION: k = self.index.get_index(a.x[0].varlist[0]) if k >= 0: self.addreal_id(A, k, c) def addfcn(self, x): for (a, c) in x.terms: termType = a.get_type() if termType == TermType.IC and len(a.x) == 1: self.fcn_list.append((self.index.get_mask(a.x[0]), self.index.get_mask(a.z))) def fcn_mask_maximize(self, zmask): did = True while did: did = False for cx, cz in self.fcn_list: if cz | zmask == zmask and cx | zmask != zmask: zmask |= cx did = True return zmask def fcn_mask_minimize(self, zmask): n = self.index.num_rv() for i in range(n - 1, -1, -1): if zmask & (1 << i): if self.checkfcn_mask(1 
<< i, zmask - (1 << i)): zmask -= 1 << i return zmask def checkfcn_mask(self, xmask, zmask): if xmask < 0: return False if zmask | xmask == zmask: return True did = True while did: did = False for cx, cz in self.fcn_list: if cz | zmask == zmask and cx | zmask != zmask: zmask |= cx if zmask | xmask == zmask: return True did = True return False def checkfcn(self, x, z): xmask = self.index.get_mask(x) if xmask < 0: return False zmask = self.index.get_mask(z.inter(self.index.comprv)) return self.checkfcn_mask(xmask, zmask) def get_bnet_fcn_region(self): r = Region.universe() if self.bnet is not None: r = self.bnet.get_region() for cx, cz in self.fcn_list: r &= Expr.Hc(self.index.from_mask(cx), self.index.from_mask(cz)) == 0 return r def addExpr_ge0(self, x): if x.size() == 0: return self.Au.addrow(info = -x) self.addExpr(self.Au, -x) self.bu.append(0.0) # print(str(x) + " " + str(x.get_meta("pf_note"))) # ??????????????? def addExpr_ge0_fcn(self, x): if x.size() == 0: return if self.fcn_mode >= 1: if x.isnonpos(): self.addfcn(x) def addExpr_eq0(self, x): if x.size() == 0: return self.Ae.addrow(info = x.copy()) self.addExpr(self.Ae, x) self.be.append(0.0) def addExpr_eq0_fcn(self, x): if x.size() == 0: return if self.fcn_mode >= 1: if x.isnonpos() or x.isnonneg(): self.addfcn(x) def add_ent_ineq(self): quantum = self.quantum n = self.index.num_rv() npow = (1 << n) hge0 = self.hge0 hcge0 = self.hcge0 ige0 = self.ige0 icge0 = self.icge0 if not quantum and (not hge0 or not hcge0 or not ige0 or not icge0): if ige0: for xmask in range(1, 1 << n): for ymask in range(xmask + 1, 1 << n): self.addExpr_ge0(Expr.I(self.index.comprv.from_mask(xmask), self.index.comprv.from_mask(ymask))) if hge0: for xmask in range(1, 1 << n): self.addExpr_ge0(Expr.H(self.index.comprv.from_mask(xmask))) if hcge0: for zmask in range(1, 1 << n): for x in range(n): if (zmask & (1 << x)) == 0: self.addExpr_ge0(Expr.Hc(self.index.comprv[x], self.index.comprv.from_mask(zmask))) return if quantum: for x in 
range(n): zmask = 0 while zmask < npow: wmask = npow - 1 - (1 << x) - zmask if zmask > wmask: break self.Au.addrow(info = -Expr.Hc(self.index.comprv[x], self.index.comprv.from_mask(zmask)) - Expr.Hc(self.index.comprv[x], self.index.comprv.from_mask(wmask))) self.addIc_mask(self.Au, x, x, zmask, -1.0) self.addIc_mask(self.Au, x, x, wmask, -1.0) self.bu.append(0.0) zmask += 1 if (zmask & (1 << x)) != 0: zmask += (1 << x) else: if self.dual_form: for x in range(n): zmask = 0 while zmask < npow: self.Au.addrow(info = -Expr.Hc(self.index.comprv[x], self.index.comprv.from_mask(zmask))) self.addIc_mask(self.Au, x, x, zmask, -1.0) self.bu.append(0.0) zmask += 1 if (zmask & (1 << x)) != 0: zmask += (1 << x) else: for x in range(n): self.Au.addrow(info = -Expr.Hc(self.index.comprv[x], self.index.comprv.from_mask(npow - 1 - (1 << x)))) self.addIc_mask(self.Au, x, x, npow - 1 - (1 << x), -1.0) self.bu.append(0.0) for x in range(n): for y in range(x + 1, n): zmask = 0 while zmask < npow: self.Au.addrow(info = -Expr.Ic(self.index.comprv[x], self.index.comprv[y], self.index.comprv.from_mask(zmask))) self.addIc_mask(self.Au, x, y, zmask, -1.0) if self.lptype == LinearProgType.HC1BN or self.lptype == LinearProgType.HMIN: self.Au.simplify_last_row() if self.Au.last_row_isempty(): self.Au.poprow() else: self.bu.append(0.0) else: self.bu.append(0.0) zmask += 1 if (zmask & (1 << x)) != 0: zmask += (1 << x) if (zmask & (1 << y)) != 0: zmask += (1 << y) def finish(self, skip_ent_ineq = False, skip_remove_empty = False, skip_solver = False): ceps = PsiOpts.settings["eps"] dual_form = self.dual_form if not skip_ent_ineq: self.add_ent_ineq() PsiRec.num_lpprob += 1 #self.Au.simplify() #self.Ae.simplify() if self.lptype == LinearProgType.HC1BN or self.lptype == LinearProgType.HMIN: self.Au.unique_row() self.bu = [0.0] * len(self.Au.x) if False: k = self.index.get_index(IVar.one()) if k >= 0: self.Ae.addrow() self.addreal_id(self.Ae, k, 1.0) self.be.append(1.0) self.affine_present = True k = 
self.index.get_index(IVar.eps()) if k >= 0: self.Ae.addrow() self.addreal_id(self.Ae, k, 1.0) self.be.append(self.lp_eps) self.eps_present = True self.affine_present = True k = self.index.get_index(IVar.inf()) if k >= 0: self.Ae.addrow() self.addreal_id(self.Ae, k, 1.0) self.be.append(self.lp_ubound) self.affine_present = True if True: id_one = self.index.get_index(IVar.one()) if id_one >= 0: id_one += self.realshift id_eps = self.index.get_index(IVar.eps()) if id_eps >= 0: id_eps += self.realshift id_inf = self.index.get_index(IVar.inf()) if id_inf >= 0: id_inf += self.realshift if id_one >= 0 or id_eps >= 0 or id_inf >= 0: self.affine_present = True if id_eps >= 0: self.eps_present = True if self.affine_present and not dual_form: self.lp_bounded = True if self.lp_bounded: if self.index.num_rv() > 0: self.Au.addrow(info = Expr.zero()) self.addH_mask(self.Au, (1 << self.index.num_rv()) - 1, 1.0) #for i in range(self.realshift): # self.Au.add_last_row(i, 1.0) self.bu.append(self.lp_ubound) for i in range(self.index.num_real()): if Term.fromcomp(self.index.compreal[i]).isrealvar(): self.Au.addrow(info = Expr.zero()) self.addreal_id(self.Au, i, 1.0) self.bu.append(self.lp_ubound) self.Au.addrow(info = Expr.zero()) self.addreal_id(self.Au, i, -1.0) self.bu.append(self.lp_ubound) cols = self.Au.nonzero_cols() coles = self.Ae.nonzero_cols() cols = [a or b for a, b in zip(cols, coles)] #print(self.Au.x) if True: self.constmap = {} if id_one >= 0: cols[id_one] = False self.constmap[id_one] = 1.0 if id_eps >= 0: cols[id_eps] = False self.constmap[id_eps] = self.lp_eps if id_inf >= 0: cols[id_inf] = False self.constmap[id_inf] = self.lp_ubound if len(self.constmap): self.bu = [b - a for a, b in zip(self.Au.sumrows(self.constmap, remove = True), self.bu)] self.be = [b - a for a, b in zip(self.Ae.sumrows(self.constmap, remove = True), self.be)] #print(self.Au.x) if skip_remove_empty: cols = [True] * len(cols) self.xvarid = [0] * self.nvar self.nxvar = 0 for i in 
range(self.nvar): if cols[i]: self.xvarid[i] = self.nxvar self.nxvar += 1 else: self.xvarid[i] = -1 self.Au.mapcol(self.xvarid) self.Au.width = self.nxvar self.Ae.mapcol(self.xvarid) self.Ae.width = self.nxvar # print(self.xvarid) # print(self.Au.x) # print(self.bu) #print(self.Au.x) if True: for i in range(len(self.Au.x)): if len(self.Au.x[i]) == 0: if self.bu[i] < -ceps: self.pinfeas = True self.bu[i] = None for i in range(len(self.Ae.x)): if len(self.Ae.x[i]) == 0: if abs(self.be[i]) > ceps: self.pinfeas = True self.be[i] = None # self.Au.x = [a for a in self.Au.x if len(a)] # self.Ae.x = [a for a in self.Ae.x if len(a)] self.Au.remove_empty_rows() self.Ae.remove_empty_rows() self.bu = [a for a in self.bu if a is not None] self.be = [a for a in self.be if a is not None] #print(self.Au.x) if not skip_solver: self.calc_solver() def calc_solver(self): ceps = PsiOpts.settings["eps"] dual_form = self.dual_form self.dual_form_infeas = False if self.solver == "scipy": self.solver_param["Aus"] = self.Au.tolil() self.solver_param["Aes"] = self.Ae.tolil() elif self.solver.startswith("pulp."): prob = None xvar = None if dual_form: if self.dual_form_obj is not None: self.dual_form_ncons = len(self.Au.x) + len(self.Ae.x) * 2 self.dual_form_weights = [] prob = pulp.LpProblem("lpentineq" + str(PsiRec.num_lpprob), pulp.LpMinimize) xvar = pulp.LpVariable.dicts("x", [str(i) for i in range(self.dual_form_ncons)]) vexprs = [None] * (self.nxvar + 1) for i, (cx, cb, csn, cri) in enumerate(itertools.chain(((tx, tb, -1, tri) for tx, tb, tri in zip(self.Au.x, self.bu, self.Au.rowinfo)), ((tx, tb, tsn, tri) for tx, tb, tri in zip(self.Ae.x, self.be, self.Ae.rowinfo) for tsn in [-1, 1]))): prob += xvar[str(i)] >= 0 cweight = 1.0 + i * 0.2 / self.dual_form_ncons if cri is not None: cweight += 0.2 * len(cri.allcomp()) if cri is not None and cri.get_meta("dual_weight") is not None: cweight *= cri.get_meta("dual_weight") self.dual_form_weights.append(cweight) for (j, c) in cx: # print(str(i) 
+ " " + str(j) + " " + str(c)) if vexprs[j] is None: vexprs[j] = c * csn * xvar[str(i)] else: vexprs[j] += c * csn * xvar[str(i)] if abs(cb) > ceps: if vexprs[self.nxvar] is None: vexprs[self.nxvar] = cb * csn * xvar[str(i)] else: vexprs[self.nxvar] += cb * csn * xvar[str(i)] for j in range(self.nxvar): if vexprs[j] is None: if abs(self.dual_form_obj[j]) > ceps: self.dual_form_infeas = True continue prob += vexprs[j] == self.dual_form_obj[j] if vexprs[self.nxvar] is None: if -self.dual_form_cutoff > ceps: self.dual_form_infeas = True else: # print(vexprs[self.nxvar]) # print(self.dual_form_cutoff) prob += vexprs[self.nxvar] >= -self.dual_form_cutoff else: prob = pulp.LpProblem("lpentineq" + str(PsiRec.num_lpprob), pulp.LpMinimize) xvar = pulp.LpVariable.dicts("x", [str(i) for i in range(self.nxvar)]) for a, b in zip(self.Au.x, self.bu): if len(a): prob += pulp.LpConstraint(pulp.lpSum([xvar[str(j)] * c for (j, c) in a]), sense = -1, rhs = b) for a, b in zip(self.Ae.x, self.be): if len(a): prob += pulp.LpConstraint(pulp.lpSum([xvar[str(j)] * c for (j, c) in a]), sense = 0, rhs = b) if False: for a, b in zip(self.Au.x, self.bu): if len(a): #print(" $ ".join([str((j, c)) for (j, c) in a])) prob += pulp.LpConstraint(pulp.LpAffineExpression([(xvar[str(j)], c) for (j, c) in a]), sense = -1, rhs = b) for a, b in zip(self.Ae.x, self.be): if len(a): prob += pulp.LpConstraint(pulp.LpAffineExpression([(xvar[str(j)], c) for (j, c) in a]), sense = 0, rhs = b) if False: for i in range(len(self.Au.x)): cexpr = None for (j, c) in self.Au.x[i]: if cexpr is None: cexpr = c * xvar[str(j)] else: cexpr += c * xvar[str(j)] if cexpr is not None: prob += cexpr <= self.bu[i] for i in range(len(self.Ae.x)): cexpr = None for (j, c) in self.Ae.x[i]: if cexpr is None: cexpr = c * xvar[str(j)] else: cexpr += c * xvar[str(j)] if cexpr is not None: prob += cexpr == self.be[i] #print(prob) self.solver_param["prob"] = prob self.solver_param["xvar"] = xvar elif self.solver.startswith("pyomo."): 
solver_opt = self.solver[self.solver.index(".") + 1 :] opt = SolverFactory(solver_opt) model = pyo.ConcreteModel() if self.dual_enabled: model.dual = pyo.Suffix(direction=pyo.Suffix.IMPORT) if dual_form: if self.dual_form_obj is not None: self.dual_form_ncons = len(self.Au.x) + len(self.Ae.x) * 2 self.dual_form_weights = [] model.n = pyo.Param(default = self.dual_form_ncons) model.x = pyo.Var(pyo.RangeSet(model.n), domain=pyo.Reals) model.c = pyo.ConstraintList() vexprs = [None] * (self.nxvar + 1) for i, (cx, cb, csn, cri) in enumerate(itertools.chain(((tx, tb, -1, tri) for tx, tb, tri in zip(self.Au.x, self.bu, self.Au.rowinfo)), ((tx, tb, tsn, tri) for tx, tb, tri in zip(self.Ae.x, self.be, self.Ae.rowinfo) for tsn in [-1, 1]))): model.c.add(model.x[i + 1] >= 0) cweight = 1.0 + i * 0.2 / self.dual_form_ncons if cri is not None: cweight += 0.2 * len(cri.allcomp()) if cri is not None and cri.get_meta("dual_weight") is not None: cweight *= cri.get_meta("dual_weight") self.dual_form_weights.append(cweight) for (j, c) in cx: # print(str(i) + " " + str(j) + " " + str(c)) if vexprs[j] is None: vexprs[j] = c * csn * model.x[i + 1] else: vexprs[j] += c * csn * model.x[i + 1] if abs(cb) > ceps: if vexprs[self.nxvar] is None: vexprs[self.nxvar] = cb * csn * model.x[i + 1] else: vexprs[self.nxvar] += cb * csn * model.x[i + 1] for j in range(self.nxvar): if vexprs[j] is None: if abs(self.dual_form_obj[j]) > ceps: self.dual_form_infeas = True continue model.c.add(vexprs[j] == self.dual_form_obj[j]) if vexprs[self.nxvar] is None: if -self.dual_form_cutoff > ceps: self.dual_form_infeas = True else: # print(vexprs[self.nxvar]) # print(self.dual_form_cutoff) model.c.add(vexprs[self.nxvar] >= -self.dual_form_cutoff) else: model.n = pyo.Param(default=self.nxvar) model.x = pyo.Var(pyo.RangeSet(model.n), domain=pyo.Reals) model.c = pyo.ConstraintList() for i in range(len(self.Au.x)): cexpr = None for (j, c) in self.Au.x[i]: if cexpr is None: cexpr = c * model.x[j + 1] else: cexpr += c 
* model.x[j + 1] if cexpr is not None: model.c.add(cexpr <= self.bu[i]) for i in range(len(self.Ae.x)): cexpr = None for (j, c) in self.Ae.x[i]: if cexpr is None: cexpr = c * model.x[j + 1] else: cexpr += c * model.x[j + 1] if cexpr is not None: model.c.add(cexpr == self.be[i]) self.solver_param["opt"] = opt self.solver_param["model"] = model elif self.solver.startswith("ortools."): solver_opt = self.solver[self.solver.index(".") + 1 :] model = ortools.linear_solver.pywraplp.Solver.CreateSolver(solver_opt) xvar = None if not model: raise RuntimeError("Fail to initialize solver " + self.solver) # if self.dual_enabled: # model.dual = pyo.Suffix(direction=pyo.Suffix.IMPORT) if dual_form: if self.dual_form_obj is not None: self.dual_form_ncons = len(self.Au.x) + len(self.Ae.x) * 2 self.dual_form_weights = [] # model.n = pyo.Param(default = self.dual_form_ncons) # model.x = pyo.Var(pyo.RangeSet(model.n), domain=pyo.Reals) n = self.dual_form_ncons xvar = [model.NumVar(0, model.infinity(), 'x' + str(i)) for i in range(n)] # model.c = pyo.ConstraintList() vexprs = [None] * (self.nxvar + 1) for i, (cx, cb, csn, cri) in enumerate(itertools.chain(((tx, tb, -1, tri) for tx, tb, tri in zip(self.Au.x, self.bu, self.Au.rowinfo)), ((tx, tb, tsn, tri) for tx, tb, tri in zip(self.Ae.x, self.be, self.Ae.rowinfo) for tsn in [-1, 1]))): # model.Add(x[i] >= 0) cweight = 1.0 + i * 0.2 / self.dual_form_ncons if cri is not None: cweight += 0.2 * len(cri.allcomp()) if cri is not None and cri.get_meta("dual_weight") is not None: cweight *= cri.get_meta("dual_weight") self.dual_form_weights.append(cweight) for (j, c) in cx: # print(str(i) + " " + str(j) + " " + str(c)) if vexprs[j] is None: vexprs[j] = c * csn * xvar[i] else: vexprs[j] += c * csn * xvar[i] if abs(cb) > ceps: if vexprs[self.nxvar] is None: vexprs[self.nxvar] = cb * csn * xvar[i] else: vexprs[self.nxvar] += cb * csn * xvar[i] for j in range(self.nxvar): if vexprs[j] is None: if abs(self.dual_form_obj[j]) > ceps: 
self.dual_form_infeas = True continue model.Add(vexprs[j] == self.dual_form_obj[j]) if vexprs[self.nxvar] is None: if -self.dual_form_cutoff > ceps: self.dual_form_infeas = True else: # print(vexprs[self.nxvar]) # print(self.dual_form_cutoff) model.Add(vexprs[self.nxvar] >= -self.dual_form_cutoff) else: # model.n = pyo.Param(default=self.nxvar) # model.x = pyo.Var(pyo.RangeSet(model.n), domain=pyo.Reals) # model.c = pyo.ConstraintList() n = self.nxvar xvar = [model.NumVar(-model.infinity(), model.infinity(), 'x' + str(i)) for i in range(n)] for i in range(len(self.Au.x)): cexpr = None for (j, c) in self.Au.x[i]: if cexpr is None: cexpr = c * xvar[j] else: cexpr += c * xvar[j] if cexpr is not None: model.Add(cexpr <= self.bu[i]) for i in range(len(self.Ae.x)): cexpr = None for (j, c) in self.Ae.x[i]: if cexpr is None: cexpr = c * xvar[j] else: cexpr += c * xvar[j] if cexpr is not None: model.Add(cexpr == self.be[i]) self.solver_param["model"] = model self.solver_param["xvar"] = xvar def id_toexpr(self): n = self.index.num_rv() xvarinv = [0] * self.nxvar for i in range(self.nvar): if self.xvarid[i] >= 0: xvarinv[self.xvarid[i]] = i cellposinv = list(range(1, self.realshift + 1)) if self.lptype == LinearProgType.HC1BN or self.lptype == LinearProgType.HMIN: for mask in range((1 << n) - 1, 0, -1): if self.cellpos[mask] >= 0: cellposinv[self.cellpos[mask]] = mask def rf(j2): if j2 >= self.nxvar: if j2 == self.nxvar: return (Expr.const(1.0), 0) return None j = xvarinv[j2] if j >= self.realshift: return (Expr.real(self.index.compreal.varlist[j - self.realshift].name), 0) mask = cellposinv[j] term = None if self.lptype == LinearProgType.H or self.lptype == LinearProgType.HMIN: term = Term.H(Comp.empty()) for i in range(n): if mask & (1 << i) != 0: term.x[0].varlist.append(self.index.comprv.varlist[i]) elif self.lptype == LinearProgType.HC1BN: term = Term.Hc(Comp.empty(), Comp.empty()) for i in range(n): if mask & (1 << i) != 0: 
term.z.varlist.append(self.index.comprv.varlist[i]) term.x[0].varlist.append(term.z.varlist.pop()) return (Expr.fromterm(term), mask) return rf def get_var_exprs(self): idt = self.id_toexpr() r = [idt(i)[0] for i in range(self.nxvar)] return ExprArray(r) def row_toexpr(self): idt = self.id_toexpr() def rf(x): if isinstance(x, SparseMat): x = x.x[0] expr = Expr.zero() if len(x) == 0: return expr if isinstance(x[0], tuple): for (j2, c) in x: te, mask = idt(j2) #expr.terms.append((te, c)) expr += te * c else: ceps = PsiOpts.settings["eps"] for i in range(len(x)): if abs(x[i]) > ceps: te, mask = idt(i) expr += te * x[i] return expr return rf def get_region(self, toreal = None, toreal_only = False, A = None, skip_simplify = False): if toreal is None: toreal = Comp.empty() torealmask = self.index.get_mask(toreal) idt = self.id_toexpr() if A is None: A = ([(a, 1, b) for a, b in zip(self.Au.x, self.bu)] + [(a, 0, b) for a, b in zip(self.Ae.x, self.be)]) else: A = [(a, 1, 0.0) for a in A.x] r = Region.universe() for x, sn, b in A: expr = Expr.zero() toreal_present = False for (j2, c) in x: te, mask = idt(j2) termreal = (mask & torealmask) != 0 toreal_present |= termreal if termreal: expr += Expr.real("R_" + str(te)) * c else: expr += te * c if toreal_present or not toreal_only: if not skip_simplify: expr.simplify_quick() if sn == 1: r &= (expr <= Expr.const(b)) else: r &= (expr == Expr.const(b)) return r def get_dual_region(self, mul_coeff = True, omit_trivial = False, tosum = None): if self.dual_e is None or self.dual_u is None: return None rowt = self.row_toexpr() ceps = PsiOpts.settings["eps"] r = Region.universe() for x, xb, d in zip(self.Au.x, self.bu, self.dual_u): if abs(d) > ceps: x2 = (xb - rowt(x)).simplified_quick() if omit_trivial and x2.isnonneg(): continue if mul_coeff: r.exprs_ge.append(x2 * -d) else: r.exprs_ge.append(x2) if tosum is not None: tosum += x2 * -d for x, xb, d in zip(self.Ae.x, self.be, self.dual_e): if abs(d) > ceps: x2 = (xb - 
rowt(x)).simplified_quick() if mul_coeff: r.exprs_eq.append(x2 * -d) else: r.exprs_eq.append(x2) if tosum is not None: tosum += x2 * -d return r def get_dual_sum_region(self, x): r = self.get_dual_region() with PsiOpts(proof_enabled = False): rsum = sum(r.exprs_ge + r.exprs_eq, Expr.zero()) rdiff = (x - rsum).simplified_quick() if not rdiff.iszero(): rowt = self.row_toexpr() for t in rdiff: v = self.get_vec(t)[0] if v is not None: tv = (t - rowt(v)).simplified_quick() if not tv.iszero(): r.exprs_eq.append(tv) return r def get_dual_sum_meta(self, rowt = None, meta = "sim"): if self.dual_e is None or self.dual_u is None: return None ceps = PsiOpts.settings["eps"] bnet = None if rowt is None: rowt = self.row_toexpr() r = Expr.zero() # print(self.Au.x) # print(self.Au.rowinfo) # print(self.Ae.x) # print(self.Ae.rowinfo) for x, ri, xb, d in itertools.chain(zip(self.Au.x, self.Au.rowinfo, self.bu, self.dual_u), zip(self.Ae.x, self.Ae.rowinfo, self.be, self.dual_e)): if abs(d) > ceps: # print(str(ri) + " " + str(d) + " " + str(ri.get_meta(meta))) if ri.get_meta(meta) is not None: r += ri * iutil.float_snap(d, PsiOpts.settings["max_denom_lp"], force = True) return r def copy_empty_nobnet(self): if self.lptype == LinearProgType.HC1BN or self.lptype == LinearProgType.HMIN: alt_prog = LinearProg(self.index.copy(), self.lptype, BayesNet()) alt_prog.finish(skip_ent_ineq = True, skip_remove_empty = True, skip_solver = True) return alt_prog return None def col_steps(self, cri, rowt): cri2 = rowt(self.get_vec(cri, sparse = True, cat = True)) if (cri - cri2).simplified_quick().iszero(): return [cri] # print((cri, cri2)) if not (len(cri) == 1 and cri.terms[0][0].ishc() and len(cri2) == 1 and cri2.terms[0][0].ishc() and cri.terms[0][0].x == cri2.terms[0][0].x): return [cri, cri2] # print("A") t1 = cri.terms[0][0] t2 = cri2.terms[0][0] if t1.z.super_of(t2.z) or t2.z.super_of(t1.z): return [cri, cri2] z3 = t1.z + t2.z cri3 = Expr.Hc(t1.x[0], z3) # print((cri, cri2, cri3)) if 
self.get_vec(cri2 - cri3, sparse = True, cat = True).iszero(): return [cri, cri3, cri2] return [cri, cri2] def get_dual_terms(self, expr, rowt = None, alt_prog = None, ri_simplify = True, nshuffle = 3, deficit_separate = True): verbose = PsiOpts.settings.get("verbose_proof_step", False) if self.dual_e is None or self.dual_u is None: return None ceps = PsiOpts.settings["eps"] bnet = None if rowt is None: rowt = self.row_toexpr() alt_rowt = None deficit = None if alt_prog is not None: alt_rowt = alt_prog.row_toexpr() deficit_type = "#TRI" if deficit_separate: deficit_type = "#DTRI" r = [] tmpi = 0 for tmpi, (x, ri, xb, d, iseq) in enumerate(itertools.chain(zip(self.Au.x, self.Au.rowinfo, self.bu, self.dual_u, [False] * len(self.bu)), zip(self.Ae.x, self.Ae.rowinfo, self.be, self.dual_e, [True] * len(self.be)))): if abs(d) > ceps: # r.append((SparseMat.from_row(x, self.nxvar).tonumpyarray_row(), # rowt(x).simplified_break_hc(bnet = bnet), d)) crow = SparseMat.from_row(x, self.nxvar + 1).tonumpyarray_row() crow[self.nxvar] = -xb # ri = ri.simplified_break_hc(bnet = bnet) if ri_simplify: ri = ri.simplified_break_hc(bnet = self.bnet) d = iutil.float_snap(d, PsiOpts.settings["max_denom_lp"], force = True) iseq = iseq or (abs(xb) < ceps and (ri * d).isnonpos()) if alt_prog is not None: crow2 = numpy.array(alt_prog.get_vec(ri, cat = True)) ri3 = rowt(crow) crow3 = numpy.array(alt_prog.get_vec(ri3, cat = True)) # print(crow2) # print(crow3) # print(ri) # print(ri3) # r.append((crow3 - crow2, (ri3 - ri).simplified(), d, True)) crow = crow2 if deficit is None: deficit = (crow3 - crow2) * d else: deficit += (crow3 - crow2) * d if verbose: arow = alt_rowt(crow3 - crow2) if not arow.iszero(): print("Deficit added : " + str(ri) + " " + str(ri3) + " : " + str(arow)) curtype = "" rii = ri.get_meta("pf_note") if isinstance(rii, list): curtype = str(rii[0]) elif (ri * d).isnonneg(): curtype = "#TRI" elif (ri * d).isnonpos(): if self.noskip: # curtype = "#N" + str(tmpi) curtype = 
"#CTRI" else: curtype = "#TRI" if any(abs(x) > ceps for x in crow): r.append((crow, ri, d, iseq, curtype)) # print(str(r[-1][1]) + " * " + str(d) + " " + str(r[-1][1].get_meta("pf_note"))) # ??????????????? if alt_prog is not None: exprrow = numpy.array(alt_prog.get_vec(expr, cat = True)) expr2 = rowt(self.get_vec(expr, cat = True)) exprrow2 = numpy.array(alt_prog.get_vec(expr2, cat = True)) if deficit is None: deficit = exprrow - exprrow2 else: deficit += exprrow - exprrow2 if deficit is not None and any(abs(x) > ceps for x in deficit): # for i in range(len(deficit) - 1, -1, -1): for i in range(len(deficit)): x = deficit[i] if abs(x) <= ceps: continue cri = alt_rowt([(i, 1.0)]) steps = self.col_steps(cri, rowt) if len(steps) <= 1: if verbose: print("Deficit UNRESOLVED: " + str(cri)) #+ " : " + str(alt_rowt(deficit))) print(" remainder: " + str(alt_rowt(deficit))) continue stepsrow = [numpy.array(alt_prog.get_vec(s, sparse = False, cat = True)) for s in steps] for j in range(len(steps) - 1): r.append((stepsrow[j] - stepsrow[j + 1], (steps[j] - steps[j + 1]).simplified_quick(), x, True, deficit_type)) if verbose: print("Deficit resolved " + str(j) + ": " + str(steps[j]) + " " + str(steps[j + 1]) + " : " + str((steps[j] - steps[j + 1]).simplified_quick())) deficit += (stepsrow[-1] - stepsrow[0]) * x if verbose: print(" remainder: " + str(alt_rowt(deficit))) # cri = alt_rowt([(i, x)]) # cri2 = rowt(self.get_vec(cri, sparse = True, cat = True)) # crow2 = alt_prog.get_vec(cri2, sparse = True, cat = True) # if len(crow2.x[0]) == 1 and crow2.x[0][0][0] == i: # if verbose: # print("Deficit UNRESOLVED: " + str(i) + " " + str(cri) + " " + str(cri2) #+ " : " + str(alt_rowt(deficit))) # continue # crow2 = crow2.tonumpyarray_row() # crow2[i] -= x # r.append((crow2, -(cri - cri2).simplified(), -1.0, True, "#TRI")) # deficit += crow2 # if verbose: # print("Deficit resolved : " + str(i) + " " + str(cri) + " " + str(cri2) #+ " : " + str(alt_rowt(deficit))) if deficit is not None 
and any(abs(x) > ceps for x in deficit): r.append((-deficit, -alt_rowt(deficit).simplified_quick(), -1.0, True, deficit_type)) if verbose: print("Deficit RESIDUAL : " + str(alt_rowt(deficit).simplified_quick())) if alt_prog is not None: rowt = alt_rowt ro = r ropt = None rnd = random.Random(123) for it in range(nshuffle): r = iutil.copy(ro) if it: rnd.shuffle(r) did = True while did: did = False for i in range(len(r) - 1, -1, -1): # for j in range(i - 1, -1, -1): for j in range(i): if abs(abs(r[i][2]) - abs(r[j][2])) > ceps: continue isn = 1 if r[i][2] * r[j][2] < 0: isn = -1 if r[i][4] == "#CTRI": if r[j][4] != "#TRI": continue elif r[j][4] == "#CTRI": if r[i][4] != "#TRI": continue else: if r[i][4] != r[j][4]: continue rii = r[i][1].get_meta("pf_note") rij = r[j][1].get_meta("pf_note") # rii_first = "" # if isinstance(rii, list): # rii_first = str(rii[0]) # elif (r[i][1] * r[i][2]).isnonneg(): # rii_first = "#SHA" # rij_first = "" # if isinstance(rij, list): # rij_first = str(rij[0]) # elif (r[j][1] * r[j][2]).isnonneg(): # rij_first = "#SHA" # print(str(r[i][1]) + " " + rii_first + " " + str(r[j][1]) + " " + rij_first) # if rii_first != rij_first: # continue t = (r[j][1] + r[i][1] * isn).simplified_break_hc(bnet = bnet) sumx = r[j][0] + r[i][0] * isn t2 = rowt(sumx).simplified_break_hc(bnet = bnet) if (len(t2), t2.complexity()) < (len(t), t.complexity()): t = t2 # if t.complexity() < (r[i][1] + r[j][1]).complexity(): # if t.len_iccount() <= max(r[i][1].len_iccount(), r[j][1].len_iccount()): if len(t) <= max(len(r[i][1]), len(r[j][1])): # print(str(rii) + " " + str(rij)) if rii is not None or rij is not None: t.add_meta("pf_note", rii or rij) # print(t) r[j] = (sumx, t, r[j][2], r[i][3] and r[j][3], r[j][4]) r.pop(i) did = True break if ropt is None or len(r) < len(ropt): ropt = r r = ropt # print([(x * d, t * d, iseq) for x, t, d, iseq in r]) # Split IC3 for i in range(len(r)): if len(r[i][1]) == 1 and r[i][1].terms[0][0].isic3(): x = r[i][1].terms[0][0].x lenx 
= len(x) x = x + x z = r[i][1].terms[0][0].z coeff = r[i][1].terms[0][1] * r[i][2] for k in range(lenx): e1 = Expr.Ic(x[k + 1], x[k + 2], z).simplified_quick() e2 = Expr.Ic(x[k + 1], x[k + 2], z + x[k]).simplified_quick() e1eq = False e2eq = False if coeff > 0: # if self.get_vec(e2, sparse = True, cat = True).iszero(): if ((self.bnet is not None and self.bnet.check_ic(e2)) or self.get_vec(e2, sparse = True, cat = True).iszero()): e2eq = True else: # if self.get_vec(e1, sparse = True, cat = True).iszero(): if ((self.bnet is not None and self.bnet.check_ic(e1)) or self.get_vec(e1, sparse = True, cat = True).iszero()): e1eq = True if e1eq or e2eq: r[i] = (numpy.array(alt_prog.get_vec(e1, cat = True)), e1, coeff, e1eq, "#TRI") r.append((numpy.array(alt_prog.get_vec(e2, cat = True)), e2, -coeff, e2eq, "#TRI")) break return [(x * d, t * d, iseq, curtype) for x, t, d, iseq, curtype in r] def get_dual_region_terms(self, expr, skip_trivial = True, ri_simplify = True, nshuffle = 3, compress = True): if self.dual_e is None or self.dual_u is None: return None # print(expr) rowt = self.row_toexpr() ineqs = self.get_dual_terms(expr, rowt = rowt, ri_simplify = ri_simplify, nshuffle = nshuffle) # print(ineqs) r = Region.universe() for i, (x, y, iseq, curtype) in enumerate(ineqs): if skip_trivial and curtype == "#TRI": continue if iseq: r.exprs_eq.append(y) else: r.exprs_ge.append(y) # print(r) return r def get_dual_steps(self, expr, prefer_short = True, simplify_exhaust = False, simplify_prog = True, chain = False, step_optimize = True, alt_prog = None, ri_simplify = True, nshuffle = 3, compress = True, deficit_separate = True): verbose = PsiOpts.settings.get("verbose_proof_step", False) if self.dual_e is None or self.dual_u is None: return None # print(expr) rowt = self.row_toexpr() ceps = PsiOpts.settings["eps"] ineqs = self.get_dual_terms(expr, rowt = rowt, alt_prog = alt_prog, ri_simplify = ri_simplify, nshuffle = nshuffle, deficit_separate = deficit_separate) # print(ineqs) 
veclen = self.nxvar + 1 get_vec = self.get_vec bnet = self.bnet if alt_prog is not None: veclen = alt_prog.nxvar + 1 get_vec = alt_prog.get_vec rowt = alt_prog.row_toexpr() bnet = None # for x, d in zip(self.Au.x, self.dual_u): # if abs(d) > ceps: # ineqs.append(SparseMat.from_row(x, self.nxvar).tonumpyarray_row() * d) # for x, d in zip(self.Ae.x, self.dual_e): # if abs(d) > ceps: # ineqs.append(SparseMat.from_row(x, self.nxvar).tonumpyarray_row() * d) prog = None if simplify_prog and self.bnet is not None: prog = self.get_bnet_fcn_region().init_simplified_prog(self.index.copy()) # for x, xin in zip(prog.Ae.x, prog.Ae.rowinfo): # print(str(xin) + " " + str(x)) fsum = numpy.zeros(veclen) for x, y, iseq, curtype in ineqs: fsum += x fsum = SparseMat.from_dense_row(fsum) # print(rowt(fsum).simplified()) csum = numpy.zeros(veclen) csum = SparseMat.from_dense_row(csum) csumt = Expr.zero() ineqs = [(SparseMat.from_dense_row(x), y, iseq, curtype) for x, y, iseq, curtype in ineqs] sn = 1 r0 = [] r1 = [] riseq = [] rcurtype = [] left_realvars = None if chain: expr0, expr, eqnstr = expr.split_present(">=", lhsvar = "real") # print(expr0) # print(expr) # print(eqnstr) gv = get_vec(expr0) csum = numpy.array(list(gv[0]) + [gv[1]]) csum = SparseMat.from_dense_row(csum) if eqnstr == ">=": sn = -1 fsum = fsum * sn + csum left_realvars = expr0.allcomprealvar() # print(rowt(fsum)) r0.append(None) r1.append(expr0) riseq.append(False) rcurtype.append("#START") exsum_gv = get_vec(expr) exsum = numpy.array(list(exsum_gv[0]) + [exsum_gv[1]]) exsum = SparseMat.from_dense_row(exsum) while ineqs: mini = -1 mint = None minc = 1000000000000000000 for i, (x, y, iseq, curtype) in enumerate(ineqs): t = None if prog is not None: t = (csumt + y * sn).simplified_quick() else: # t = rowt(csum + x * sn).simplified_break_hc(bnet = self.bnet) # t2 = (rowt(csum - exsum + x * sn) + expr).simplified_break_hc(bnet = self.bnet) t = rowt(csum + x * sn).simplified_quick(bnet = bnet) t2 = (rowt(csum - exsum + x 
* sn) + expr).simplified_quick(bnet = bnet) if t2.complexity() < t.complexity(): t = t2 if simplify_exhaust: t.simplify_break_hc() tc = t.complexity() if chain: tc += len(t.allcomprealvar()) * 10000 # if not y.ispresent(left_realvars): # tc += 1000000 else: if not y.ispresent("realvar"): tc += 1000000 if tc < minc: minc = tc mini = i mint = t if not step_optimize: break if prog is not None: mint = mint.simplified_prog(prog = prog) mint.sort() csumt = mint # r0.append(rowt(ineqs[mini]).simplified_break_hc(bnet = bnet)) r0.append(ineqs[mini][1]) csum += ineqs[mini][0] * sn r1.append(mint) riseq.append(ineqs[mini][2]) rcurtype.append(ineqs[mini][3]) ineqs.pop(mini) # print(r0) # print(r1) # print(riseq) if compress: i = 0 while i < len(r0) - 1: maxj = i while maxj + 1 < len(r0) and rcurtype[maxj + 1] == rcurtype[i]: maxj += 1 for j in range(maxj, i + 1, -1): cdiff = ((r1[j] - r1[i]) * sn).simplified_quick() if len(cdiff) == 1 and cdiff.terms[0][0].isicle2(): r1[i + 1] = r1[j] r0[i + 1] = cdiff riseq[i + 1] = all(riseq[i + 1 : j + 1]) del r0[i + 2 : j + 1] del r1[i + 2 : j + 1] del riseq[i + 2 : j + 1] del rcurtype[i + 2 : j + 1] break i += 1 return (r0, r1, riseq, sn) def write_pf(self, x): verbose = PsiOpts.settings.get("verbose_proof_step", False) write_pf_repeat_claim = PsiOpts.settings.get("proof_repeat_implicant", False) include_note = PsiOpts.settings.get("proof_note", False) note_skip_trivial = PsiOpts.settings.get("proof_note_skip_trivial", False) pf = None if PsiOpts.settings["proof_step_dualsum"]: chain = PsiOpts.settings["proof_step_chain"] alt_prog = None if PsiOpts.settings["proof_step_bayesnet"]: alt_prog = self.copy_empty_nobnet() dualsteps = self.get_dual_steps(x, prefer_short = PsiOpts.settings["proof_step_dualsum_short"], simplify_exhaust = PsiOpts.settings["proof_step_dualsum_exhaust"], simplify_prog = PsiOpts.settings["proof_step_dualsum_prog"], chain = chain, step_optimize = PsiOpts.settings["proof_step_optimize"], alt_prog = alt_prog, ri_simplify 
= PsiOpts.settings["proof_step_term_simplify"], nshuffle = PsiOpts.settings["proof_step_term_nshuffle"], compress = PsiOpts.settings["proof_step_compress"], deficit_separate = PsiOpts.settings["proof_deficit_separate"]) if dualsteps is None: return cadds, csums, ciseqs, csn = dualsteps if len(cadds) == 0: return if not PsiOpts.settings["proof_noskip"] and len(cadds) <= 1: return pf = None if chain: ineqchain = [] for i, (cadd, csum, ciseq) in enumerate(zip(cadds, csums, ciseqs)): cclaim = [] pf_note = None # if cadd is not None: # print(str(cadd)+ " " + str(cadd.get_meta("pf_note"))) if include_note and cadd is not None: if not note_skip_trivial or ciseq or not cadd.isnonneg(): cclaim += ["(", "since", " "] pf_note = cadd.get_meta("pf_note") if pf_note is not None: if isinstance(pf_note, list): cclaim += pf_note else: cclaim += [pf_note] cclaim += [":", " "] if ciseq: cclaim.append((cadd.copy() == 0).remove_meta("pf_note")) else: cclaim.append((cadd.copy() >= 0).remove_meta("pf_note")) cclaim += [")"] ceqnstr = "" if ciseq: ceqnstr = "=" elif i: ceqnstr = "<=" if csn > 0 else ">=" ineqchain.append([csum, ceqnstr, cclaim]) pf = ProofObj.from_region(ineqchain, c = "Steps: ") if PsiOpts.settings["proof_note_color"] is not None: pf.meta["note_color"] = PsiOpts.settings["proof_note_color"] if PsiOpts.settings["proof_note_newline"] is not None: pf.meta["note_newline"] = PsiOpts.settings["proof_note_newline"] else: rt = Region.universe() for t in cadds: rt &= t >= 0 if not write_pf_repeat_claim or rt.isuniverse(): pf = ProofObj.from_region(x >= 0, c = "Claim: ") else: pf = ProofObj.from_region(("implies", rt, x >= 0), c = "Claim: ") for i, (cadd, csum) in enumerate(zip(cadds, csums)): cclaim = [] if i == 0: cclaim += ["Have: ", cadd >= 0] else: cclaim += ["Add: ", cadd >= 0] pf_note = None if include_note: pf_note = cadd.get_meta("pf_note") if pf_note is not None: cclaim += [" ", "("] if isinstance(pf_note, list): cclaim += pf_note else: cclaim += [pf_note] cclaim += 
[")"] cadd = None if chain: cadd = csum else: cadd = csum >= 0 pf += ProofObj.from_region(cadd, c = cclaim) else: r = self.get_dual_sum_region(x) if r is None: return pf = ProofObj.from_region(r, c = ["Duals for ", x >= 0]) PsiOpts.set_setting(proof_add = pf) def write_pf_old(self, x): write_pf_repeat_claim = PsiOpts.settings.get("proof_repeat_implicant", False) r = self.get_dual_sum_region(x) if r is None: return xstr = str(x) + " >= 0" pf = None if PsiOpts.settings["proof_step_dualsum"]: rt = r.removed_trivial() pf = None if not write_pf_repeat_claim or rt.isuniverse(): pf = ProofObj.from_region(x >= 0, c = "Claim: ") else: pf = ProofObj.from_region(("implies", rt, x >= 0), c = "Claim: ") cadds, csums = r.get_sum_seq(prefer_short = PsiOpts.settings["proof_step_dualsum_short"], simplify_exhaust = PsiOpts.settings["proof_step_dualsum_exhaust"], bnet = self.bnet, target = x) for i, (cadd, csum) in enumerate(zip(cadds, csums)): if i == 0: pf += ProofObj.from_region(csum >= 0, c = "Have:") else: pf += ProofObj.from_region(csum >= 0, c = ["Add: ", cadd >= 0]) # self.trytry() # r2 = Region.universe() # cur = Expr.zero() # for x in r.exprs_ge: # cur = (cur + x).simplified() # r2.exprs_ge.append(cur) # for x in r.exprs_eq: # cur = (cur + x).simplified() # r2.exprs_ge.append(cur) # pf += ProofObj.from_region(r2, c = "Steps for " + xstr) else: pf = ProofObj.from_region(r, c = ["Duals for ", x >= 0]) PsiOpts.set_setting(proof_add = pf) def get_extreme_rays_vec(self, A = None): ma = None cn = 0 if A is None: cn = self.Au.width ma = self.Ae.tonumpyarray() ma = numpy.vstack((numpy.zeros(cn), self.Au.tonumpyarray(), ma, -ma)) else: cn = A.width ma = numpy.vstack((numpy.zeros(cn), A.tonumpyarray())) #print(ma) hull = scipy.spatial.ConvexHull(ma) #print("ConvexHull finished") r = [] #tone = numpy.array(self.get_vec(Expr.H(self.index.comprv))) rset = set() ceps = PsiOpts.settings["eps"] for i in range(len(hull.simplices)): if abs(hull.equations[i,-1]) > ceps: continue t = 
hull.equations[i,:-1] vv = max(abs(t)) if vv > ceps: t = t / vv ts = ",".join(iutil.float_tostr(x) for x in t) if ts not in rset: rset.add(ts) r.append(t) return r def get_region_elim_rays(self, aux = None, A = None, skip_simplify = False): if aux is None: aux = Comp.empty() ceps = PsiOpts.settings["eps"] auxmask = self.index.get_mask(aux) #print("Before get_extreme_rays_vec") vs = self.get_extreme_rays_vec(A) #print("After get_extreme_rays_vec") idt = self.id_toexpr() var_expr = [None] * self.nxvar var_id = [-1] * self.nxvar var_id_inv = [-1] * self.nxvar nleft = 0 for j in range(self.nxvar): var_expr[j], mask = idt(j) if mask & auxmask == 0: var_id[j] = nleft var_id_inv[nleft] = j nleft += 1 vset = set() vset.add(",".join(["0"] * nleft)) ma = numpy.zeros((1, nleft)) for v in vs: nvv = 0 vv = numpy.zeros(nleft) for i in range(len(v)): if var_id[i] >= 0: vv[var_id[i]] = v[i] nvv += 1 if nvv > 0: ts = ",".join(iutil.float_tostr(x) for x in vv) #print("ts = " + ts) if ts not in vset: vset.add(ts) ma = numpy.vstack((ma, vv)) eig, ev = numpy.linalg.eig(ma.T.dot(ma)) #print(nleft) #print(ma) #print(ma.T.dot(ma)) #print(eig) #print(ev) ev0 = numpy.zeros((nleft, 0)) ev1 = numpy.zeros((nleft, 0)) for i in range(nleft): if abs(eig[i]) <= ceps: ev0 = numpy.hstack((ev0, ev[:,i:i+1])) else: ev1 = numpy.hstack((ev1, ev[:,i:i+1])) def expr_fromvec(vec): vv = max(abs(vec)) if vv > ceps: vec = vec / vv cexpr = Expr.zero() for i in range(nleft): if abs(vec[i]) > ceps: cexpr += var_expr[var_id_inv[i]] * vec[i] if not skip_simplify: cexpr.simplify_quick() return cexpr mv = ma.dot(ev1) #print(ev0) #print(ev1) r = Region.universe() for i in range(ev0.shape[1]): expr = expr_fromvec(ev0[:,i]) #print(expr) r.iand_norename(expr == 0) if ev1.shape[1] == 0: return r if ev1.shape[1] == 1: svis = [False, False] for i in range(len(mv)): if mv[i, 1] > ceps: svis[1] = True if mv[i, 1] < -ceps: svis[0] = True if svis[0] and svis[1]: return r expr = expr_fromvec(ev1[:,0]) if svis[0]: if not 
expr.isnonneg(): r.iand_norename(expr >= 0) elif svis[1]: if not expr.isnonpos(): r.iand_norename(expr <= 0) else: r.iand_norename(expr == 0) return r hull = scipy.spatial.ConvexHull(mv) #print(len(hull.simplices)) #tone_o = numpy.array(self.get_vec(Expr.H(self.index.comprv - aux))) #tone = numpy.array([tone_o[var_id_inv[j]] for j in range(nleft)]) rset = set() for i in range(len(hull.simplices)): if abs(hull.equations[i,-1]) > ceps: continue t = ev1.dot(hull.equations[i,:-1]) vv = max(abs(t)) #print(t) #print(vv) if vv > ceps: t = t / vv #print(t) ts = ",".join(iutil.float_tostr(x) for x in t) if ts not in rset: rset.add(ts) expr = expr_fromvec(t) if not expr.isnonpos(): r.iand_norename(expr <= 0) if not skip_simplify: r.simplify_quick() return r def get_extreme_rays(self, A = None): idt = self.id_toexpr() vs = self.get_extreme_rays_vec(A) r = RegionOp.union([]) ceps = PsiOpts.settings["eps"] for v in vs: tr = Region.universe() for i in range(len(v)): if abs(v[i]) > ceps: te, mask = idt(i) tr &= te == v[i] r |= tr return r def get_vec(self, x, sparse = False, cat = False): optobj = SparseMat(self.nvar) optobj.addrow() self.addExpr(optobj, x) optobj.simplify_last_row() c1 = 0.0 if len(self.constmap): c1 = optobj.sumrows(self.constmap, remove = True)[0] if not optobj.mapcol(self.xvarid): return None, None optobj.width = self.nxvar if sparse: if cat: optobj.width += 1 if c1 != 0.0: optobj.x[0].append((self.nvar + 1, c1)) return optobj else: return (optobj, c1) else: if cat: return optobj.row_dense(0) + [c1] else: return (optobj.row_dense(0), c1) def call_prog(self, c): verbose = PsiOpts.settings.get("verbose_lp", False) ceps = PsiOpts.settings["eps"] zero_obj = False if all(abs(x) <= ceps for x in c): if not self.affine_present: return (0.0, [0.0 for i in range(self.nxvar)]) c = [0.0] * len(c) c[0] = 1.0 zero_obj = True # return (None, None) if len(self.Au.x) == 0 and len(self.Ae.x) == 0: return (None, None) if self.solver == "scipy": with warnings.catch_warnings(): 
warnings.simplefilter("ignore") res = scipy.optimize.linprog(c, self.solver_param["Aus"], self.bu, self.solver_param["Aes"], self.be, bounds = (None, None), method = "interior-point", options={'sparse': True}) if res.status == 0: return (0.0 if zero_obj else float(res.fun), [float(res.x[i]) for i in range(self.nxvar)]) return (None, None) elif self.solver.startswith("pulp."): prob = self.solver_param["prob"] xvar = self.solver_param["xvar"] cexpr = None for i in range(len(c)): if abs(c[i]) > PsiOpts.settings["eps"]: if cexpr is None: cexpr = c[i] * xvar[str(i)] else: cexpr += c[i] * xvar[str(i)] if cexpr is not None: prob.setObjective(cexpr) try: res = prob.solve(iutil.pulp_get_solver(self.solver)) except Exception as err: if verbose: warnings.warn(str(err), RuntimeWarning) res = 0 if pulp.LpStatus[res] == "Optimal": return (0.0 if zero_obj else prob.objective.value(), [xvar[str(i)].value() for i in range(self.nxvar)]) return (None, None) elif self.solver.startswith("pyomo."): coptions = PsiOpts.get_pyomo_options() opt = self.solver_param["opt"] model = self.solver_param["model"] def o_rule(model): cexpr = None for i in range(len(c)): if abs(c[i]) > PsiOpts.settings["eps"]: if cexpr is None: cexpr = c[i] * model.x[i + 1] else: cexpr += c[i] * model.x[i + 1] return cexpr model.del_component("o") model.o = pyo.Objective(rule=o_rule) # print("CALL_PROG START " + str(sum(abs(x) for x in c))) # print(c) with warnings.catch_warnings(): warnings.simplefilter("ignore") try: res = opt.solve(model, options = coptions, tee = PsiOpts.settings["verbose_solver"]) except Exception as err: if verbose: warnings.warn(str(err), RuntimeWarning) res = None # print("CALL_PROG FINISHED") if (res is not None and res.solver.status == pyo.SolverStatus.ok and res.solver.termination_condition == pyo.TerminationCondition.optimal): return (0.0 if zero_obj else model.o(), [model.x[i + 1]() for i in range(self.nxvar)]) return (None, None) elif self.solver.startswith("ortools."): model = 
self.solver_param["model"] xvar = self.solver_param["xvar"] timelimit = PsiOpts.timer_left_sec() if timelimit is not None: model.set_time_limit(timelimit * 1000) cexpr = None for i in range(len(c)): if abs(c[i]) > PsiOpts.settings["eps"]: if cexpr is None: cexpr = c[i] * xvar[i] else: cexpr += c[i] * xvar[i] if cexpr is not None: model.Minimize(cexpr) try: res = model.Solve() except Exception as err: if verbose: warnings.warn(str(err), RuntimeWarning) res = 0 if res == ortools.linear_solver.pywraplp.Solver.OPTIMAL: return (0.0 if zero_obj else model.Objective().Value(), [xvar[i].solution_value() for i in range(self.nxvar)]) return (None, None) def checkexpr_ge0(self, x, saved = False, optval = None): verbose = PsiOpts.settings.get("verbose_lp", False) if self.pinfeas: return True if self.eps_present: x = x.substituted(Expr.eps(), Expr.eps() * (self.lp_eps_obj / self.lp_eps)) zero_cutoff = self.zero_cutoff if saved: if False: optobj = SparseMat(self.nvar) optobj.addrow() self.addExpr(optobj, x) optobj.simplify_last_row() if len(optobj.x[0]) == 0: return True if not optobj.mapcol(self.xvarid): return False optobj.width = self.nxvar optobj, optobjc1 = self.get_vec(x, sparse = True) if optobj is None: return False #cvec = numpy.array(c) for i in range(len(self.saved_var)): a = self.saved_var[i] #if sum(x * y for x, y in zip(c, a)) < zero_cutoff: #if numpy.dot(a, cvec) < zero_cutoff: if sum(ca * a[j] for j, ca in optobj.x[0]) + optobjc1 < zero_cutoff: for j in range(i, 0, -1): self.saved_var[j], self.saved_var[j - 1] = self.saved_var[j - 1], self.saved_var[j] return False return True verbose_lp_cons = PsiOpts.settings.get("verbose_lp_cons", False) if verbose_lp_cons: print("============ LP constraints ============") print(self.get_region(skip_simplify = True)) print("============ LP objective ============") if False: optobj = SparseMat(self.nvar) optobj.addrow() self.addExpr(optobj, x) optobj.simplify_last_row() optobjc1 = 0.0 if len(self.constmap): optobjc1 = 
optobj.sumrows(self.constmap, remove = True)[0] if not optobj.mapcol(self.xvarid): return False optobj.width = self.nxvar optobj, optobjc1 = self.get_vec(x, sparse = True) if optobj is None: print("objective contains new terms") else: optreg = self.get_region(A = optobj, skip_simplify = True) if len(optreg.exprs_ge) > 0: print(optreg.exprs_ge[0] + optobjc1) print("========================================") if self.fcn_mode >= 1 and not self.noskip: if x.isnonpos_hc(): fcn_res = True for (a, c) in x.terms: if not self.checkfcn(a.x[0], a.z): fcn_res = False break #print("FCN " + str(x) + " " + str(fcn_res)) if fcn_res: if verbose: print("LP True: fcn") return True if self.fcn_mode >= 2: if verbose: print("LP False: fcn") return False if (self.lptype == LinearProgType.HC1BN or self.lptype == LinearProgType.HMIN) and not self.noskip: if x.isnonpos_ic2(): if self.bnet.check_ic(x): if verbose: print("LP True: bnet") return True #return False res = None ceps = PsiOpts.settings["eps"] c, c1 = self.get_vec(x) if c is None: pass else: #if len([x for x in c if abs(x) > PsiOpts.settings["eps"]]) == 0: if all(abs(x) <= ceps for x in c): if c1 >= -ceps: if verbose: print("LP True: zero") if self.dual_form and self.noskip: self.dual_u = [0.0] * len(self.Au.x) self.dual_e = [0.0] * len(self.Ae.x) if self.dual_pf: self.write_pf(x) return True else: c = None if c is None: c = [-ceps * 2] + [0.0] * (self.nxvar - 1) c1 = -1.0 #return False if len(self.Au.x) == 0 and len(self.Ae.x) == 0: if verbose: print("LP False: no constraints") return False if verbose: print("LP nrv=" + str(self.index.num_rv()) + " nreal=" + str(self.index.num_real()) + " nvar=" + str(self.Au.width) + "/" + str(self.nvar) + " nineq=" + str(len(self.Au.x)) + " neq=" + str(len(self.Ae.x)) + " solver=" + self.solver) if self.solver == "scipy": rec_limit = 50 if self.Au.width > rec_limit: warnings.warn("The scipy solver is not recommended for problems of size > " + str(rec_limit) + " (current " + str(self.Au.width) + 
"). " + "Please switch to another solver. See " + "https://github.com/cheuktingli/psitip#solver", RuntimeWarning) with warnings.catch_warnings(): warnings.simplefilter("ignore") res = scipy.optimize.linprog(c, self.solver_param["Aus"], self.bu, self.solver_param["Aes"], self.be, bounds = (None, None), method = "interior-point", options={'sparse': True}) if verbose: print(" status=" + str(res.status) + " optval=" + str(res.fun)) if self.affine_present and self.lp_bounded and res.status == 2: return True if res.status == 0: self.optval = float(res.fun) + c1 if optval is not None: optval.append(self.optval) if self.val_enabled: self.val_x = [0.0] * self.nxvar for i in range(self.nxvar): self.val_x[i] = float(res.x[i]) if self.optval >= zero_cutoff: return True if res.status == 0 and self.save_res: self.saved_var.append(array.array("d", [float(a) for a in res.x])) #self.saved_var.append(numpy.array(list(res.x))) if verbose: print(" added : " + str(len(self.saved_var)) + ", " + str(sum(self.saved_var[-1]))) return False elif self.solver.startswith("pulp."): if self.dual_form: self.dual_form_obj = c self.dual_form_cutoff = c1 self.calc_solver() if self.dual_form_infeas: return False prob = self.solver_param["prob"] xvar = self.solver_param["xvar"] cexpr = None if self.dual_form: for j in range(self.dual_form_ncons): tcoeff = self.dual_form_weights[j] if cexpr is None: cexpr = tcoeff * xvar[str(j)] else: cexpr += tcoeff * xvar[str(j)] else: for i in range(len(c)): if abs(c[i]) > PsiOpts.settings["eps"]: if cexpr is None: cexpr = c[i] * xvar[str(i)] else: cexpr += c[i] * xvar[str(i)] if cexpr is not None: prob.setObjective(cexpr) try: res = prob.solve(iutil.pulp_get_solver(self.solver)) except Exception as err: if verbose: warnings.warn(str(err), RuntimeWarning) res = 0 if verbose: print(" status=" + pulp.LpStatus[res] + " optval=" + str(prob.objective.value())) if self.affine_present and self.lp_bounded and (pulp.LpStatus[res] == "Infeasible" or pulp.LpStatus[res] == 
"Undefined"): return True #if pulp.LpStatus[res] == "Infeasible": # return True if pulp.LpStatus[res] == "Optimal": if self.dual_form: self.dual_u = [0.0] * len(self.Au.x) self.dual_e = [0.0] * len(self.Ae.x) for j in range(len(self.Au.x)): self.dual_u[j] = -xvar[str(j)].value() for j in range(len(self.Ae.x)): self.dual_e[j] = -xvar[str(j * 2 + len(self.Au.x))].value() + xvar[str(j * 2 + len(self.Au.x) + 1)].value() if self.dual_pf: self.write_pf(x) return True self.optval = prob.objective.value() + c1 if optval is not None: optval.append(self.optval) if self.dual_enabled: self.dual_u = [0.0] * len(self.Au.x) self.dual_e = [0.0] * len(self.Ae.x) for i, (name, c) in enumerate(prob.constraints.items()): #print(str(i) + " " + name + " " + str(c.pi)) if c.pi is None: self.dual_u = None self.dual_e = None break if i < len(self.Au.x): self.dual_u[i] = c.pi else: self.dual_e[i - len(self.Au.x)] = c.pi if self.dual_pf: self.write_pf(x) if self.val_enabled: self.val_x = [0.0] * self.nxvar for i in range(self.nxvar): self.val_x[i] = xvar[str(i)].value() if self.optval >= zero_cutoff: return True if pulp.LpStatus[res] == "Optimal" and self.save_res: self.saved_var.append(array.array("d", [xvar[str(i)].value() for i in range(len(c))])) #self.saved_var.append(numpy.array([xvar[str(i)].value() for i in range(len(c))])) if verbose: print(" added : " + str(len(self.saved_var)) + ", " + str(sum(self.saved_var[-1]))) return False else: return True elif self.solver.startswith("pyomo."): if self.dual_form: self.dual_form_obj = c self.dual_form_cutoff = c1 self.calc_solver() if self.dual_form_infeas: return False coptions = PsiOpts.get_pyomo_options() opt = self.solver_param["opt"] model = self.solver_param["model"] def o_rule(model): if self.dual_form: return sum(model.x[j + 1] * self.dual_form_weights[j] for j in range(self.dual_form_ncons)) else: cexpr = None for i in range(len(c)): if abs(c[i]) > PsiOpts.settings["eps"]: if cexpr is None: cexpr = c[i] * model.x[i + 1] else: cexpr 
+= c[i] * model.x[i + 1] return cexpr model.del_component("o") model.o = pyo.Objective(rule=o_rule) with warnings.catch_warnings(): warnings.simplefilter("ignore") try: res = opt.solve(model, options = coptions, tee = PsiOpts.settings["verbose_solver"]) except Exception as err: if verbose: warnings.warn(str(err), RuntimeWarning) res = None if verbose and res is not None: print(" status=" + ("OK" if res.solver.status == pyo.SolverStatus.ok else "NO")) if res is not None and self.affine_present and self.lp_bounded and res.solver.termination_condition == pyo.TerminationCondition.infeasible: return True #print("save_res = " + str(self.save_res)) if (res is not None and res.solver.status == pyo.SolverStatus.ok and res.solver.termination_condition == pyo.TerminationCondition.optimal): if self.dual_form: self.dual_u = [0.0] * len(self.Au.x) self.dual_e = [0.0] * len(self.Ae.x) for j in range(len(self.Au.x)): self.dual_u[j] = -model.x[j + 1]() for j in range(len(self.Ae.x)): self.dual_e[j] = -model.x[j * 2 + len(self.Au.x) + 1]() + model.x[j * 2 + len(self.Au.x) + 2]() if self.dual_pf: self.write_pf(x) return True self.optval = model.o() + c1 if optval is not None: optval.append(self.optval) if self.dual_enabled: self.dual_u = [0.0] * len(self.Au.x) self.dual_e = [0.0] * len(self.Ae.x) for c in model.component_objects(pyo.Constraint, active=True): for i, index in enumerate(c): if i < len(self.Au.x): self.dual_u[i] = model.dual[c[index]] else: self.dual_e[i - len(self.Au.x)] = model.dual[c[index]] if self.dual_pf: self.write_pf(x) if self.val_enabled: self.val_x = [0.0] * self.nxvar for i in range(self.nxvar): self.val_x[i] = model.x[i + 1]() if self.optval >= zero_cutoff: #print("RETURN TRUE") return True if self.save_res: self.saved_var.append(array.array("d", [model.x[i + 1]() for i in range(len(c))])) #self.saved_var.append(numpy.array([model.x[i + 1]() for i in range(len(c))])) if verbose: print(" added : " + str(len(self.saved_var)) + ", " + 
str(sum(self.saved_var[-1]))) return False elif self.solver.startswith("ortools."): if self.dual_form: self.dual_form_obj = c self.dual_form_cutoff = c1 self.calc_solver() if self.dual_form_infeas: return False model = self.solver_param["model"] xvar = self.solver_param["xvar"] timelimit = PsiOpts.timer_left_sec() if timelimit is not None: model.set_time_limit(timelimit * 1000) cexpr = None if self.dual_form: for j in range(self.dual_form_ncons): tcoeff = self.dual_form_weights[j] if cexpr is None: cexpr = tcoeff * xvar[j] else: cexpr += tcoeff * xvar[j] else: for i in range(len(c)): if abs(c[i]) > PsiOpts.settings["eps"]: if cexpr is None: cexpr = c[i] * xvar[i] else: cexpr += c[i] * xvar[i] if cexpr is not None: model.Minimize(cexpr) try: res = model.Solve() except Exception as err: if verbose: warnings.warn(str(err), RuntimeWarning) res = 0 if verbose: print(" status=" + str(res) + " optval=" + str(model.Objective().Value())) if self.affine_present and self.lp_bounded and (res == ortools.linear_solver.pywraplp.Solver.INFEASIBLE): return True #if pulp.LpStatus[res] == "Infeasible": # return True if res == ortools.linear_solver.pywraplp.Solver.OPTIMAL: if self.dual_form: self.dual_u = [0.0] * len(self.Au.x) self.dual_e = [0.0] * len(self.Ae.x) for j in range(len(self.Au.x)): self.dual_u[j] = -xvar[j].solution_value() for j in range(len(self.Ae.x)): self.dual_e[j] = -xvar[j * 2 + len(self.Au.x)].solution_value() + xvar[j * 2 + len(self.Au.x) + 1].solution_value() if self.dual_pf: self.write_pf(x) return True self.optval = model.Objective().Value() + c1 if optval is not None: optval.append(self.optval) # if self.dual_enabled: # self.dual_u = [0.0] * len(self.Au.x) # self.dual_e = [0.0] * len(self.Ae.x) # for i, (name, c) in enumerate(prob.constraints.items()): # #print(str(i) + " " + name + " " + str(c.pi)) # if c.pi is None: # self.dual_u = None # self.dual_e = None # break # if i < len(self.Au.x): # self.dual_u[i] = c.pi # else: # self.dual_e[i - len(self.Au.x)] = 
c.pi # if self.dual_pf: # self.write_pf(x) if self.val_enabled: self.val_x = [0.0] * self.nxvar for i in range(self.nxvar): self.val_x[i] = xvar[i].solution_value() if self.optval >= zero_cutoff: return True if res == ortools.linear_solver.pywraplp.Solver.OPTIMAL and self.save_res: self.saved_var.append(array.array("d", [xvar[i].solution_value() for i in range(len(c))])) if verbose: print(" added : " + str(len(self.saved_var)) + ", " + str(sum(self.saved_var[-1]))) return False else: return True else: if self.solver == "": raise RuntimeError("No solver found. Please install a solver. See: " + "https://github.com/cheuktingli/psitip#solver") else: raise RuntimeError("Solver " + self.solver + " not found. Please install another solver. See: " + "https://github.com/cheuktingli/psitip#solver") def checkexpr_eq0(self, x, saved = False): return self.checkexpr_ge0(x, saved) and self.checkexpr_ge0(-x, saved) def checkexpr(self, x, sg, saved = False): if sg == "==": return self.checkexpr_eq0(x, saved) elif sg == ">=": return self.checkexpr_ge0(x, saved) elif sg == "<=": return self.checkexpr_ge0(-x, saved) else: return False def evalexpr_ge0_saved(self, x): c, c1 = self.get_vec(x) if c is None: return -1e8 r = 0.0 for a in self.saved_var: r += min(sum([x * y for x, y in zip(c, a)]) + c1, 0.0) return r def get_val(self, expr): if self.val_x is None: return None c, c1 = self.get_vec(expr, sparse = True) r = c1 if len(c.x) == 0: return r for j, a in c.x[0]: r += self.val_x[j] * a return r def __call__(self, x): if isinstance(x, Expr): return self.get_val(x) elif isinstance(x, Region): return x.evalcheck(lambda expr: self.get_val(expr)) return None def __getitem__(self, x): return self(x) def clear_dual(self): self.dual_u = None self.dual_e = None def get_dual_expr(self, x): if self.dual_u is None or self.dual_e is None: return None c, c1 = self.get_vec(x, sparse = True) c.simplify() #print(c) r = 0.0 for a, d in itertools.chain(zip(self.Ae.x, self.dual_e), zip(self.Au.x, 
self.dual_u)): #print(a) a2 = SparseMat.from_row(a, self.Au.width) a2.simplify() t = a2.ratio(c) if t is not None: r += d * t return r def get_dual(self, reg): if isinstance(reg, Expr): return self.get_dual_expr(reg) r = [] for x in reg.exprs_ge: r.append(self.get_dual_expr(x)) for x in reg.exprs_eq: r.append(self.get_dual_expr(x)) if len(r) == 1: return r[0] else: return r def proj_hull(prog, n, init_pt = None, toexpr = None, iscone = False, isfrac = None, max_facet = None, num_simplex = None, init_pts_outer = None, pts_outer = None): """Convex hull method for polyhedron projection. C. Lassez and J.-L. Lassez, Quantifier elimination for conjunctions of linear constraints via a convex hull algorithm, IBM Research Report, T.J. Watson Research Center, RC 16779 (1991) """ if cdd is None: raise RuntimeError("Convex hull method requires pycddlib. Please install it first.") verbose = PsiOpts.settings.get("verbose_discover", False) verbose_outer = PsiOpts.settings.get("verbose_discover_outer", False) verbose_detail = PsiOpts.settings.get("verbose_discover_detail", False) verbose_terms = PsiOpts.settings.get("verbose_discover_terms", False) verbose_terms_inner = PsiOpts.settings.get("verbose_discover_terms_inner", False) verbose_terms_outer = PsiOpts.settings.get("verbose_discover_terms_outer", False) verbose_aux_reduce = PsiOpts.settings.get("verbose_aux_reduce", False) #max_denom = PsiOpts.settings.get("max_denom", 1000) if isfrac is None: isfrac = PsiOpts.settings.get("discover_hull_frac_enabled", False) frac_denom = PsiOpts.settings.get("discover_hull_frac_denom", -1) if max_facet is None: max_facet = PsiOpts.settings.get("discover_max_facet", 10000000000) if num_simplex is None: num_simplex = PsiOpts.settings.get("discover_num_simplex", 1) ceps = PsiOpts.settings["eps_lp"] rnd_started = False mat = None if init_pts_outer is not None and len(init_pts_outer) >= 2: cen = [0.0] * n for p in init_pts_outer: for i in range(n): cen[i] += p[i] for i in range(n): cen[i] /= 
len(init_pts_outer) init_pts = [] did = False for p in init_pts_outer: v = [cenx - px for cenx, px in zip(cen, p)] if all(abs(vx) <= ceps for vx in v): continue val, a = prog(v) if not all(abs(ax - px) <= ceps for ax, px in zip(a, p)): did = True init_pts.append(a) if pts_outer is not None: pts_outer.append(a) if not did: if verbose_aux_reduce: print("discover fails to shrink") return None mat = cdd.Matrix([[1] + p for p in init_pts], number_type=("fraction" if isfrac else "float")) else: if init_pt is None: if iscone: init_pt = [0] * n else: _, init_pt = prog([1] * n) if init_pt is None: init_pt = [0] * n if pts_outer is not None: pts_outer.append(init_pt) mat = cdd.Matrix([[1] + init_pt], number_type=("fraction" if isfrac else "float")) mat.rep_type = cdd.RepType.GENERATOR ineqs_tight = [] ineqs_tried = [] matP = None matP_avg = None matQ = None matQ_null = None did = True while did: if PsiOpts.is_timer_ended(): break if verbose: print("NUMPOINT = " + str(mat.row_size) + " NUMDIM = " + str(mat.col_size), flush = True) did = False tgt_num_point = mat.col_size + num_simplex isfull = max_facet is None or (max_facet > 0 and mat.col_size * 0.5 * numpy.log(mat.row_size) <= numpy.log(max_facet)) # or tgt_num_point >= mat.row_size poly = None ineqs = [] lset = set() ineqs_row_size = 0 if isfull: poly = cdd.Polyhedron(mat) # print("MAT:") # print(mat) ineqs = poly.get_inequalities() # print("INEQ:") # print(ineqs) lset = ineqs.lin_set ineqs_row_size = ineqs.row_size else: if not rnd_started: if not PsiOpts.has_timer(): warnings.warn("Convex hull method: Max number of facets discover_max_facet = " + str(max_facet) + " reached (current = " + str(round(numpy.exp(mat.col_size * 0.5 * numpy.log(mat.row_size)))) + "). Switching to randomized subset. 
Program will not terminate unless the block is enclosed by \"with PsiOpts(timelimit = ???):\" or \"with PsiOpts(stop_file = ???):\".", RuntimeWarning) rnd_started = True rnd = PsiOpts.get_random() if matP is None: # tgt_num_point = rnd.randrange(1, mat.col_size + 1) matP = numpy.array([[float(mat[i][j]) for j in range(1, mat.col_size)] for i in range(mat.row_size)], ndmin = 2) # print(matP, flush = True) matP_avg = numpy.mean(matP, 0) # print(matP_avg, flush = True) matQ = numpy.array([matP[i,:] - matP_avg for i in range(mat.row_size)]) matQ_null = scipy.linalg.null_space(matQ) # print(matQ_null, flush = True) # currank = numpy.linalg.matrix_rank(matQ) currank = matQ.shape[1] - matQ_null.shape[1] min_num_point = max(currank, 1) max_num_point = min(matP.shape) # tgt_num_point = rnd.randrange(min_num_point, max_num_point + 1) tgt_num_point = min_num_point if verbose: print("RANDOM SUBSET NUMPOINT = " + str(tgt_num_point), flush = True) cid = list(range(mat.row_size)) rnd.shuffle(cid) if iscone: matA = numpy.array([[0.0] * matQ.shape[1]] + [matQ[cid[i],:] for i in range(tgt_num_point - 1)], ndmin = 2) else: matA = numpy.array([matQ[cid[i],:] for i in range(tgt_num_point)], ndmin = 2) t = numpy.linalg.lstsq(matA, numpy.ones(tgt_num_point), rcond = None)[0] # print(t.shape, flush = True) ineqs.append([matP_avg.dot(t) + 1.0] + list(-t)) for i in range(matQ_null.shape[1]): t = matQ_null[:, i] ineqs.append([matP_avg.dot(t)] + list(-t)) lset.add(len(ineqs) - 1) ineqs_row_size = len(ineqs) if False: cmat = cdd.Matrix([mat[cid[i]] for i in range(tgt_num_point)], number_type=("fraction" if isfrac else "float")) poly = cdd.Polyhedron(cmat) ineqs = poly.get_inequalities() lset = ineqs.lin_set ineqs_row_size = ineqs.row_size did = True if verbose: print("HULL FINISHED", flush = True) if verbose_terms or verbose_terms_inner: print("INNER:", flush = True) for i in range(ineqs_row_size): y = ineqs[i] print(" " + str(toexpr(y[1:n+1])) + (" == " if i in lset else " >= ") + 
iutil.float_tostr(-y[0]), flush = True) for i in range(ineqs_row_size): if PsiOpts.is_timer_ended(): break for sgn in ([1, -1] if i in lset else [1]): if PsiOpts.is_timer_ended(): break #print("IROWSIZE " + str(ineqs_row_size)) x = [a * sgn for a in ineqs[i]] xnorm = sum(abs(a) for a in x) if xnorm <= ceps: continue x = [a / xnorm for a in x] if not isfull: isbad = False for i2 in range(mat.row_size): y = mat[i2] if sum(yt * xt for (yt, xt) in zip(y, x)) < -ceps: isbad = True break if isbad: if verbose_detail: print("NOT EXTREMAL " + str(toexpr(x[1:n+1])) + " + " + str(x[0]) + " >= 0", flush = True) continue for y in ineqs_tried: if sum(abs(a - b) for a, b in zip(x, y)) <= ceps: break else: if verbose_detail: print("MIN " + str(toexpr(x[1:n+1])), flush = True) isshortcut = False opt = None v = None if toexpr is not None and abs(x[0]) <= ceps: if toexpr(x[1:n+1]).simplified_quick().isnonneg(): opt = 0.0 isshortcut = True if not isshortcut: # print("PROG " + str(x)) opt, v = prog(x[1:n+1]) # print(" VS " + str(opt) + " " + str(-x[0])) vo = v ineqs_tried.append(list(x)) if verbose_detail: print("MIN FINISHED", flush = True) if opt is None or opt >= -x[0] - ceps: ctoadd = opt is not None # ctoadd = True if ctoadd: ineqs_tight.append(list(x)) if verbose or verbose_terms_outer: if verbose and opt is None: print("NONE", flush = True) # print(x) if verbose_terms_outer or verbose_outer or abs(x[0]) <= 100: if ctoadd: if verbose: print("ADD " + str(toexpr(x[1:n+1])) + " >= " + iutil.float_tostr(-x[0]) + (" SHORTCUT" if isshortcut else ""), flush = True) if verbose_terms or verbose_terms_outer: print("OUTER:", flush = True) print(alland([toexpr(y[1:n+1]) + y[0] >= 0 for y in ineqs_tight]), flush = True) print() # for y in ineqs_tight: # print(" " + str(toexpr(y[1:n+1])) + " >= " + iutil.float_tostr(-y[0])) #print("TIGHT " + str(list(x))) continue if isfrac: if frac_denom > 0: v = [fractions.Fraction(a).limit_denominator(frac_denom) for a in v] else: v = [fractions.Fraction(a) 
for a in v] if iscone: vnorm = sum(abs(a) for a in v) if vnorm > ceps: v = [a / vnorm for a in v] v = [0] + v else: v = [1] + v for i2 in range(mat.row_size): y = mat[i2] if sum(abs(a - b) for a, b in zip(v, y)) <= ceps: break else: if verbose_detail: print("PT " + str(v), flush = True) if pts_outer is not None: pts_outer.append(vo) mat.extend([v]) matP = None did = True return ineqs_tight def discover_hull(self, A, iscone = False, init_pts_outer = None, pts_outer = None): """Convex hull method for polyhedron projection. C. Lassez and J.-L. Lassez, Quantifier elimination for conjunctions of linear constraints via a convex hull algorithm, IBM Research Report, T.J. Watson Research Center, RC 16779 (1991) """ verbose = PsiOpts.settings.get("verbose_discover", False) verbose_detail = PsiOpts.settings.get("verbose_discover_detail", False) verbose_terms = PsiOpts.settings.get("verbose_discover_terms", False) verbose_terms_outer = PsiOpts.settings.get("verbose_discover_terms_outer", False) n = self.nxvar m = len(A.x) toexpr = None if True or verbose or verbose_detail or verbose_terms or verbose_terms_outer: itoexpr = self.row_toexpr() def ctoexpr(x): c = [0.0] * n for i in range(m): for j, a in A.x[i]: c[j] += a * x[i] return itoexpr(c) toexpr = ctoexpr #print(A.x) #print(self.Au.x) def cprog(x): c = [0.0] * n for i in range(m): for j, a in A.x[i]: c[j] += a * x[i] opt, v = self.call_prog(c) if opt is None: return (None, None) r = [0.0] * m for i in range(m): for j, a in A.x[i]: r[i] += a * v[j] return (opt, r) return LinearProg.proj_hull(cprog, m, toexpr = toexpr, iscone = iscone, init_pts_outer = init_pts_outer, pts_outer = pts_outer) def corners_value(self, ispolar = False, isfrac = None): if cdd is None: raise RuntimeError("Requires pycddlib. 
Please install it first.") if isfrac is None: isfrac = PsiOpts.settings.get("discover_hull_frac_enabled", False) frac_denom = PsiOpts.settings.get("discover_hull_frac_denom", -1) ceps = PsiOpts.settings["eps_lp"] n = self.nxvar # print(n) # print(self.Au.x) # print(self.Ae.x) mat = None if len(self.Au.x): ma = numpy.hstack([numpy.array([self.bu]).T, -self.Au.tonumpyarray()]) if mat is None: mat = cdd.Matrix(ma, number_type=("fraction" if isfrac else "float")) mat.rep_type = cdd.RepType.INEQUALITY else: mat.extend(ma) if len(self.Ae.x): ma = numpy.hstack([numpy.array([self.be]).T, -self.Ae.tonumpyarray()]) if mat is None: mat = cdd.Matrix(ma, linear = True, number_type=("fraction" if isfrac else "float")) mat.rep_type = cdd.RepType.INEQUALITY else: mat.extend(ma, linear = True) if mat is None: return None # print(mat) poly = cdd.Polyhedron(mat) gs = poly.get_generators() fs = poly.get_inequalities() # print(gs) gi = poly.get_incidence() fi = poly.get_input_incidence() #print(fs) if ispolar: gs, fs = fs, gs gi, fi = fi, gi ng = gs.row_size angles = [0.0] * ng if n >= 2: for i in range(ng): avg = None if ispolar: avg = gs[i][1:] else: avg = [0.0] * n for k in gi[i]: avg = [fs[k][j + 1] + avg[j] for j in range(n)] angles[i] = math.atan2(avg[1], avg[0]) if len(gi[i]) == 0: angles[i] = 1e20 gsj = sorted([i for i in range(ng) if len(gi[i])], key = lambda k: angles[k]) gsjinv = [0] * ng for i in range(len(gsj)): gsjinv[gsj[i]] = i gsr = [list(gs[gsj[i]]) for i in range(len(gsj))] #fir = [[gsjinv[a] for a in x] for x in fi if len(x) > 0 and len(x) < ng] fir = [[gsjinv[a] for a in x] for x in fi if len(x) > 0] return (gsr, fir) def table(self, *args, **kwargs): """Plot the information diagram as a Karnaugh map. """ return universe().table(*args, self, **kwargs) def venn(self, *args, **kwargs): """Plot the information diagram as a Venn diagram. Can handle up to 5 random variables (uses Branko Grunbaum's Venn diagram for n=5). 
""" return universe().venn(*args, self, **kwargs) class Level(IBaseObj): """The level of a region in the linear entropy hierarchy. """ def __init__(self, quant = 0, n = 0): self.quant = quant self.n = n if self.n == 0: self.quant = 0 @staticmethod def sigma(n): """ Sigma_n. Returns ------- Level """ return Level(1 if n else 0, n) @staticmethod def pi(n): """ Pi_n. Returns ------- Level """ return Level(-1 if n else 0, n) @staticmethod def delta(n): """ Delta_n. Returns ------- Level """ return Level(0, n) def __eq__(self, other): return self.quant == other.quant and self.n == other.n def __ne__(self, other): return not self == other def __and__(self, other): if self.n < other.n: return self if other.n < self.n: return other if self == other: return self return Level(0, self.n) def __or__(self, other): if self.n > other.n: return self if other.n > self.n: return other if self == other: return self if self.quant == 0: return other if other.quant == 0: return self return Level(0, self.n + 1) def __invert__(self): return Level(-self.quant, self.n) def __rshift__(self, other): return (~self) | other def qimplies(self, other, q = 0): if q > 0: return (self | other).exists() elif q < 0: return (self >> other).forall() else: return (self | other).exists() & (self >> other).forall() def __bool__(self): return self.n != 0 def __int__(self): return self.n def int_lv(self): return self.n * 2 + int(self.quant != 0) def __ge__(self, other): return self.int_lv() > other.int_lv() or self == other def __le__(self, other): return other >= self def __gt__(self, other): return self.int_lv() > other.int_lv() def __lt__(self, other): return other > self def exists(self, forall = False): tquant = -1 if forall else 1 if self.n and (self.quant == tquant or self.quant == 0): return Level(tquant, self.n) else: return Level(tquant, self.n + 1) def forall(self): return self.exists(True) def __str__(self): r = "" if self.quant == 0: r += "Delta" elif self.quant > 0: r += "Sigma" else: r += "Pi" 
r += "_" + str(self.n) return r @latex_postprocess def _latex_(self): r = "" if self.quant == 0: r += "\Delta" elif self.quant > 0: r += "\Sigma" else: r += "\Pi" r += "_" t = str(self.n) if len(t) > 1: r += "{" + t + "}" else: r += t return r def __repr__(self): r = "Level." if self.quant == 0: r += "delta" elif self.quant > 0: r += "sigma" else: r += "pi" r += "(" + str(self.n) + ")" return r class RegionLevel(IBaseObj): """The level of a region in the linear entropy hierarchy, with the same syntax as Region. """ def __init__(self, quant = 0, n = 0): if isinstance(quant, Level): self.level = quant elif isinstance(quant, Region): self.level = quant.level() elif isinstance(quant, RegionLevel): self.level = quant.level else: self.level = Level(quant, n) @staticmethod def sigma(n): """ Sigma_n. Returns ------- Level """ return RegionLevel(1 if n else 0, n) @staticmethod def pi(n): """ Pi_n. Returns ------- Level """ return RegionLevel(-1 if n else 0, n) @staticmethod def delta(n): """ Delta_n. 
Returns ------- Level """ return RegionLevel(0, n) def __or__(self, other): if not isinstance(self, RegionLevel): self = self.region_level() if not isinstance(other, RegionLevel): other = other.region_level() return RegionLevel(self.level | other.level) def __and__(self, other): return self | other def inter(self, other): if not isinstance(self, RegionLevel): self = self.region_level() if not isinstance(other, RegionLevel): other = other.region_level() return RegionLevel(self.level & other.level) def __invert__(self): if not isinstance(self, RegionLevel): self = self.region_level() return RegionLevel(~self.level) def __rshift__(self, other): if not isinstance(self, RegionLevel): self = self.region_level() if not isinstance(other, RegionLevel): other = other.region_level() return (~self) | other def qimplies(self, other, q = 0): if not isinstance(self, RegionLevel): self = self.region_level() if not isinstance(other, RegionLevel): other = other.region_level() return RegionLevel(self.level.qimplies(other.level, q = q)) def __eq__(self, other): if not isinstance(self, RegionLevel): self = self.region_level() if not isinstance(other, RegionLevel): other = other.region_level() return (self >> other) & (other >> self) def __ne__(self, other): return ~(self == other) def __bool__(self): return bool(self.level) def __int__(self): return int(self.level) def int_lv(self): return self.level.int_lv() def __ge__(self, other): if not isinstance(self, RegionLevel): self = self.region_level() if not isinstance(other, RegionLevel): other = other.region_level() return self.level >= other.level def __le__(self, other): if not isinstance(self, RegionLevel): self = self.region_level() if not isinstance(other, RegionLevel): other = other.region_level() return self.level <= other.level def __gt__(self, other): if not isinstance(self, RegionLevel): self = self.region_level() if not isinstance(other, RegionLevel): other = other.region_level() return self.level > other.level def 
__lt__(self, other): if not isinstance(self, RegionLevel): self = self.region_level() if not isinstance(other, RegionLevel): other = other.region_level() return self.level < other.level def exists(self, *args, forall = False): return RegionLevel(self.level.exists(forall = forall)) def forall(self, *args): return self.exists(*args, forall = True) def __str__(self): return str(self.level) @latex_postprocess def _latex_(self): return self.level._latex_() def __repr__(self): r = "RegionLevel." if self.level.quant == 0: r += "delta" elif self.level.quant > 0: r += "sigma" else: r += "pi" r += "(" + str(self.level.n) + ")" return r class RegionType: NIL = 0 NORMAL = 1 UNION = 2 INTER = 3 class Region(IBaseObj): """A region consisting of equality and inequality constraints""" def __init__(self, exprs_ge, exprs_eq, aux, inp, oup, exprs_gei = None, exprs_eqi = None, auxi = None, meta = None): self.exprs_ge = exprs_ge self.exprs_eq = exprs_eq self.aux = aux self.inp = inp self.oup = oup if exprs_gei is not None: self.exprs_gei = exprs_gei else: self.exprs_gei = [] if exprs_eqi is not None: self.exprs_eqi = exprs_eqi else: self.exprs_eqi = [] if auxi is not None: self.auxi = auxi else: self.auxi = Comp.empty() self.meta = meta def get_type(self): return RegionType.NORMAL def isnormalcons(self): return not self.imp_present() @staticmethod def universe(): return Region([], [], Comp.empty(), Comp.empty(), Comp.empty()) @staticmethod def Ic(x, y, z = None): if z is None: z = Comp.empty() x = x - z y = y - z if x.isempty() or y.isempty(): return Region.universe() return Region([], [Expr.Ic(x, y, z)], Comp.empty(), Comp.empty(), Comp.empty()) @staticmethod def empty(): return Region([-Expr.one()], [], Comp.empty(), Comp.empty(), Comp.empty()) @staticmethod def from_bool(b): if b: return Region.universe() else: return Region.empty() @staticmethod def parse(s): """Parse a string, e.g. 
I(X;Y,Z|W) + 2H(X Z) \le 3 """ return RegionParser.parse_default(s) def setuniverse(self): self.exprs_ge = [] self.exprs_eq = [] self.aux = Comp.empty() self.inp = Comp.empty() self.oup = Comp.empty() self.exprs_gei = [] self.exprs_eqi = [] self.auxi = Comp.empty() def setempty(self): self.exprs_ge = [-Expr.one()] self.exprs_eq = [] self.aux = Comp.empty() self.inp = Comp.empty() self.oup = Comp.empty() self.exprs_gei = [] self.exprs_eqi = [] self.auxi = Comp.empty() def isempty(self): if not (len(self.exprs_gei) == 0 and len(self.exprs_eqi) == 0): return False ceps = PsiOpts.settings["eps"] for x in self.exprs_ge: t = x.get_const() if t is not None and t < -ceps: return True for x in self.exprs_eq: t = x.get_const() if t is not None and abs(t) > ceps: return True return False def isuniverse(self, sgn = True, canon = False): if canon and (not self.aux.isempty() or not self.auxi.isempty()): return False if sgn: return len(self.exprs_ge) == 0 and len(self.exprs_eq) == 0 and len(self.exprs_gei) == 0 and len(self.exprs_eqi) == 0 else: return self.isempty() def iseq(self): """Is this pure equality. """ if not self.aux.isempty() or not self.auxi.isempty(): return False return len(self.exprs_ge) == 0 and len(self.exprs_eq) > 0 and len(self.exprs_gei) == 0 and len(self.exprs_eqi) == 0 def isineq(self): """Is this pure inequality. """ if not self.aux.isempty() or not self.auxi.isempty(): return False return len(self.exprs_ge) > 0 and len(self.exprs_eq) == 0 and len(self.exprs_gei) == 0 and len(self.exprs_eqi) == 0 def expr(self): """Returns the sum of the expressions in this region. 
""" return sum(self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi, Expr.zero()) def copy(self): return Region([x.copy() for x in self.exprs_ge], [x.copy() for x in self.exprs_eq], self.aux.copy(), self.inp.copy(), self.oup.copy(), [x.copy() for x in self.exprs_gei], [x.copy() for x in self.exprs_eqi], self.auxi.copy(), iutil.copy(self.meta)) def copy_noaux(self): return Region([x.copy() for x in self.exprs_ge], [x.copy() for x in self.exprs_eq], Comp.empty(), self.inp.copy(), self.oup.copy(), [x.copy() for x in self.exprs_gei], [x.copy() for x in self.exprs_eqi], Comp.empty(), iutil.copy(self.meta)) def noaux(self): return self.copy_noaux() def copy_(self, other): self.exprs_ge = [x.copy() for x in other.exprs_ge] self.exprs_eq = [x.copy() for x in other.exprs_eq] self.aux = other.aux.copy() self.inp = other.inp.copy() self.oup = other.oup.copy() self.exprs_gei = [x.copy() for x in other.exprs_gei] self.exprs_eqi = [x.copy() for x in other.exprs_eqi] self.auxi = other.auxi.copy() self.meta = iutil.copy(other.meta) def imp_intersection(self): return Region([x.copy() for x in self.exprs_ge] + [x.copy() for x in self.exprs_gei], [x.copy() for x in self.exprs_eq] + [x.copy() for x in self.exprs_eqi], self.aux.copy() + self.auxi.copy(), self.inp.copy(), self.oup.copy()) def imp_intersection_noaux(self): return Region([x.copy() for x in self.exprs_ge] + [x.copy() for x in self.exprs_gei], [x.copy() for x in self.exprs_eq] + [x.copy() for x in self.exprs_eqi], Comp.empty(), Comp.empty(), Comp.empty()) def imp_copy(self): return Region([], [], Comp.empty(), Comp.empty(), Comp.empty(), [x.copy() for x in self.exprs_gei], [x.copy() for x in self.exprs_eqi], self.auxi.copy()) def imp_flipped(self): return Region([x.copy() for x in self.exprs_gei], [x.copy() for x in self.exprs_eqi], self.auxi.copy(), self.inp.copy(), self.oup.copy(), [x.copy() for x in self.exprs_ge], [x.copy() for x in self.exprs_eq], self.aux.copy()) def consonly(self): return Region([x.copy() 
for x in self.exprs_ge], [x.copy() for x in self.exprs_eq], self.aux.copy(), self.inp.copy(), self.oup.copy()) def imp_flippedonly(self): return Region([x.copy() for x in self.exprs_gei], [x.copy() for x in self.exprs_eqi], self.auxi.copy(), Comp.empty(), Comp.empty()) def imp_flippedonly_noaux(self): return Region([x.copy() for x in self.exprs_gei], [x.copy() for x in self.exprs_eqi], Comp.empty(), Comp.empty(), Comp.empty()) def imp_present(self): return len(self.exprs_gei) > 0 or len(self.exprs_eqi) > 0 or not self.auxi.isempty() def imp_flip(self): self.exprs_ge, self.exprs_gei = self.exprs_gei, self.exprs_ge self.exprs_eq, self.exprs_eqi = self.exprs_eqi, self.exprs_eq self.aux, self.auxi = self.auxi, self.aux return self def imp_only_copy_to(self, other): other.exprs_ge = [] other.exprs_eq = [] other.aux = Comp.empty() other.inp = Comp.empty() other.oup = Comp.empty() other.exprs_gei = [x.copy() for x in self.exprs_gei] other.exprs_eqi = [x.copy() for x in self.exprs_eqi] other.auxi = self.auxi.copy() def __len__(self): return len(self.exprs_ge) + len(self.exprs_eq) def __getitem__(self, key): t = [(a, False) for a in self.exprs_ge] + [(a, True) for a in self.exprs_eq] r = t[key] if not isinstance(r, list): r = [r] c = Region.universe() for a, eq in r: if eq: c.exprs_eq.append(a) else: c.exprs_ge.append(a) return c def add_meta(self, key, value, children = True): if not children: return IBaseObj.add_meta(self, key, value) for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi: x.add_meta(key, value) return self def get_meta(self, key): t = IBaseObj.get_meta(self, key) if t is not None: return t for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi: t = x.get_meta(key) if t is not None: return t return None def remove_meta(self, key): IBaseObj.remove_meta(self, key) for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi: x.remove_meta(key) return self def sum_entrywise(self, other): return Region([x + y 
for (x, y) in zip(self.exprs_ge, other.exprs_ge)], [x + y for (x, y) in zip(self.exprs_eq, other.exprs_eq)], self.aux.interleaved(other.aux), self.inp.interleaved(other.inp), self.oup.interleaved(other.oup), [x + y for (x, y) in zip(self.exprs_gei, other.exprs_gei)], [x + y for (x, y) in zip(self.exprs_eqi, other.exprs_eqi)], self.auxi.interleaved(other.auxi)) def ispresent(self, x): """Return whether any variable in x appears here.""" for z in self.exprs_ge: if z.ispresent(x): return True for z in self.exprs_eq: if z.ispresent(x): return True if self.aux.ispresent(x): return True if self.inp.ispresent(x): return True if self.oup.ispresent(x): return True for z in self.exprs_gei: if z.ispresent(x): return True for z in self.exprs_eqi: if z.ispresent(x): return True if self.auxi.ispresent(x): return True return False def __contains__(self, other): x = other if isinstance(x, str): if x == "=" or x == "==": return bool(self.exprs_eq) or bool(self.exprs_eqi) if x == ">=" or x == "<=": return bool(self.exprs_ge) or bool(self.exprs_gei) if x == ">" or x == "<": return Expr.eps() in self if x == "exists": return not self.aux.isempty() if x == "forall": return not self.auxi.isempty() x = rv(x) if isinstance(x, Region): return self.contains_region(other) for z in self.exprs_ge: if x in z: return True for z in self.exprs_eq: if x in z: return True if x in self.aux: return True if x in self.inp: return True if x in self.oup: return True for z in self.exprs_gei: if x in z: return True for z in self.exprs_eqi: if x in z: return True if x in self.auxi: return True return False def affine_present(self): """Return whether there are any affine constraint.""" return self.ispresent((Expr.one() + Expr.eps() + Expr.inf()).allcomp()) def imp_ispresent(self, x): for z in self.exprs_gei: if z.ispresent(x): return True for z in self.exprs_eqi: if z.ispresent(x): return True if self.auxi.ispresent(x): return True return False def contains_region(self, other): if isinstance(other, RegionOp): 
return False for x, isge in [(y, True) for y in other.exprs_ge] + [(y, False) for y in other.exprs_eq]: found = False if isge: for z in self.exprs_ge + self.exprs_gei: t = z.get_ratio(x) if t is not None and t > 0: found = True break if found: continue for z in self.exprs_eq + self.exprs_eqi: t = z.get_ratio(x) if t is not None and t != 0: found = True break if found: continue return False return True def rename_var(self, name0, name1): for x in self.exprs_ge: x.rename_var(name0, name1) for x in self.exprs_eq: x.rename_var(name0, name1) self.aux.rename_var(name0, name1) self.inp.rename_var(name0, name1) self.oup.rename_var(name0, name1) for x in self.exprs_gei: x.rename_var(name0, name1) for x in self.exprs_eqi: x.rename_var(name0, name1) self.auxi.rename_var(name0, name1) def rename_map(self, namemap): """Rename according to name map. """ for x in self.exprs_ge: x.rename_map(namemap) for x in self.exprs_eq: x.rename_map(namemap) self.aux.rename_map(namemap) self.inp.rename_map(namemap) self.oup.rename_map(namemap) for x in self.exprs_gei: x.rename_map(namemap) for x in self.exprs_eqi: x.rename_map(namemap) self.auxi.rename_map(namemap) return self @fcn_substitute def substitute(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound), in place""" for x in self.exprs_ge: x.substitute(v0, v1) for x in self.exprs_eq: x.substitute(v0, v1) for x in self.exprs_gei: x.substitute(v0, v1) for x in self.exprs_eqi: x.substitute(v0, v1) if not isinstance(v0, Expr): self.aux.substitute(v0, v1) self.inp.substitute(v0, v1) self.oup.substitute(v0, v1) self.auxi.substitute(v0, v1) if iutil.check_meta_subs_criteria(v0, v1, self): iutil.substitute(self.meta, v0, v1) return self @fcn_substitute def substitute_whole(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound), in place""" for x in self.exprs_ge: x.substitute_whole(v0, v1) for x in self.exprs_eq: x.substitute_whole(v0, v1) for x in self.exprs_gei: x.substitute_whole(v0, v1) for x in self.exprs_eqi: 
x.substitute_whole(v0, v1) if not isinstance(v0, Expr): self.aux.substitute_whole(v0, v1) self.inp.substitute_whole(v0, v1) self.oup.substitute_whole(v0, v1) self.auxi.substitute_whole(v0, v1) if iutil.check_meta_subs_criteria(v0, v1, self): iutil.substitute_whole(self.meta, v0, v1) return self @fcn_substitute def substitute_aux(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound), and remove auxiliary v0, in place""" for x in self.exprs_ge: x.substitute(v0, v1) for x in self.exprs_eq: x.substitute(v0, v1) for x in self.exprs_gei: x.substitute(v0, v1) for x in self.exprs_eqi: x.substitute(v0, v1) if not isinstance(v0, Expr): self.aux -= v0 self.inp -= v0 self.oup -= v0 self.auxi -= v0 if iutil.check_meta_subs_criteria(v0, v1, self): iutil.substitute(self.meta, v0, v1) return self def substituted_aux(self, *args, **kwargs): """Substitute variable v0 by v1 (v1 can be compound), and remove auxiliary v0, return result""" r = self.copy() r.substitute_aux(*args, **kwargs) return r def remove_present(self, v): self.exprs_ge = [x for x in self.exprs_ge if not x.ispresent(v)] self.exprs_eq = [x for x in self.exprs_eq if not x.ispresent(v)] self.exprs_gei = [x for x in self.exprs_gei if not x.ispresent(v)] self.exprs_eqi = [x for x in self.exprs_eqi if not x.ispresent(v)] if isinstance(v, Comp): self.aux -= v self.inp -= v self.oup -= v self.auxi -= v def remove_notpresent(self, v): self.exprs_ge = [x for x in self.exprs_ge if x.ispresent(v)] self.exprs_eq = [x for x in self.exprs_eq if x.ispresent(v)] self.exprs_gei = [x for x in self.exprs_gei if x.ispresent(v)] self.exprs_eqi = [x for x in self.exprs_eqi if x.ispresent(v)] if isinstance(v, Comp): self.aux = self.aux.inter(v) self.inp = self.inp.inter(v) self.oup = self.oup.inter(v) self.auxi = self.auxi.inter(v) def remove_notcontained(self, v): t = self.allcomp() - v if not t.isempty(): self.remove_present(t) def remove_relax(self, v, sn = 1): if sn < 0: self.imp_flip() if not self.remove_relax(v): 
self.setempty() return False self.imp_flip() return True if any(x.ispresent(v) for x in self.exprs_eqi): self.setuniverse() return False for x in self.exprs_gei: if not x.try_remove(v, -1): self.setuniverse() return False tlist = self.exprs_ge self.exprs_ge = [] for x in tlist: if x.try_remove(v, 1): self.exprs_ge.append(x) self.exprs_eq = [x for x in self.exprs_eq if not x.ispresent(v)] if isinstance(v, Comp): self.aux -= v self.inp -= v self.oup -= v self.auxi -= v return True # def removed_constraints(self, v): # """Return the region after the contraints in v are removed. # """ # cs = self.copy() # if not isinstance(v, Comp): # cs.remove_relax(v) # return cs # bnet = cs.get_bayesnet(roots = v) # bnet += v # cs.remove_relax(v) # return (cs & bnet.get_region()).simplified_quick() def condition(self, b): """Condition on random variable b, in place""" for x in self.exprs_ge: x.condition(b) for x in self.exprs_eq: x.condition(b) for x in self.exprs_gei: x.condition(b) for x in self.exprs_eqi: x.condition(b) return self def conditioned(self, b): """Condition on random variable b, return result""" r = self.copy() r.condition(b) return r def symm_sort(self, terms): """Sort the random variables in terms assuming symmetry among those terms.""" for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi: x.symm_sort(terms) def find(self, *args): return self.allcomp().find(*args) def placeholder(*args): r = Expr.zero() for a in args: if isinstance(a, Comp): r += Expr.H(a) * 0 elif isinstance(a, Expr): r += a * 0 return r >= 0 def empty_placeholder(*args): r = Region.empty() r.iand_norename(Region.placeholder(*args)) return r def record_to(self, index): for x in self.exprs_ge: x.record_to(index) for x in self.exprs_eq: x.record_to(index) index.record(self.aux) index.record(self.inp) index.record(self.oup) for x in self.exprs_gei: x.record_to(index) for x in self.exprs_eqi: x.record_to(index) index.record(self.auxi) def name_avoid(self, name0, regs = None): index 
= IVarIndex() self.record_to(index) if regs is not None: for r in regs: r.record_to(index) return index.name_avoid(name0) # name1 = name0 # while index.get_index_name(name1) >= 0: # name1 += PsiOpts.settings["rename_char"] # return name1 def rename_avoid(self, reg, name0): index = IVarIndex() reg.record_to(index) sindex = IVarIndex() self.record_to(sindex) name1 = name0 while (index.get_index_name(name1) >= 0 or (name1 != name0 and sindex.get_index_name(name1) >= 0)): name1 += PsiOpts.settings["rename_char"] if name1 != name0: self.rename_var(name0, name1) def aux_addprefix(self, pref = "@@"): for i in range(len(self.aux.varlist)): self.rename_var(self.aux.varlist[i].name, pref + self.aux.varlist[i].name) def aux_present(self): return not self.getaux().isempty() or not self.getauxi().isempty() def aux_clear(self): self.aux = Comp.empty() self.auxi = Comp.empty() def getaux(self): return self.aux.copy() def getauxi(self): return self.auxi.copy() def getauxall(self): return self.aux + self.auxi def getauxs(self): r = [] if not self.aux.isempty(): r.append((self.aux.copy(), True)) if not self.auxi.isempty(): r.append((self.auxi.copy(), False)) return r def aux_avoid(self, reg, samesuffix = True): if samesuffix: self.aux_avoid_from(reg.allcomprv_noaux(), samesuffix = True) reg.aux_avoid_from(self.allcomprv(), samesuffix = True) else: for a in reg.getauxi().varlist: reg.rename_avoid(self, a.name) for a in reg.getaux().varlist: reg.rename_avoid(self, a.name) for a in self.getauxi().varlist: self.rename_avoid(reg, a.name) for a in self.getaux().varlist: self.rename_avoid(reg, a.name) def aux_avoid_from(self, reg, samesuffix = True): if not PsiOpts.settings["avoid_enabled"]: return if samesuffix: if isinstance(reg, Region): reg = reg.allcomprv() reg = reg + self.allcomprv_noaux() auxcomp = self.getauxall() regindex = IVarIndex() reg.record_to(regindex) if not regindex.ispresent(auxcomp): return rename_char = PsiOpts.settings["rename_char"] for rep in ["set", "add", 
"suffix"]: for k in range(1, 512 if rep else 1000): rdict = {} rset = set() bad = False for a in auxcomp: t = iutil.set_suffix_num(a.get_name(), k, rename_char, replace_mode = rep) if t in rset: bad = True break rdict[a.get_name()] = t rset.add(t) if not bad and not any(regindex.get_index_name(a) >= 0 for a in rset): self.rename_map(rdict) return else: for a in self.getaux().varlist: self.rename_avoid(reg, a.name) for a in self.getauxi().varlist: self.rename_avoid(reg, a.name) def iand_norename(self, other): co = other self.exprs_ge += [x.copy() for x in co.exprs_ge] self.exprs_eq += [x.copy() for x in co.exprs_eq] self.exprs_gei += [x.copy() for x in co.exprs_gei] self.exprs_eqi += [x.copy() for x in co.exprs_eqi] self.aux += co.aux self.auxi += co.auxi return self def __iand__(self, other): if isinstance(other, bool): if not other: return Region.empty() return self other = iutil.ensure_region(other) if (isinstance(other, RegionOp) or self.imp_present() or other.imp_present() or (not PsiOpts.settings["prefer_expand"] and (self.aux_present() or other.aux_present()))): return RegionOp.inter([self]) & other if not self.aux_present() and not other.aux_present(): self.exprs_ge += [x.copy() for x in other.exprs_ge] self.exprs_eq += [x.copy() for x in other.exprs_eq] self.exprs_gei += [x.copy() for x in other.exprs_gei] self.exprs_eqi += [x.copy() for x in other.exprs_eqi] return self co = other.copy() self.aux_avoid(co) self.exprs_ge += [x.copy() for x in co.exprs_ge] self.exprs_eq += [x.copy() for x in co.exprs_eq] self.exprs_gei += [x.copy() for x in co.exprs_gei] self.exprs_eqi += [x.copy() for x in co.exprs_eqi] self.aux += co.aux self.auxi += co.auxi return self def __and__(self, other): r = self.copy() r &= other return r def __rand__(self, other): r = self.copy() r &= other return r def __pow__(self, other): if other <= 0: return Region.universe() r = self.copy() for i in range(other - 1): r &= self return r def __or__(self, other): other = 
iutil.ensure_region(other) if self.isempty(): return other.copy() return RegionOp.union([self]) | other def __ror__(self, other): other = iutil.ensure_region(other) if self.isempty(): return other.copy() return RegionOp.union([self]) | other def __ior__(self, other): other = iutil.ensure_region(other) if self.isempty(): return other.copy() return RegionOp.union([self]) | other def implicate(self, other, skip_simplify = False): other = iutil.ensure_region(other) co = other.copy() if not skip_simplify and PsiOpts.settings["imp_simplify"]: if co.imp_present(): co.simplify() self.aux_avoid(co) self.exprs_ge += [x.copy() for x in co.exprs_gei] self.exprs_eq += [x.copy() for x in co.exprs_eqi] self.exprs_gei += [x.copy() for x in co.exprs_ge] self.exprs_eqi += [x.copy() for x in co.exprs_eq] #self.aux += co.auxi #self.auxi = co.aux + self.auxi self.aux = co.auxi + self.aux self.auxi += co.aux return self def implicated(self, other, skip_simplify = False): other = iutil.ensure_region(other) # if isinstance(other, RegionOp) or self.imp_present() or other.imp_present() or not other.aux.isempty(): if isinstance(other, RegionOp) or not self.auxi.isempty() or other.imp_present() or not other.aux.isempty(): return RegionOp.union([self]).implicated(other, skip_simplify) r = self.copy() r.implicate(other, skip_simplify) return r def __le__(self, other): other = iutil.ensure_region(other) return other.implicated(self) def __ge__(self, other): other = iutil.ensure_region(other) return self.implicated(other) def __rshift__(self, other): other = iutil.ensure_region(other) return other.implicated(self) def __rrshift__(self, other): other = iutil.ensure_region(other) return self.implicated(other) def __lshift__(self, other): other = iutil.ensure_region(other) return self.implicated(other) def __rlshift__(self, other): other = iutil.ensure_region(other) return other.implicated(self) def __eq__(self, other): other = iutil.ensure_region(other) #return self.implies(other) and 
other.implies(self) return RegionOp.inter([self.implicated(other), other.implicated(self)]) def __ne__(self, other): other = iutil.ensure_region(other) return ~(RegionOp.inter([self.implicated(other), other.implicated(self)])) def relax_term(self, term, gap): self.simplify_quick() for x in self.exprs_eq: c = x.get_coeff(term) if c != 0.0: self.exprs_ge.append(x.copy()) self.exprs_ge.append(-x) x.setzero() for x in self.exprs_ge: c = x.get_coeff(term) if c > 0.0: x.substitute(Expr.fromterm(term), Expr.fromterm(term) + gap) elif c < 0.0: x.substitute(Expr.fromterm(term), Expr.fromterm(term) - gap) self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] def relax(self, w, gap): """Relax real variables in w by gap, in place""" for (a, c) in w.terms: if a.get_type() == TermType.REAL: self.relax_term(a, gap) return self def relaxed(self, w, gap): """Relax real variables in w by gap, return result""" r = self.copy() r.relax(w, gap) return r def balance(self, v = None, w = None, skip_simplify = False): if w is None: w = self.allcomprv() for x in self.exprs_ge + self.exprs_gei: t = x.balanced(v, w, sn = 1) if t is None: t = Expr.zero() x.copy_(t) for x in self.exprs_eq: if x.isbalanced(v): continue t = x.balanced(v, w, sn = 1) if t is not None: self.exprs_ge.append(t) t = x.balanced(v, w, sn = -1) if t is not None: self.exprs_ge.append(-t) x.setzero() for x in self.exprs_eqi: if x.isbalanced(v): continue t = x.balanced(v, w, sn = 1) if t is not None: self.exprs_gei.append(t) t = x.balanced(v, w, sn = -1) if t is not None: self.exprs_gei.append(-t) x.setzero() self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()] self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] self.exprs_gei = [x for x in self.exprs_gei if not x.iszero()] self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()] if not skip_simplify: self.simplify() return self def balanced(self, v = None, skip_simplify = False): r = self.copy() r.balance(v, skip_simplify = True) if not 
skip_simplify: return r.simplified() return r def one_flipped(self, strict = False): if len(self.exprs_eq) > 0 or len(self.exprs_ge) != 1: return None if strict: c = self.exprs_ge[0].get_coeff(Term.eps()) r = self.exprs_ge[0].substituted(Expr.eps(), Expr.zero()) if c >= 0: return r + Expr.eps() <= 0 else: return r <= 0 else: return self.exprs_ge[0] <= 0 def broken_present(self, w, flipped = True): """Convert region to intersection of individual constraints if they contain w""" if self.imp_present(): return RegionOp.from_region(self).broken_present(w, flipped) r = RegionOp.inter([]) cs = self.copy() for x in cs.exprs_eq: if x.ispresent(w): r.regs.append((x == 0, True)) x.setzero() for x in cs.exprs_ge: if x.ispresent(w): if flipped: r.regs.append((~(x <= 0), True)) else: r.regs.append((x >= 0, True)) x.setzero() cs.exprs_eq = [x for x in cs.exprs_eq if not x.iszero()] cs.exprs_ge = [x for x in cs.exprs_ge if not x.iszero()] cs.aux = Comp.empty() cs.auxi = Comp.empty() r.regs.append((cs, True)) return r.exists(self.aux).forall(self.auxi) def corners_optimum(self, w, sn): """Return union of regions corresponding to maximum/minimum of the real variable w""" for x in self.exprs_eq: if x.get_coeff(w.terms[0][0]) != 0: return self.copy() r = [] for i in range(len(self.exprs_ge)): x = self.exprs_ge[i] if x.get_coeff(w.terms[0][0]) * sn < 0: cs2 = self.copy() cs2.exprs_ge.pop(i) cs2.exprs_eq.append(x.copy()) cs2.aux = Comp.empty() cs2.auxi = Comp.empty() r.append(cs2) if len(r) == 0: return Region.universe() if len(r) == 1: return r[0].exists(self.aux).forall(self.auxi) return RegionOp.union(r).exists(self.aux).forall(self.auxi) def corners_optimum_eq(self, w, sn): """Return union of regions corresponding to maximum/minimum of the real variable w""" for x in self.exprs_eq: if x.get_coeff(w.terms[0][0]) != 0: return self.copy() r = [] cs = self.copy() cs.remove_present(w.terms[0][0].x[0]) for i in range(len(self.exprs_ge)): x = self.exprs_ge[i] if x.get_coeff(w.terms[0][0]) 
* sn < 0: cs2 = cs.copy() cs2.exprs_eqi.append(x.copy()) r.append(cs2) if len(r) == 0: return Region.universe() if len(r) == 1: return r[0] return RegionOp.inter(r) def corners(self, w): """Return union of regions corresponding to corner points of the real variables in w""" terms = [] if isinstance(w, Expr): for (a, c) in w.terms: if a.get_type() == TermType.REAL: terms.append(a) else: for w2 in w: for (a, c) in w2.terms: if a.get_type() == TermType.REAL: terms.append(a) n = len(terms) cmat = [] cmatall = [] for x in self.exprs_eq: coeff = [x.get_coeff(term) for term in terms] cmat.append(coeff[:]) cmatall.append(coeff[:]) rank = numpy.linalg.matrix_rank(cmat) if rank >= n: return [self.copy()] cs = self.copy() cs.aux = Comp.empty() ges = [] gec = [] for x in cs.exprs_ge: coeff = [x.get_coeff(term) for term in terms] cmatall.append(coeff[:]) if len([x for x in coeff if abs(x) > PsiOpts.settings["eps"]]) > 0: ges.append(x.copy()) gec.append(coeff) x.setzero() cs.exprs_ge = [x for x in cs.exprs_ge if not x.iszero()] rankall = numpy.linalg.matrix_rank(cmatall) r = [] for comb in itertools.combinations(range(len(ges)), rankall - rank): mat = cmat[:] for i in comb: mat.append(gec[i]) if numpy.linalg.matrix_rank(mat) >= rankall: cs2 = cs.copy() for i2 in range(len(ges)): if i2 in comb: cs2.exprs_eq.append(ges[i2].copy()) else: cs2.exprs_ge.append(ges[i2].copy()) r.append(cs2) return RegionOp.union(r).exists(self.aux) def corners_value(self, w, ispolar = False, skip_discover = False, inf_value = 1e6): """Return the vertices and the facet list of the polytope with coordinates in the list w.""" if not skip_discover: t = real_array("#TMPVAR", 0, len(w)) return self.discover([(a, b) for a, b in zip(t, w)]).corners_value(t, ispolar, True, inf_value) cindex = IVarIndex() self.record_to(cindex) for a in w: a.record_to(cindex) prog = self.imp_flipped().init_prog(index = cindex) g, f = prog.corners_value(ispolar = ispolar) A = SparseMat(0) Ab = [] for a in w: c, c1 = 
prog.get_vec(a, sparse = True) A.extend(c) Ab.append(c1) r = [] for b in g: v = [sum([b[j + 1] * c for j, c in row], 0.0) for row in A.x] if abs(b[0]) < 1e-11: maxv = max(abs(x) for x in v) v = [0.0] + [(x / maxv) * inf_value for x in v] else: v = [1.0 if b[0] > 0 else -1.0] + [x / abs(b[0]) + y for x, y in zip(v, Ab)] r.append(v) return (r, f) def sign_present(self, term): sn_present = [False] * 2 for x in self.exprs_ge: c = x.get_coeff(term) if c > 0.0: sn_present[1] = True elif c < 0.0: sn_present[0] = True for x in self.exprs_eq: c = x.get_coeff(term) if c != 0.0: sn_present[0] = True sn_present[1] = True for x in self.exprs_gei: c = x.get_coeff(term) if c > 0.0: sn_present[0] = True elif c < 0.0: sn_present[1] = True for x in self.exprs_eqi: c = x.get_coeff(term) if c != 0.0: sn_present[0] = True sn_present[1] = True return sn_present def substitute_sign(self, v0, v1s): v0term = v0.terms[0][0] sn_present = [False] * 2 for x in self.exprs_eq: if x.ispresent(v0): if v1s: self.exprs_ge.append(x.copy()) self.exprs_ge.append(-x) x.setzero() sn_present[0] = True sn_present[1] = True for x in self.exprs_eqi: if x.ispresent(v0): if v1s: self.exprs_gei.append(x.copy()) self.exprs_gei.append(-x) x.setzero() sn_present[0] = True sn_present[1] = True for x in self.exprs_ge: if x.ispresent(v0): c = x.get_coeff(v0term) if c > 0.0: if v1s: x.substitute(v0, v1s[1]) sn_present[1] = True else: if v1s: x.substitute(v0, v1s[0]) sn_present[0] = True for x in self.exprs_gei: if x.ispresent(v0): c = x.get_coeff(v0term) if c > 0.0: if v1s: x.substitute(v0, v1s[0]) sn_present[0] = True else: if v1s: x.substitute(v0, v1s[1]) sn_present[1] = True if v1s: self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()] return sn_present def substitute_duplicate(self, v0, v1s): for l in [self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi]: olen = len(l) for ix in range(olen): x = l[ix] if x.ispresent(v0): for v1 in v1s: 
l.append(x.substituted(v0, v1)) x.setzero() self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()] self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] self.exprs_gei = [x for x in self.exprs_gei if not x.iszero()] self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()] # def flattened_minmax(self, term, sgn, bds): # sbds = self.get_lb_ub_eq(term) # reg.eliminate_term(self) def lowest_present(self, v, sn): if not sn: return None if self.ispresent(v): return self return None def flatten_regterm(self, term): self.simplify_quick() sn = term.sn sn_present = [False] * 2 for x in self.exprs_ge: c = x.get_coeff(term) if c * sn > 0.0: sn_present[1] = True elif c * sn < 0.0: sn_present[0] = True for x in self.exprs_eq: c = x.get_coeff(term) if c != 0.0: sn_present[0] = True sn_present[1] = True self.exprs_ge.append(x.copy()) self.exprs_ge.append(-x) x.setzero() for x in self.exprs_gei: c = x.get_coeff(term) if c * sn > 0.0: sn_present[0] = True elif c * sn < 0.0: sn_present[1] = True for x in self.exprs_eqi: c = x.get_coeff(term) if c != 0.0: sn_present[0] = True sn_present[1] = True self.exprs_gei.append(x.copy()) self.exprs_gei.append(-x) x.setzero() self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()] cs = self if sn == 0: sn_present[1] = False sn_present[0] = True if sn_present[1]: tmpvar = Expr.real("FLAT_TMP_" + str(term)) for x in cs.exprs_ge: c = x.get_coeff(term) if c * sn > 0.0: x.substitute(Expr.fromterm(term), tmpvar) for x in cs.exprs_gei: c = x.get_coeff(term) if c * sn < 0.0: x.substitute(Expr.fromterm(term), tmpvar) reg2 = term.reg.copy() reg2.substitute(Expr.fromterm(term), tmpvar) cs.aux_avoid(reg2) newindep = Expr.Ic(reg2.getauxi(), cs.allcomprv() - cs.getaux() - reg2.allcomprv(), reg2.allcomprv_noaux()).simplified_quick() if reg2.get_type() == RegionType.NORMAL: reg2 = reg2.corners_optimum_eq(tmpvar, sn) cs = reg2 & cs if not newindep.iszero(): cs = 
cs.iand_norename((newindep == 0).imp_flipped()) cs.eliminate_quick(tmpvar) if sn_present[0]: newvar = Expr.real(str(term)) reg2 = term.reg.copy() cs.aux_avoid(reg2) newindep = Expr.Ic(reg2.getaux(), cs.allcomprv() - cs.getaux() - reg2.allcomprv(), reg2.allcomprv_noaux()).simplified_quick() if reg2.get_type() == RegionType.NORMAL: reg2 = reg2.corners_optimum(Expr.fromterm(term), sn) cs = cs.implicated(reg2) #cs &= reg2.imp_flipped() if not newindep.iszero(): cs = cs.iand_norename((newindep == 0).imp_flipped()) cs.substitute(Expr.fromterm(term), newvar) return cs def flatten_ivar(self, ivar): cs = self newvar = Comp([ivar.copy_noreg()]) reg2 = ivar.reg.copy() cs.aux_avoid(reg2) cs = cs.implicated(reg2, skip_simplify = True) cs.substitute(Comp([ivar]), newvar) if not ivar.reg_det: newindep = Expr.Ic(reg2.getaux() + newvar, cs.allcomprv() - cs.getaux() - reg2.allcomprv(), reg2.allcomprv_noaux() - newvar).simplified_quick() if not newindep.iszero(): cs = cs.iand_norename((newindep == 0).imp_flipped()) return cs def isregtermpresent(self): for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi: if x.isregtermpresent(): return True return False def numexpr(self): return len(self.exprs_ge) + len(self.exprs_eq) + len(self.exprs_gei) + len(self.exprs_eqi) def numterm(self): return sum([len(x.terms) for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi]) def isplain(self): if not self.aux.isempty(): return False if self.isregtermpresent(): return False return True def regtermmap(self, cmap, recur): for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi + [Expr.H(self.aux + self.auxi)]: for (a, c) in x.terms: rvs = a.allcomprv_shallow() for b in rvs.varlist: if b.reg is not None: s = b.name if not (s in cmap): cmap[s] = b if recur: b.reg.regtermmap(cmap, recur) if a.get_type() == TermType.REGION: s = a.x[0].varlist[0].name if not (s in cmap): cmap[s] = a if recur: a.reg.regtermmap(cmap, recur) def flatten(self): verbose = 
PsiOpts.settings.get("verbose_flatten", False) write_pf_enabled = PsiOpts.settings.get("proof_enabled", False) cs = self did = False didall = False regterms = {} cs.regtermmap(regterms, False) regterms_in = {} for (name, term) in regterms.items(): term.reg.regtermmap(regterms_in, True) for (name, term) in regterms.items(): if not(name in regterms_in): if verbose: print("========= flatten ========") print(cs) print("========= term ========") print(term) print("========= region ========") print(term.reg) if isinstance(term, IVar): cs = cs.flatten_ivar(term) else: cs = cs.flatten_regterm(term) did = True didall = True if verbose: print("========= to ========") print(cs) break if write_pf_enabled: if didall: pf = ProofObj.from_region(self, c = "Expanded definitions to") PsiOpts.set_setting(proof_add = pf) if did: return cs.flatten() return cs def flatten_term(self, x, isimp = True): cs = self if isinstance(x, Comp): for y in x.varlist: cs = cs.flatten_ivar(y, isimp) else: for y, _ in x.terms: cs = cs.flatten_regterm(y, isimp) return cs def flattened(self, *args, minmax_elim = False): if not self.isregtermpresent() and len(args) == 0: return self.copy() cs = None if not isinstance(cs, RegionOp): cs = RegionOp.inter([self]) else: cs = self.copy() cs.flatten(minmax_elim = minmax_elim) for x in args: cs.flatten_term(x) t = cs.tosimple() if t is not None: return t return cs def incorporate_tmp(self, x, tmplink): if isinstance(x, Comp): self &= (H(x) == tmplink) elif isinstance(x, Expr): self &= (x == tmplink) return self def incorporate_remove_tmp(self, tmplink): self.remove_present(tmplink) def to_cause_consequence(self): return [(self.imp_flippedonly_noaux(), self.consonly(), self.auxi.copy())] def and_cause_consequence(self, other, avoid = None, added_reg = None): cs = self.copy() other = other.copy() cs.aux_avoid(other) clist = other.to_cause_consequence() for imp, cons, auxi in clist: tcs = cs.tosimple() if tcs is None: return cs cs = tcs timp = imp.tosimple() if timp 
is None: continue tcs = imp.exists(auxi) tcs.implicate(cs) # tcs = imp.copy() # tcs.implicate(cs) # tcs = tcs.exists(auxi) # print(imp) # print(cons) # print(auxi) # print("TCS") # print(tcs) # print(cons.allcomp().get_markers()) # print() hint_aux_avoid = None if avoid is not None: hint_aux_avoid = [] for a in auxi: hint_aux_avoid.append((a, avoid.copy())) for rr in tcs.check_getaux_inplace_gen(hint_aux_avoid = hint_aux_avoid): if iutil.signal_type(rr) == "": # print(rr) tcons = cons.copy() Comp.substitute_list(tcons, rr) cs &= tcons if added_reg is not None: added_reg &= tcons return cs def toregionop(self): if not isinstance(self, RegionOp): return RegionOp.inter([self]) else: return self.copy() def get_req_cons(self): return self.toregionop().get_req_cons() def incorporated(self, *args): cs = self.toregionop() # cargs = [a.copy() for a in args] cargs = [] for a in args: if isinstance(a, Comp): for b in a: cargs.append(b.copy()) else: cargs.append(a.copy()) to_remove_aux = [] tmplink = Expr.real("#TMPLINK") for i in range(len(cargs)): x = cargs[i] if isinstance(x, Region): cs = cs.and_cause_consequence(x).toregionop() else: cs.flatten_term(x, isimp = False) x_noreg = x.copy_noreg() for j in range(i + 1, len(cargs)): if isinstance(x, Comp) and isinstance(cargs[j], Region) and x not in cargs[j].allcomprv_noaux(): continue cargs[j].substitute(x, x_noreg) if not cs.ispresent(x_noreg): #cs.eliminate(x_noreg, forall = True) cs.incorporate_tmp(x_noreg, tmplink) to_remove_aux.append(x_noreg) if len(to_remove_aux): cs.incorporate_remove_tmp(tmplink) # print("AFTER INCORPORATED") # print(cs) return cs.simplified_quick() def instantiated(self, *args): return universe().incorporated(*args, self) def flattened_self(self): r = self.copy() r = r.flatten() return r.simplify_quick() def tosimple(self): if not self.auxi.isempty(): return None return self.copy() def tosimple_noaux(self): if self.aux_present(): return None return self.tosimple() def tosimple_safe(self): return 
self.tosimple() def tonormal_safe(self): return self.copy() def level(self, mode = ""): """The level of a region in the linear entropy hierarchy. """ r = Level() if not self.aux.isempty(): r = r.exists() if not self.auxi.isempty(): r = r.forall() return r def region_level(self, mode = ""): return RegionLevel(self.level(mode = mode)) def complexity(self): return sum(x.complexity() for x in self.exprs_eq + self.exprs_ge + self.exprs_eqi + self.exprs_gei) + (len(self.aux) + len(self.auxi)) * 100 def sorting_priority(self): return self.complexity() def z3(self, vardict = None): if z3 is None: return None if vardict is None: vardict = iutil.z3_vardict(self.rvs, self.aux, self.auxi, self.reals) # if not self.auxi.isempty(): # raise ValueError("Universally-quantified random variables are not supported for Z3. Use another solver.") if self.imp_present(): r = z3.Or(z3.Not(self.imp_flippedonly_noaux().z3(vardict)), self.consonly().z3(vardict)) if not self.auxi.isempty(): r = z3.ForAll([x.z3(vardict) for x in self.auxi], r) return r r = [(x.z3(vardict) == 0) for x in self.exprs_eq] + [(x.z3(vardict) >= 0) for x in self.exprs_ge] if len(r) == 1: r = r[0] else: r = z3.And(r) if not self.aux.isempty(): return z3.Exists([x.z3(vardict) for x in self.aux], r) else: return r def check_z3(self): truth = PsiOpts.settings["truth"] if truth is not None: with PsiOpts(truth = None): return (truth >> self).check_z3() indreg = self.get_indreg_checked() if indreg is not None: with PsiOpts(indreg_enabled = False): return (indreg >> self).check_z3() vardict = iutil.z3_vardict(self.rvs, self.aux, self.auxi, self.reals) t = self.z3(vardict) solver = vardict["#solver"] solver.add(z3.Not(t)) res = solver.check() return res == z3.unsat def commonpart_extend(self, v, forbid_h = True): ceps = PsiOpts.settings["eps"] tvar = Expr.real("#TVAR") did = False didpos = False toadd = [] todel = set() for x in self.exprs_eq: c = x.commonpart_coeff(v, forbid_h = forbid_h) if c is None or numpy.isnan(c): return 
None if numpy.isinf(c): return None if abs(c) <= ceps: continue if x.isnonneg() and c > 0: return None did = True didpos = True toadd.append((x, c)) for x in self.exprs_ge: c = x.commonpart_coeff(v, forbid_h = forbid_h) if c is None or numpy.isnan(c): return None if numpy.isinf(c): if c < 0: return None todel.add(x) continue if abs(c) <= ceps: continue if x.isnonpos() and c < 0: return None did = True if c > 0: didpos = True toadd.append((x, c)) for x in self.exprs_eqi: c = x.commonpart_coeff(v, forbid_h = forbid_h) if c is None or numpy.isnan(c): return None if numpy.isinf(c): return None if abs(c) <= ceps: continue did = True didpos = True toadd.append((x, c)) for x in self.exprs_gei: c = x.commonpart_coeff(v, forbid_h = forbid_h) if c is None or numpy.isnan(c): return None if numpy.isinf(c): return None if abs(c) <= ceps: continue did = True didpos = True toadd.append((x, c)) if not didpos: return None self.exprs_eq = [x for x in self.exprs_eq if x not in todel] self.exprs_ge = [x for x in self.exprs_ge if x not in todel] self.exprs_eqi = [x for x in self.exprs_eqi if x not in todel] self.exprs_gei = [x for x in self.exprs_gei if x not in todel] for x, c in toadd: x += tvar * c self.exprs_ge.append(tvar) self.eliminate(tvar) return self def var_neighbors(self, v): r = v.copy() for x in self.exprs_eq + self.exprs_ge + self.exprs_eqi + self.exprs_gei: r += x.var_neighbors(v) return r def aux_strengthen(self, addrv = None, other_neighbors = None): if self.aux.isempty(): return self if self.imp_present(): cs = self.consonly() cs.aux_strengthen(addrv) self.cons_shallow_copy_(cs) return self allc = self.allcomprv() if addrv is not None: allc += addrv toadd = Region.universe() ag = [] acon = [] for a in self.aux: v = self.var_neighbors(a) if other_neighbors is not None: v = v + other_neighbors.var_neighbors(a) ag.append(v.copy()) acon.append(v.copy()) for it in range(len(self.aux)): for i, a in enumerate(self.aux): for j, b in enumerate(self.aux): if 
acon[i].ispresent(b): acon[i] += acon[j] acon[j] += acon[i] def recur(ca, prevv): v = Comp.empty() maxdeg = -1 maxdega = None for i, a in enumerate(self.aux): if ca.ispresent(a): if not acon[i].super_of(ca): recur(acon[i].inter(ca), None) recur(ca - acon[i], None) return v += ag[i] deg = len(ag[i]) + len(ag[i] - ca) * 50 if deg > maxdeg: maxdeg = deg maxdega = a if maxdeg < 0: return # print(str(ca) + " " + str(v) + " " + str(allc-v) + " " + str(prevv)) if prevv is None or v != prevv: b = allc - v if not b.isempty(): toadd.exprs_eq.append(Expr.Ic(ca, b, v - ca)) recur(ca - maxdega, v) recur(self.aux.copy(), None) self.iand_norename(toadd) return self def aux_strengthen_old(self, addrv = None, other_neighbors = None): if self.aux.isempty(): return self if self.imp_present(): cs = self.consonly() cs.aux_strengthen(addrv) self.cons_shallow_copy_(cs) return self allc = self.allcomprv() if addrv is not None: allc += addrv toadd = Region.universe() allc_on = allc if other_neighbors is not None: allc_on += other_neighbors.allcomprv() allc -= self.aux allc_on -= self.aux for a in self.aux: v = self.var_neighbors(a) if other_neighbors is not None: v = v + other_neighbors.var_neighbors(a) v = v.inter(allc_on) b = allc - v if not b.isempty(): toadd.exprs_eq.append(Expr.Ic(a, b, v - a)) allc += a allc_on += a self.iand_norename(toadd) return self def simplify_aux_commonpart(self, reg = None, minlen = 1, maxlen = 1, forbid_h = True): if self.aux.isempty(): return self if self.imp_present(): cs = self.consonly() cs.simplify_aux_commonpart(reg, minlen, maxlen, forbid_h) self.cons_shallow_copy_(cs) return self if reg is None: reg = Region.universe() did = False taux = self.aux.copy() taux2 = taux.copy() tauxi = self.auxi.copy() self.aux = Comp.empty() self.auxi = Comp.empty() for clen in range(minlen, maxlen + 1): if clen == 2: continue if clen > len(taux): break for v in igen.subset(taux, minsize = clen): if PsiOpts.is_timer_ended(): break if clen == 1: if 
self.commonpart_extend(v, forbid_h = forbid_h) is not None: did = True else: for v2 in igen.partition(v, clen): if PsiOpts.is_timer_ended(): break if self.commonpart_extend(v2, forbid_h = forbid_h) is not None: did = True self.aux = taux2 self.auxi = tauxi # if did: # self.simplify_quick() return self def remove_realvar_sn(self, v, sn): # v = v.allcomp() v = v.terms[0][0] for x in self.exprs_ge: # if len(x) == 1 and x.terms[0][1] * sn > 0 and x.terms[0][0].get_type() == IVarType.REAL and x.allcomp() == v: if x.get_coeff(v) * sn > 0: x.setzero() self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()] def istight(self, canon = False): return all(x.istight(canon) for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi) def tighten(self): for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi: x.tighten() def add_meta_present(self, b, key, value): for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi: if x.ispresent(b): x.add_meta(key, value) def optimum(self, v, b, sn, name = None, reg_outer = None, assume_feasible = True, allow_reuse = False, quick = None, quick_outer = None, tighten = False): """Return the variable obtained from maximizing (sn=1) or minimizing (sn=-1) the expression v over variables b (Comp, Expr or list) """ if reg_outer is not None: if quick_outer is None: quick_outer = quick a0 = self.optimum(v, b, sn, name = name, reg_outer = None, assume_feasible = assume_feasible, allow_reuse = allow_reuse, quick = quick) a1 = reg_outer.optimum(v, b, sn, name = name, reg_outer = None, assume_feasible = assume_feasible, allow_reuse = allow_reuse, quick = quick_outer) a1.terms[0][0].substitute(a1.terms[0][0].x[0], a0.terms[0][0].x[0]) a0.terms[0][0].reg_outer = a1.terms[0][0].reg if tighten: a0.tighten() return a0 tmpstr = "" if name is not None: tmpstr = name else: if sn > 0: tmpstr = "max" else: tmpstr = "min" tmpstr += str(iutil.hash_short(self)) tmpstr = tmpstr + "(" + str(v) + ")" tmpvar = 
Expr.real(tmpstr) if allow_reuse and v.size() == 1 and v.terms[0][0].get_type() == TermType.REAL: coeff = v.terms[0][1] if coeff < 0: sn *= -1 cs = self.copy() if b is not None: b = Region.get_allcomp(b) - v.allcomp() if quick is False: cs.eliminate(b) else: cs.eliminate_quick(b) if assume_feasible: if cs.get_type() == RegionType.NORMAL: cs.remove_realvar_sn(v, sn) if not allow_reuse: cs.substitute(Expr.fromterm(Term(v.terms[0][0].copy().x, Comp.empty())), tmpvar) v = tmpvar cs.add_meta_present(v, "pf_note", "at optimum") return Expr.fromterm(Term(v.terms[0][0].copy().x, Comp.empty(), cs, sn)) * coeff cs = self.copy() toadd = Region.universe() if sn > 0: toadd.exprs_ge.append(v - tmpvar) else: toadd.exprs_ge.append(tmpvar - v) toadd = toadd.flattened(minmax_elim = True) cs = cs & toadd # cs.iand_norename(toadd) if quick is None: quick = v.isrealvar() and v.allcomp().super_of(Region.get_allcomp(b)) return cs.optimum(tmpvar, b, sn, name = name, reg_outer = reg_outer, assume_feasible = assume_feasible, allow_reuse = True, quick = quick) def maximum(self, expr, vs, reg_outer = None, **kwargs): """Return the variable obtained from maximizing the expression expr over variables vs (Comp, Expr or list) """ return self.optimum(expr, vs, 1, reg_outer = reg_outer, **kwargs) def minimum(self, expr, vs, reg_outer = None, **kwargs): """Return the variable obtained from minimizing the expression expr over variables vs (Comp, Expr or list) """ return self.optimum(expr, vs, -1, reg_outer = reg_outer, **kwargs) def init_prog(self, index = None, lptype = None, save_res = False, lp_bounded = None, dual_enabled = None, val_enabled = None, simplified_prog = False): if index is None: index = IVarIndex() self.record_to(index) prog = None if lptype is None: lptype = PsiOpts.settings["lptype"] if PsiOpts.settings["lptype_H_if_proof"] and PsiOpts.settings["proof_enabled"]: lptype = LinearProgType.H if lp_bounded is None: if save_res: lp_bounded = True if simplified_prog: lp_bounded = False 
save_res = False val_enabled = False if lptype == LinearProgType.H: prog = LinearProg(index, lptype, lp_bounded = lp_bounded, save_res = save_res, prereg = self, dual_enabled = dual_enabled, val_enabled = val_enabled) elif lptype == LinearProgType.HC1BN or lptype == LinearProgType.HMIN: bnet = self.get_bayesnet_imp(skip_simplify = True, add_hc = PsiOpts.settings["lp_bnet_hc"]) if PsiOpts.settings["lp_bnet_reverse"]: bnet = bnet.reversed() kindex = bnet.index.copy() kindex.add_varindex(index) prog = LinearProg(kindex, lptype, bnet, lp_bounded = lp_bounded, save_res = save_res, prereg = self, dual_enabled = dual_enabled, val_enabled = val_enabled) for x in self.exprs_gei: prog.addExpr_ge0(x) for x in self.exprs_eqi: prog.addExpr_eq0(x) if simplified_prog: for x in igen.subset(index.comprv, minsize = 1): prog.addExpr_eq0((Expr.H(x) * (1.0 + 1.0 / len(x)) * 0.1).add_meta("sim", True)) # prog.addExpr_eq0((Expr.H(x) / (1.0 + 1.0 / len(x)) * 0.1).add_meta("sim", True)) for x in index.comprv: for y in igen.subset(index.comprv - x, minsize = 1): if len(x) == len(y) and str(x) > str(y): continue prog.addExpr_eq0((Expr.I(x, y) * (1.0 + 1.0 / len(x + y)) * 0.1).add_meta("sim", True)) prog.finish(skip_ent_ineq = simplified_prog) return prog def init_simplified_prog(self, index = None): r = None with PsiOpts(lp_dual_form = True, proof_enabled = False): r = self.imp_flipped().init_prog(index, simplified_prog = True) return r def get_basis(self, more_vars = None): """Get a basis of the entropy region of this region (may not be minimal). 
""" cs = self.consonly().imp_flipped() index = IVarIndex() cs.record_to(index) if more_vars is not None: more_vars.record_to(index) return cs.init_prog(index).get_var_exprs() def get_prog_region(self, toreal = None, toreal_only = False): cs = self.consonly().imp_flipped() index = IVarIndex() cs.record_to(index) r = cs.init_prog(index, lptype = LinearProgType.H).get_region(toreal, toreal_only) return r def get_extreme_rays(self): cs = self.consonly().imp_flipped() index = IVarIndex() cs.record_to(index) r = cs.init_prog(index, lptype = LinearProgType.H).get_extreme_rays() return r def implies_ineq_cons_hash(self, expr, sg): chash = hash(expr) if sg == ">=": for x in self.exprs_ge: if chash == hash(x): return True for x in self.exprs_eq: if chash == hash(x): return True return False def implies_ineq_cons_quick(self, expr, sg): """Return whether self implies expr >= 0 or expr == 0, without linear programming""" if sg == "==" and expr.isnonneg(): sg = ">=" expr = -expr if sg == ">=": for x in self.exprs_ge: d = (expr - x).simplified_quick() if d.isnonneg(): return True for x in self.exprs_eq: d = (expr - x).simplified_quick() if d.isnonneg(): return True d = (expr + x).simplified_quick() if d.isnonneg(): return True return False if sg == "==": for x in self.exprs_eq: d = (expr - x).simplified_quick() if d.iszero(): return True return False return False def implies_ineq_quick(self, expr, sg, bnet = None): """Return whether self implies expr >= 0 or expr == 0, without linear programming""" expr = expr.simplified_quick(bnet = bnet) if expr.iszero(): return True if sg == "==" and expr.isnonneg(): sg = ">=" expr = -expr if sg == ">=": if expr.isnonneg(): return True for x in self.exprs_gei: d = (expr - x).simplified_quick(bnet = bnet) if d.isnonneg(): return True for x in self.exprs_eqi: d = (expr - x).simplified_quick(bnet = bnet) if d.isnonneg(): return True d = (expr + x).simplified_quick(bnet = bnet) if d.isnonneg(): return True return False if sg == "==": for x in 
self.exprs_eqi: d = (expr - x).simplified_quick(bnet = bnet) if d.iszero(): return True return False return False def implies_ineq_prog(self, index, progs, expr, sg, save_res = False, saved = False, allow_shortcut = True): if saved == "both": return (self.implies_ineq_prog(index, progs, expr, sg, save_res, saved = True, allow_shortcut = allow_shortcut) and self.implies_ineq_prog(index, progs, expr, sg, save_res, saved = False, allow_shortcut = allow_shortcut)) #print("save_res = " + str(save_res)) if not saved and allow_shortcut and self.implies_ineq_quick(expr, sg): return True if len(progs) == 0: progs.append(self.init_prog(index, save_res = save_res)) if progs[0].checkexpr(expr, sg, saved = saved): return True return False def implies_impflipped_saved(self, other, index, progs): verbose_subset = PsiOpts.settings.get("verbose_subset", False) if verbose_subset: print(self) print(other) for x in other.exprs_ge: if not self.implies_ineq_prog(index, progs, x, ">=", save_res = True, saved = "both"): if verbose_subset: print(str(x) + " >= 0 FAIL") return False for x in other.exprs_eq: if not self.implies_ineq_prog(index, progs, x, "==", save_res = True, saved = "both"): if verbose_subset: print(str(x) + " == 0 FAIL") return False if verbose_subset: print("SUCCESS") return True def implies_saved(self, other, index, progs): self.imp_flip() r = self.implies_impflipped_saved(other, index, progs) self.imp_flip() return r def check_quick(self, bnet = None, skip_simplify = False): """Return whether implication is true""" verbose_subset = PsiOpts.settings.get("verbose_subset", False) cs = self if not skip_simplify: cs = self.simplified_quick(zero_group = 2) cs.split() if verbose_subset: print(cs) for x in cs.exprs_ge: if not cs.implies_ineq_quick(x, ">=", bnet = bnet): if verbose_subset: print(str(x) + " >= 0 FAIL") return False for x in cs.exprs_eq: if not cs.implies_ineq_quick(x, "==", bnet = bnet): if verbose_subset: print(str(x) + " == 0 FAIL") return False if 
verbose_subset: print("SUCCESS") return True def check_plain(self, skip_simplify = False, quick = False, bnet = None): """Return whether implication is true""" if quick: return self.check_quick(bnet = bnet, skip_simplify = skip_simplify) verbose_subset = PsiOpts.settings.get("verbose_subset", False) allow_shortcut = not (PsiOpts.settings["proof_noskip"] and PsiOpts.settings["proof_enabled"]) cs = self if not skip_simplify: if allow_shortcut: cs = self.simplified_quick(zero_group = 2) if not PsiOpts.settings.get("lp_zero_group", False) or (PsiOpts.settings["lp_no_zero_group_if_proof"] and PsiOpts.settings["proof_enabled"]): cs.split() # print(PsiOpts.settings["proof_noskip"]) # print(PsiOpts.settings["proof_enabled"]) # print(cs) index = IVarIndex() cs.record_to(index) progs = [] if verbose_subset: print(cs) for x in cs.exprs_ge: if not cs.implies_ineq_prog(index, progs, x, ">=", allow_shortcut = allow_shortcut): if verbose_subset: print(str(x) + " >= 0 FAIL") return False for x in cs.exprs_eq: if not cs.implies_ineq_prog(index, progs, x, "==", allow_shortcut = allow_shortcut): if verbose_subset: print(str(x) + " == 0 FAIL") return False if verbose_subset: print("SUCCESS") return True def check_dual(self, optimal_terms = True, force_H = False, mul_coeff = True, skip_simplify = True, existing_only = False): truth = PsiOpts.settings["truth"] if truth is not None: with PsiOpts(truth = None): return (truth >> self).check_dual(optimal_terms=optimal_terms, force_H=force_H, mul_coeff=mul_coeff, skip_simplify=skip_simplify, existing_only=existing_only) indreg = self.get_indreg_checked() if indreg is not None: with PsiOpts(indreg_enabled = False): return (indreg >> self).check_dual(optimal_terms=optimal_terms, force_H=force_H, mul_coeff=mul_coeff, skip_simplify=skip_simplify, existing_only=existing_only) if self.aux_present(): rr = self.check_getaux() if rr is None: return None subdict = Comp.substitute_list_to_dict(rr, multi = True) cs = self.copy() cs = 
cs.substituted_dict_union(subdict).noaux() return cs.check_dual(optimal_terms=optimal_terms, force_H=force_H, mul_coeff=mul_coeff, skip_simplify=skip_simplify, existing_only=existing_only) verbose_subset = PsiOpts.settings.get("verbose_subset", False) ceps = PsiOpts.settings["eps_lp"] # cs = self.imp_flippedonly_noaux() cs = self r = Region.universe() index = IVarIndex() cs.record_to(index) progs = [] ctx = None if force_H: ctx = PsiOpts(lptype = "H") else: ctx = contextlib.nullcontext() with PsiOpts(lp_dual_form = True), ctx: progs.append(self.init_prog(index, dual_enabled = True)) for x in self.exprs_ge + self.exprs_eq + [-y for y in self.exprs_eq]: if not cs.implies_ineq_prog(index, progs, x, ">=", allow_shortcut = False): if verbose_subset: print(str(x) + " >= 0 FAIL") return None if optimal_terms: t = progs[0].get_dual_region_terms(x) if t is not None: r.iand_norename(t) else: r.iand_norename(x >= 0) elif existing_only: for y in self.exprs_gei: t = progs[0].get_dual_expr(y) if t is not None and abs(t) > ceps: if mul_coeff: r.exprs_ge.append(y * t) else: r.exprs_ge.append(y) for y in self.exprs_eqi: t = progs[0].get_dual_expr(y) if t is not None and abs(t) > ceps: if mul_coeff: r.exprs_eq.append(y * t) else: r.exprs_eq.append(y) else: csum = Expr.zero() cdual = progs[0].get_dual_region(mul_coeff = mul_coeff, omit_trivial = not skip_simplify, tosum=csum) # print(cdual) # print(csum) if cdual is not None: r.iand_norename(cdual) csum = x - csum csum.simplify() if not csum.isnonneg(): r.iand_norename(csum >= 0) else: r.iand_norename(x >= 0) # v, prog = cs.minimum(x).solve_prog("dual") # if v is None or v < -ceps or prog is None: # return None # cdual = prog.get_dual_region() # if cdual is not None: # r &= cdual if not skip_simplify: r = r.simplified_quick() return r def min_assumption(self, optimal_terms = True, force_H = False, mul_coeff = False, skip_simplify = False, existing_only = False): """Get a smaller set of assumptions needed to prove an implication. 
WARNING: If not used with parameter "force_H=True", may omit conditional independence and functional dependencies. Returns the Region produced by check_dual, or None if the implication could not be proved. """ return self.check_dual(optimal_terms=optimal_terms, force_H=force_H, mul_coeff=mul_coeff, skip_simplify=skip_simplify, existing_only=existing_only) def ic_list(self, v): """Return a list of (other side, coefficient) pairs taken from terms a with a.isic2() (presumably two-argument mutual-information terms) in the exprs_ge constraints that are not nonpositive: when v appears in a.x[0], a copy of a.x[1] is recorded, and vice versa.""" r = [] for x in self.exprs_ge: if x.isnonpos(): continue for (a, c) in x.terms: if not a.isic2(): continue if a.x[0].ispresent(v): r.append((a.x[1].copy(), c)) if a.x[1].ispresent(v): r.append((a.x[0].copy(), c)) return r def ic_list_similarity(self, vl, wl): """Return a similarity score between two lists produced by ic_list. Every pair of entries whose coefficients do not have opposite signs adds 2 if their variable sets are equal, or 1 if one is a superset of the other (super_of).""" r = 0 for (va, vc) in vl: for (wa, wc) in wl: if vc * wc < 0: continue if va == wa: r += 2 elif va.super_of(wa) or wa.super_of(va): r += 1 return r def get_hc(self): """Return an Expr summing Term.Hc(a.x[0] - a.z, a.z) with coefficient 1.0 for every term a with a.ishc() that appears in a nonpositive exprs_ge constraint, or in an exprs_eq constraint that is nonpositive or nonnegative; terms where a.x[0] - a.z is empty are skipped. These are presumably the conditional entropies forced to zero, i.e. functional dependencies.""" r = Expr.zero() for x in self.exprs_ge: if x.isnonpos(): for a, c in x.terms: if a.ishc(): cx = a.x[0] - a.z if not cx.isempty(): r.terms.append((Term.Hc(cx, a.z), 1.0)) for x in self.exprs_eq: if x.isnonpos() or x.isnonneg(): for a, c in x.terms: if a.ishc(): cx = a.x[0] - a.z if not cx.isempty(): r.terms.append((Term.Hc(cx, a.z), 1.0)) return r def get_var_avoid(self, a): """Return the intersection (inter) of x.get_var_avoid(a) over every constraint in exprs_ge, exprs_eq, exprs_gei and exprs_eqi, ignoring None results; return None if every constraint returned None.""" r = None for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi: t = x.get_var_avoid(a) if r is None: r = t elif t is not None: r = r.inter(t) return r def get_aux_avoid_list(self): """Return a list of (aux, avoid) pairs, one for each auxiliary variable from self.getaux() for which get_var_avoid returns something other than None.""" r = [] caux = self.getaux() for a in caux: t = self.get_var_avoid(a) if t is not None: r.append((a, t)) return r def check_getaux_inplace(self, must_include = None, single_include = None, hint_pair = None, hint_aux = None, hint_aux_avoid = None, max_iter = None, leaveone = None): """Return whether implication is true, with auxiliary search result: the first result of check_getaux_inplace_gen whose iutil.signal_type is the empty string (a list of (auxiliary, value) pairs), or None if no such result is produced. The pairs from get_aux_avoid_list() are appended to hint_aux_avoid (the caller's list is not modified).""" if hint_aux_avoid is None: hint_aux_avoid = [] hint_aux_avoid = hint_aux_avoid + self.get_aux_avoid_list() for rr in self.check_getaux_inplace_gen(must_include = must_include, single_include = single_include, hint_pair = hint_pair, hint_aux = hint_aux, hint_aux_avoid = hint_aux_avoid, max_iter = max_iter, leaveone = leaveone): if iutil.signal_type(rr)
== "": return rr return None def check_getaux_inplace_gen(self, must_include = None, single_include = None, hint_pair = None, hint_aux = None, hint_aux_avoid = None, max_iter = None, leaveone = None): """Generator that yields the results of the auxiliary-variable search. A plain result is a list of (auxiliary Comp, value) pairs, where the value is a Comp, or a list of Comps in the forall_multiuse / leave-one cases; an empty list is yielded when there are no auxiliaries and the implication holds. Leave-one results are tuples ("leaveone", result, expr). If the time or iteration budget (max_iter, default auxsearch_max_iter) runs out, the tuple ("max_iter_reached", ) is yielded and the generator stops. NOTE(review): (1) in the sandwich step, the loop over ix in range(-1, n) tests sns[ix] == 0 also for ix == -1, which reads sns[-1] (the last auxiliary); the all-extremes check is therefore skipped whenever the last auxiliary is absent from the constraint. The intended test is probably ix >= 0 and sns[ix] == 0. (2) In check_recur, the test (auxflag[i] in auxcache[i]) and not auxcache[i] checks the dict rather than the cached value, so that pruning never fires; probably not auxcache[i][auxflag[i]] was intended. (3) In check_local, auxsetflag[i] uses the loop variable left over from the earlier for i in range(i0, n) loops; check_local_recur uses the current auxiliary index, so i2r looks intended. (4) In check_local, the si == 0 branch is unreachable because si starts at 1.""" if leaveone is None: leaveone = PsiOpts.settings["auxsearch_leaveone"] write_pf_enabled = PsiOpts.settings.get("proof_enabled", False) write_pf_repeat_claim = PsiOpts.settings.get("proof_repeat_implicant", False) leaveone_add_ineq = PsiOpts.settings["auxsearch_leaveone_add_ineq"] if self.aux.isempty(): if len(self.exprs_eq) == 0 and len(self.exprs_ge) == 0: yield [] return if not leaveone: oproof = None proof_stepped_in = False if write_pf_enabled: oproof = PsiOpts.get_proof().copy() pcase = PsiOpts.get_proof().get_case() if not pcase.isuniverse(): pf = ProofObj.from_region(self if write_pf_repeat_claim else self.consonly(), c = ["Case ", pcase.copy()]) PsiOpts.set_setting(proof_step_in = pf) pf = PsiOpts.get_proof() proof_stepped_in = True desc_more = self.get_meta("pf_note_case") if desc_more is not None: pf.insert_step(ProofObj.from_region(pf.claim, c = desc_more), 0) pf.claim = None pf.desc += [":"] cres = self.check_plain(skip_simplify = True) if proof_stepped_in: PsiOpts.set_setting(proof_step_out = True) if cres: yield [] else: if oproof is not None: PsiOpts.set_proof(oproof) return verbose = PsiOpts.settings.get("verbose_auxsearch", False) verbose_step = PsiOpts.settings.get("verbose_auxsearch_step", False) verbose_result = PsiOpts.settings.get("verbose_auxsearch_result", False) verbose_cache = PsiOpts.settings.get("verbose_auxsearch_cache", False) verbose_step_cached = PsiOpts.settings.get("verbose_auxsearch_step_cached", False) as_generator = True if must_include is None: must_include = Comp.empty() noncircular = PsiOpts.settings["imp_noncircular"] noncircular_allaux = PsiOpts.settings["imp_noncircular_allaux"] #noncircular_skipfail = True forall_multiuse = PsiOpts.settings["forall_multiuse"]
forall_multiuse_numsave = PsiOpts.settings["forall_multiuse_numsave"] auxsearch_local = PsiOpts.settings["auxsearch_local"] save_res = auxsearch_local if max_iter is None: max_iter = PsiOpts.settings["auxsearch_max_iter"] lpcost = 100 maxcost = max_iter * lpcost curcost = [0] cs = self.copy() index = IVarIndex() #cs.record_to(index) progs = [] #ccomp = must_include + cs.auxi + (index.comprv - cs.auxi - must_include) ccomp = must_include + cs.auxi + (cs.allcomprv() - cs.aux - cs.auxi - must_include) index.record(ccomp) index.record(cs.allcompreal()) auxcomp = cs.aux.copy() auxcond = Comp.empty() for a in auxcomp.varlist: b = Comp([a]) if self.imp_ispresent(b): auxcond += b auxcomp = auxcond + auxcomp n = auxcomp.size() n_cond = auxcond.size() m = ccomp.size() #m_flip = cs.auxi.size() clist = collections.deque() clist_hashset = set() flipflag = 0 auxiclist = [cs.ic_list(a) for a in auxcomp] cs_flipped = cs.imp_flipped() ciclist = [cs_flipped.ic_list(a) for a in ccomp] cvisflag = 0 for i in range(n): maxj = -1 maxval = 0 for j in range(m): t = cs.ic_list_similarity(auxiclist[i], ciclist[j]) if t > maxval: maxval = t maxj = j if maxj >= 0: flipflag |= 1 << (m * i + maxj) cvisflag |= 1 << maxj for i in range(n): if (flipflag >> (m * i)) & ((1 << m) - 1) == 0: flipflag |= (((1 << (must_include + cs.auxi).size()) - 1) & ~cvisflag) << (m * i) mustflag = 0 for i in range(n): #flipflag += ((1 << (must_include + cs.auxi).size()) - 1) << (m * i) mustflag += ((1 << must_include.size()) - 1) << (m * i) singleflag = 0 if single_include is not None: for j in range(m): if single_include.ispresent(ccomp[j]): singleflag += 1 << j #m_flip = 0 # mp2 = (1 << m) comppair = [-1 for j in range(m)] auxpair = [-1 for i in range(n)] compside = [0 for j in range(m)] auxside = [0 for i in range(n)] if hint_pair is not None: for (a, b) in hint_pair: ai = -1 bi = -1 for i in range(n): if auxcomp.varlist[i] == a.varlist[0]: ai = i auxside[i] = 1 elif auxcomp.varlist[i] == b.varlist[0]: bi = i 
auxside[i] = -1 if ai >= 0 and bi >= 0: auxpair[ai] = bi auxpair[bi] = ai ai = -1 bi = -1 for i in range(m): if ccomp.varlist[i] == a.varlist[0]: ai = i compside[i] = 1 elif ccomp.varlist[i] == b.varlist[0]: bi = i compside[i] = -1 if ai >= 0 and bi >= 0: comppair[ai] = bi comppair[bi] = ai setflag = 0 if hint_aux is not None: for taux, tc in hint_aux: pair_allowed = taux.get_marker_key("incpair") is not None tcmask = 0 tcmask_pair = 0 for j in range(m): if tc.ispresent(ccomp[j]): tcmask |= 1 << j if pair_allowed and comppair[j] < 0: tcmask_pair |= 1 << j elif pair_allowed and comppair[j] >= 0 and tc.ispresent(ccomp[comppair[j]]): tcmask_pair |= 1 << j for i in range(n): if taux.ispresent(auxcomp[i]): setflag |= tcmask << (m * i) elif pair_allowed and auxpair[i] >= 0 and taux.ispresent(auxcomp[auxpair[i]]): setflag |= tcmask_pair << (m * i) auxlist = [Comp.empty() for i in range(n)] auxflag = [0 for i in range(n)] eqs = [] for x in cs.exprs_ge: eqs.append((x.copy(), ">=")) for x in cs.exprs_eq: eqs.append((x.copy(), "==")) eqvs_range = n * 2 + 1 eqvs = [[] for i in range(eqvs_range)] eqvsid = [[] for i in range(eqvs_range)] eqvpresflag = [[] for i in range(eqvs_range)] eqvleaveok = [[] for i in range(eqvs_range)] eqvs_emptyid = n * 2 for (x, sg) in eqs: maxi = -1 mini = 100000000 presflag = 0 for i in range(n): if x.ispresent(Comp([auxcomp.varlist[i]])): maxi = max(maxi, i) mini = min(mini, i) presflag |= 1 << i ii = eqvs_emptyid if maxi >= 0: if mini == maxi: ii = maxi * 2 else: ii = maxi * 2 + 1 eqvsid[ii].append(len(eqvs[ii])) eqvs[ii].append((x, sg)) eqvpresflag[ii].append(presflag) eqvleaveok[ii].append(sg == ">=" and not x.isnonpos() and not x.ispresent(auxcond)) eqvsns = [[] for i in range(eqvs_range)] for i in range(eqvs_range): for j in range(len(eqvs[i])): x, sg = eqvs[i][j] if sg == "==": eqvsns[i].append(0) else: csn = None for i2 in range(n): if eqvpresflag[i][j] & (1 << i2) != 0: tsn = x.get_sign(auxcomp.varlist[i2]) if tsn == 0 or ((csn is not None) 
and csn != tsn): csn = 0 break else: csn = tsn eqvsns[i].append(csn if csn is not None else 0) eqvsncache = [[] for i in range(n * 2 + 1)] eqvflagcache = [[] for i in range(n * 2 + 1)] for i in range(eqvs_range): eqvsncache[i] = [[[], []] for j in range(len(eqvs[i]))] eqvflagcache[i] = [{} for j in range(len(eqvs[i]))] auxsetflag = [(setflag >> (m * i)) & ((1 << m) - 1) for i in range(n)] auxavoidflag = [0 for i in range(n)] if hint_aux_avoid is not None: for taux, tc in hint_aux_avoid: pair_allowed = taux.get_marker_key("incpair") is not None tcmask = 0 tcmask_pair = 0 for j in range(m): if tc.ispresent(ccomp[j]): tcmask |= 1 << j if pair_allowed and comppair[j] < 0: tcmask_pair |= 1 << j elif pair_allowed and comppair[j] >= 0 and tc.ispresent(ccomp[comppair[j]]): tcmask_pair |= 1 << j for i in range(n): if taux.ispresent(auxcomp[i]): auxavoidflag[i] |= tcmask elif pair_allowed and auxpair[i] >= 0 and taux.ispresent(auxcomp[auxpair[i]]): auxavoidflag[i] |= tcmask_pair if False: for i in range(n): t = self.get_var_avoid(auxcomp[i]) if t is not None: auxavoidflag[i] |= ccomp.get_mask(t) for i in range(n): auxavoidflag[i] &= ~auxsetflag[i] avoidflag = sum(auxavoidflag[i] << (m * i) for i in range(n)) flipflag &= ~avoidflag disjoint_ids = [-1 for i in range(n)] symm_ids = [-1 for i in range(n)] nonsubset_ids = [-1 for i in range(n)] nonempty_is = [False for i in range(n)] symm_nonempty_ns = [0] * n for i in range(n): if auxcomp.varlist[i].markers is None: continue cdict = {v: w for v, w in auxcomp.varlist[i].markers} disjoint_ids[i] = cdict.get("disjoint", -1) symm_ids[i] = cdict.get("symm", -1) nonsubset_ids[i] = cdict.get("nonsubset", -1) nonempty_is[i] = "nonempty" in cdict symm_nonempty_ns[i] = cdict.get("symm_nonempty", 0) for i in range(n): if symm_nonempty_ns[i] > 0: nsymm = 0 for i2 in range(i + 1, n): if symm_ids[i] == symm_ids[i2]: nsymm += 1 if nsymm < symm_nonempty_ns[i]: nonempty_is[i] = True #print("NONSUBSET " + "; ".join(str(auxcomp[i]) for i in 
range(n) if nonsubset_ids[i] >= 0)) fcns = cs_flipped.get_hc() fcns_mask = [] for a, c in fcns.terms: fcns_mask.append((index.get_mask(a.x[0]), index.get_mask(a.z))) if verbose: print("========= aux search ========") print(cs.imp_flippedonly()) print("========= subset of =========") #print(co) for i in range(n * 2 + 1): hs = "" if i == n * 2: hs = "NONE" elif i % 2 == 0: hs = str(auxcomp.varlist[i // 2]) else: hs = "<=" + str(auxcomp.varlist[i // 2]) for j in range(len(eqvs[i])): eqvsnstr = "" if i < n * 2: eqvsnstr = " " + ("amb" if eqvsns[i][j] == 0 else "inc" if eqvsns[i][j] == 1 else "dec") if leaveone and eqvleaveok[i][j]: eqvsnstr += " " + "leaveok" print(iutil.strpad(hs, 8, ": " + str(eqvs[i][j][0]) + " " + str(eqvs[i][j][1]) + " 0" + eqvsnstr)) print("========= variables =========") print(ccomp) print("========= auxiliary =========") #print(auxcomp) print(str(auxcond) + " ; " + str(auxcomp - auxcond)) if len(fcns_mask) > 0: print("========= functions =========") for x, z in fcns_mask: print(str(ccomp.from_mask(x)) + " <- " + str(ccomp.from_mask(z))) if hint_pair is not None: print("========= pairing =========") for i in range(n): if auxpair[i] > i: print(str(auxcomp.varlist[i]) + " <-> " + str(auxcomp.varlist[auxpair[i]])) for i in range(m): if comppair[i] > i: print(str(ccomp.varlist[i]) + " <-> " + str(ccomp.varlist[comppair[i]])) print("========= initial =========") for i in range(n): cflag = (flipflag >> (m * i)) & ((1 << m) - 1) csetflag = auxsetflag[i] cavoidflag = auxavoidflag[i] ccor = Comp.empty() cset = Comp.empty() cavoid = Comp.empty() for j in range(m): if cflag & (1 << j) != 0: ccor += ccomp[j] if csetflag & (1 << j) != 0: cset += ccomp[j] if cavoidflag & (1 << j) != 0: cavoid += ccomp[j] print(str(auxcomp.varlist[i]) + " : " + str(ccor) + (" Fix: " + str(cset) if not cset.isempty() else "") + (" Avoid: " + str(cavoid) if not cavoid.isempty() else "")) if self.aux.isempty() and write_pf_enabled and leaveone: oproof = PsiOpts.get_proof().copy() 
pf = ProofObj.from_region(self if write_pf_repeat_claim else self.consonly(), c = "Claim:") PsiOpts.set_setting(proof_step_in = pf) pf = PsiOpts.get_proof() # print(self) curleave = None for i in range(n * 2 + 1): for j in range(len(eqvs[i])): cs = self.noaux() cs.exprs_ge = [] cs.exprs_eq = [] if eqvs[i][j][1] == "==": cs.exprs_eq.append(eqvs[i][j][0]) else: cs.exprs_ge.append(eqvs[i][j][0]) tres = cs.check_plain(skip_simplify = True) if not tres: if not eqvleaveok[i][j] or curleave is not None: PsiOpts.set_setting(proof_step_out = True) PsiOpts.set_proof(oproof) return else: curleave = eqvs[i][j][0] PsiOpts.set_setting(proof_step_out = True) pcase = PsiOpts.get_proof().get_case() if curleave is not None or not pcase.isuniverse(): tcase = pcase.copy() if curleave is not None: tcase &= (curleave >= 0) PsiOpts.get_proof().set_case(pcase & (curleave <= 0)) pf.desc = ["Case ", tcase] desc_more = self.get_meta("pf_note_case") if desc_more is not None: # pf.desc += [", ", "\n"] + desc_more pf.insert_step(ProofObj.from_region(pf.claim, c = desc_more), 0) pf.claim = None pf.desc += [":"] # pf2 = ProofObj.from_region(None, c = ["Remaining case: ", curleave <= 0]) # PsiOpts.set_setting(proof_add = pf2) # print("WRITE PF LEAVEONE " + str(curleave)) if curleave is None: yield [] else: yield ("leaveone", [], -curleave) return if len(eqs) == 0: #print(bin(setflag)) #print(bin(avoidflag)) #print("m=" + str(m) + " n=" + str(n)) #print("ccomp=" + str(ccomp) + " auxcomp=" + str(auxcomp)) #for a in auxcomp.varlist: # print(str(a) + " " + str(a.markers)) mleft = [j for j in range(m * n) if setflag & (1 << j) == 0 and avoidflag & (1 << j) == 0] def check_nocond_recur(i, size, allflag): #print(str(i) + " " + str(size) + " " + bin(allflag)) if i == n: if mustflag != 0 and mustflag & allflag == 0: return rr = [(auxcomp[i2].copy(), auxlist[i2].copy()) for i2 in range(n)] yield rr return mlefti = [j - m * i for j in mleft if j >= m * i and j < m * (i + 1)] symm_break = -1 if symm_ids[i] >= 
0: for i2 in range(i): if symm_ids[i] == symm_ids[i2]: symm_break = max(symm_break, auxflag[i2]) if disjoint_ids[i] >= 0: for i2 in range(i): if disjoint_ids[i] == disjoint_ids[i2]: mlefti = [j for j in mlefti if auxflag[i2] & (1 << j) == 0] sizelb = max(size - sum(j >= m * (i + 1) for j in mleft), 0) sizeub = min(min(len(mlefti), m), size) if sizelb > sizeub: return for tsize in range(sizelb, sizeub + 1): for comb in itertools.combinations(list(reversed(mlefti)), tsize): curcost[0] += 1 if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost): return auxflag[i] = sum(1 << j for j in comb) | auxsetflag[i] if auxflag[i] < symm_break: break if nonempty_is[i] and auxflag[i] == 0: continue if any(auxflag[i] | z == auxflag[i] and auxflag[i] & x != 0 for x, z in fcns_mask): continue if nonsubset_ids[i] >= 0: tbad = False for i2 in range(i): if nonsubset_ids[i] == nonsubset_ids[i2]: if auxflag[i] | auxflag[i2] == auxflag[i] or auxflag[i] | auxflag[i2] == auxflag[i2]: tbad = True break if tbad: continue auxlist[i] = ccomp.from_mask(auxflag[i]) #print(str(i) + " " + str(m) + " " + str(ccomp) + " " + bin(auxflag[i]) + " " + str(auxlist[i])) for rr in check_nocond_recur(i + 1, size - len(comb), allflag | (auxflag[i] << (m * i))): yield rr #print("START") for tsize in range(len(mleft) + 1): for rr in check_nocond_recur(0, tsize, 0): #print("; ".join(str(v) + ":" + str(w) for v, w in rr)) yield rr #print("END") return auxcache = [{} for i in range(n)] if verbose: print("========= progress: =========") leaveone_static = None # *********** Check conditions that does not depend on aux *********** if n_cond == 0: for (x, sg) in eqvs[eqvs_emptyid]: if not cs.implies_ineq_prog(index, progs, x, sg, save_res = save_res): if leaveone and sg == ">=": if verbose_step: print(" F LO " + str(x) + " " + sg + " 0") leaveone = False leaveone_static = -x else: if verbose_step: print(" F " + str(x) + " " + sg + " 0") return None allflagcache = [{} for i in range(n + 1)] cursizepass 
= 0 cs = Region.universe() cs_added = Region.universe() #self.imp_only_copy_to(cs_added) condflagadded = {} condflagadded_true = collections.deque() flagcache = [set() for i in range(n + 1)] maxprogress = [-1, 0, -1] flipflaglen = 0 numfail = [0] numclear = [0] auxavoidflag_orig = list(auxavoidflag) def clear_cache(mini, maxi): if verbose_cache: print("========= cache clear: " + str(mini) + " - " + str(maxi) + " =========") progs[:] = [] auxcache[mini:maxi] = [{} for i in range(mini, maxi)] for a in flagcache: a.clear() for i in range(mini, maxi): auxavoidflag[i] = auxavoidflag_orig[i] for i in range(mini * 2, eqvs_range): for j in range(len(eqvs[i])): if eqvpresflag[i][j] & ((1 << maxi) - (1 << mini)) != 0: eqvsncache[i][j] = [[], []] eqvflagcache[i][j] = {} # eqvsncache[i] = [[[], []] for j in range(len(eqvs[i]))] # eqvflagcache[i] = [{} for j in range(len(eqvs[i]))] def build_region(i, allflag, allownew, cflag = 0, add_ineq = None): numclear[0] += 1 cs_added_changed = False prev_csstr = cs.tostring(tosort = True) if add_ineq is not None: prev_csaddedstr = cs_added.tostring(tosort = True) cs_added.exprs_gei.append(add_ineq) cs_added.simplify_quick(zero_group = 2) #cs_added.split() if prev_csaddedstr != cs_added.tostring(tosort = True): cs_added_changed = True clist.appendleft(cflag) while len(clist) > forall_multiuse_numsave: clist.pop() if verbose_step: print("========= leave one added =========") print(cs_added.imp_flipped()) print("==================") cs_added.imp_only_copy_to(cs) elif i >= n_cond and (forall_multiuse or leaveone): csnew = Region.universe() if not (allflag in condflagadded): self.imp_only_copy_to(csnew) for i3 in range(i, n_cond): csnew.remove_present(Comp([auxcomp.varlist[i3]])) for i3 in range(i): csnew.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3]) if allownew: condflagadded[allflag] = True prev_csaddedstr = cs_added.tostring(tosort = True) cs_added.iand_norename(csnew) cs_added.simplify_quick(zero_group = 2) #cs_added.split() if 
prev_csaddedstr != cs_added.tostring(tosort = True): cs_added_changed = True condflagadded_true.appendleft(allflag) while len(condflagadded_true) > forall_multiuse_numsave: condflagadded_true.pop() if verbose_step: print("========= forall added =========") print(cs_added.imp_flipped()) print("==================") cs_added.imp_only_copy_to(cs) if not allownew: cs.iand_norename(csnew) else: self.imp_only_copy_to(cs) for i3 in range(i, n_cond): cs.remove_present(Comp([auxcomp.varlist[i3]])) for i3 in range(i): cs.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3]) if cs_added_changed or prev_csstr != cs.tostring(tosort = True): if verbose_cache: print("========= cleared =========") print(prev_csstr) print("========= to =========") print(cs.tostring(tosort = True)) maxi = n if forall_multiuse and not cs_added_changed: maxi = n_cond if noncircular: if noncircular_allaux: clear_cache(n_cond, maxi) else: clear_cache(min(1, n_cond), maxi) else: clear_cache(0, maxi) else: if verbose_cache: print("========= not cleared =========") print(prev_csstr) #build_region(0, 0, True) build_region(0, 0, leaveone) def is_marker_sat(i): if nonempty_is[i] and auxflag[i] == 0: return False if any(auxflag[i] | z == auxflag[i] and auxflag[i] & x != 0 for x, z in fcns_mask): return False if nonsubset_ids[i] >= 0: for i2 in range(i): if nonsubset_ids[i] == nonsubset_ids[i2]: if auxflag[i] | auxflag[i2] == auxflag[i] or auxflag[i] | auxflag[i2] == auxflag[i2]: return False if symm_ids[i] >= 0: for i2 in range(i): if symm_ids[i] == symm_ids[i2]: if auxflag[i] < auxflag[i2]: return False if disjoint_ids[i] >= 0: tbad = False for i2 in range(i): if disjoint_ids[i] == disjoint_ids[i2]: if auxflag[i] & auxflag[i2] != 0: return False return True # *********** Sandwich procedure *********** dosandwich = PsiOpts.settings["auxsearch_sandwich"] dosandwich_inc = PsiOpts.settings["auxsearch_sandwich_inc"] if dosandwich: if verbose: print("========== sandwich =========") for i in range(n * 2 + 1): if i == 
eqvs_emptyid: continue for j in range(len(eqvs[i])): if leaveone and eqvleaveok[i][j]: continue presflag = eqvpresflag[i][j] for sg in ([1, -1] if eqvs[i][j][1] == "==" else [1]): x = eqvs[i][j][0] * sg sns = [0] * n bad = False for i2 in range(n): if presflag & (1 << i2): sns[i2] = x.get_sign(auxcomp.varlist[i2]) if sns[i2] == 0: bad = True break if not dosandwich_inc and sns[i2] > 0: bad = True break if bad: continue for ix in range(-1, n): if sns[ix] == 0: continue xt = x.copy() for i2 in range(n): if i2 == ix: continue cmask = 0 if sns[i2] < 0: cmask = (setflag >> (m * i2)) & ((1 << m) - 1) else: cmask = ~(avoidflag >> (m * i2)) & ((1 << m) - 1) # print(avoidflag) # print(cmask) xt.substitute(Comp([auxcomp.varlist[i2]]), ccomp.from_mask(cmask)) if ix == -1: tres = cs.implies_ineq_prog(index, progs, xt, ">=", save_res = save_res, saved = "both") if verbose_step: print(str(x) + " : " + str(xt) + " >= 0 " + str(tres)) if not tres: return None continue cmask0 = (setflag >> (m * ix)) & ((1 << m) - 1) cmask1 = ~(avoidflag >> (m * ix)) & ((1 << m) - 1) cmaskc = cmask1 & ~cmask0 for k in range(m): if cmaskc & (1 << k): cmask = 0 if sns[ix] < 0: cmask = cmask0 | (1 << k) else: cmask = cmask1 & ~(1 << k) xt2 = xt.copy() xt2.substitute(Comp([auxcomp.varlist[ix]]), ccomp.from_mask(cmask)) tres = cs.implies_ineq_prog(index, progs, xt2, ">=", save_res = save_res, saved = "both") if verbose_step: print(str(x) + " : " + str(xt2) + " >= 0 " + str(tres)) if not tres: if sns[ix] < 0: avoidflag |= 1 << (m * ix + k) auxavoidflag[ix] |= 1 << k else: setflag |= 1 << (m * ix + k) auxsetflag[ix] |= 1 << k flipflag &= ~avoidflag auxavoidflag_orig = list(auxavoidflag) if verbose: print("======= after sandwich ======") for i in range(n): cflag = (flipflag >> (m * i)) & ((1 << m) - 1) csetflag = auxsetflag[i] cavoidflag = auxavoidflag[i] ccor = Comp.empty() cset = Comp.empty() cavoid = Comp.empty() for j in range(m): if cflag & (1 << j) != 0: ccor += ccomp[j] if csetflag & (1 << j) != 0: 
cset += ccomp[j] if cavoidflag & (1 << j) != 0: cavoid += ccomp[j] print(str(auxcomp.varlist[i]) + " : " + str(ccor) + (" Fix: " + str(cset) if not cset.isempty() else "") + (" Avoid: " + str(cavoid) if not cavoid.isempty() else "")) print("=============================") def check_local(i0, allflag, leave_id = None): cflag = ((flipflag | setflag) >> (m * i0)) << (m * i0) mleft = [j for j in range(m * i0, m * n) if setflag & (1 << j) == 0] #mleft = mleft[::-1] while True: cleave_id = leave_id for i in range(i0, n): auxflag[i] = (cflag >> (m * i)) & ((1 << m) - 1) auxlist[i] = ccomp.from_mask(auxflag[i]) cres = True bad = False for i in range(i0, n): if not is_marker_sat(i): bad = True break if bad: cres = False if cres: for i2 in range(i0 * 2, n * 2): i2r = -1 isone = False if i2 != eqvs_emptyid: i2r = i2 // 2 isone = (i2 % 2 == 0) auxflagi = auxflag[i2r] if isone and (auxflagi in auxcache[i2r]): cres = auxcache[i2r][auxflagi] else: for ieqid in range(len(eqvs[i2])): ieq = eqvsid[i2][ieqid] if cleave_id == (i2, ieq): continue (x, sg) = eqvs[i2][ieq] x2 = x.copy() if isone: x2.substitute(Comp([auxcomp.varlist[i2r]]), auxlist[i2r]) else: for i3 in range(i2r + 1): x2.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3]) curcost[0] += lpcost if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost): return False tres = cs.implies_ineq_prog(index, progs, x2, sg, save_res = save_res) #print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(i)]), # 26, " LO#" + str(numfail[0]), 8, " " + str(x) + " " + sg + " 0 " + str(tres))) if not tres: if isone and eqvsns[i2][ieq] < 0 and (not leaveone or not eqvleaveok[i2][ieq]) and iutil.bitcount(auxflagi & ~auxsetflag[i]) == 1: auxavoidflag[i2r] |= auxflagi & ~auxsetflag[i] if verbose_step: print(iutil.strpad(" L AVOID " + str(auxcomp.varlist[i2r]), 12, " : " + str(ccomp.from_mask(auxavoidflag[i2r])))) if leaveone and cleave_id is None and eqvleaveok[i2][ieq]: cleave_id = (i2, ieq) if verbose_step: 
print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(n)]), 26, " LSET#" + str(numfail[0]), 12, " " + str(x) + " " + sg + " 0")) else: if verbose_step: numfail[0] += 1 print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(n)]), 26, " LO#" + str(numfail[0]), 12, " " + str(x) + " " + sg + " 0")) eqvsid[i2].pop(ieqid) eqvsid[i2].insert(0, ieq) flagcache[i2r + 1].add(cflag & ((1 << ((i2r + 1) * m)) - 1)) #print("FCA " + str(i2r + 1) + " " + bin(cflag & ((1 << ((i2r + 1) * m)) - 1))) cres = False break if isone and not leaveone: auxcache[i2r][auxflagi] = cres if not cres: break if cres: if as_generator: if leaveone and cleave_id is not None: x2 = -eqvs[cleave_id[0]][cleave_id[1]][0] for i3 in range(n): x2.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3]) #print("BUILD " + str(x2)) x2.simplify_quick() if not x2.isnonneg(): if leaveone_add_ineq: build_region(n_cond, allflag, False, cflag = cflag, add_ineq = x2) yield ("leaveone", x2.copy()) else: yield allflag | cflag else: return True flagcache[n].add(cflag) def check_local_recur(i, kflag, sizelb, sizeub, kleave_id): csizelb = max(sizelb - sum(j >= m * (i + 1) for j in mleft), 0) csizeub = sizeub #print(str(i) + " " + str(csizelb) + " " + str(csizeub)) for tsize in range(csizelb, csizeub + 1): mlefti = [j - m * i for j in mleft if j >= m * i and j < m * (i + 1)] mlefti = [j for j in mlefti if auxavoidflag[i] & (1 << j) == 0] cflagi = ((cflag >> (m * i)) & ((1 << m) - 1)) mleftimustflag = cflagi & auxavoidflag[i] ttsize = tsize - iutil.bitcount(mleftimustflag) if ttsize > len(mlefti): break if ttsize < 0: continue for comb in itertools.combinations(mlefti, ttsize): curcost[0] += 1 if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost): return False auxflag[i] = sum(1 << j for j in comb) ^ cflagi auxflag[i] &= ~auxavoidflag[i] if not auxcache[i].get(auxflag[i], True): continue if not is_marker_sat(i): continue auxlist[i] = ccomp.from_mask(auxflag[i]) ckflag = kflag | (auxflag[i] << (m * 
i)) ckleave_id = kleave_id if i == n - 1 and mustflag != 0 and mustflag & (allflag | ckflag) == 0: continue if ckflag in flagcache[i + 1]: continue cres = True for i2 in range(i * 2, (i + 1) * 2): i2r = -1 isone = False if i2 != eqvs_emptyid: i2r = i2 // 2 isone = (i2 % 2 == 0) auxflagi = auxflag[i2r] for ieqid in range(len(eqvs[i2])): ieq = eqvsid[i2][ieqid] (x, sg) = eqvs[i2][ieq] x2 = x.copy() if isone: x2.substitute(Comp([auxcomp.varlist[i2r]]), auxlist[i2r]) else: for i3 in range(i2r + 1): x2.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3]) curcost[0] += lpcost if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost): return False tres = cs.implies_ineq_prog(index, progs, x2, sg, save_res = save_res, saved = True) if not tres: if isone and eqvsns[i2][ieq] < 0 and (not leaveone or not eqvleaveok[i2][ieq]) and iutil.bitcount(auxflagi & ~auxsetflag[i]) == 1: auxavoidflag[i2r] |= auxflagi & ~auxsetflag[i] if verbose_step: print(iutil.strpad(" T AVOID " + str(auxcomp.varlist[i2r]), 12, " : " + str(ccomp.from_mask(auxavoidflag[i2r])))) if leaveone and ckleave_id is None and eqvleaveok[i2][ieq]: #if verbose_step and verbose_step_cached: # print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(i2r + 1)]), # 26, " TSET=" + str(tsize) + ",#" + str(numfail[0]), 12, " " + str(x) + " " + sg + " 0")) ckleave_id = (i2, ieq) else: if verbose_step and verbose_step_cached: numfail[0] += 1 print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(i2r + 1)]), 26, " TO=" + str(tsize) + ",#" + str(numfail[0]), 12, " " + str(x) + " " + sg + " 0")) eqvsid[i2].pop(ieqid) eqvsid[i2].insert(0, ieq) flagcache[i2r + 1].add(ckflag) #print("FCB " + str(i2r + 1) + " " + bin(ckflag)) cres = False break if isone and not cres and not leaveone: auxcache[i2r][auxflagi] = cres if not cres: break if not cres: continue if i == n - 1: return True if check_local_recur(i + 1, ckflag, sizelb - len(comb), sizeub - len(comb), ckleave_id): return True if 
PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost): return False return False sizeseq = [0, 1, 2, 4] sizeseq = [s for s in sizeseq if s < len(mleft)] + [len(mleft)] found = False for si in range(1, len(sizeseq)): if si == 0: tcflag = cflag cflag = 0 if check_local_recur(i0, 0, 1, 1, None): found = True cflag = sum(auxflag[i] << (m * i) for i in range(i0, n)) break cflag = tcflag else: if check_local_recur(i0, 0, sizeseq[si - 1] + 1, sizeseq[si], None): found = True cflag = sum(auxflag[i] << (m * i) for i in range(i0, n)) break if not found: return False def check_recur(i, size, stepsize, allflag): if i == n and mustflag != 0 and mustflag & allflag == 0: return False if allflag in allflagcache[i]: return False cprogress = i * (n * 2 + 2) if cprogress > maxprogress[0]: maxprogress[0] = cprogress maxprogress[1] = allflag | (flipflag & ~((1 << (m * i)) - 1)) maxprogress[2] = i if not noncircular and i == n_cond: if verbose_cache: print("========= cache clear: circ, suff # " + str(numclear[0]) + " =========") build_region(n_cond, allflag, False) i2lb = 0 i2ub = 0 if noncircular: if i > 0: i2lb = (i - 1) * 2 i2ub = i * 2 else: if i >= n_cond: i2lb = 0 if i > n_cond: i2lb = (i - 1) * 2 i2ub = i * 2 for i2 in range(i2lb, i2ub): cres = True i2r = -1 isone = False if i2 != eqvs_emptyid: i2r = i2 // 2 isone = (i2 % 2 == 0) auxflagi = auxflag[i2r] if isone and (auxflagi in auxcache[i2r]): cres = auxcache[i2r][auxflagi] else: for ieqid in range(len(eqvs[i2])): ieq = eqvsid[i2][ieqid] (x, sg) = eqvs[i2][ieq] x2 = x.copy() if isone: x2.substitute(Comp([auxcomp.varlist[i2r]]), auxlist[i2r]) else: for i3 in range(i2r + 1): x2.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3]) auxflagpres = 0 auxflagprescn = 0 for i3 in range(i2r + 1): if eqvpresflag[i2][ieq] & (1 << i3) != 0: auxflagpres |= auxflag[i3] << (m * auxflagprescn) auxflagprescn += 1 tres = None computed = False eqvsn = eqvsns[i2][ieq] if eqvsn == 0: tres = eqvflagcache[i2][ieq].get(auxflagpres, None) 
else: eqvsn = (eqvsn + 1) // 2 for f in eqvsncache[i2][ieq][eqvsn]: if auxflagpres == f | auxflagpres: tres = (eqvsn == 1) break if tres is None: for f in eqvsncache[i2][ieq][1 - eqvsn]: if f == f | auxflagpres: tres = (eqvsn == 0) break if tres is None: curcost[0] += lpcost if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost): return False tres = cs.implies_ineq_prog(index, progs, x2, sg, save_res = save_res) computed = True eqvsn = eqvsns[i2][ieq] if eqvsn == 0: eqvflagcache[i2][ieq][auxflagpres] = tres else: eqvsncache[i2][ieq][1 if tres else 0].append(auxflagpres) if not tres: if verbose_step and (verbose_step_cached or computed): numfail[0] += 1 print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(i)]), 26, " S=" + str(cursizepass) + ",T=" + str(stepsize) + ",L=" + str(flipflaglen) + ",#" + str(numfail[0]), 18, " " + str(x) + " " + sg + " 0" + ("" if computed else " (Ca)"))) eqvsid[i2].pop(ieqid) eqvsid[i2].insert(0, ieq) cres = False break if isone: auxcache[i2r][auxflagi] = cres if not cres: allflagcache[i][allflag] = True return False cprogress = i * (n * 2 + 2) + i2 + 1 if cprogress > maxprogress[0]: maxprogress[0] = cprogress maxprogress[1] = allflag | (flipflag & ~((1 << (m * i)) - 1)) maxprogress[2] = i if i == n_cond and n_cond > 0: if noncircular: if verbose_cache: print("========= cache clear: nonc, checkempty # " + str(numclear[0]) + " =========") build_region(i, allflag, True) for ieqid in range(len(eqvs[eqvs_emptyid])): ieq = eqvsid[eqvs_emptyid][ieqid] (x, sg) = eqvs[eqvs_emptyid][ieq] curcost[0] += lpcost if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost): return False if not cs.implies_ineq_prog(index, progs, x, sg, save_res = save_res): if verbose_step: numfail[0] += 1 print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(i)]), 26, " S=" + str(cursizepass) + ",T=" + str(stepsize) + ",L=" + str(flipflaglen) + ",#" + str(numfail[0]), 18, " " + str(x) + " " + sg + " 0")) 
allflagcache[i][allflag] = True eqvsid[eqvs_emptyid].pop(ieqid) eqvsid[eqvs_emptyid].insert(0, ieq) return False if i == n: if as_generator: yield allflag return else: return True if i == n_cond and auxsearch_local: if as_generator: for rr in check_local(i, allflag): yield rr return else: return check_local(i, allflag) cflipflag = (flipflag >> (m * i)) & ((1 << m) - 1) if i >= flipflaglen - 1: i2 = auxpair[i] if i2 >= 0 and i2 < i: cflipflag = auxflag[i2] for j in range(m): j2 = comppair[j] if j2 >= 0: if (auxflag[i2] & (1 << j) != 0) and (auxflag[i2] & (1 << j2) == 0): cflipflag &= ~(1 << j) csetflag = (setflag >> (m * i)) & ((1 << m) - 1) mleft = [j for j in range(m) if csetflag & (1 << j) == 0] #sizelb = max(0, size - m * (n - i - 1)) sizelb = 0 sizeub = min(min(size, len(mleft)), stepsize) pnumclear = -1 for tsize in range(sizelb, sizeub + 1): for comb in itertools.combinations(mleft, tsize): curcost[0] += 1 if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost): return False #print(tsize, comb) tflag = 0 for j in comb: tflag += (1 << j) auxlist[i] = Comp.empty() auxflag[i] = 0 for j in range(m): if (csetflag & (1 << j) != 0) or (cflipflag & (1 << j) != 0) != (tflag & (1 << j) != 0): auxflag[i] += (1 << j) auxlist[i].varlist.append(ccomp.varlist[j]) if (auxflag[i] & singleflag) != 0 and iutil.bitcount(auxflag[i]) > 1: continue if i < n_cond: pass else: if (auxflag[i] in auxcache[i]) and not auxcache[i]: continue if noncircular and i < n_cond and numclear[0] != pnumclear: if verbose_cache: print("========= cache clear: nonc, inc # " + str(numclear[0]) + " =========") if noncircular_allaux: build_region(0, 0, False) else: build_region(i, allflag, False) pnumclear = numclear[0] recur = check_recur(i+1, size - tsize, stepsize, allflag + (auxflag[i] << (m * i))) if as_generator: for rr in recur: yield rr else: if recur: return True if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost): return False return False res_hashset = set() 
maxsize = m * n - iutil.bitcount(setflag) size = 0 stepsize = 0 while True: cursizepass = size prevprogress = maxprogress[0] recur = check_recur(0, size, stepsize, 0) if not as_generator: recur = [True] if recur else [] for rr in recur: if verbose or verbose_result: print("========= success cost " + str(curcost[0]) + "/" + str(maxcost) + " =========") #print("========= final region =========") #print(cs.imp_flipped()) print("========== aux =========") for i in range(n): print(iutil.strpad(str(auxcomp.varlist[i]), 6, ": " + str(auxlist[i]))) namelist = [auxcomp.varlist[i].name for i in range(n)] res = [] for i in range(n): i2 = namelist.index(self.aux.varlist[i].name) cval = auxlist[i2].copy() if forall_multiuse and i2 < n_cond and len(condflagadded_true) > 0: cval = [ccomp.from_mask(x >> (m * i2)).copy() for x in condflagadded_true] if len(cval) == 1: cval = cval[0] if i2 >= n_cond and len(clist) > 0: cval = [cval] + [ccomp.from_mask(x >> (m * i2)).copy() for x in clist] if len(cval) == 1: cval = cval[0] res.append((Comp([self.aux.varlist[i].copy()]), cval)) if as_generator: res_hash = hash(iutil.list_tostr_std(res)) if not (res_hash in res_hashset): if iutil.signal_type(rr) == "leaveone": yield ("leaveone", res, rr[1]) elif leaveone_static is not None: yield ("leaveone", res, leaveone_static.copy()) else: yield res res_hashset.add(res_hash) else: return res if size >= maxsize: break if n_cond == 0 and auxsearch_local: break flipflag = maxprogress[1] flipflaglen = maxprogress[2] if prevprogress != maxprogress[0]: size = 1 else: if size < 2: size += 1 else: size *= 2 if size >= maxsize: size = maxsize stepsize = m else: clen = max(flipflaglen, 1) stepsize = (size + clen - 1) // clen if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost): yield ("max_iter_reached", ) return None def add_sfrl_imp(self, x, y, gap = None, noaux = True, name = None): ccomp = self.allcomprv() - self.aux if name is None: name = self.name_avoid(y.tostring(add_bracket = True) 
+ "%" + x.tostring(add_bracket = True)) newvar = Comp.rv(name) self.exprs_gei.append(-Expr.I(x, newvar)) self.exprs_gei.append(-Expr.Hc(y, x + newvar)) others = ccomp - x - y if not others.isempty(): self.exprs_gei.append(-Expr.Ic(newvar, others, x + y)) if gap is not None: if not isinstance(gap, Expr): gap = Expr.const(gap) self.exprs_gei.append(gap.copy() - Expr.Ic(x, newvar, y)) if not noaux: self.auxi += newvar return newvar @fcn_list_to_list def add_sfrl(self, x, y, gap = None, noaux = True, name = None): self.imp_flip() r = self.add_sfrl_imp(x, y, gap, noaux, name) self.imp_flip() return r def add_esfrl_imp(self, x, y, gap = None, noaux = True): if x.super_of(y): return Comp.empty(), y.copy() if x.isempty(): return y.copy(), Comp.empty() ccomp = self.allcomprv() - self.aux newvar = Comp.rv(self.name_avoid(y.tostring(add_bracket = True) + "%" + x.tostring(add_bracket = True))) newvark = Comp.rv(self.name_avoid(y.tostring(add_bracket = True) + "%" + x.tostring(add_bracket = True) + "_K")) self.exprs_gei.append(-Expr.I(x, newvar)) self.exprs_gei.append(-Expr.Hc(newvark, x + newvar)) self.exprs_gei.append(-Expr.Hc(y, newvar + newvark)) others = ccomp - x - y if not others.isempty(): self.exprs_gei.append(-Expr.Ic(newvar + newvark, others, x + y)) if gap is not None: if not isinstance(gap, Expr): gap = Expr.const(gap) self.exprs_gei.append(gap.copy() + Expr.I(x, y) - Expr.H(newvark)) if not noaux: self.auxi += newvar + newvark return newvar, newvark def add_esfrl(self, x, y, gap = None, noaux = True): self.imp_flip() r = self.add_esfrl_imp(x, y, gap, noaux) self.imp_flip() return r def check_getaux_sfrl(self, sfrl_level = None, sfrl_minsize = 0, sfrl_maxsize = None, sfrl_gap = None, hint_pair = None, hint_aux = None): """Return whether implication is true, with auxiliary search results.""" verbose_sfrl = PsiOpts.settings.get("verbose_sfrl", False) if sfrl_level is None: sfrl_level = PsiOpts.settings["sfrl_level"] if sfrl_maxsize is None: sfrl_maxsize = 
PsiOpts.settings["sfrl_maxsize"] if sfrl_gap is None: sfrl_gap = PsiOpts.settings["sfrl_gap"] gap = None gappresent = False if sfrl_gap == "zero": gap = Expr.zero() elif sfrl_gap != "": gap = Expr.real(sfrl_gap) gappresent = True enable_multiple = (sfrl_level >= PsiOpts.SFRL_LEVEL_MULTIPLE) n = sfrl_maxsize cs = self ccomp = cs.auxi + (cs.allcomprv() - cs.aux - cs.auxi) m = ccomp.size() sfrlcomp = [[0, 0] for i in range(n)] tres = None def check_getaux_sfrl_recur(i): if i >= sfrl_minsize: cs = self.copy() csfrl = Comp.empty() for i2 in range(i): sfrlx = sum([ccomp[j] for j in range(m) if sfrlcomp[i2][0] & (1 << j) != 0]) sfrly = sum([ccomp[j] for j in range(m) if sfrlcomp[i2][1] & (1 << j) != 0]) csfrl += cs.add_sfrl_imp(sfrlx, sfrly, gap, noaux = False) if verbose_sfrl: print("========== SFRL ========= =========") print(csfrl) if enable_multiple: tres = cs.check_getaux_inplace(must_include = csfrl, hint_pair = hint_pair, hint_aux = hint_aux) else: tres = cs.check_getaux_inplace(must_include = csfrl, single_include = csfrl, hint_pair = hint_pair, hint_aux = hint_aux) if tres is not None: return tres if i == n: return None for size in range(2, m + 1): for xsize in range(1, size): for xtuple in itertools.combinations(range(m), xsize): xmask = sum([1 << x for x in xtuple]) yset = [i for i in range(m) if not (i in xtuple)] for ytuple in itertools.combinations(yset, size - xsize): ymask = sum([1 << y for y in ytuple]) sfrlcomp[i][0] = xmask sfrlcomp[i][1] = ymask tres = check_getaux_sfrl_recur(i + 1) if tres is not None: return tres if False: for xmask in range(1, (1 << m) - 1): ymask = (1 << m) - 1 - xmask while ymask != 0: sfrlcomp[i][0] = xmask sfrlcomp[i][1] = ymask tres = check_getaux_sfrl_recur(i + 1) if tres is not None: return tres ymask = (ymask - 1) & ~xmask return check_getaux_sfrl_recur(0) def substituted_dict_union_plain(self, d): d2 = dict() for key, value in d.items(): if self.ispresent(key): if isinstance(value, list): d2[key] = value else: d2[key] = 
[value] cons0 = self.noaux() cons = [] for subs in itertools.product(*[[(key, v) for v in value] for key, value in d2.items()]): t = cons0.substituted(list(subs)) t.add_meta("pf_note_case", ["Substitute ", CompArray(list(subs)).add_meta("omit_bracket", True).add_meta("subs", True), ":"]) cons.append(t) # cons = [self.consonly().noaux()] # taux = self.aux.copy() # for key, value in d.items(): # taux -= key # if not any(c.ispresent(key) for c in cons): # continue # if not isinstance(value, list): # value = [value] # cons = [c.substituted(key, x) for c in cons for x in value] if len(cons) == 1: cons = cons[0] else: cons = RegionOp.union(cons) return cons def substituted_dict_union(self, d): taux = self.aux.copy() for key, value in d.items(): taux -= key r = self.imp_flippedonly_noaux() >> self.consonly().noaux().substituted_dict_union_plain(d).exists(taux) if not self.auxi.isempty(): r = r.forall(self.auxi.copy()) return r def check_getaux(self, hint_pair = None, hint_aux = None): """Return whether implication is true, with auxiliary search results.""" truth = PsiOpts.settings["truth"] if truth is not None: r = None with PsiOpts(truth = None): r = (truth >> self).check_getaux(hint_pair, hint_aux) return r indreg = self.get_indreg_checked() if indreg is not None: r = None with PsiOpts(indreg_enabled = False): r = (indreg >> self).check_getaux(hint_pair, hint_aux) return r if PsiOpts.settings["maxent_lex_enabled"] and self.maxent_present(): with PsiOpts(maxent_lex_enabled = False): for cs in self.maxent_lex_self_gen(): r = cs.check_getaux(hint_pair, hint_aux) if r is not None: return r return None if self.isplain(): r = self.check() if r: return [] return None write_pf_enabled = PsiOpts.settings.get("proof_enabled", False) write_pf_repeat_claim = PsiOpts.settings.get("proof_repeat_implicant", False) oproof = None if write_pf_enabled: oproof = PsiOpts.get_proof().copy() pf = ProofObj.from_region(self if write_pf_repeat_claim else self.consonly(), c = "Claim:") 
PsiOpts.set_setting(proof_step_in = pf) for rr in self.check_getaux_gen(hint_pair, hint_aux): if iutil.signal_type(rr) == "": if write_pf_enabled: if not self.is_getaux_op() and len(rr) > 0: subdict = Comp.substitute_list_to_dict(rr, multi = True) if not Comp.substitute_dict_ismulti(subdict): pf = ProofObj.from_region(None, c = ["Substitute ", CompArray(subdict).add_meta("omit_bracket", True).add_meta("subs", True), ":"]) PsiOpts.set_setting(proof_add = pf) cs = self.copy() # Comp.substitute_list(cs, rr, isaux = True) cs = cs.substituted_dict_union(subdict) if cs.getaux().isempty(): with PsiOpts(proof_enabled = True): if isinstance(cs, RegionOp): cs.check() else: cs.check_plain() PsiOpts.set_setting(proof_step_out = True) return rr if write_pf_enabled: PsiOpts.set_setting(proof_step_out = True) PsiOpts.set_proof(oproof) return None def check_getaux_dict(self, multi = True, **kwargs): r = self.check_getaux(**kwargs) if r is None: return None return Comp.substitute_list_to_dict(r, multi = multi) def check_getaux_array(self, multi = True, **kwargs): r = self.check_getaux(**kwargs) if r is None: return None return CompArray(Comp.substitute_list_to_dict(r, multi = multi)).add_meta("omit_bracket", True).add_meta("subs", True) def solve(self, method = "c", full = False, proof = None, multi = True, display_reg = None, **kwargs): """Return whether implication is true, with auxiliary search results.""" truth = PsiOpts.settings["truth"] truthcomp = Comp.empty() if truth is not None: truthcomp = truth.allcomprv() if display_reg is None: display_reg = PsiOpts.settings["solve_display_reg"] if full: method = "c,-e,-c,e" if proof is None: proof = True if display_reg is None: display_reg = True methods = method.split(",") if len(methods) > 1: t0 = None for cmethod in methods: t = self.solve(method = cmethod, proof = proof, multi = multi, display_reg = display_reg, **kwargs) if t0 is None: t0 = t if t.truth is not None: return t return t0 omethod = method negate = False if 
method.startswith("-"): method = method[1:] negate = True cs = None if negate: cs = ~(self.forall_completed(truthcomp)) else: cs = self.copy() if method == "c": proof_kwargs = {key: val for key, val in kwargs.items() if key.startswith("proof_")} other_kwargs = {key: val for key, val in kwargs.items() if not key.startswith("proof_")} cproof = None cont = PsiOpts() if proof: cont = PsiOpts(proof_new = True, **proof_kwargs) with cont: r = cs.check_getaux(**other_kwargs) if proof: cproof = PsiOpts.get_proof() if r is None: return CheckResult([], reg = self.copy(), truth = None, method = omethod, display_reg = display_reg) return CheckResult(Comp.substitute_list_to_dict(r, multi = multi), reg = self.copy(), truth = not negate, method = omethod, getaux = r, proof = cproof, display_reg = display_reg) if method == "e": r = None try: r = cs.forall_completed(truthcomp).example() except Exception as err: r = None if r is None: return CheckResult([], reg = self.copy(), truth = None, method = omethod, display_reg = display_reg) return CheckResult([], reg = self.copy(), truth = not negate, method = omethod, model = r, display_reg = display_reg) return CheckResult([], reg = self.copy(), truth = None, method = omethod, display_reg = display_reg) def is_getaux_op(self): truth = PsiOpts.settings["truth"] if truth is not None: return True if self.isregtermpresent(): return True return False def presolve_process(self, relax = True): if relax: if PsiOpts.settings.get("presolve_aux_hull", False): self.simplify_aux_hull() elif PsiOpts.settings.get("presolve_aux_hull_quick", False): self.simplify_aux_hull(quick = True) if PsiOpts.settings.get("presolve_simplify", False): self.simplify() return self else: if PsiOpts.settings.get("presolve_aux_eq", False): self.simplify_aux_eq() elif PsiOpts.settings.get("presolve_aux_eq_quick", False): self.simplify_aux_eq(quick = True) return self def get_indreg(self, reg_add = None, skip_abscont = False): r = Region.universe() reals = 
list(self.allcompreal_exprlist()) + list(iutil.allcompreal_exprlist(reg_add)) for R in reals: Rt = R.get_maxent_comp() if Rt is not None: r.iand_norename(R >= Expr.H(Rt)) allcomprv = self.allcomprv() + r.allcomprv() if reg_add is not None: allcomprv = allcomprv + iutil.allcomprv(reg_add) return r & allcomprv.get_indreg(skip_abscont = skip_abscont) def get_indreg_checked(self, reg_add = None): if PsiOpts.settings["indreg_enabled"]: indreg = self.get_indreg(reg_add) if indreg is not None and not indreg.isuniverse(): return indreg return None def check_getaux_gen(self, hint_pair = None, hint_aux = None): """Generator that yields all auxiliary search results.""" truth = PsiOpts.settings["truth"] if truth is not None: with PsiOpts(truth = None): for rr in (truth >> self).check_getaux_gen(hint_pair, hint_aux): yield rr return indreg = self.get_indreg_checked() if indreg is not None: with PsiOpts(indreg_enabled = False): for rr in (indreg >> self).check_getaux_gen(hint_pair, hint_aux): yield rr return if PsiOpts.settings["maxent_lex_enabled"] and self.maxent_present(): with PsiOpts(maxent_lex_enabled = False): for cs in self.maxent_lex_self_gen(): for rr in cs.check_getaux_gen(hint_pair, hint_aux): yield rr return if self.isregtermpresent(): cs = RegionOp.inter([self]) for rr in cs.check_getaux_gen(hint_pair, hint_aux): yield rr return write_pf_enabled = PsiOpts.settings.get("proof_enabled", False) cs = self.copy() cs.presolve_process() cs.simplify_quick(zero_group = 2) cs.split() PsiOpts.settings["proof_enabled"] = False res = None sfrl_level = PsiOpts.settings["sfrl_level"] hint_aux_avoid = self.get_aux_avoid_list() for rr in cs.check_getaux_inplace_gen(hint_pair = hint_pair, hint_aux = hint_aux, hint_aux_avoid = hint_aux_avoid): PsiOpts.settings["proof_enabled"] = write_pf_enabled yield rr PsiOpts.settings["proof_enabled"] = False if sfrl_level > 0: res = cs.check_getaux_sfrl(sfrl_minsize = 1, hint_pair = hint_pair, hint_aux = hint_aux) if res is not None: 
PsiOpts.settings["proof_enabled"] = write_pf_enabled yield res PsiOpts.settings["proof_enabled"] = False PsiOpts.settings["proof_enabled"] = write_pf_enabled def check(self): """Return whether implication is true""" if iutil.get_solver() == "z3": return self.check_z3() truth = PsiOpts.settings["truth"] if truth is not None: with PsiOpts(truth = None): return (truth >> self).check() indreg = self.get_indreg_checked() if indreg is not None: with PsiOpts(indreg_enabled = False): return (indreg >> self).check() if PsiOpts.settings["maxent_lex_enabled"] and self.maxent_present(): with PsiOpts(maxent_lex_enabled = False): for cs in self.maxent_lex_self_gen(): if cs.check(): return True return False if self.isplain(): return self.check_plain() return self.check_getaux() is not None def assumption(self, mode = None): """Retrieve the strengthened assumptions for proving this region. The implication in this region must be true if this assumption does not hold. """ return RegionOp.inter([self]).assumption(mode = mode) def truth(self): """The region given by the assumptions. 
""" r = Region.universe() truth = PsiOpts.settings["truth"] if truth is not None: r &= truth indreg = self.get_indreg_checked() if indreg is not None: r &= indreg return r def evalcheck(self, f): truth = PsiOpts.settings["truth"] if truth is not None: with PsiOpts(truth = None): return (truth >> self).evalcheck(f) indreg = self.get_indreg_checked() if indreg is not None: with PsiOpts(indreg_enabled = False): return (indreg >> self).evalcheck(f) ceps = PsiOpts.settings["eps_check"] for x in self.exprs_gei: if not float(f(x)) >= -ceps: return True for x in self.exprs_eqi: if not abs(float(f(x))) <= ceps: return True for x in self.exprs_ge: if not float(f(x)) >= -ceps: return False for x in self.exprs_eq: if not abs(float(f(x))) <= ceps: return False return True def eval_max_violate(self, f): truth = PsiOpts.settings["truth"] if truth is not None: with PsiOpts(truth = None): return (truth >> self).eval_max_violate(f) indreg = self.get_indreg_checked() if indreg is not None: with PsiOpts(indreg_enabled = False): return (indreg >> self).eval_max_violate(f) ceps = PsiOpts.settings["eps_check"] for x in self.exprs_gei: t = float(f(x)) if not numpy.isnan(t) and not t >= -ceps: return 0.0 for x in self.exprs_eqi: t = float(f(x)) if not numpy.isnan(t) and not abs(t) <= ceps: return 0.0 r = 0.0 for x in self.exprs_ge: t = float(f(x)) if numpy.isnan(t): return numpy.inf r = max(r, -t) for x in self.exprs_eq: t = float(f(x)) if numpy.isnan(t): return numpy.inf r = max(r, abs(t)) return r def eval_sum_violate(self, f, pow = 1, leak = 0.1): # truth = PsiOpts.settings["truth"] # if truth is not None: # with PsiOpts(truth = None): # return (truth >> self).eval_sum_violate(f, pow = pow) r = 0.0 for x in self.exprs_ge: t = ConcReal.unbox(f(x)) if numpy.isnan(float(t)): continue if t < 0: r = r + (-t) ** pow else: r = r - (t ** pow) * leak for x in self.exprs_eq: t = ConcReal.unbox(f(x)) if numpy.isnan(float(t)): continue if t < 0: r = r + (-t) ** pow else: r = r + t ** pow return r 
def expr_sum_violate(self, *args, **kwargs): return Expr.fcn(lambda P: self.eval_sum_violate(P, *args, **kwargs)) def example(self, card = None, denom = None): truth = PsiOpts.settings["truth"] if truth is not None: with PsiOpts(truth = None): return (truth & self).example(card = card) indreg = self.get_indreg_checked() if indreg is not None: with PsiOpts(indreg_enabled = False): return (indreg & self).example(card = card) cs = self.exists(self.allcomprv() - self.getaux() - self.getauxi()) if card is None: card = PsiOpts.settings["opt_example_card"] if card is None or isinstance(card, int): card = [card] if denom is None: denom = PsiOpts.settings["max_denom_try"] for ccard in card: P = ConcModel() cont = PsiOpts(opt_num_points_mul = PsiOpts.settings["example_opt_num_points_mul"]) if ccard is not None: cont = PsiOpts(opt_num_points_mul = PsiOpts.settings["example_opt_num_points_mul"], opt_aux_card = ccard) with cont: if not P[cs]: continue if denom > 0: P2 = P.opt_model().copy() P2.fraction_snap(denom = denom, eps = numpy.inf) if P2[cs.noaux()]: P2.set_force_float(5) return P2 return P.opt_model() return None def implies(self, other, quick = False, bnet = None, **kwargs): """Whether self implies other""" if kwargs: k0, k1 = PsiOpts.setting_strengthen_split(kwargs) cs = self if k0: with PsiOpts(**{"simplify_" + key: val for key, val in k0.items()}): cs = cs.simplified() if k1: with PsiOpts(**{"simplify_" + key: val for key, val in k1.items()}): other = other.simplified() return cs.implies(other, quick = quick) if quick: return (self <= other).check_quick(bnet = bnet) else: return (self <= other).check() def implies_getaux(self, other, hint_pair = None, hint_aux = None): """Whether self implies other, with auxiliary search result""" res = (self <= other).check_getaux(hint_pair, hint_aux) if res is None: return None return res #auxlist = other.aux.varlist + self.auxi.varlist #return [(Comp([auxlist[i]]), res[i][1]) for i in range(len(res))] def implies_getaux_gen(self, 
other, hint_pair = None, hint_aux = None): """Whether self implies other, yield all auxiliary search result""" for rr in (self <= other).check_getaux_gen(hint_pair, hint_aux): yield rr def equiv(self, other, quick = False, **kwargs): """Whether self is equivalent to other""" return self.implies(other, quick = quick, **kwargs) and other.implies(self, quick = quick, **kwargs) # def allcomp(self): # index = IVarIndex() # self.record_to(index) # return index.comprv + index.compreal # def allcomprv(self): # index = IVarIndex() # self.record_to(index) # return index.comprv # def allcompreal(self): # index = IVarIndex() # self.record_to(index) # return index.compreal def indep_components(self): return self.get_bayesnet().indep_components() def abscont_preserved(self): return self.get_ic(include_ic = False, include_hc = True) <= 0 def maxent_lex_region(self, a, s, varadd = None): """Region given by the lex operation. Mokshay Madiman, Adam W Marcus, and Prasad Tetali, "Entropy and set cardinality inequalities for partition-determined functions", Random Structures & Algorithms 40, 4 (2012), pp. 399--424. 
""" vars = self.allcomprv() + iutil.allcomprv(a) + iutil.allcomprv(s) + iutil.allcomprv(varadd) indreg = self.get_indreg_checked(reg_add = vars) if indreg is not None: with PsiOpts(indreg_enabled = False): return (self & indreg).maxent_lex_region(a, s, varadd = varadd) r = Region.universe() for b in vars: for mask in range(1 << len(s)): s0 = sum((x for i, x in enumerate(s) if mask & (1 << i)), Comp.empty()) s1 = sum((x for i, x in enumerate(s) if not mask & (1 << i)), Comp.empty()) if (self >> (~Expr.Hc(b, s0) & ~Expr.Hc(a, b + s1))).check_plain(): r.iand_norename(~Expr.Hc(s0, b)) return r def minimal_dependency(self, x, ys): if not (self >> ~Expr.Hc(x, sum(ys, Comp.empty()))).check_plain(): return None r = (1 << len(ys)) - 1 for i in range(len(ys)): a = sum((b for j, b in enumerate(ys) if r & (1 << j) and j != i), Comp.empty()) if (self >> ~Expr.Hc(x, a)).check_plain(): r -= 1 << i return r def maxent_lex_region_gen(self, vs, varadd = None, nochange = None): """Region given by the lex operation. Mokshay Madiman, Adam W Marcus, and Prasad Tetali, "Entropy and set cardinality inequalities for partition-determined functions", Random Structures & Algorithms 40, 4 (2012), pp. 399--424. 
"""
    # Generator. ``vs`` is iterated as (variable, force) pairs and
    # ``nochange`` is a Comp of variables that must not be touched. A forced
    # variable that cannot be handled ends the generator (``return``); a
    # non-forced one is simply skipped.
    if nochange is None:
        nochange = Comp.empty()
    vars = self.allcomprv() + iutil.allcomprv(vs) + iutil.allcomprv(varadd)
    indreg = self.get_indreg_checked(reg_add = vars)
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            for r in (self & indreg).maxent_lex_region_gen(vs, varadd = varadd, nochange = nochange):
                yield r
        return
    # The unmodified region is always yielded first.
    yield self
    icomps = self.indep_components()
    if len(icomps) <= 1:
        icomps = [vars]
    # Collect candidates (v, dep, treg, force). ``dep`` is a bitmask over
    # ``icomps`` from minimal_dependency; ``treg`` is v's lex region.
    mask_force = 0
    cvs = []
    for v, force in vs:
        if nochange.ispresent(v):
            if force:
                return
            continue
        dep = self.minimal_dependency(v, icomps)
        if dep is None:
            if force:
                return
            continue
        deplist = [y for j, y in enumerate(icomps) if dep & (1 << j)]
        # if not force and len(deplist) <= 1:
        #     continue
        if any(nochange.ispresent(x) for x in deplist):
            if force:
                return
            continue
        treg = self.maxent_lex_region(v, deplist, varadd = varadd)
        if not force and treg.isuniverse():
            continue
        if force:
            mask_force += 1 << len(cvs)
        cvs.append((v, dep, treg, force))
    # Try each non-empty subset of candidates that contains every forced one.
    # Subsets whose dependency masks overlap are rejected (sdep = -1).
    for mask in range(1, 1 << len(cvs)):
        if mask & mask_force != mask_force:
            continue
        tvs = []
        sdep = 0
        r = Region.universe()
        for i, (v, dep, treg, force) in enumerate(cvs):
            if not mask & (1 << i):
                continue
            if sdep & dep:
                sdep = -1
                break
            sdep |= dep
            depsum = sum((y for j, y in enumerate(icomps) if dep & (1 << j)), Comp.empty())
            tvs.append(depsum)
            r.iand_norename((H0(v) == H(v)).add_meta("pf_note", ["lex ", depsum]))
            r.iand_norename(treg)
        if sdep == -1:
            continue
        # Components not used by any selected candidate.
        depleft = [y for j, y in enumerate(icomps) if not sdep & (1 << j)]
        cs = Region.universe()
        if depleft:
            cs = self.copy()
            # NOTE(review): presumably relaxes away constraints on the
            # variables outside ``depleft`` — confirm remove_relax semantics.
            cs.remove_relax(vars - sum(depleft, Comp.empty()))
        r = (r & cs & self.abscont_preserved() & self.get_indreg(skip_abscont = True) & indep(*depleft, *tvs))
        r = r.simplified_quick()
        yield r

def maxent_ub_list(self):
    """Return the distinct ``get_maxent_comp()`` values of terms appearing
    with a negative coefficient in ``exprs_ge`` or with any coefficient in
    ``exprs_eq``, in order of first appearance."""
    r = []
    for x in self.exprs_ge:
        for a, c in x.terms:
            if c >= 0:
                continue
            a2 = a.get_maxent_comp()
            if a2 is not None and a2 not in r:
                r.append(a2)
    for x in self.exprs_eq:
        for a, c in x.terms:
            a2 = a.get_maxent_comp()
            if a2 is not None and a2 not in r:
                r.append(a2)
    return r

def abscont_nochange_comp(self):
    """Union of the random variables of all terms in ``exprs_ge``/``exprs_eq``
    whose ``get_maxent_comp()`` is None."""
    r = Comp.empty()
    for x in self.exprs_ge + self.exprs_eq:
        for a, c in x.terms:
            if a.get_maxent_comp() is None:
                r += a.allcomprv()
    return r

def maxent_present(self):
    """Whether any real expression in this region has a maxent component."""
    return any(x.get_maxent_comp() is not None for x in self.allcompreal_exprlist())

def maxent_lex_self_gen(self, varadd = None):
    """For each region ``tr`` produced by maxent_lex_region_gen on
    ``imp_flippedonly_noaux()``, yield ``(tr >> self.consonly())`` with
    ``self.auxi`` universally quantified."""
    vs = [(x, True) for x in self.maxent_ub_list()]
    for tr in self.imp_flippedonly_noaux().maxent_lex_region_gen(vs, varadd=varadd, nochange=self.abscont_nochange_comp()):
        r = (tr >> self.consonly()).forall(self.auxi)
        # print(r)
        yield r

def allcomprealvar(self):
    """Union of ``allcomprealvar()`` over all four constraint lists."""
    r = Comp.empty()
    for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi:
        r += x.allcomprealvar()
    return r
    # index = IVarIndex()
    # self.record_to(index)
    # return index.compreal - Comp([IVar.eps(), IVar.one()])

def allcompreal_exprlist(self):
    """Duplicate-free ExprArray of ``allcompreal_exprlist()`` over all four
    constraint lists."""
    r = ExprArray([])
    for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi:
        r.iadd_noduplicate(x.allcompreal_exprlist())
    return r

def allcomprealvar_exprlist(self):
    """Duplicate-free ExprArray of ``allcomprealvar_exprlist()`` over all
    four constraint lists."""
    r = ExprArray([])
    for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi:
        r.iadd_noduplicate(x.allcomprealvar_exprlist())
    return r

def allcomprv_noaux(self):
    """All random variables except those in ``getauxall()``."""
    return self.allcomprv() - self.getauxall()

def aux_remove(self):
    """Clear ``aux`` and ``auxi`` in place and return self."""
    self.aux = Comp.empty()
    self.auxi = Comp.empty()
    return self

def completed_semigraphoid_ic(self, vs = None, max_iter = None):
    """Close the conditional independences from ``get_ic()`` under derived
    combinations and return them as a sum of ``Expr.Ic`` terms.

    Each statement is a triple (m0, m1, mz) of bitmasks over ``index``,
    turned back into ``Expr.Ic(m0, m1, mz)`` at the end. ``get_ic`` terms with
    a single ``x`` are stored as (m0, m0, mz). ``vs``, if given, is a list of
    Comps or (name, comp) pairs, and the result is re-expressed over it.
    ``max_iter`` bounds the number of pair combinations tried.
    """
    verbose = PsiOpts.settings.get("verbose_semigraphoid", False)
    if vs is not None:
        tvs = []
        for t in vs:
            if isinstance(t, Comp):
                tvs.append((t, t))
            else:
                tvs.append(t)
        if len(tvs) == 0:
            return Expr.zero()
        vs = tvs

    def mask_impl(a, b):
        # True if triple b follows from triple a. Parts of a's two sides that
        # lie in bz are moved into a's conditioning set, which must then equal
        # bz; b's two sides must then be subsets of a's.
        a0, a1, az = a
        b0, b1, bz = b
        a0x = a0 & (bz & ~az)
        a0 &= ~a0x
        az |= a0x
        a1x = a1 & (bz & ~az)
        a1 &= ~a1x
        az |= a1x
        if az != bz:
            return False
        return a0 | b0 == a0 and a1 | b1 == a1

    icexpr = self.get_ic()
    index = IVarIndex()
    icexpr.record_to(index)
    icl = set()
    if verbose:
        print("========== SEMIGRAPHOID =========")
        print(icexpr)
        print("=====================================")
    for a, c in icexpr.terms:
        if len(a.x) == 1:
            mz = index.get_mask(a.z)
            m0 = index.get_mask(a.x[0]) & ~mz
            icl.add((m0, m0, mz))
        elif len(a.x) == 2:
            mz = index.get_mask(a.z)
            m0 = index.get_mask(a.x[0]) & ~mz
            m1 = index.get_mask(a.x[1]) & ~mz
            # Normalize so that the smaller mask comes first.
            if m0 > m1:
                m0, m1 = m1, m0
            icl.add((m0, m1, mz))
    #for a0, a1, az in icl:
    #    print(str(Expr.Ic(index.from_mask(a0), index.from_mask(a1), index.from_mask(az))))
    citer = 0
    iclw = icl.copy()  # every triple ever derived, to avoid re-adding
    did = True
    # Combine pairs of triples (both side orderings) until a fixed point is
    # reached or max_iter is exceeded.
    while did:
        did = False
        icl2 = icl.copy()
        for a0k, a1k, azk in icl:
            if max_iter is not None and citer > max_iter:
                break
            for b0k, b1k, bzk in icl:
                if max_iter is not None and citer > max_iter:
                    break
                if (a0k, a1k, azk) == (b0k, b1k, bzk):
                    continue
                if azk & ~(b0k | b1k | bzk) != 0:
                    continue
                if bzk & ~(a0k | a1k | azk) != 0:
                    continue
                for aj in range(2):
                    for bj in range(2):
                        citer += 1
                        if max_iter is not None and citer > max_iter:
                            break
                        a0, a1, az = a0k, a1k, azk
                        b0, b1, bz = b0k, b1k, bzk
                        if aj != 0:
                            a0, a1 = a1, a0
                        a0o, a1o, azo = a0, a1, az
                        if bj != 0:
                            b0, b1 = b1, b0
                        b0o, b1o, bzo = b0, b1, bz
                        if a0 & b0 == 0:
                            continue
                        b0x = b0 & (az & ~bz)
                        b0 &= ~b0x
                        bz |= b0x
                        b1z = b1 | bz
                        a0x = a0 & (b1z & ~az)
                        a0 &= ~a0x
                        az |= a0x
                        a1x = a1 & (b1z & ~az)
                        a1 &= ~a1x
                        az |= a1x
                        b1 &= az
                        a0 &= b0 & ~bz
                        a1 &= ~bz
                        b1 &= ~bz
                        # if verbose:
                        #     print(str(Expr.Ic(index.from_mask(a0o), index.from_mask(a1o), index.from_mask(azo)))
                        #         + " - " + str(Expr.Ic(index.from_mask(b0o), index.from_mask(b1o), index.from_mask(bzo)))
                        #         + " : " + str(Expr.Ic(index.from_mask(a0), index.from_mask(a1 | b1), index.from_mask(bz)))
                        #         + " / " + str(index.from_mask(az)) + "=" + str(index.from_mask(b1 | bz))
                        #     )
                        # The pair only combines when a's conditioning set equals b1 | bz.
                        if az != b1 | bz:
                            continue
                        if a0 == 0 or a1 | b1 == 0:
                            continue
                        t = (a0, a1 | b1, bz)
                        # Skip results already implied by either input.
                        if mask_impl((a0o, a1o, azo), t) or mask_impl((b0o, b1o, bzo), t):
                            continue
                        if verbose:
                            print(str(citer) + ": " + str(Expr.Ic(index.from_mask(a0o), index.from_mask(a1o), index.from_mask(azo)))
                                + " & " + str(Expr.Ic(index.from_mask(b0o), index.from_mask(b1o), index.from_mask(bzo)))
                                + " -> " + str(Expr.Ic(index.from_mask(a0), index.from_mask(a1 | b1), index.from_mask(bz)))
                            )
                        if t[0] > t[1]:
                            t = (t[1], t[0], t[2])
                        if t in iclw:
                            continue
                        #print(str(Expr.Ic(index.from_mask(t[0]), index.from_mask(t[1]), index.from_mask(t[2]))))
                        icl2.add(t)
                        iclw.add(t)
                        did = True
        # Keep only triples not implied by another triple (for-else: add
        # when no break occurred).
        icl.clear()
        for a0, a1, az in icl2:
            for b0, b1, bz in icl2:
                if (a0, a1, az) == (b0, b1, bz):
                    continue
                if (mask_impl((b0, b1, bz), (a0, a1, az))
                    or mask_impl((b0, b1, bz), (a1, a0, az))):
                    break
            else:
                icl.add((a0, a1, az))

    if vs is not None:
        # Re-express each triple over ``vs``: bit i of the new masks refers to vs[i].
        vmasks = [index.get_mask(b) for a, b in vs]
        vvars = CompArray([a for a, b in vs])
        icl2 = set()
        for a0, a1, az in icl:
            b0 = 0
            b1 = 0
            bz = 0
            for i in range(len(vs)):
                if a0 | vmasks[i] == a0:
                    b0 |= 1 << i
                if a1 | vmasks[i] == a1:
                    b1 |= 1 << i
            if b0 == 0 or b1 == 0:
                continue

            def recur(i, zmask, azvis):
                # Choose, over all i, which vs[i] go into the conditioning
                # set; keep choices whose union covers az.
                if i == len(vs):
                    if azvis & az == az:
                        c0 = b0 & ~zmask
                        c1 = b1 & ~zmask
                        if c0 != 0 and c1 != 0:
                            if c0 > c1:
                                c0, c1 = c1, c0
                            icl2.add((c0, c1, zmask))
                    return
                vm = vmasks[i]
                recur(i + 1, zmask, azvis)
                if vm & az != 0 and (a0 | a1 | az) & vm == vm:
                    recur(i + 1, zmask + (1 << i), azvis | vm)

            recur(0, 0, 0)
        icl.clear()
        for a0, a1, az in icl2:
            for b0, b1, bz in icl2:
                if (a0, a1, az) == (b0, b1, bz):
                    continue
                if (mask_impl((b0, b1, bz), (a0, a1, az))
                    or mask_impl((b0, b1, bz), (a1, a0, az))):
                    break
            else:
                icl.add((a0, a1, az))
        r = Expr.zero()
        for a0, a1, az in icl:
            r += Expr.Ic(vvars.from_mask(a0), vvars.from_mask(a1), vvars.from_mask(az))
        return r

    r = Expr.zero()
    for a0, a1, az in icl:
        r += Expr.Ic(index.from_mask(a0), index.from_mask(a1), index.from_mask(az))
    return r

def completed_semigraphoid_ic_new(self, vs = None, max_iter = None, include_hc = True):
    """Variant of completed_semigraphoid_ic that also handles single-``x``
    ``get_ic`` terms as functional-dependence pairs ``fcns``.

    Each entry of ``fcns`` is (mz, m0): a mask ``mz`` and the mask it expands
    to. Triples are normalized with ``clean``, which closes every mask under
    ``fcns``. The loop also stops early when ``PsiOpts.is_timer_ended()``.
    The ``fcns`` entries are appended to the result as ``Expr.Hc`` terms.
    """
    verbose = PsiOpts.settings.get("verbose_semigraphoid", False)
    if vs is not None:
        tvs = []
        for t in vs:
            if isinstance(t, Comp):
                tvs.append((t, t))
            else:
                tvs.append(t)
        if len(tvs) == 0:
            return Expr.zero()
        vs = tvs

    def mask_impl(a, b):
        # Same check as in completed_semigraphoid_ic: whether triple b
        # follows from triple a.
        a0, a1, az = a
        b0, b1, bz = b
        a0x = a0 & (bz & ~az)
        a0 &= ~a0x
        az |= a0x
        a1x = a1 & (bz & ~az)
        a1 &= ~a1x
        az |= a1x
        if az != bz:
            return False
        return a0 | b0 == a0 and a1 | b1 == a1

    icexpr = self.get_ic(include_hc = include_hc)
    index = IVarIndex()
    icexpr.record_to(index)
    icl = set()
    nvar = len(index.comprv)
    # fcn_masks = [1 << i for i in range(nvar)]
    fcns = []

    def fcn_expand(x):
        # Close mask x under fcns: whenever x contains mz, add m0; repeat
        # until nothing changes.
        did = True
        while did:
            did = False
            for mz, m0 in fcns:
                if x & mz == mz and x | m0 != x:
                    x |= m0
                    did = True
        return x

    def clean(a):
        # Normalize a triple: expand az, expand each side together with az
        # and then remove az, and order the sides so the smaller mask is first.
        a0, a1, az = a
        # for i in range(nvar):
        #     if a0 & (1 << i):
        #         a0 |= fcn_masks[i]
        #     if a1 & (1 << i):
        #         a1 |= fcn_masks[i]
        #     if az & (1 << i):
        #         az |= fcn_masks[i]
        az = fcn_expand(az)
        a0 = fcn_expand(a0 | az) & ~az
        a1 = fcn_expand(a1 | az) & ~az
        if a0 > a1:
            a0, a1 = a1, a0
        return (a0, a1, az)

    if verbose:
        print("========== SEMIGRAPHOID =========")
        print(icexpr)
        print("=====================================")
    for a, c in icexpr.terms:
        if len(a.x) == 1:
            mz = index.get_mask(a.z)
            m0 = index.get_mask(a.x[0])
            fcns.append((mz, mz | m0))
    fcns = [(mz, fcn_expand(m0)) for mz, m0 in fcns]
    for a, c in icexpr.terms:
        if len(a.x) == 1:
            pass
            # mz = index.get_mask(a.z)
            # m0 = index.get_mask(a.x[0]) & ~mz
            # m1 = ((1 << nvar) - 1) & ~mz
            # # icl.add((m0, m0, mz))
            # icl.add((m0, m1, mz))
        elif len(a.x) == 2:
            mz = index.get_mask(a.z)
            m0 = index.get_mask(a.x[0]) & ~mz
            m1 = index.get_mask(a.x[1]) & ~mz
            # if m0 > m1:
            #     m0, m1 = m1, m0
            icl.add(clean((m0, m1, mz)))
    # for a0, a1, az in icl:
    #     print(str(Expr.Ic(index.from_mask(a0), index.from_mask(a1), index.from_mask(az))))
    citer = 0
    iclw = icl.copy()  # every triple ever derived, to avoid re-adding
    did = True
    while did:
        did = False
        icl2 = icl.copy()
        # Check contraction axiom
        for a0k, a1k, azk in icl:
            if max_iter is not None and citer > max_iter:
                break
            if PsiOpts.is_timer_ended():
                break
            for b0k, b1k, bzk in icl:
                if max_iter is not None and citer > max_iter:
                    break
                if (a0k, a1k, azk) == (b0k, b1k, bzk):
                    continue
                if azk & ~(b0k | b1k | bzk) != 0:
                    continue
                if bzk & ~(a0k | a1k | azk) != 0:
                    continue
                for aj in range(2):
                    for bj in range(2):
                        citer += 1
                        if max_iter is not None and citer > max_iter:
                            break
                        a0, a1, az = a0k, a1k, azk
                        b0, b1, bz = b0k, b1k, bzk
                        if aj != 0:
                            a0, a1 = a1, a0
                        a0o, a1o, azo = a0, a1, az
                        if bj != 0:
                            b0, b1 = b1, b0
                        b0o, b1o, bzo = b0, b1, bz
                        if a0 & b0 == 0:
                            continue
                        b0x = b0 & (az & ~bz)
                        b0 &= ~b0x
                        bz |= b0x
                        b1z = b1 | bz
                        a0x = a0 & (b1z & ~az)
                        a0 &= ~a0x
                        az |= a0x
                        a1x = a1 & (b1z & ~az)
                        a1 &= ~a1x
                        az |= a1x
                        b1 &= az
                        a0 &= b0 & ~bz
                        a1 &= ~bz
                        b1 &= ~bz
                        # if verbose: # REMOVE
                        #     print(str(Expr.Ic(index.from_mask(a0o), index.from_mask(a1o), index.from_mask(azo)))
                        #         + " - " + str(Expr.Ic(index.from_mask(b0o), index.from_mask(b1o), index.from_mask(bzo)))
                        #         + " : " + str(Expr.Ic(index.from_mask(a0), index.from_mask(a1 | b1), index.from_mask(bz)))
                        #         + " / " + str(index.from_mask(az)) + "=" + str(index.from_mask(b1 | bz))
                        #     )
                        if az != b1 | bz:
                            continue
                        if a0 == 0 or a1 | b1 == 0:
                            continue
                        # clean() also orders the two sides, so no swap is needed below.
                        t = clean((a0, a1 | b1, bz))
                        if mask_impl((a0o, a1o, azo), t) or mask_impl((b0o, b1o, bzo), t):
                            continue
                        if verbose:
                            print(str(citer) + ": " + str(Expr.Ic(index.from_mask(a0o), index.from_mask(a1o), index.from_mask(azo)))
                                + " & " + str(Expr.Ic(index.from_mask(b0o), index.from_mask(b1o), index.from_mask(bzo)))
                                + " -> " + str(Expr.Ic(index.from_mask(a0), index.from_mask(a1 | b1), index.from_mask(bz)))
                            )
                        # if t[0] > t[1]:
                        #     t = (t[1], t[0], t[2])
                        if t in iclw:
                            continue
                        #print(str(Expr.Ic(index.from_mask(t[0]), index.from_mask(t[1]), index.from_mask(t[2]))))
                        icl2.add(t)
                        iclw.add(t)
                        did = True
        # Keep only triples not implied by another triple (for-else).
        icl.clear()
        for a0, a1, az in icl2:
            for b0, b1, bz in icl2:
                if (a0, a1, az) == (b0, b1, bz):
                    continue
                if (mask_impl((b0, b1, bz), (a0, a1, az))
                    or mask_impl((b0, b1, bz), (a1, a0, az))):
                    break
            else:
                icl.add((a0, a1, az))

    if vs is not None:
        # Re-express each triple over ``vs``: bit i of the new masks refers to vs[i].
        vmasks = [index.get_mask(b) for a, b in vs]
        vvars = CompArray([a for a, b in vs])
        icl2 = set()
        for a0, a1, az in icl:
            b0 = 0
            b1 = 0
            bz = 0
            for i in range(len(vs)):
                if a0 | vmasks[i] == a0:
                    b0 |= 1 << i
                if a1 | vmasks[i] == a1:
                    b1 |= 1 << i
            if b0 == 0 or b1 == 0:
                continue

            def recur(i, zmask, azvis):
                # Choose which vs[i] go into the conditioning set; keep
                # choices whose union covers az.
                if i == len(vs):
                    if azvis & az == az:
                        c0 = b0 & ~zmask
                        c1 = b1 & ~zmask
                        if c0 != 0 and c1 != 0:
                            if c0 > c1:
                                c0, c1 = c1, c0
                            icl2.add((c0, c1, zmask))
                    return
                vm = vmasks[i]
                recur(i + 1, zmask, azvis)
                if vm & az != 0 and (a0 | a1 | az) & vm == vm:
                    recur(i + 1, zmask + (1 << i), azvis | vm)

            recur(0, 0, 0)
        icl.clear()
        for a0, a1, az in icl2:
            for b0, b1, bz in icl2:
                if (a0, a1, az) == (b0, b1, bz):
                    continue
                if (mask_impl((b0, b1, bz), (a0, a1, az))
                    or mask_impl((b0, b1, bz), (a1, a0, az))):
                    break
            else:
                icl.add((a0, a1, az))
        r = Expr.zero()
        for a0, a1, az in icl:
            r += Expr.Ic(vvars.from_mask(a0), vvars.from_mask(a1), vvars.from_mask(az))
        # NOTE(review): the masks in ``fcns`` are built with index.get_mask
        # (bits refer to ``index``), but here they are decoded with
        # vvars.from_mask, whose bits refer to positions in ``vs``. This looks
        # inconsistent — confirm before relying on the vs-mode output.
        for mz, m0 in fcns:
            r += Expr.Hc(vvars.from_mask(m0 & ~mz), vvars.from_mask(mz))
        return r

    r = Expr.zero()
    for a0, a1, az in icl:
        r += Expr.Ic(index.from_mask(a0), index.from_mask(a1), index.from_mask(az))
    for mz, m0 in fcns:
        r += Expr.Hc(index.from_mask(m0 & ~mz), index.from_mask(mz))
    return r

def completed_semigraphoid(self, vs = None, max_iter = None):
    """ Use semi-graphoid axioms to deduce more conditional independence.
    Judea Pearl and Azaria Paz, "Graphoids: a graph-based logic for reasoning
    about relevance relations", Advances in Artificial Intelligence (1987),
    pp. 357--363.
""" return self.completed_semigraphoid_ic(vs = vs, max_iter = max_iter) <= 0 def eliminated_ic(self, w): icexpr = self.get_ic() index = IVarIndex() icexpr.record_to(index) icl = set() r = Expr.zero() for a, c in icexpr.terms: if a.z.ispresent(w): continue b = a.copy() for i in range(len(b.x)): b.x[i] -= w if not b.iszero(): r += Expr.fromterm(b) return r <= 0 def completed_sfrl(self, gap = None, max_iter = None): index = IVarIndex() self.record_to(index) n = len(index.comprv) cs = self.copy() tmpvar = Comp.empty() for i in range(n): for j in range(n): if i == j: continue tmpvar += cs.add_sfrl(index.comprv[i], index.comprv[j], gap, noaux = False) cs2 = cs.completed_semigraphoid(max_iter = max_iter) cs3 = cs2.eliminated_ic(tmpvar) return cs3 def convexified(self, v = None, bnet = None, q = None, inp = False, forall = False): """Convexify with respect to random variables v by a time sharing RV q, return result""" r = self.copy() if r.isregtermpresent(): r = r.flattened(minmax_elim = True) r = r.tosimple_safe() if r is None: return None v = Region.get_allcomp(v) qname = "Q" if q is not None: qname = str(q) qname = r.name_avoid(qname) q = Comp.rv(qname) allcomp = r.allcomprv() r.condition(q) cmi = None if inp: cmi = Expr.Ic(q, allcomp - r.getauxall() - v - r.inp - r.oup, r.inp) == 0 elif bnet is None: cmi = Expr.I(q, allcomp - r.getauxall() - v) == 0 else: bnet2 = None if isinstance(bnet, BayesNet): bnet2 = bnet.copy() else: bnet2 = BayesNet(bnet) for v2 in v: bnet2 += (q, v2) cmi = bnet2.get_region() if forall: r |= ~cmi return r.forall(q) else: r &= cmi return r.exists(q) def tounion(self): return RegionOp.pack_type(self, RegionType.UNION) def convexified_diag(self, v = None, cross_only = True, skip_simplify = False): """Convexify with respect to the real variables in v along the diagonal, return result""" if v is None: v = self.allcomprealvar_exprlist() index = IVarIndex() self.record_to(index) for x in v: x.record_to(index) namemap = [dict(), dict()] v_new = [[], 
[]] toelim = Expr.zero() for x in v: cname = x.get_name() for it in range(2): nname = index.name_avoid(cname) namemap[it][cname] = nname v_new[it].append(Expr.real(nname)) Expr.real(nname).record_to(index) toelim += Expr.real(nname) r = RegionOp.empty() ru = self.tounion() if cross_only: r = ru.copy() for i, j in itertools.combinations(range(len(v)), 2): v0j = v[i] + v[j] - v_new[0][i] v1j = v[i] + v[j] - v_new[1][i] for k0, a0 in enumerate(ru): a0t = a0.copy() a0t.substitute(v[i], v_new[0][i]) a0t.substitute(v[j], v0j) for k1, a1 in enumerate(ru): if cross_only and k0 == k1: continue a1t = a1.copy() a1t.substitute(v[i], v_new[1][i]) a1t.substitute(v[j], v1j) t = a0t & a1t t &= (v[i] >= v_new[0][i]) & (v[i] <= v_new[1][i]) #print(str(i) + " , " + str(j)) r |= t.exists(v_new[0][i]+v_new[1][i]) if not skip_simplify: return r.simplified() else: return r def isconvex(self, v = None, bnet = None, inp = False): """Check whether region is convex with respect to random variables v and real variables. 
False return value does NOT necessarily mean region is not convex """ t = self.convexified(v, bnet, inp = inp) if t is None: return False return t.implies(self) #return ((self + self) / 2).implies(self) def clean_eps(self): if not self.ispresent(IVar.eps()): return for a in self.exprs_ge + self.exprs_gei: eps_coeff = a.get_coeff(Term.eps()) a.substitute(Expr.eps(), Expr.zero()) if eps_coeff < 0: a -= Expr.eps() def simplify_quick(self, reg = None, zero_group = 0): """Simplify a region in place, without linear programming Optional argument reg with constraints assumed to be true zero_group = 2: group all nonnegative terms as a single inequality """ if not PsiOpts.settings.get("simplify_enabled", False): return self write_pf_enabled = (PsiOpts.settings.get("proof_enabled", False) and PsiOpts.settings.get("proof_step_simplify", False)) if write_pf_enabled: prevself = self.copy() #self.remove_missing_aux() if reg is None: reg = Region.universe() for x in self.exprs_ge: x.simplify_quick(reg = reg) for x in self.exprs_eq: x.simplify_quick(reg = reg) self.clean_eps() index = IVarIndex() self.record_to(index) gemask = [index.get_mask(x.allcomprv_shallow()) for x in self.exprs_ge] eqmask = [index.get_mask(x.allcomprv_shallow()) for x in self.exprs_eq] did = True if True: did = False for i in range(len(self.exprs_ge)): if not self.exprs_ge[i].iszero(): for j in range(i): if not self.exprs_ge[j].iszero() and gemask[i] == gemask[j]: ratio = self.exprs_ge[i].get_ratio(self.exprs_ge[j], skip_simplify = True) if ratio is None: continue if ratio > PsiOpts.settings["eps"]: self.exprs_ge[i] = Expr.zero() gemask[i] = 0 did = True break elif ratio < -PsiOpts.settings["eps"]: self.exprs_eq.append(self.exprs_ge[i]) eqmask.append(gemask[i]) self.exprs_ge[i] = Expr.zero() gemask[i] = 0 self.exprs_ge[j] = Expr.zero() gemask[j] = 0 did = True break for i in range(len(self.exprs_ge)): if not self.exprs_ge[i].iszero(): for j in range(len(self.exprs_eq)): if not self.exprs_eq[j].iszero() and 
gemask[i] == eqmask[j]: ratio = self.exprs_ge[i].get_ratio(self.exprs_eq[j], skip_simplify = True) if ratio is None: continue self.exprs_ge[i] = Expr.zero() gemask[i] = 0 did = True break for i in range(len(self.exprs_eq)): if not self.exprs_eq[i].iszero(): for j in range(i): if not self.exprs_eq[j].iszero() and eqmask[i] == eqmask[j]: ratio = self.exprs_eq[i].get_ratio(self.exprs_eq[j], skip_simplify = True) if ratio is None: continue self.exprs_eq[i] = Expr.zero() eqmask[i] = 0 did = True break self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()] self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] for i in range(len(self.exprs_ge)): if self.exprs_ge[i].isnonneg(): self.exprs_ge[i] = Expr.zero() self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()] if True: allzero = Expr.zero() for i in range(len(self.exprs_ge)): if self.exprs_ge[i].isnonpos(): for (a, c) in self.exprs_ge[i].terms: allzero -= Expr.fromterm(a) self.exprs_ge[i] = Expr.zero() for i in range(len(self.exprs_eq)): if self.exprs_eq[i].isnonpos() or self.exprs_eq[i].isnonneg(): for (a, c) in self.exprs_eq[i].terms: allzero -= Expr.fromterm(a) self.exprs_eq[i] = Expr.zero() if not allzero.iszero(): allzero.simplify_quick(reg = reg) allzero.sortIc() #self.exprs_ge.append(allzero) self.exprs_ge.insert(0, allzero) if zero_group == 2: pass else: for i in range(len(self.exprs_ge)): if self.exprs_ge[i].isnonpos(): for (a, c) in self.exprs_ge[i].terms: if zero_group == 1: self.exprs_ge.append(-Expr.fromterm(a)) else: self.exprs_eq.append(Expr.fromterm(a)) self.exprs_ge[i] = Expr.zero() for i in range(len(self.exprs_eq)): if self.exprs_eq[i].isnonpos() or self.exprs_eq[i].isnonneg(): for (a, c) in self.exprs_eq[i].terms: if zero_group == 1: self.exprs_ge.append(-Expr.fromterm(a)) else: self.exprs_eq.append(Expr.fromterm(a)) self.exprs_eq[i] = Expr.zero() self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()] self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] for x in 
self.exprs_ge: x.simplify_mul(1) for x in self.exprs_eq: x.simplify_mul(2) if self.imp_present(): t = self.imp_flippedonly() t.simplify_quick(reg, zero_group) self.exprs_gei = t.exprs_ge self.exprs_eqi = t.exprs_eq for x in self.exprs_ge: if self.implies_ineq_quick(x, ">="): x.setzero() for x in self.exprs_eq: if self.implies_ineq_quick(x, "=="): x.setzero() self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()] self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] if write_pf_enabled: if self.tostring() != prevself.tostring(): # pf = ProofObj.from_region(prevself, c = "Simplify") # pf += ProofObj.from_region(self, c = "Simplified as") pf = ProofObj.from_region(("equiv", prevself, self), c = "Simplify:") PsiOpts.set_setting(proof_add = pf) return self def iand_simplify_quick(self, other, skip_simplify = True): did = False for x in other.exprs_ge: if not self.implies_ineq_cons_hash(x, ">="): self.exprs_ge.append(x) did = True for x in other.exprs_eq: if not self.implies_ineq_cons_hash(x, "=="): self.exprs_eq.append(x) did = True if not skip_simplify and did: self.simplify_quick(zero_group = 1) return self def split_ic2(self): ge_insert = [] for i in range(len(self.exprs_ge)): if self.exprs_ge[i].isnonpos(): t = Expr.zero() for (a, c) in self.exprs_ge[i].terms: if a.isic2(): ge_insert.append(-Expr.fromterm(a)) else: t += Expr.fromterm(a) * c self.exprs_ge[i] = t for i in range(len(self.exprs_eq)): if self.exprs_eq[i].isnonpos() or self.exprs_eq[i].isnonneg(): t = Expr.zero() for (a, c) in self.exprs_eq[i].terms: if a.isic2(): ge_insert.append(-Expr.fromterm(a)) else: t += Expr.fromterm(a) * c self.exprs_eq[i] = t self.exprs_ge = ge_insert + self.exprs_ge self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()] self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] def split(self): ge_insert = [] for i in range(len(self.exprs_ge)): if self.exprs_ge[i].isnonpos(): for (a, c) in self.exprs_ge[i].terms: ge_insert.append(-Expr.fromterm(a)) 
self.exprs_ge[i] = Expr.zero() for i in range(len(self.exprs_eq)): if self.exprs_eq[i].isnonpos() or self.exprs_eq[i].isnonneg(): for (a, c) in self.exprs_eq[i].terms: ge_insert.append(-Expr.fromterm(a)) self.exprs_eq[i] = Expr.zero() self.exprs_ge = ge_insert + self.exprs_ge self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()] self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] if self.imp_present(): t = self.imp_flippedonly() t.split() self.exprs_gei = t.exprs_ge self.exprs_eqi = t.exprs_eq @staticmethod def intersection_interleave(rlist): r = Region.universe() r.exprs_ge = iutil.list_interleave(x.exprs_ge for x in rlist) r.exprs_eq = iutil.list_interleave(x.exprs_eq for x in rlist) r.exprs_gei = iutil.list_interleave(x.exprs_gei for x in rlist) r.exprs_eqi = iutil.list_interleave(x.exprs_eqi for x in rlist) r.aux = sum(iutil.list_interleave(x.aux for x in rlist), Comp.empty()) r.auxi = sum(iutil.list_interleave(x.auxi for x in rlist), Comp.empty()) return r def symmetrized(self, symm_set, union = False, convexify = False, skip_simplify = False): if symm_set is None: return self r = Region.universe() if union: r = RegionOp.empty() cs = self.copy() n = len(symm_set) m = min(len(a) for a in symm_set) tmpvar = [[] for i in range(n)] for i in range(n): for j in range(m): if isinstance(symm_set[i][j], Comp): tmpvar[i].append(Comp.rv("#TMPVAR" + str(i) + "_" + str(j))) else: tmpvar[i].append(Expr.real("#TMPVAR" + str(i) + "_" + str(j))) cs.substitute(symm_set[i][j], tmpvar[i][j]) tcss = [] for p in itertools.permutations(range(n)): tcs = cs.copy() for i in range(n): for j in range(m): tcs.substitute(tmpvar[i][j], symm_set[p[i]][j]) tcss.append(tcs) if union: for tcs in tcss: r |= tcs else: r = Region.intersection_interleave(tcss) if convexify: r = r.convexified_diag(skip_simplify = True) if skip_simplify: return r else: return r.simplified() def simplify_bayesnet(self, reg = None, reduce_ic = False): if isinstance(reg, RegionOp): reg = 
reg.tosimple_noaux() if reg is None: reg = Region.universe() icexpr = Expr.zero() for x in self.exprs_ge: if x.isnonpos(): icexpr += x for x in self.exprs_eq: if x.isnonpos(): icexpr += x elif x.isnonneg(): icexpr -= x if reg.isuniverse() and icexpr.iszero(): return bnet = (reg & (icexpr >= 0)).get_bayesnet(skip_simplify = True) for x in self.exprs_ge: # Prevent circular simplification if not x.isnonpos(): x.simplify_quick(bnet = bnet) for x in self.exprs_eq: # Prevent circular simplification if not (x.isnonpos() or x.isnonneg()): x.simplify_quick(bnet = bnet) if reduce_ic: for x in self.exprs_ge: if x.isnonpos(): if bnet.check_ic(-x): x.setzero() for x in self.exprs_eq: if x.isnonpos(): if bnet.check_ic(-x): x.setzero() elif x.isnonneg(): if bnet.check_ic(x): x.setzero() self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()] self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] self.iand_norename(bnet.get_region()) return self def zero_exprs(self, avoid = None): for x in self.exprs_ge: if x is avoid: continue if x.isnonpos(): for y in x: yield y for x in self.exprs_eq: if x is avoid: continue if x.isnonpos() or x.isnonneg(): for y in x: yield y else: yield x def empty_rvs(self): r = Comp.empty() for x in self.zero_exprs(): if len(x) == 1 and x.terms[0][1] != 0 and x.terms[0][0].ish(): r += x.terms[0][0].x[0] return r def simplify_pair(self, reg = None): did = False if isinstance(reg, RegionOp): reg = reg.tosimple_noaux() if reg is None: reg = Region.universe() for x, xs in [(a, ">=") for a in self.exprs_ge] + [(a, "==") for a in self.exprs_eq]: xcomp = x.complexity() for y in igen.pm(itertools.chain(self.zero_exprs(avoid = x), reg.zero_exprs())): x2 = (x + y).simplified_quick() x2comp = x2.complexity() if x2comp < xcomp: did = True x.copy_(x2) xcomp = x2comp for z in reg.empty_rvs(): did = True self.substitute(z, Comp.empty()) for z in self.empty_rvs(): did = True self.substitute(z, Comp.empty()) self.exprs_eq.append(Expr.H(z)) if did: 
self.simplify_quick() return self def simplify_redundant(self, reg = None, proc = None, full = True, quick = False, bnet = None): write_pf_enabled = (PsiOpts.settings.get("proof_enabled", False) and PsiOpts.settings.get("proof_step_simplify", False)) aux_relax = PsiOpts.settings.get("simplify_aux_relax", False) if write_pf_enabled: prevself = self.copy() red_reg = Region.universe() prev_write_pf_enabled = PsiOpts.settings.get("proof_enabled", False) PsiOpts.settings["proof_enabled"] = False if reg is None: reg = Region.universe() #if self.isregtermpresent(): # return self allcompreal = self.allcompreal() + reg.allcompreal() aux = self.aux def preprocess(r): if proc is not None: r = proc(r) if aux_relax: r.aux += aux r.aux_strengthen() r.aux = Comp.empty() return r for i in range(len(self.exprs_ge) - 1, -1, -1): if PsiOpts.is_timer_ended(): break t = self.exprs_ge[i] self.exprs_ge[i] = Expr.zero() cs = self.imp_intersection_noaux() & reg if not full: cs.remove_notcontained(t.allcomp() + allcompreal) cs = preprocess(cs) if not (cs <= (t >= 0)).check_plain(quick = quick, bnet = bnet): self.exprs_ge[i] = t elif write_pf_enabled: red_reg.exprs_ge.append(t) self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()] for i in range(len(self.exprs_eq) - 1, -1, -1): if PsiOpts.is_timer_ended(): break t = self.exprs_eq[i] self.exprs_eq[i] = Expr.zero() cs = self.imp_intersection_noaux() & reg if not full: cs.remove_notcontained(t.allcomp() + allcompreal) cs = preprocess(cs) if not (cs <= (t == 0)).check_plain(quick = quick, bnet = bnet): self.exprs_eq[i] = t elif write_pf_enabled: red_reg.exprs_eq.append(t) self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] for i in range(len(self.exprs_gei) - 1, -1, -1): if PsiOpts.is_timer_ended(): break t = self.exprs_gei[i] self.exprs_gei[i] = Expr.zero() cs = self.imp_flippedonly_noaux() & reg if not full: cs.remove_notcontained(t.allcomp() + allcompreal) cs = preprocess(cs) if not (cs <= (t >= 0)).check_plain(quick = quick, 
bnet = bnet): self.exprs_gei[i] = t elif write_pf_enabled: red_reg.exprs_gei.append(t) self.exprs_gei = [x for x in self.exprs_gei if not x.iszero()] for i in range(len(self.exprs_eqi) - 1, -1, -1): if PsiOpts.is_timer_ended(): break t = self.exprs_eqi[i] self.exprs_eqi[i] = Expr.zero() cs = self.imp_flippedonly_noaux() & reg if not full: cs.remove_notcontained(t.allcomp() + allcompreal) cs = preprocess(cs) if not (cs <= (t == 0)).check_plain(quick = quick, bnet = bnet): self.exprs_eqi[i] = t elif write_pf_enabled: red_reg.exprs_eqi.append(t) self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()] if False: if self.imp_present(): t = self.imp_flippedonly() t.simplify_redundant(reg) self.exprs_gei = t.exprs_ge self.exprs_eqi = t.exprs_eq if write_pf_enabled: if not red_reg.isuniverse(): pf = ProofObj.from_region(red_reg, c = "Remove redundant constraints") pf = ProofObj.from_region(self, c = "Result") PsiOpts.set_setting(proof_add = pf) PsiOpts.settings["proof_enabled"] = prev_write_pf_enabled return self def simplify_symm(self, symm_set, quick = False): """Simplify a region, assuming symmetry among variables in symm_set. """ if isinstance(symm_set, Comp): self.symm_sort(symm_set) self.simplify_quick() if not quick: self.simplify_redundant(proc = lambda t: t.symmetrized(symm_set, skip_simplify = True)) def simplified_symm(self, symm_set, quick = False): """Simplify a region, assuming symmetry among variables in symm_set. 
""" r = self.copy() r.simplify_symm(symm_set, quick) return r def var_mi_only(self, v): return all(a.var_mi_only(v) for a in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi) def aux_eq_towards(self, v, tgt, reg = None, quick = False, alt_cases = None): if self.imp_present(): return False if not self.var_mi_only(v): return False if isinstance(reg, RegionOp): reg = reg.tosimple() if reg is None: reg = Region.universe() sreg = self.copy_noaux() & reg sreg_bnet = sreg.get_bayesnet() ege = [a for a in self.exprs_ge if a.ispresent(v)] eeq = [a for a in self.exprs_eq if a.ispresent(v)] eget = [a.substituted(v, tgt) for a in ege] eeqt = [a.substituted(v, tgt) for a in eeq] for e0, e1 in zip(eeq, eeqt): with PsiOpts(proof_enabled = False): if not sreg.implies(e1 == 0, quick = quick, bnet = sreg_bnet): return False # oe0 = None # oe1 = None oe0s = [] oe1s = [] for i, (e0, e1) in enumerate(zip(ege, eget)): with PsiOpts(proof_enabled = False): if sreg.implies(e1 >= 0, quick = quick, bnet = sreg_bnet): continue if alt_cases is None and oe0s: if sreg.implies(e0 <= oe0s[0], quick = quick, bnet = sreg_bnet): oe0s.pop() oe1s.pop() elif sreg.implies(e0 >= oe0s[0], quick = quick, bnet = sreg_bnet): continue else: return False oe0s.append(e0) oe1s.append(e1) # if oe0 is None or sreg.implies(e0 <= oe0, quick = quick, bnet = sreg_bnet): # # print(sreg) # # print(str(e0) + " <= " + str(oe0)) # # print() # oe0 = e0 # oe1 = e1 # continue # if oe0 is not None and sreg.implies(e0 >= oe0, quick = quick, bnet = sreg_bnet): # # print(sreg) # # print(str(e0) + " >= " + str(oe0)) # # print() # continue # return False if not oe0s: return False lowany = False with PsiOpts(proof_enabled = False): lowany = any(sreg.implies(oe1 <= 0, quick = quick, bnet = sreg_bnet) for oe1 in oe1s) if alt_cases is None: if not lowany: return False self.exprs_ge.remove(oe0s[0]) self.exprs_eq.append(oe0s[0]) return True else: for oe0, oe1 in zip(oe0s, oe1s): t = self.copy() t.exprs_ge = 
list(self.exprs_ge) t.exprs_ge.remove(oe0) t.exprs_eq.append(oe0) t = t.copy() alt_cases.append(t) if not lowany: t = self.copy() t.substitute_aux(v, tgt) alt_cases.append(t) return True def simplify_aux_eq(self, reg = None, quick = False, full = False): if self.aux.isempty(): return self if self.imp_present(): cs = self.consonly() cs.simplify_aux_eq(reg, quick) self.cons_shallow_copy_(cs) return self if isinstance(reg, RegionOp): reg = reg.tosimple() if reg is None: reg = Region.universe() # self.split() did = True for it in range(100): if PsiOpts.is_timer_ended(): break did = False for a in (self.aux if it % 2 else reversed(self.aux)): if PsiOpts.is_timer_ended(): break ane = self.var_neighbors(a) totry = [] if full: totry = igen.subset(ane) else: totry = [Comp.empty()] for tgt in totry: if PsiOpts.is_timer_ended(): break if self.aux_eq_towards(a, tgt, reg = reg, quick = quick): did = True break if did: break if not did: break return self def aux_eq_cases_inner(self, reg = None, quick = False, full = True, target = None): if self.aux.isempty() or self.imp_present(): return [self.copy()] if isinstance(reg, RegionOp): reg = reg.tosimple() if reg is None: reg = Region.universe() alt_cases = [] for a in self.aux: if PsiOpts.is_timer_ended(): break ane = self.var_neighbors(a) totry = [] if target is not None: totry = [target.inter(ane)] elif full: totry = igen.subset(ane) else: totry = [Comp.empty()] for tgt in totry: if PsiOpts.is_timer_ended(): break if self.aux_eq_towards(a, tgt, reg = reg, quick = quick, alt_cases = alt_cases): return sum((c.aux_eq_cases_inner(reg = reg, quick = quick, full = full, target = target) for c in alt_cases), []) return [self.copy()] def aux_eq_cases(self, reg = None, quick = False, full = True, target = None, skip_simplify = False): t = self.aux_eq_cases_inner(reg = reg, quick = quick, full = full, target = target) if len(t) == 1: return t[0] r = anyor(c.noaux() for c in t).exists(self.aux) if not skip_simplify: r = r.simplified() 
return r def aux_reduced(self, new_aux = None, maxsize = 3, skip_simplify = False, aux_pairs = None, aux_force = None, score_fcn = None): """ Reduce the number of auxiliaries. May enlarge the region. Parameters ---------- new_aux : Comp or int The new auxiliaries. len(new_aux) will be the number of auxiliaries in the new region. maxsize : int The max number of old auxiliaries each new auxiliary corresponds to. Returns ------- Region The new region. """ verbose = PsiOpts.settings.get("verbose_aux_reduced", False) if self.imp_present(): return self.copy() # cs = self.consonly() # return cs.aux_reduced(new_aux = new_aux, maxsize = maxsize) if new_aux is None: new_aux = [] if isinstance(new_aux, int): if new_aux == 0: new_aux = [] elif new_aux == 1: new_aux = [Comp.rv(self.name_avoid("A"))] else: t_new_aux = rv_seq("A", 1, new_aux + 1) new_aux = [] index = IVarIndex() self.record_to(index) for a in t_new_aux: cname = index.name_avoid(a.get_name()) c = Comp.rv(cname) index.record(c) new_aux.append(c) n = len(new_aux) new_aux_sum = sum(new_aux, Comp.empty()) discover_list = list(self.allcomprv() - self.aux) + list(self.allcomprealvar_exprlist()) selfnoaux = self.noaux() index_self = IVarIndex() selfnoaux.record_to(index_self) progs = [] pts_outer = None r = None with PsiOpts(discover_max_facet = None): iters = itertools.combinations(list(igen.subset(self.aux, minsize = 1, maxsize = maxsize)), n) if score_fcn is not None: tlist = [] for t in iters: score = score_fcn(t) if score is None: continue tlist.append((score, t)) tlist.sort() iters = [] for i in range(len(tlist)): score, t = tlist[i] if isinstance(score, tuple) and score[1] < 0: iters.append(t) tlist.pop(i) break iters += [t for score, t in tlist] for cauxs in iters: if PsiOpts.is_timer_ended(): break if aux_force is not None: if n > 0 and not any(c.ispresent(aux_force) for c in cauxs): continue if aux_pairs is not None: scores = [0, 0] for c in cauxs: for i, p in enumerate(aux_pairs): for it in range(2): if 
c.ispresent(p[it]): scores[it] += 100 + i if scores[0] < scores[1]: if verbose: print(", ".join(str(c) for c in cauxs) + " pair skipped") continue if r is not None: bad = True for cor in itertools.permutations(range(n)): r2 = r.copy() r2.remove_notpresent(new_aux_sum) for i in range(n): r2.substitute_aux(new_aux[i], cauxs[cor[i]]) if selfnoaux.implies_saved(r2.noaux(), index_self, progs): cauxs = tuple([cauxs[cor[i]] for i in range(n)]) bad = False break if bad: if verbose: print(", ".join(str(c) for c in cauxs) + " subset fail") continue if PsiOpts.is_timer_ended(): break if verbose: print(", ".join(str(c) for c in cauxs) + " discover...") new_pts_outer = [] # t = self.discover(discover_list + list(zip(new_aux, cauxs)), init_pts_outer = pts_outer, pts_outer = new_pts_outer) t = self.discover(discover_list + list(zip(new_aux, cauxs))) # r.append(t) if PsiOpts.is_timer_ended(): break if verbose: print(t) print() if not skip_simplify: t = t.exists(new_aux_sum.inter(t.allcomprv())).simplified().noaux() if verbose: print("Simplified:") print(t) print() pts_outer = new_pts_outer r = t # if len(r) == 1: # r = r[0] # else: # r = anyor(r) if r is None: return self.copy() r = r.exists(new_aux_sum.inter(r.allcomprv())) return r def aux_hull_towards(self, v, tgt, reg = None, quick = False): # This is needed for e.g. 
Gray-Wyner network if self.imp_present(): return False if not self.var_mi_only(v): return False if isinstance(reg, RegionOp): reg = reg.tosimple() if reg is None: reg = Region.universe() ege = [a for a in self.exprs_ge if a.ispresent(v)] eeq = [a for a in self.exprs_eq if a.ispresent(v)] eget = [a.substituted(v, tgt) for a in ege] eeqt = [a.substituted(v, tgt) for a in eeq] for e0, e1 in zip(eeq, eeqt): with PsiOpts(proof_enabled = False): if not reg.implies(e0 == e1, quick = quick): return False sns = [[False] * len(ege), [False] * len(ege)] for i, (e0, e1) in enumerate(zip(ege, eget)): with PsiOpts(proof_enabled = False): sns[0][i] = reg.implies(e0 <= e1, quick = quick) sns[1][i] = reg.implies(e0 >= e1, quick = quick) for i, (e0, e1, s0, s1) in enumerate(zip(ege, eget, sns[0], sns[1])): if s0 and s1: continue if not (s0 or s1): continue did = False lvls = [0] * len(ege) inc = Expr.zero() if s0: inc = e1 - e0 lvls[i] = 1 did = True else: inc = e0 - e1 lvls[i] = -1 inc.simplify_quick() # print("TRY " + str(e0) + " " + str(e1) + " " + str(inc)) bad = False for iz, (ez0, ez1, sz0, sz1) in enumerate(zip(ege, eget, sns[0], sns[1])): if iz == i: continue if sz0 and sz1: continue ezdiff = (ez1 - ez0).simplified_quick() if sz0: with PsiOpts(proof_enabled = False): if reg.implies(ezdiff >= inc, quick = quick): lvls[iz] = 1 did = True else: lvls[iz] = 0 else: with PsiOpts(proof_enabled = False): if reg.implies(ezdiff >= -inc, quick = quick): lvls[iz] = -1 else: bad = True break # print(" TO " + str(ez0) + " " + str(ez1) + " " + str(lvls[iz])) if bad or not did: continue t = Expr.real("#TMPVAR") for iz, (ez0, ez1, sz0, sz1) in enumerate(zip(ege, eget, sns[0], sns[1])): ez0 += lvls[iz] * t self.iand_norename(t >= 0) self.iand_norename(t <= inc) # print(self) with PsiOpts(simplify_aux_hull = False, simplify_aux_hull_lower_complexity = False, proof_enabled = False): self.eliminate(t) return True return False def cons_shallow_copy_(self, other): self.exprs_ge = other.exprs_ge 
self.exprs_eq = other.exprs_eq self.aux = other.aux return self def simplify_aux_hull(self, reg = None, quick = False, lower_complexity = False): if self.aux.isempty(): return self if self.imp_present(): cs = self.consonly() cs.simplify_aux_hull(reg, quick, lower_complexity) self.cons_shallow_copy_(cs) return self if isinstance(reg, RegionOp): reg = reg.tosimple() if reg is None: reg = Region.universe() orig_self = None if lower_complexity: orig_self = self.copy() did = True for it in range(100): if PsiOpts.is_timer_ended(): break did = False for a in (self.aux if it % 2 else reversed(self.aux)): if PsiOpts.is_timer_ended(): break ane = self.var_neighbors(a) for tgt in igen.subset(ane): if self.aux_hull_towards(a, tgt, reg = reg, quick = quick): did = True break if did: break if not did: break if lower_complexity: self.simplify_quick() if self.complexity() >= orig_self.complexity(): self.copy_(orig_self) return self def simplify_aux_combine(self, reg = None, r_did = None): if self.aux.isempty(): return self if self.imp_present(): cs = self.consonly() cs.simplify_aux_combine(reg, r_did) self.cons_shallow_copy_(cs) return self if reg is None: reg = Region.universe() did = True while did: if self.aux.isempty(): break if PsiOpts.is_timer_ended(): break selfplain = self.copy() selfplain.aux = Comp.empty() selfplain &= reg did = False for i in range(len(self.aux)): if PsiOpts.is_timer_ended(): break for j in range(i + 1, len(self.aux)): if PsiOpts.is_timer_ended(): break cs = self.copy() cs.aux = Comp.empty() cs.substitute(self.aux[i], self.aux[i] + self.aux[j]) cs.substitute(self.aux[j], self.aux[i] + self.aux[j]) with PsiOpts(proof_enabled = False): if selfplain.implies(cs): with PsiOpts(meta_subs_criteria = "eqtype"): self.substitute(self.aux[j], self.aux[i]) if r_did is not None: r_did[0] = True did = True break if did: break if did: continue for i in range(len(self.aux)): if PsiOpts.is_timer_ended(): break for j in range(len(self.aux)): if i == j: continue if 
PsiOpts.is_timer_ended(): break cs = self.copy() cs.aux = Comp.empty() cs.substitute(self.aux[j], Comp.empty()) cs.substitute(self.aux[i], self.aux[i] + self.aux[j]) with PsiOpts(proof_enabled = False): if selfplain.implies(cs): with PsiOpts(meta_subs_criteria = "eqtype"): self.substitute(self.aux[j], Comp.empty()) if r_did is not None: r_did[0] = True did = True break if did: break return self def simplify_aux(self, reg = None, cases = None, r_did = None): if self.aux.isempty(): return self if self.imp_present(): cs = self.consonly() cs.simplify_aux(reg, cases, r_did) self.cons_shallow_copy_(cs) return self leaveone = None if cases is not None: leaveone = True if reg is None: reg = Region.universe() did = True while did: if self.aux.isempty(): break did = False index = IVarIndex() self.record_to(index) reg.record_to(index) taux = self.aux.copy() taux2 = taux.copy() tauxi = self.auxi.copy() self.aux = Comp.empty() self.auxi = Comp.empty() # print(self) for a in taux: if PsiOpts.is_timer_ended(): break a2 = Comp.rv(index.name_avoid(a.get_name())) cs = (self.consonly().substituted(a, a2) & reg) >> self.exists(a) #hint_aux_avoid = self.get_aux_avoid_list() hint_aux_avoid = [(a, a2)] # print(cs) with PsiOpts(proof_enabled = False): for rr in cs.check_getaux_inplace_gen(hint_aux_avoid = hint_aux_avoid, leaveone = leaveone): stype = iutil.signal_type(rr) with PsiOpts(meta_subs_criteria = "eqtype"): if stype == "": # print(rr) ar = a.copy() Comp.substitute_list(ar, rr) if ar == a: continue # print(self) self.substitute(a, ar) # print(self) taux2 -= a if r_did is not None: r_did[0] = True did = True break elif stype == "leaveone" and cases is not None: ar = a.copy() Comp.substitute_list(ar, rr[1]) if ar == a: continue tr = self.copy() # tr.iand_norename(rr[2] <= 0) tr.substitute(a, ar) tr.aux = taux2 - a tr.auxi = tauxi.copy() tr.simplify_quick() tr.simplify_aux(reg, cases) cases.append(tr) self.iand_norename(rr[2] >= 0) if r_did is not None: r_did[0] = True did = True 
self.aux = taux2 self.auxi = tauxi if did: self.simplify_quick() return self def simplify_imp(self, reg = None): if not self.imp_present(): return def simplify_expr_exhaust(self): for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi: x.simplify_exhaust() return self def expanded_cases_reduce(self, reg = None, skip_simplify = False): r = self.copy() cases = [] r.simplify_aux(reg, cases) if len(cases) == 0: return r else: r2 = RegionOp.union([r] + cases) if not skip_simplify: return r2.simplified() else: return r2 def expanded_cases(self, reg = None, leaveone = True, skip_simplify = False): if self.aux.isempty(): return None if self.imp_present(): return None cases = [] if isinstance(reg, RegionOp): reg = reg.tosimple() if reg is None: reg = Region.universe() did = False index = IVarIndex() self.record_to(index) reg.record_to(index) # taux = self.aux.copy() # taux2 = taux.copy() # tauxi = self.auxi.copy() # self.aux = Comp.empty() # self.auxi = Comp.empty() rmap = index.calc_rename_map(self.aux) cs2 = self.copy() cs2.rename_map(rmap) s2 = self.consonly() s2.aux = Comp.empty() cs = (s2 & reg) >> cs2 hint_aux_avoid = cs.get_aux_avoid_list() with PsiOpts(forall_multiuse_numsave = 0): for rr in cs.check_getaux_inplace_gen(hint_aux_avoid = hint_aux_avoid, leaveone = leaveone): stype = iutil.signal_type(rr) if stype == "": continue print(rr) cs3 = cs2.copy() Comp.substitute_list(cs3, rr, isaux = True) cs3.eliminate(self.aux.inter(cs3.allcomp())) cases.append(cs3) elif stype == "leaveone": print(str(rr[1]) + " " + str(rr[2])) cs3 = cs2.copy() Comp.substitute_list(cs3, rr[1], isaux = True) cs3.eliminate(self.aux.inter(cs3.allcomp())) cases.append(cs3) cs2.iand_norename(rr[2] >= 0) if PsiOpts.is_timer_ended(): break cs3 = cs2.copy() cs3.rename_map({b: a for a, b in rmap.items()}) cases.append(cs3) if len(cases) == 1: return cases[0] else: r2 = RegionOp.union() for a in cases: r2 |= a if not skip_simplify: return r2.simplified() else: return r2 def 
simplify_aux_empty(self, reg = None, skip_simplify = False):
        """Try to replace subsets of the auxiliary RVs by the empty RV, in place.

        For each subset ``taux`` of ``self.aux`` produced by
        ``igen.subset(self.aux, minsize = 2)``, test whether the constraints
        (with the auxiliaries treated as ordinary RVs, intersected with
        ``reg``) imply the same constraints with ``taux`` replaced by
        ``Comp.empty()``. The first subset that passes is substituted away and
        the method recurses on the reduced region.

        reg : optional Region of constraints assumed to hold.
        skip_simplify : if True, do not call ``simplify_quick`` after a
            successful substitution.
        Returns self.
        """
        # NOTE(review): statement nesting in this block was reconstructed from
        # whitespace-collapsed source; confirm against the original layout.
        if self.aux.isempty():
            return self
        if self.imp_present():
            # Work on consonly() and copy the result back into self.
            cs = self.consonly()
            cs.simplify_aux_empty(reg, skip_simplify)
            self.cons_shallow_copy_(cs)
            return self
        if isinstance(reg, RegionOp):
            reg = reg.tosimple()
        if reg is None:
            reg = Region.universe()
        cs = self.copy()
        cs.aux = Comp.empty()
        cs = cs & reg
        # index_self / progs are passed to implies_saved on every iteration,
        # presumably so linear programs can be reused across candidates.
        index_self = IVarIndex()
        cs.record_to(index_self)
        progs = []
        for taux in igen.subset(self.aux, minsize = 2):
            cs2 = self.copy()
            cs2.aux = Comp.empty()
            cs2.substitute(taux, Comp.empty())
            if cs.implies_saved(cs2, index_self, progs):
                with PsiOpts(meta_subs_criteria = "eqtype"):
                    self.substitute_aux(taux, Comp.empty())
                if not skip_simplify:
                    self.simplify_quick()
                # Restart the search on the reduced auxiliary set.
                self.simplify_aux_empty(reg, skip_simplify)
                return self
        return self

    def simplify_aux_recombine(self, reg = None, skip_simplify = False, r_did = None):
        """Re-express the auxiliary RVs by a cheaper combination, in place.

        The auxiliaries are renamed (``rmap``), and
        ``check_getaux_inplace_gen`` is asked for substitutions under which
        the constraints without auxiliaries (intersected with ``reg``) imply
        the renamed constraints. Among the candidates, the one minimizing
        (number of original auxiliaries it uses, -total size) is kept. If one
        is found, the renamed auxiliaries are substituted by it and the
        original auxiliaries are eliminated.

        reg : optional Region of constraints assumed to hold.
        skip_simplify : if True, skip the final ``simplify_quick``.
        r_did : optional one-element list; ``r_did[0]`` is set to True when a
            substitution is made.
        Returns self.
        """
        if self.aux.isempty():
            return self
        if self.imp_present():
            cs = self.consonly()
            cs.simplify_aux_recombine(reg, skip_simplify, r_did)
            self.cons_shallow_copy_(cs)
            return self
        if isinstance(reg, RegionOp):
            reg = reg.tosimple()
        if reg is None:
            reg = Region.universe()
        did = False  # not read below
        index = IVarIndex()
        self.record_to(index)
        reg.record_to(index)
        rmap = index.calc_rename_map(self.aux)
        cs2 = self.copy()
        cs2.rename_map(rmap)
        s2 = self.consonly()
        s2.aux = Comp.empty()
        cs = (s2 & reg) >> cs2
        hint_aux_avoid = cs.get_aux_avoid_list()
        # A candidate must beat "uses every auxiliary, total size len(aux)".
        cmin = (len(self.aux), -len(self.aux))
        minlist = None
        with PsiOpts(auxsearch_leaveone_add_ineq = False):
            for rr in cs.check_getaux_inplace_gen(hint_aux_avoid = hint_aux_avoid):
                stype = iutil.signal_type(rr)
                if stype == "":
                    cnaux = 0
                    csize = 0
                    clist = []
                    cvar = Comp.empty()
                    for a in cs2.aux:
                        b = a.copy()
                        Comp.substitute_list(b, rr)
                        csize += len(b)
                        cvar += b
                        clist.append(b)
                    for a in self.aux:
                        if cvar.ispresent(a):
                            cnaux += 1
                    t = (cnaux, -csize)
                    if t < cmin:
                        cmin = t
                        minlist = clist
                if PsiOpts.is_timer_ended():
                    break
        if minlist is None:
            return self
        prevaux = self.aux.copy()
        self.rename_map(rmap)
        selfaux = self.aux.copy()
        with PsiOpts(meta_subs_criteria = "eqtype"):
            for a, b in zip(selfaux, minlist):
                self.substitute_aux(a, b)
        self.eliminate(prevaux.inter(self.allcomp()))
        if r_did is not None:
            r_did[0] = True
        if not skip_simplify:
            self.simplify_quick()
        return self

    def sort(self):
        """Sort the terms inside every constraint, then each constraint list."""
        for x in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi:
            x.sort()
        self.exprs_ge.sort(key = lambda x: x.sorting_tuple_eqn())
        self.exprs_eq.sort(key = lambda x: x.sorting_tuple_eqn())
        self.exprs_gei.sort(key = lambda x: x.sorting_tuple_eqn())
        self.exprs_eqi.sort(key = lambda x: x.sorting_tuple_eqn())

    def simplify(self, reg = None, zero_group = 0, **kwargs):
        """Simplify a region in place.

        reg : optional Region of constraints assumed to be true.
        zero_group = 2 : group all nonnegative terms as a single inequality.
        Extra keyword arguments are applied as temporary ``simplify_<key>``
        settings for the duration of this call.
        Each pass is controlled by a ``PsiOpts.settings`` flag
        (``simplify_enabled``, ``simplify_num_iter``, ``simplify_quick``, ...).
        Returns self.
        """
        if kwargs:
            r = None
            with PsiOpts(**{"simplify_" + key: val for key, val in kwargs.items()}):
                r = self.simplify(reg, zero_group)
            return r
        if not PsiOpts.settings.get("simplify_enabled", False):
            return self
        if reg is None:
            reg = Region.universe()
        nit = PsiOpts.settings.get("simplify_num_iter", 1)
        if PsiOpts.settings.get("simplify_regterm", False):
            self.simplify_regterm(reg)
        with PsiOpts(simplify_regterm = False):
            for it in range(nit):
                self.simplify_quick(reg, zero_group)
                if not PsiOpts.settings.get("simplify_quick", False):
                    if PsiOpts.settings.get("simplify_remove_missing_aux", False):
                        self.remove_missing_aux()
                    if PsiOpts.settings.get("simplify_aux_commonpart", False):
                        self.simplify_aux_commonpart(reg, maxlen = PsiOpts.settings.get("simplify_aux_xor_len", False))
                    if PsiOpts.settings.get("simplify_aux_empty", False):
                        self.simplify_aux_empty()
                    if PsiOpts.settings.get("simplify_aux_recombine", False):
                        self.simplify_aux_recombine()
                    elif PsiOpts.settings.get("simplify_aux", False):
                        self.simplify_aux()
                    # r_did[0] becomes True if simplify_aux_combine changed
                    # anything; the auxiliary pass is then re-run.
                    r_did = [False]
                    if PsiOpts.settings.get("simplify_aux_combine", False):
                        self.simplify_aux_combine(reg, r_did)
                        if r_did[0]:
                            if PsiOpts.settings.get("simplify_aux_recombine", False):
                                self.simplify_aux_recombine()
                            elif PsiOpts.settings.get("simplify_aux", False):
                                self.simplify_aux()
                    if PsiOpts.settings.get("simplify_redundant", False):
                        self.simplify_redundant(reg, full = PsiOpts.settings.get("simplify_redundant_full", False))
                    if PsiOpts.settings.get("simplify_aux_hull", False):
                        self.simplify_aux_hull(reg)
                    elif PsiOpts.settings.get("simplify_aux_hull_lower_complexity", False):
                        self.simplify_aux_hull(reg, quick = True, lower_complexity = True)
                    if PsiOpts.settings.get("simplify_aux_eq", False):
                        self.simplify_aux_eq(reg)
                    if PsiOpts.settings.get("simplify_bayesnet", False):
                        self.simplify_bayesnet(reg)
                    if PsiOpts.settings.get("simplify_expr_exhaust", False):
                        self.simplify_expr_exhaust()
                    if PsiOpts.settings.get("simplify_pair", False):
                        self.simplify_pair(reg)
                    if PsiOpts.settings.get("simplify_remove_missing_aux", False):
                        self.remove_missing_aux()
                    if PsiOpts.settings.get("simplify_aux_strengthen", False):
                        self.aux_strengthen()
        if PsiOpts.settings.get("simplify_sort", False):
            self.sort()
        return self

    def simplify_union(self, reg = None):
        """Same as simplify(reg) for a plain Region; returns self."""
        self.simplify(reg)
        return self

    def simplified_quick(self, reg = None, zero_group = 0):
        """Return a copy simplified with simplify_quick.

        reg : optional Region of constraints assumed to be true.
        zero_group = 2 : group all nonnegative terms as a single inequality.
        """
        if reg is None:
            reg = Region.universe()
        r = self.copy()
        r.simplify_quick(reg, zero_group)
        return r

    def simplified(self, reg = None, zero_group = 0, **kwargs):
        """Return a simplified copy (see simplify for the arguments).

        reg : optional Region of constraints assumed to be true.
        zero_group = 2 : group all nonnegative terms as a single inequality.
        """
        if reg is None:
            reg = Region.universe()
        r = self.copy()
        r.simplify(reg, zero_group, **kwargs)
        return r

    # The name of the next method continues on the following source line.
    def \
remove_trivial(self):
        """Drop constraints that hold trivially, in place.

        Removes inequalities ``x >= 0`` with ``x.isnonneg()`` and equalities
        ``x == 0`` with ``x`` both nonnegative and nonpositive, from
        exprs_ge / exprs_eq and from exprs_gei / exprs_eqi.
        """
        self.exprs_gei = [x for x in self.exprs_gei if not x.isnonneg()]
        self.exprs_eqi = [x for x in self.exprs_eqi if not (x.isnonneg() and x.isnonpos())]
        self.exprs_ge = [x for x in self.exprs_ge if not x.isnonneg()]
        self.exprs_eq = [x for x in self.exprs_eq if not (x.isnonneg() and x.isnonpos())]

    def removed_trivial(self):
        """Return a copy with trivially true constraints removed."""
        r = self.copy()
        r.remove_trivial()
        return r

    def get_sum_seq(self, prefer_short = True, simplify_exhaust = False, target = None, bnet = None):
        """Order the constraints (exprs_ge + exprs_eq) into a running-sum sequence.

        Greedily picks, at each step, the unused constraint x that minimizes
        ``2 * complexity(c + x) - complexity(x)``, where c is the running sum.
        With prefer_short False the first unused constraint is taken instead.

        simplify_exhaust : call simplify_break_hc on intermediate sums.
        target : if given, the last partial sums are simplified towards it.
        bnet : forwarded to simplified_quick / simplify_target.
        Returns (r0, r1): the chosen constraints and the running sums.
        """
        exprs = list(self.exprs_ge + self.exprs_eq)
        c = Expr.zero()
        r0 = []
        r1 = []
        exprs_vis = [False] * len(exprs)
        eseq = []
        for it in range(len(exprs)):
            minc = None
            minx = None
            mincomp = 1e20
            mini = 0
            for i, x in enumerate(exprs):
                if exprs_vis[i]:
                    continue
                t = None
                with PsiOpts(proof_enabled = False):
                    t = (c + x).simplified_quick(bnet = bnet)
                    if simplify_exhaust:
                        t.simplify_break_hc()
                tcomp = t.complexity() * 2 - x.complexity()
                if tcomp < mincomp:
                    mincomp = tcomp
                    minc = t
                    minx = x
                    mini = i
                if not prefer_short:
                    break
            r0.append(minx)
            r1.append(minc)
            c = minc
            exprs_vis[mini] = True
            eseq.append(mini)
        if target is not None:
            # Walk the sequence backwards, stopping after about half of it.
            for i, (x, t) in enumerate(reversed(list(zip(r0, r1)))):
                t.simplify_target(target, bnet = bnet)
                if i * 2 > len(r0):
                    break
                target = (target - x).simplified_quick(bnet = bnet)
                if simplify_exhaust:
                    target.simplify_break_hc()
        return (r0, r1)

    def get_ic(self, include_ic = True, include_hc = False, skip_simplify = False):
        """Collect the conditional-information terms forced to zero.

        Considers the constraints ``x >= 0`` with x nonpositive and
        ``x == 0`` with x nonpositive or nonnegative, and returns the sum of
        their terms ``a`` with ``a.isic2()`` (if include_ic) or ``a.ishc()``
        (if include_hc), each with coefficient 1.
        skip_simplify : if False, the region is first simplified with
            zero_group = 2.
        """
        cs = self
        if not skip_simplify:
            cs = self.simplified_quick(zero_group = 2)
        exprs = []
        r = Expr.zero()
        for x in cs.exprs_ge:
            if x.isnonpos():
                exprs.append(x)
        for x in cs.exprs_eq:
            if x.isnonpos() or x.isnonneg():
                exprs.append(x)
        for x in exprs:
            for a, c in x.terms:
                if (include_ic and a.isic2()) or (include_hc and a.ishc()):
                    r += Expr.fromterm(a)
        return r

    def remove_ic(self):
        """Strip the isic2() terms from the forced-zero constraints, in place.

        Uses the same constraint selection as get_ic. Constraints that become
        zero are dropped. Returns the sum of the removed terms (coefficient 1).
        """
        exprs = []
        r = Expr.zero()
        for x in self.exprs_ge:
            if x.isnonpos():
                exprs.append(x)
        for x in self.exprs_eq:
            if x.isnonpos() or x.isnonneg():
                exprs.append(x)
        for x in exprs:
            tterms = []
            for a, c in x.terms:
                if a.isic2():
                    r += Expr.fromterm(a)
                else:
                    tterms.append((a, c))
            x.terms = tterms
            # Terms changed, so invalidate the cached hash.
            x.mhash = None
        self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()]
        self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
        return r

    def get_bayesnet(self, vs = None, roots = None, semigraphoid_iter = None, get_list = False, skip_simplify = False):
        """Return a Bayesian network containing the conditional independence
        conditions in this region.

        vs : if given, use completed_semigraphoid_ic(vs = vs) instead of get_ic.
        semigraphoid_iter : semigraphoid completion iterations; defaults to
            PsiOpts.settings["bayesnet_semigraphoid_iter"].
        get_list : return BayesNet.from_ic_list(...) instead of one
            topologically sorted BayesNet.
        """
        if semigraphoid_iter is None:
            semigraphoid_iter = PsiOpts.settings["bayesnet_semigraphoid_iter"]
        icexpr = None
        if vs is not None:
            icexpr = self.completed_semigraphoid_ic(vs = vs, max_iter = semigraphoid_iter)
        else:
            icexpr = self.get_ic(skip_simplify = skip_simplify)
            if semigraphoid_iter > 0:
                icexpr += self.completed_semigraphoid_ic(max_iter = semigraphoid_iter)
        if get_list:
            return BayesNet.from_ic_list(icexpr, roots = roots)
        else:
            return BayesNet.from_ic(icexpr, roots = roots).tsorted()

    def graph(self, vs = None, **kwargs):
        """Return the Bayesian network among the random variables as a
        graphviz digraph that can be displayed in the console.
        """
        return self.get_bayesnet(vs = vs).graph(**kwargs)

    def get_bayesnet_imp(self, skip_simplify = False, add_hc = True):
        """Return a Bayesian network built from the forced-zero constraints
        in exprs_gei / exprs_eqi.

        add_hc : forwarded to BayesNet.from_ic.
        """
        cs = self
        if not skip_simplify:
            cs = self.simplified_quick(zero_group = 2)
        icexpr = Expr.zero()
        for x in cs.exprs_gei:
            if x.isnonpos():
                icexpr += x
        for x in cs.exprs_eqi:
            if x.isnonpos():
                icexpr += x
            elif x.isnonneg():
                icexpr -= x
        return BayesNet.from_ic(icexpr, add_var = self.allcomprv(), add_hc = add_hc).tsorted()

    def get_markov(self):
        """Get Markov chains as a list of lists.
""" bnets = self.get_bayesnet(get_list = True) r = [] for bnet in bnets: r += bnet.get_markov() return r def table(self, *args, skip_cons = False, plot = True, use_latex = None, **kwargs): """Plot the information diagram as a Karnaugh map. """ if use_latex is None: use_latex = PsiOpts.settings["venn_latex"] imp_r = self.imp_flippedonly_noaux() if not imp_r.isuniverse(): return imp_r.table(self.consonly(), *args, skip_cons, plot, **kwargs) ceps = PsiOpts.settings["eps"] index = IVarIndex() cmodel = None for a in args: if isinstance(a, ConcModel) or isinstance(a, LinearProg): cmodel = a for a in args: if isinstance(a, Comp): a.record_to(index) self.get_bayesnet().allcomp().record_to(index) if cmodel is not None: if isinstance(cmodel, ConcModel): cmodel.get_bayesnet().allcomp().record_to(index) for a in args: if isinstance(a, Expr) or isinstance(a, Region): a.record_to(index) cs = self if cmodel is not None: if isinstance(cmodel, ConcModel): cs = cs & cmodel.get_bayesnet().get_region() cs = cs.imp_flipped() cs.record_to(index) comprv = index.comprv r = CellTable(comprv) progs = [] # progs.append(self.init_prog(index, lptype = LinearProgType.H)) progs.append(cs.init_prog(index)) creg = Region.universe() for mask in range(1, 1 << len(comprv)): # ch = H(comprv.from_mask(mask) | comprv.from_mask(((1 << len(comprv)) - 1) ^ mask)) ch = I(alland(comprv[i] for i in range(len(comprv)) if mask & (1 << i)) | comprv.from_mask(((1 << len(comprv)) - 1) ^ mask)) ispos = ch.isnonneg() or cs.implies_ineq_prog(index, progs, ch, ">=") isneg = cs.implies_ineq_prog(index, progs, -ch, ">=") # print(str(ch) + " " + str(ispos) + " " + str(isneg)) if ispos and isneg: r.set_attr(mask, "enabled", False) if not skip_cons: creg.iand_norename(ch == 0) elif ispos: if cmodel is None: r.set_attr(mask, "ispos", True) if not skip_cons and not ch.isnonneg(): creg.iand_norename(ch >= 0) elif isneg: if cmodel is None: r.set_attr(mask, "isneg", True) if not skip_cons: creg.iand_norename(ch <= 0) if cmodel 
is not None: r.set_attr(mask, "cval", cmodel[ch]) cnexpr = 0 cargs = [] if not skip_cons: for cexpr in self.exprs_ge: if not creg.implies(cexpr >= 0): cargs.append(cexpr >= 0) for cexpr in self.exprs_eq: if not creg.implies(cexpr == 0): cargs.append(cexpr == 0) cargs += args for a in cargs: exprlist = [] if isinstance(a, Expr): exprlist = [(a, None)] elif isinstance(a, Region): exprlist = [(b, ">=") for b in a.exprs_ge] + [(b, "==") for b in a.exprs_eq] for b, sn in exprlist: if sn == ">=": r.add_expr(b >= 0, None if cmodel is None else cmodel[b >= 0]) elif sn == "==": r.add_expr(b == 0, None if cmodel is None else cmodel[b == 0]) else: r.add_expr(b, None if cmodel is None else cmodel[b]) maskvals = [0.0 for mask in range(1 << len(comprv))] for v, c in b.terms: if v.get_type() == TermType.IC: xmasks = [comprv.get_mask(x) for x in v.x] zmask = comprv.get_mask(v.z) for mask in range(1, 1 << len(comprv)): if mask & zmask: continue if any(not(mask & xmask) for xmask in xmasks): continue maskvals[mask] += c # print(mask) for mask in range(1, 1 << len(comprv)): if abs(maskvals[mask]) > ceps: r.set_expr_val(mask, maskvals[mask]) # r.set_attr(mask, "val_" + str(cnexpr), maskvals[mask]) cnexpr += 1 # r.nexpr = cnexpr if plot: r.plot(use_latex = use_latex, **kwargs) else: return r def venn(self, *args, style = None, **kwargs): """Plot the information diagram as a Venn diagram. Can handle up to 5 random variables (uses Branko Grunbaum's Venn diagram for n=5). 
""" if style is None: style = "" style = "venn,blend," + style return self.table(*args, style = style, **kwargs) def eliminate_term_eq(self, w): ee = None if ee is None: for x in self.exprs_eqi: c = x.get_coeff(w) if abs(c) > PsiOpts.settings["eps"]: ee = (x * (-1.0 / c)).removed_term(w) x.setzero() break if ee is None: for x in self.exprs_eq: c = x.get_coeff(w) if abs(c) > PsiOpts.settings["eps"]: ee = (x * (-1.0 / c)).removed_term(w) x.setzero() break if ee is not None: self.substitute(Expr.fromterm(w), ee) self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()] return True return False def get_lb_ub_eq(self, w): """ Get lower bounds, upper bounds and equality constraints about w. Parameters ---------- w : Expr The real variable of interest. Returns ------- tuple A tuple of 3 lists of Expr: Lower bounds, upper bounds and equality constraints about w. """ if isinstance(w, Expr): w = Term.fromcomp(w.allcomp()) el = [] er = [] ee = [] for x in self.exprs_ge: c = x.get_coeff(w) if abs(c) <= PsiOpts.settings["eps"]: pass elif c > 0: er.append((x * (-1.0 / c)).removed_term(w)) else: el.append((x * (-1.0 / c)).removed_term(w)) for x in self.exprs_eq: c = x.get_coeff(w) if abs(c) <= PsiOpts.settings["eps"]: pass else: ee.append((x * (-1.0 / c)).removed_term(w)) return (er, el, ee) def get_one_bound(self, w, max_term = 1): er, el, ee = self.get_lb_ub_eq(w) if not ee: if len(er) == 1: ee.append(er[0]) elif len(er) <= max_term and len(er) <= len(el): ee.append(emax(*er)) if len(el) == 1: ee.append(el[0]) elif len(el) <= max_term and len(el) <= len(er): ee.append(emin(*el)) if not ee: return None if len(ee) == 1: return ee[0].copy() def tmp_complexity(x): return (x.complexity(), len(repr(x))) # return (x.isregtermpresent(), not x.ispresent("realvar"), x.complexity(), len(repr(x))) r = None rc = None for x in ee: c = tmp_complexity(x) if r is None or c < rc: rc = c r = x return r.copy() def 
eliminate_term(self, w, forall = False, reg_record = None): verbose = PsiOpts.settings.get("verbose_eliminate", False) reg_record_r = None if reg_record is not None: reg_record_r = Region.universe() reg_record.append((Expr.fromterm(w), reg_record_r)) el = [] er = [] ee = [] if verbose: print("=========== Eliminate ===========") print(w) for x in self.exprs_gei: c = x.get_coeff(w) if abs(c) <= PsiOpts.settings["eps"]: x.remove_term(w) elif c > 0: er.append((x * (-1.0 / c)).removed_term(w)) x.setzero() else: el.append((x * (-1.0 / c)).removed_term(w)) x.setzero() for x in self.exprs_eqi: c = x.get_coeff(w) if abs(c) <= PsiOpts.settings["eps"]: x.remove_term(w) else: ee.append((x * (-1.0 / c)).removed_term(w)) x.setzero() elni = len(el) erni = len(er) eeni = len(ee) if not forall and elni + erni + eeni > 0: self.setuniverse() return self for x in self.exprs_ge: c = x.get_coeff(w) if abs(c) <= PsiOpts.settings["eps"]: x.remove_term(w) else: if c > 0: er.append((x * (-1.0 / c)).removed_term(w)) else: el.append((x * (-1.0 / c)).removed_term(w)) if reg_record_r is not None: reg_record_r.exprs_ge.append(x.copy()) x.setzero() for x in self.exprs_eq: c = x.get_coeff(w) if abs(c) <= PsiOpts.settings["eps"]: x.remove_term(w) else: ee.append((x * (-1.0 / c)).removed_term(w)) if reg_record_r is not None: reg_record_r.exprs_ge.append(x.copy()) x.setzero() if len(ee) > 0: if eeni == 0: for i in range(elni): x = el[i] for j in range(erni): y = er[j] self.exprs_gei.append(x - y) for i in range(len(el)): x = el[i] if i < elni and 0 < eeni: self.exprs_gei.append(x - ee[0]) else: self.exprs_ge.append(x - ee[0]) for j in range(len(er)): y = er[j] if j < erni and 0 < eeni: self.exprs_gei.append(ee[0] - y) else: self.exprs_ge.append(ee[0] - y) for i in range(1, len(ee)): x = ee[i] if i < eeni: self.exprs_eqi.append(x - ee[0]) else: self.exprs_eq.append(x - ee[0]) else: for i in range(len(el)): x = el[i] for j in range(len(er)): y = er[j] if i < elni and j < erni: self.exprs_gei.append(x - 
y) else: self.exprs_ge.append(x - y) if verbose: print("=========== To ===========") print(self) return self @staticmethod def eliminate_reg_record_clean(reg_record, keep_terms, cross=False): for a, b in reg_record: if cross or not keep_terms.ispresent(a): for a2, b2 in reg_record: if not a.ispresent(a2) and b2.ispresent(a): b2.iand_norename(b) b2.eliminate(a) def eliminate_toreal(self, w, forall = False): verbose = PsiOpts.settings.get("verbose_eliminate_toreal", False) if verbose: print("========== elim real ========") print(self) print("========== var =========") print(w) reals = self.allcompreal() self.iand_norename(self.get_prog_region(toreal = w, toreal_only = True)) self.imp_flip() self.iand_norename(self.get_prog_region(toreal = w, toreal_only = True)) self.imp_flip() self.remove_present(w) reals = self.allcompreal() - reals if verbose: print("========== to =========") print(self) for a in reals.varlist: self.eliminate_term(Term.fromcomp(Comp([a])), forall = forall) self.simplify_quick() if verbose: print("========== elim " + str(a) + " =========") print(self) return self def eliminate_toreal_rays(self, w): cs = self.consonly().imp_flipped() index = IVarIndex() cs.record_to(index) r = cs.init_prog(index, lptype = LinearProgType.H).get_region_elim_rays(w) self.exprs_ge = r.exprs_ge self.exprs_eq = r.exprs_eq return self def remove_aux(self, w): self.aux -= w self.auxi -= w def remove_missing_aux(self): #return taux = self.aux self.aux = Comp.empty() tauxi = self.auxi self.auxi = Comp.empty() allcomp = self.allcomprv() self.aux = taux.inter(allcomp) self.auxi = tauxi.inter(allcomp) @staticmethod def get_allcomp(w): if w is None: return Comp.empty() if isinstance(w, CompArray) or isinstance(w, ExprArray): w = w.allcomp() if isinstance(w, (tuple, list)): w = sum((Region.get_allcomp(a) for a in w), Comp.empty()) if isinstance(w, Expr): w = w.allcomp() return w def eliminate_ic(self, w): self.iand_norename(self.completed_semigraphoid(self.allcomprv() - w)) 
self.split() self.remove_relax(w) return self def eliminate(self, w, reg = None, toreal = False, forall = False, quick = False, method = "", reg_record = None): """Fourier-Motzkin elimination, in place. w is the Expr object with the real variables to eliminate. If w contains random variables, they will be treated as auxiliary RV. """ w = Region.get_allcomp(w) # if isinstance(w, Comp): # w = Expr.H(w) if method == "real": toreal = True if ((not toreal and not forall and not self.auxi.isempty() and any(v.get_type() == IVarType.RV for v in w.allcomp())) or (forall and any(v.get_type() == IVarType.REAL for v in w.allcomp())) or (not forall and self.imp_present())): return RegionOp.inter([self]).eliminate(w, reg, toreal, forall) #self.simplify_quick(reg) if toreal and PsiOpts.settings["eliminate_rays"]: w2 = w w = Comp.empty() toelim = Comp.empty() for v in w2.allcomp(): if toreal or v.get_type() == IVarType.REAL: w += v elif v.get_type() == IVarType.RV: toelim += v if not toelim.isempty(): self.eliminate_toreal_rays(toelim) toelim = Comp.empty() rvs = Comp.empty() simplify_needed = False for v in w.allcomp(): if v.get_type() == IVarType.REAL or toreal: if simplify_needed: if quick: self.simplify_quick(reg) else: with PsiOpts(simplify_redundant_full = False, simplify_aux_commonpart = False, simplify_aux = False, simplify_bayesnet = False): self.simplify(reg) if v.get_type() == IVarType.REAL: self.eliminate_term(Term.fromcomp(v), forall = forall, reg_record = reg_record) else: self.eliminate_toreal(v, forall = forall) simplify_needed = True else: rvs += v if forall: self.auxi += v else: self.aux += v if method == "ic" or method == "ci": if not forall and not rvs.isempty(): self.eliminate_ic(rvs) simplify_needed = True if simplify_needed: if quick: self.simplify_quick(reg) else: self.simplify(reg) return self def eliminate_quick(self, w, reg = None, toreal = False, forall = False, method = ""): """Fourier-Motzkin elimination, in place. 
w is the Expr object with the real variables to eliminate. If w contains random variables, they will be treated as auxiliary RV. """ return self.eliminate(w, reg = reg, toreal = toreal, forall = forall, quick = True, method = method) def eliminated(self, w, reg = None, toreal = False, forall = False, method = ""): """Fourier-Motzkin elimination, return region after elimination. w is the Expr object with the real variable to eliminate. If w contains random variables, they will be treated as auxiliary RV. """ r = self.copy() r = r.eliminate(w, reg, toreal, forall, method = method) return r def eliminated_quick(self, w, reg = None, toreal = False, forall = False, method = ""): """Fourier-Motzkin elimination, return region after elimination. w is the Expr object with the real variable to eliminate. If w contains random variables, they will be treated as auxiliary RV. """ r = self.copy() r = r.eliminate_quick(w, reg, toreal, forall, method = method) return r def exists(self, w = None, reg = None, toreal = False, method = ""): """Alias of eliminated. """ if w is None: w = self.allcomprv_noaux() r = self.copy() r = r.eliminate(w, reg, toreal, forall = False, method = method) return r def exists_quick(self, w = None, reg = None, toreal = False, method = ""): """Alias of eliminated_quick. """ if w is None: w = self.allcomprv_noaux() r = self.copy() r = r.eliminate_quick(w, reg, toreal, forall = False, method = method) return r def forall(self, w = None, reg = None, toreal = False, method = ""): """Region of intersection for all variable w. """ if w is None: w = self.allcomprv_noaux() r = self.copy() r = r.eliminate(w, reg, toreal, forall = True, method = method) return r def forall_quick(self, w = None, reg = None, toreal = False, method = ""): """Region of intersection for all variable w. 
""" if w is None: w = self.allcomprv_noaux() r = self.copy() r = r.eliminate_quick(w, reg, toreal, forall = True, method = method) return r def rv_seq_avoid(self, w): index = IVarIndex() self.record_to(index) w2 = [] for a in w: b = Comp.rv(index.name_avoid(a.get_name())) w2.append(b) b.record_to(index) return sum(w2, Comp.empty()) def unique(self, w = None): """The random variable w satisfying the constraints in this region is unique. For both existence and uniqueness, use r.exists_unique(w) instead. """ if w is None: w = self.allcomprv_noaux() w2 = self.rv_seq_avoid(w) return ((self & self.substituted(list(zip(w, w2)))) >> alland(equiv(a, b) for a, b in zip(w, w2))).forall(w+w2) def exists_unique(self, w = None): """The random variable w satisfying the constraints in this region exists and is unique. """ # return self.exists(w) & self.unique(w) if w is None: w = self.allcomprv_noaux() w2 = self.rv_seq_avoid(w) return (self & (self.substituted(list(zip(w, w2))) >> alland(equiv(a, b) for a, b in zip(w, w2))).forall(w2)).exists(w) def projected(self, w = None, reg = None, quick = False): """Project the region to real variables in the list w by eliminating all other real variables. E.g. for a Region r with 3 real variables R1,R2,R3, r.projected([R1, R2]) keeps R1,R2 and eliminates R3, and r.projected(S == R1+R2+R3) projects the region to the diagonal S == R1+R2+R3 (i.e., introduces S and eliminates R1,R2,R3). 
""" if w is None: w = [] if isinstance(w, Expr): w = list(w) if isinstance(w, Region): w = [w] compreal = self.allcompreal() r = self.copy() for a in w: if isinstance(a, Expr): compreal -= a.allcomp() elif isinstance(a, Region): a2 = a if a2.isregtermpresent(): a2 = a2.flattened(minmax_elim = True) if r.get_type() == RegionType.NORMAL and a2.get_type() == RegionType.NORMAL: r.iand_norename(a2) else: r &= a2 # print(r) if quick: r = r.eliminate_quick(compreal, reg) else: r = r.eliminate(compreal, reg) return r def splice_rate(self, w0, w1): """Allow decreasing the real variable w0 (Expr) and increasing w1 (Expr) by the same ammount. """ if isinstance(w0, Expr): w0 = [w0] if isinstance(w1, Expr): w1 = [w1] t = Expr.real("#TMPVAR") for a0 in w0: self.substitute(a0, a0 + t) for a1 in w1: self.substitute(a1, a1 - t) for a0 in w0: self &= a0 >= 0 self &= t >= 0 self.eliminate(t) return self def spliced_rate(self, w0, w1): """Allow decreasing the real variable w0 (Expr) and increasing w1 (Expr) by the same ammount. """ r = self.copy() r.splice_rate(w0, w1) return r def marginal_eliminate(self, w): """Set input, in place. Denote the RV's in w (type Comp) as input variables, and the region is the union over distributions of w. """ self.inp += w return self def marginal_exists(self, w): """Set input, return result. Denote the RV's in w (type Comp) as input variables, and the region is the union over distributions of w. """ r = self.copy() r.marginal_eliminate(w) return r def marginal_forall(self, w): """Set input, return result. Currently we do not differentiate between input distribution union and intersection. May change in the future. """ r = self.copy() r.marginal_eliminate(w) return r def kernel_eliminate(self, w): """Set output, in place. Denote the RV's in w (type Comp) as output variables, and the region is the union over channels leading to w with same marginal on w. """ self.oup += w return self def kernel_exists(self, w): """Set output, return result. 
Denote the RV's in w (type Comp) as output variables, and the region is the union over channels leading to w with same marginal on w. """ r = self.copy() r.kernel_eliminate(w) return r def kernel_forall(self, w): """Set output, return result. Currently we do not differentiate between channel union and intersection. May change in the future. """ r = self.copy() r.kernel_eliminate(w) return r def issymmetric(self, w, quick = False): """Check whether region is symmetric with respect to the variables in w """ csstr = "" if quick: csstr = self.tostring(tosort = True) for i in range(1, len(w)): t = self.copy() tvar = Comp.rv("SYMM_TMP") t.substitute(w[i], tvar) t.substitute(w[0], w[i]) t.substitute(tvar, w[0]) if quick: if t.tostring(tosort = True) != csstr: return False else: if not t.implies(self): return False if not self.implies(t): return False return True def __imul__(self, other): compreal = self.allcompreal() for a in compreal.varlist: self.substitute(Expr.fromcomp(Comp([a])), Expr.fromcomp(Comp([a])) / other) self.simplify_quick() return self def __mul__(self, other): r = self.copy() r *= other return r def __rmul__(self, other): r = self.copy() r *= other return r def __itruediv__(self, other): self *= 1.0 / other return self def __truediv__(self, other): r = self.copy() r *= 1.0 / other return r def distribute(self, force_split = False): return self def sum_minkowski(self, other): """Minkowski sum of two regions with respect to their real variables. 
""" if other.get_type() != RegionType.NORMAL: return other.sum_minkowski(self) cs = self.copy() co = other.copy() param_real = cs.allcomprealvar().inter(co.allcomprealvar()) param_real_expr = Expr.zero() for a in param_real.varlist: newname = "TENSOR_TMP_" + a.name cs.substitute(Expr.real(a.name), Expr.real(a.name) - Expr.real(newname)) co.substitute(Expr.real(a.name), Expr.real(newname)) param_real_expr += Expr.real(newname) if PsiOpts.settings["tensorize_simplify"]: return (cs & co).eliminated(param_real_expr) else: return (cs & co).eliminated_quick(param_real_expr) def __add__(self, other): return self.sum_minkowski(other) def __iadd__(self, other): return self + other # def __invert__(self): # if not self.imp_present(): # t = self.one_flipped() # if t is not None: # return t.forall(self.aux) # return ~RegionOp.union([self]) # def negate(self): # return ~self def __invert__(self): return ~RegionOp.union([self]) def try_negate(self, eps_only = False): if not self.imp_present(): t = self.one_flipped(strict = True) if t is not None and (not eps_only or not t.ispresent(Expr.eps())): return t.forall(self.aux) return None def negate(self): t = self.try_negate() if t is not None: return t return ~RegionOp.union([self]) def forall_completed(self, exclude = None): toforall = self.rvs - (self.aux + self.auxi) if exclude is not None: toforall -= exclude return self.forall(toforall) def noeq(self): for a in self.exprs_eq: if not a.isnonneg(): self.exprs_ge.append(a.copy()) if not a.isnonpos(): self.exprs_ge.append(-a) a.setzero() for a in self.exprs_eqi: if not a.isnonneg(): self.exprs_gei.append(a.copy()) if not a.isnonpos(): self.exprs_gei.append(-a) a.setzero() self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()] self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()] return self def copy_noeq(self): r = self.copy() r.noeq() return r def toregionop_split(self, force_split = False): r = RegionOp.union([]) if self.imp_present(): 
r.regs.append((self.imp_flippedonly_noaux(), False)) t = self.consonly().copy_noaux() r.regs.append((t.exists(self.aux.copy()), True)) return r.forall(self.auxi.copy()) else: if force_split: t = self.copy_noaux().copy_noeq() if len(t.exprs_ge) == 1: return self.copy() return RegionOp.inter([a for a in t]).exists(self.aux.copy()).forall(self.auxi.copy()) else: return RegionOp.inter([self.copy_noaux()]).exists(self.aux.copy()).forall(self.auxi.copy()) def copy_rename(self): """Return a copy with renamed variables, together with map from old name to new. """ namemap = {} r = self.copy() index = IVarIndex() self.record_to(index) param_rv = index.comprv.copy() param_real = index.compreal.copy() for a in param_rv.varlist: name1 = index.name_avoid(a.name) index.record(Comp.rv(name1)) namemap[a.name] = name1 r.rename_var(a.name, name1) for a in param_real.varlist: name1 = index.name_avoid(a.name) index.record(Comp.real(name1)) namemap[a.name] = name1 r.rename_var(a.name, name1) return (r, namemap) # def nfold(self, bnet = None, n = 2, natural_eqprob = True, all_eqprob = True): # vs = self.allcomprv_noaux() # vmap = [] # for v in vs: # vname = v.get_name() # vmap[vname] = rv_array(vname, n) def tensorize(self, reg_subset = None, chan_cond = None, nature = None, timeshare = False, hint_aux = None, same_dist = False): """Check whether region tensorizes, return auxiliary RVs if tensorizes. chan_cond : The condition on the channel (e.g. degraded broadcast channel) """ for rr in self.tensorize_gen(reg_subset = reg_subset, chan_cond = chan_cond, nature = nature, timeshare = timeshare, hint_aux = hint_aux, same_dist = same_dist): if iutil.signal_type(rr) == "": return rr return None def tensorize_gen(self, reg_subset = None, chan_cond = None, nature = None, timeshare = False, hint_aux = None, same_dist = False): """Check whether region tensorizes, yield all auxiliary RVs if tensorizes. chan_cond : The condition on the channel (e.g. 
degraded broadcast channel)
        """
        # NOTE(review): nesting reconstructed from whitespace-collapsed source;
        # in particular, confirm whether the `rsum.aux = ...interleaved(...)`
        # assignment belongs only to the non-timeshare branch.
        # r2 is a copy of self with every variable renamed (second "letter").
        r2, namemap = self.copy_rename()
        rx = None
        if reg_subset is None:
            rx = self.copy()
        else:
            rx = reg_subset.copy()
        if chan_cond is None:
            chan_cond = Region.universe()
        chan_cond2 = chan_cond.copy()
        chan_cond2.rename_map(namemap)
        if nature is None:
            nature = Comp.empty()
        nature2 = nature.copy()
        nature2.rename_map(namemap)
        index = IVarIndex()
        self.record_to(index)
        param_rv = index.comprv - self.getauxall()
        param_real = index.compreal
        param_rv_map = Comp.empty()
        param_real_expr = Expr.zero()
        # In rx, every non-auxiliary RV X stands for the pair (X, X').
        for a in param_rv.varlist:
            rx.substitute(Comp.rv(a.name), Comp.rv(a.name) + Comp.rv(namemap[a.name]))
            param_rv_map += Comp.rv(namemap[a.name])
        rsum = None
        if timeshare:
            rsum = self.sum_entrywise(r2)
            for a in param_real.varlist:
                rx.substitute(Expr.real(a.name), Expr.real(namemap[a.name]))
                rsum.substitute(Expr.real(a.name), Expr.zero())
        else:
            # rsum: rates R' = R1 + R2 with R1 from self and R2 from r2.
            for a in param_real.varlist:
                r2.substitute(Expr.real(namemap[a.name]), Expr.real(namemap[a.name]) - Expr.real(a.name))
                rx.substitute(Expr.real(a.name), Expr.real(namemap[a.name]))
                param_real_expr += Expr.real(a.name)
            rsum = self.copy()
            rsum.iand_norename(r2)
            if PsiOpts.settings["tensorize_simplify"]:
                rsum = rsum.eliminated(param_real_expr)
            else:
                rsum = rsum.eliminated_quick(param_real_expr)
            if rsum.get_type() == RegionType.NORMAL:
                rsum.aux = self.aux.interleaved(r2.aux)
        # Rename the rate variables back to their original names.
        for a in param_real.varlist:
            rx.substitute(Expr.real(namemap[a.name]), Expr.real(a.name))
            rsum.substitute(Expr.real(namemap[a.name]), Expr.real(a.name))
        # NOTE(review): aux_avoid_from is called on the caller-supplied
        # chan_cond (not a copy); if it renames in place, the caller's object
        # is modified.
        chan_cond.aux_avoid_from(rx)
        rx &= chan_cond
        chan_cond2.aux_avoid_from(rx)
        rx &= chan_cond2
        chan_cond_comp = chan_cond.allcomprv()
        chan_cond2_comp = chan_cond2.allcomprv()
        # Pairs the i-th variables of the two conditions; relies on
        # allcomprv() listing them in corresponding order.
        for i in range(chan_cond_comp.size()):
            namemap[chan_cond_comp.varlist[i].name] = chan_cond2_comp.varlist[i].name
        rx.iand_norename(Expr.Ic(self.allcomprv() - self.inp - self.oup - self.aux - self.auxi, r2.allcomprv() - r2.oup - r2.aux - r2.auxi, self.inp) == 0)
        rx.iand_norename(Expr.Ic(self.inp, r2.allcomprv() - r2.inp - r2.oup - r2.aux - r2.auxi, r2.inp) == 0)
        if not nature.isempty():
            rx.iand_norename(Expr.I(nature, nature2) == 0)
        hint_pair = []
        for (key, value) in namemap.items():
            hint_pair.append((Comp.rv(key), Comp.rv(value)))
        if same_dist:
            rx.iand_norename(eqdist(param_rv, param_rv_map))
        for rr in rx.implies_getaux_gen(rsum, hint_pair = hint_pair, hint_aux = hint_aux):
            yield rr

    def check_converse(self, reg_subset = None, chan_cond = None, nature = None, hint_aux = None):
        """Check whether self is the capacity region of the operational
        region reg_subset; return the auxiliary RVs if true.

        chan_cond : The condition on the channel (e.g. degraded broadcast channel).
        Runs tensorize with timeshare = True.
        """
        return self.tensorize(reg_subset, chan_cond, nature, True, hint_aux = hint_aux)

    def check_converse_gen(self, reg_subset = None, chan_cond = None, nature = None, hint_aux = None):
        """Check whether self is the capacity region of the operational
        region reg_subset; yield all auxiliary RVs if true.

        chan_cond : The condition on the channel (e.g. degraded broadcast channel).
        """
        for rr in self.tensorize_gen(reg_subset, chan_cond, nature, True, hint_aux = hint_aux):
            yield rr

    def __xor__(self, other):
        """r ^ w: return the region with w eliminated."""
        return self.eliminated(other)

    def __ixor__(self, other):
        """r ^= w: eliminate w in place."""
        return self.eliminate(other)

    def isfeasible(self):
        """Whether this region is feasible (does not imply 1 <= 0).
        """
        return not self.implies(Expr.one() <= 0)

    def __bool__(self):
        return self.check()

    def __call__(self):
        return self.check()

    def truth_value(self):
        """Truth value of this region. Either True (proved to be True),
        False (proved to be False), or None (cannot be proved to be True/False).
        """
        if bool(self):
            return True
        elif bool(~self):
            return False
        return None

    def assume(self):
        """Assume this region is true in the current context.
        """
        PsiOpts.set_setting(truth_add = self)

    def assume_only(self):
        """Assume this region is true in the current context. Overwrite existing assumptions.
""" if self.isuniverse(): PsiOpts.set_setting(truth = None) else: PsiOpts.set_setting(truth = self.copy()) def assumed(self): """Create a context where this region is assumed to be true. Use "with region.assumed(): ..." """ return PsiOpts(truth_add = self) def assumed_only(self): """Create a context where this region is assumed to be true. Overwrite existing assumptions. Use "with region.assumed_only(): ..." """ if self.isuniverse(): return PsiOpts(truth = None) else: return PsiOpts(truth = self.copy()) def proof(self, mode = None, **kwargs): """Get the proof of this region. """ if mode is None: mode = True r = None with PsiOpts(proof_new = mode, **{"proof_" + key: val for key, val in kwargs.items()}): bool(self) r = PsiOpts.get_proof() return r def bound(self, expr, var, sgn = 0, minsize = 1, maxsize = 3, coeffmode = 1, skip_simplify = False): """Automatically discover bounds on expr in terms of variables in var. Parameters: sgn : Set to 1 for upper bound, -1 for lower bound, 0 for both. minsize : Minimum number of terms in bound. maxsize : Maximum number of terms in bound. coeffmode : Set to 0 to only allow positive terms. Set to 1 to allow positive/negative terms, but not all negative. Set to 2 to allow positive/negative terms. skip_simplify : Set to True to skip final simplification. 
""" if sgn == 0: r = (self.bound(expr, var, sgn = 1, minsize = minsize, maxsize = maxsize, coeffmode = coeffmode, skip_simplify = skip_simplify) & self.bound(expr, var, sgn = -1, minsize = minsize, maxsize = maxsize, coeffmode = coeffmode, skip_simplify = skip_simplify)) if not skip_simplify: r.simplify_quick() return r varreal = [] varrv = [] if isinstance(var, list): for v in var: if isinstance(v, Expr): varreal.append(v) else: varrv.append(v) else: for v in var.allcomp(): if v.get_type() == IVarType.REAL: varreal.append(Expr.fromcomp(v)) elif v.get_type() == IVarType.RV: varrv.append(v) fcn = None if sgn > 0: fcn = lambda x: self.implies(expr <= x) elif sgn < 0: fcn = lambda x: self.implies(expr >= x) s = None for _, st in igen.test(igen.subset(itertools.chain(igen.sI(varrv), varreal), minsize = minsize, maxsize = maxsize, coeffmode = coeffmode), fcn, sgn = sgn, yield_set = True): s = st if s is None: return Region.universe() r = Region.universe() for x in s: if sgn > 0: r.exprs_ge.append(x - expr) elif sgn < 0: r.exprs_ge.append(expr - x) if not skip_simplify: r.simplify() return r def union_list(self, distribute = True): return [self.copy()] class SearchEntry: def __init__(self, x, x_reg = None, cmin = -1, cmax = 1): self.x = x if x_reg == None: self.x_reg = x.copy() else: self.x_reg = x_reg self.cmin = cmin self.cmax = cmax def discover(self, entries, method = "hull_auto", minsize = 1, maxsize = 2, skip_simplify = False, reg_init = None, skipto_ex = None, toreal_prefix = None, balanced = False, init_pts_outer = None, pts_outer = None): """Automatically discover inequalities between entries. Parameters: entries : List of variables of interest. minsize : Minimum number of terms in bound. maxsize : Maximum number of terms in bound. skip_simplify : Set to True to skip final simplification. 
""" ceps = PsiOpts.settings["eps"] truth = PsiOpts.settings["truth"] if truth is not None: with PsiOpts(truth = None): return (self & truth).discover(entries, method, minsize, maxsize, skip_simplify, reg_init, skipto_ex, toreal_prefix, balanced, init_pts_outer, pts_outer) indreg = self.get_indreg_checked(reg_add = entries) if indreg is not None: with PsiOpts(indreg_enabled = False): return (self & indreg).discover(entries, method, minsize, maxsize, skip_simplify, reg_init, skipto_ex, toreal_prefix, balanced, init_pts_outer, pts_outer) verbose = PsiOpts.settings.get("verbose_discover", False) verbose_detail = PsiOpts.settings.get("verbose_discover_detail", False) verbose_terms = PsiOpts.settings.get("verbose_discover_terms", False) verbose_terms_inner = PsiOpts.settings.get("verbose_discover_terms_inner", False) verbose_terms_outer = PsiOpts.settings.get("verbose_discover_terms_outer", False) simp_step = 10 varreal = [] varrv = [] varrv_markers = [] maxlen = 1 css = [cs.simplified_quick() for cs in self.union_list()] # cs = self.simplified_quick() #plain = cs.isplain() plain = True selfifs = None progs_self = [] progs_r = [] index_self = IVarIndex() for cs in css: cs.record_to(index_self) index_r = IVarIndex() isaffine = any(cs.affine_present() for cs in css) if isinstance(entries, Expr): # entries = entries.allcomp() entries = [entries] entries2 = [] for a in entries: if isinstance(a, tuple) and len(a) and a[0] is None: continue if isinstance(a, (CompArray, ExprArray)): for b in a: entries2.append(b) else: entries2.append(a) entries = entries2 maxent_comps = [] nonmaxent_present = False for a in entries: a2 = a cmarkers = {} if balanced: cmarkers["balanced"] = True if isinstance(a, tuple): a2 = [] for a0 in a: if isinstance(a0, str): cmarkers[a0] = True else: a2.append(a0) if len(a2) == 1: a2 = a2[0] else: a2 = tuple(a2) if not isinstance(a2, tuple): if isinstance(a2, Expr) and toreal_prefix is not None: a2 = (a2, Expr.real(toreal_prefix + str(a2))) else: a2 = (a2, 
None) if isinstance(a2[1], ExprArray) or isinstance(a2[1], CompArray): maxlen = max(maxlen, len(a2[1])) a2[0].record_to(index_r) aself = None if a2[1] is None: aself = a2[0] else: aself = a2[1] aself.record_to(index_self) if isinstance(aself, Expr) and aself.affine_present(): isaffine = True plain = plain and not aself.isregtermpresent() if isinstance(aself, Expr): for t in aself: tt = t.get_maxent_comp() if tt is not None: if tt not in maxent_comps: maxent_comps.append(tt) else: nonmaxent_present = True else: nonmaxent_present = True if isinstance(a2[0], Expr): varreal.append(a2[0]) else: varrv.append(a2[0]) varrv_markers.append(cmarkers) if maxent_comps and not nonmaxent_present: if PsiOpts.settings["maxent_lex_enabled"]: with PsiOpts(maxent_lex_enabled = False): r = Region.universe() for cs in self.maxent_lex_region_gen([(x, False) for x in maxent_comps], varadd = index_self.comprv): # print(cs) t = cs.discover(entries, method, minsize, maxsize, skip_simplify, reg_init, skipto_ex, toreal_prefix, balanced, init_pts_outer, pts_outer) r &= t return r.simplified_quick() for cs in css: if cs.get_type() == RegionType.NORMAL and not cs.aux.isempty() and not cs.imp_present(): cs.aux_strengthen(index_self.comprv) cs.aux = Comp.empty() plain = plain and all(cs.isplain() for cs in css) r = Region.universe() if reg_init is not None: r = reg_init.copy() #print(skipto_ex) if skipto_ex is not None: if not isinstance(skipto_ex, str): skipto_ex = skipto_ex.tostring() nadd = 0 if plain: selfifs = [cs.imp_flipped() for cs in css] if method == "hull_auto": if isaffine: method = "hull" else: method = "hull_cone" if (method == "hull" or method == "hull_cone") and not plain: method = "guess" #print(index_r.comprv) #print(index_self.comprv) if method == "semigraphoid" or method == "ic" or method == "ci": method = "semigraphoid" perform_semigraphoid = method == "semigraphoid" or method == "hull" or method == "hull_cone" if len(css) > 1: perform_semigraphoid = False sg = None if 
perform_semigraphoid: tvs = [] for a in entries: if isinstance(a, tuple) and isinstance(a[1], Comp): tvs.append((a[0], a[1])) elif isinstance(a, Comp): tvs.append(a) sg = css[0].completed_semigraphoid(vs = tvs) if method == "semigraphoid": return sg # sg = None vis = set() terms = [] if method == "hull" or method == "hull_cone": if sg is None or sg.isuniverse(): terms = list(itertools.chain(varreal, ent_vector(*varrv))) else: terms = list(itertools.chain(varreal, sg.get_basis(more_vars = sum(varrv, Comp.empty())))) else: terms = list(itertools.chain(varreal, igen.sI(varrv))) for a, markers in zip(varrv, varrv_markers): if markers.get("balanced", False): for t in terms: c = t.get_coeff(a) if abs(c) > ceps: t += Expr.H(a) * -c t.simplify_quick() terms = [t for t in terms if not t.iszero()] if verbose: print("discover nterms = " + str(len(terms))) lastres = [False] def expr_tr(ex): exsum = Expr.zero() for i in range(maxlen): tex = ex.copy() for a in entries: if isinstance(a, tuple): if isinstance(a[1], ExprArray) or isinstance(a[1], CompArray): if i < len(a[1]): tex.substitute(a[0], a[1][i]) else: if isinstance(a[0], Expr): tex.substitute(a[0], Expr.zero()) else: tex.substitute(a[0], Comp.empty()) else: tex.substitute(a[0], a[1]) exsum += tex return exsum #print(method) if method == "hull" or method == "hull_cone": # prog = cs.imp_flipped().init_prog(index = index_self, lp_bounded = True) # A = SparseMat(0) # for a in terms: # A.extend(prog.get_vec(expr_tr(a), sparse = True)[0]) # rt = prog.discover_hull(A, iscone = (method == "hull_cone"), init_pts_outer = init_pts_outer, pts_outer = pts_outer) progs = [cs.imp_flipped().init_prog(index = index_self, lp_bounded = True) for cs in css] Ams = [] for prog in progs: A = SparseMat(0) for a in terms: A.extend(prog.get_vec(expr_tr(a), sparse = True)[0]) Ams.append(A) m = len(terms) def toexpr(x): r = Expr.zero() for i in range(m): if abs(x[i]) > ceps: r += terms[i] * x[i] return r def cprog(x): ropt = None rr = None for A, 
prog in zip(Ams, progs): n = prog.nxvar c = [0.0] * n for i in range(m): for j, a in A.x[i]: c[j] += a * x[i] opt, v = prog.call_prog(c) if opt is None: return (None, None) r = [0.0] * m for i in range(m): for j, a in A.x[i]: r[i] += a * v[j] # return (opt, r) if ropt is None or opt < ropt: ropt = opt rr = r return (ropt, rr) rt = LinearProg.proj_hull(cprog, m, toexpr = toexpr, iscone = (method == "hull_cone"), init_pts_outer = init_pts_outer, pts_outer = pts_outer) if rt is None: return None inf_thres = progs[0].lp_ubound / 5.0 for x in rt: if abs(x[0]) > inf_thres: continue expr = Expr.zero() for i in range(len(terms)): if abs(x[i + 1]) > ceps: expr += terms[i] * x[i + 1] r.iand_norename(expr >= -x[0]) if not (sg is None or sg.isuniverse()): r &= sg if not skip_simplify: r.simplify() return r def exgen(): for size in range(minsize, maxsize + 1): if size == 1: for i in range(len(terms)): yield terms[i] res0 = lastres[0] yield -terms[i] res1 = lastres[0] if res0 and res1: terms[i] = None terms[:] = [a for a in terms if a is not None] elif size == 2: for i in range(len(terms)): for j in range(i): if terms[j] is not None: yield terms[i] - terms[j] res0 = lastres[0] yield terms[j] - terms[i] res1 = lastres[0] if res0 and res1: terms[i] = None break terms[:] = [a for a in terms if a is not None] else: for ex in igen.subset(terms, minsize = size, maxsize = maxsize, coeffmode = -1, replacement = True): yield ex break for ex in exgen(): if PsiOpts.is_timer_ended(): break if skipto_ex is not None: if skipto_ex != ex.tostring(): lastres[0] = False continue skipto_ex = None if verbose_terms: print(str(ex) + " <= 0") ex = ex.simplified_quick() if ex.isnonpos(): lastres[0] = True continue exhash = hash(ex) if exhash in vis: lastres[0] = False continue vis.add(exhash) r2 = ex <= 0 if r.implies_saved(r2, index_r, progs_r): lastres[0] = True continue exsum = expr_tr(ex) if ((plain and all(selfif.implies_impflipped_saved(exsum <= 0, index_self, progs_self) for selfif in selfifs)) 
or (not plain and all(cs.implies(exsum <= 0) for cs in css))): if verbose: print("ADD " + str(ex) + " <= 0") r &= r2 nadd += 1 if not skip_simplify and nadd % simp_step == 0: r.simplify() progs_r = [] #print("OKAY " + str(r2)) if verbose_detail: print(str(r)) lastres[0] = True else: #print("FAIL " + str(r2) + " " + str(plain) + " " + str(saved_info)) lastres[0] = False if not skip_simplify: r.simplify() return r def hull(self, entries = None, **kwargs): """Find the convex hull outer bound of this region. """ if entries is None: entries = [] return self.discover(list(self.rvs) + list(self.reals) + list(entries), **kwargs) @staticmethod def markov_tostring(cm, style = 0, tosort = False): if len(cm) % 2 == 1 and all(cm[i].isempty() for i in range(1, len(cm), 2)): tlist = [cm[i].tostring(style = style, tosort = tosort, add_bracket = not (style & PsiOpts.STR_STYLE_PSITIP)) for i in range(0, len(cm), 2)] if style & PsiOpts.STR_STYLE_LATEX: return (" " + PsiOpts.settings["latex_indep"] + " ").join(tlist) else: return "indep(" + ", ".join(tlist) + ")" else: tlist = [cm[i].tostring(style = style, tosort = tosort, add_bracket = not (style & PsiOpts.STR_STYLE_PSITIP)) for i in range(len(cm))] if style & PsiOpts.STR_STYLE_LATEX: return (" " + PsiOpts.settings["latex_markov"] + " ").join(tlist) else: return "markov(" + ", ".join(tlist) + ")" def tostring(self, style = 0, tosort = False, lhsvar = "real", inden = 0, add_bracket = False, small = False, skip_outer_exists = False): """Convert to string. 
Parameters: style : Style of string conversion STR_STYLE_STANDARD : I(X,Y;Z|W) STR_STYLE_PSITIP : I(X+Y&Z|W) """ style = iutil.convert_str_style(style) r = "" pf_note = PsiOpts.settings["str_proof_note"] if pf_note: note_pre = self.get_meta("pf_note_pre") if note_pre is not None: cs = self.copy() cs.remove_meta("pf_note_pre") r += cs.tostring(style = style, tosort = tosort, lhsvar = lhsvar, inden = inden, add_bracket = add_bracket, small = small, skip_outer_exists = skip_outer_exists) r = [iutil.pf_note_str(x, style, add_bracket = False) for x in note_pre] + [r] if style & PsiOpts.STR_STYLE_LATEX: r = iutil.latex_split_line(r, None) else: if isinstance(r, list): r = "\n".join(r) return r if isinstance(lhsvar, str) and lhsvar == "real": lhsvar = self.allcomprealvar() nlstr = "\n" if style & PsiOpts.STR_STYLE_LATEX: nlstr = "\\\\\n" spacestr = " " if style & PsiOpts.STR_STYLE_LATEX: spacestr = "\\;" if skip_outer_exists and not self.aux.isempty() and self.auxi.isempty() and self.isuniverse(): return spacestr * inden + self.aux.tostring(style = style, tosort = tosort) if self.isuniverse(sgn = False, canon = True): if style & PsiOpts.STR_STYLE_PSITIP: return spacestr * inden + "empty()" elif style & PsiOpts.STR_STYLE_LATEX: return spacestr * inden + PsiOpts.settings["latex_region_empty"] imp_pres = False if not self.auxi.isempty(): if style & PsiOpts.STR_STYLE_PSITIP: pass else: if style & PsiOpts.STR_STYLE_LATEX: if not style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER: r += PsiOpts.settings["latex_forall"] + " " r += self.auxi.tostring(style = style, tosort = tosort) r += PsiOpts.settings["latex_quantifier_sep"] + " " if self.exprs_gei or self.exprs_eqi or not self.auxi.isempty(): add_bracket_inner = False inden_inner2 = inden curadd_bracket = bool(add_bracket or (not self.auxi.isempty() and (self.exprs_gei or self.exprs_eqi))) if style & PsiOpts.STR_STYLE_LATEX: if curadd_bracket: r += "\\left\\{" else: r += "(" if self.exprs_gei or self.exprs_eqi: imp_pres = True 
inden_inner = inden inden_inner1 = inden + 1 inden_inner2 = inden + 3 add_bracket_inner = bool(style & PsiOpts.STR_STYLE_PSITIP) r += spacestr * inden if style & PsiOpts.STR_STYLE_LATEX: r += "\\begin{array}{l}\n" if not small: r += "\displaystyle" r += " " inden_inner = 0 inden_inner1 = 0 inden_inner2 = 0 else: pass r += self.imp_flippedonly_noaux().tostring(style = style, tosort = tosort, lhsvar = lhsvar, inden = inden_inner1, add_bracket = add_bracket_inner).lstrip() if style & PsiOpts.STR_STYLE_PSITIP: r += nlstr + spacestr * inden_inner + ">> " elif style & PsiOpts.STR_STYLE_LATEX: r += (nlstr + ("\displaystyle " if not small else " ") + spacestr * inden_inner + PsiOpts.settings["latex_matimplies"] + " " + spacestr + " ") else: r += nlstr + spacestr * inden_inner + "=> " cs = self.copy() cs.exprs_gei = [] cs.exprs_eqi = [] cs.auxi = Comp.empty() r += cs.tostring(style = style, tosort = tosort, lhsvar = lhsvar, inden = inden_inner2, add_bracket = add_bracket_inner).lstrip() if not self.auxi.isempty(): if style & PsiOpts.STR_STYLE_PSITIP: pass else: if style & PsiOpts.STR_STYLE_LATEX: if style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER: r += " , " + PsiOpts.settings["latex_forall"] + " " r += self.auxi.tostring(style = style, tosort = tosort) else: r += " , forall " r += self.auxi.tostring(style = style, tosort = tosort) if style & PsiOpts.STR_STYLE_LATEX: if self.exprs_gei or self.exprs_eqi: r += nlstr + "\\end{array}" if curadd_bracket: r += "\\right\\}" else: r += ")" if not self.auxi.isempty(): if style & PsiOpts.STR_STYLE_PSITIP: r += ".forall(" + self.auxi.tostring(style = style, tosort = tosort) + ")" return r if not self.aux.isempty(): if style & PsiOpts.STR_STYLE_PSITIP: pass else: if style & PsiOpts.STR_STYLE_LATEX: if not style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER: if not skip_outer_exists: r += PsiOpts.settings["latex_exists"] + " " r += self.aux.tostring(style = style, tosort = tosort) r += PsiOpts.settings["latex_quantifier_sep"] + " " cs = self bnets = 
None if style & PsiOpts.STR_STYLE_MARKOV: cs = cs.copy() tic = cs.remove_ic() bnets = BayesNet.from_ic_list(tic) eqnlist = ([x.tostring_eqn(">=", style = style, tosort = tosort, lhsvar = lhsvar) for x in cs.exprs_ge] + [x.tostring_eqn("==", style = style, tosort = tosort, lhsvar = lhsvar) for x in cs.exprs_eq]) if tosort: eqnlist = zip(eqnlist, [lhsvar is not None and any(x.ispresent(t) for t in lhsvar) for x in cs.exprs_ge] + [lhsvar is not None and any(x.ispresent(t) for t in lhsvar) for x in cs.exprs_eq]) eqnlist = sorted(eqnlist, key=lambda a: (not a[1], len(a[0]), a[0])) eqnlist = [x for x, t in eqnlist] if style & PsiOpts.STR_STYLE_MARKOV: eqnlist2 = [] for bnet in bnets: ms = bnet.get_markov() for cm in ms: eqnlist2.append(Region.markov_tostring(cm, style, tosort)) # if len(cm) % 2 == 1 and all(cm[i].isempty() for i in range(1, len(cm), 2)): # tlist = [cm[i].tostring(style = style, tosort = tosort, # add_bracket = not (style & PsiOpts.STR_STYLE_PSITIP)) for i in range(0, len(cm), 2)] # if style & PsiOpts.STR_STYLE_LATEX: # eqnlist2.append((" " + PsiOpts.settings["latex_indep"] + " ").join(tlist)) # else: # eqnlist2.append("indep(" + ", ".join(tlist) + ")") # else: # tlist = [cm[i].tostring(style = style, tosort = tosort, # add_bracket = not (style & PsiOpts.STR_STYLE_PSITIP)) for i in range(len(cm))] # if style & PsiOpts.STR_STYLE_LATEX: # eqnlist2.append((" " + PsiOpts.settings["latex_markov"] + " ").join(tlist)) # else: # eqnlist2.append("markov(" + ", ".join(tlist) + ")") if tosort: eqnlist2 = sorted(eqnlist2, key = lambda a: len(a)) eqnlist += eqnlist2 first = True use_array = style & PsiOpts.STR_STYLE_LATEX_ARRAY and len(eqnlist) > 1 isplu = True if style & PsiOpts.STR_STYLE_LATEX: isplu = len(eqnlist) > 1 and use_array else: isplu = not self.aux.isempty() or not self.auxi.isempty() or len(eqnlist) > 1 use_bracket = add_bracket or isplu if style & PsiOpts.STR_STYLE_PSITIP: r += spacestr * inden if use_bracket: r += "(" elif style & 
PsiOpts.STR_STYLE_LATEX: r += spacestr * inden if use_bracket: r += "\\left\\{" if use_array: r += "\\begin{array}{l}\n" else: r += spacestr * inden if use_bracket: r += PsiOpts.settings["str_brace_l"] for x in eqnlist: if style & PsiOpts.STR_STYLE_PSITIP: if first: if use_bracket: r += " " else: r += nlstr + spacestr * inden + " &" if isplu: r += "(" if use_bracket: r += " " elif style & PsiOpts.STR_STYLE_LATEX: if use_array: if first: r += spacestr * inden + " " else: r += "," + nlstr + spacestr * inden + " " if small: # r += "{\\scriptsize " pass else: if first: r += " " else: r += "," + spacestr + " " else: if first: if use_bracket: r += " " else: r += PsiOpts.settings["str_ineq_sep"] + nlstr + spacestr * inden + " " r += x if style & PsiOpts.STR_STYLE_PSITIP: r += " " if isplu: r += ")" elif style & PsiOpts.STR_STYLE_LATEX: if use_array: if small: # r += "}" pass first = False if len(eqnlist) == 0: if r != "": r += " " if style & PsiOpts.STR_STYLE_PSITIP: r += "universe()" elif style & PsiOpts.STR_STYLE_LATEX: r += PsiOpts.settings["latex_region_universe"] if style & PsiOpts.STR_STYLE_PSITIP: if use_bracket: r += " )" elif style & PsiOpts.STR_STYLE_LATEX: if use_array: r += nlstr + spacestr * inden + "\\end{array}" if use_bracket: r += " \\right\\}" else: if use_bracket: r += " " + PsiOpts.settings["str_brace_r"] if not self.aux.isempty(): if style & PsiOpts.STR_STYLE_PSITIP: pass else: if style & PsiOpts.STR_STYLE_LATEX: if style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER: r += " , " + PsiOpts.settings["latex_exists"] + " " r += self.aux.tostring(style = style, tosort = tosort) else: r += " , exists " r += self.aux.tostring(style = style, tosort = tosort) if imp_pres: r += ")" if not self.aux.isempty(): if style & PsiOpts.STR_STYLE_PSITIP: r += ".exists(" + self.aux.tostring(style = style, tosort = tosort) + ")" return r def __str__(self): lhsvar = None if PsiOpts.settings.get("str_lhsreal", False): lhsvar = "real" return 
self.tostring(PsiOpts.settings["str_style"], tosort = PsiOpts.settings["str_tosort"], lhsvar = lhsvar) def tostring_repr(self, style): if PsiOpts.settings.get("repr_check", False): #return str(self.check()) if self.check(): return str(True) lhsvar = None if PsiOpts.settings.get("str_lhsreal", False): lhsvar = "real" if PsiOpts.settings.get("repr_simplify", False): return self.simplified_quick().tostring(style, tosort = PsiOpts.settings["str_tosort"], lhsvar = lhsvar) return self.tostring(style, tosort = PsiOpts.settings["str_tosort"], lhsvar = lhsvar) def __repr__(self): return self.tostring_repr(PsiOpts.settings["str_style_repr"]) @latex_postprocess def _latex_(self): return self.tostring_repr(iutil.convert_str_style("latex")) def __hash__(self): #return hash(self.tostring(tosort = True)) return hash(( hash(frozenset(hash(x) for x in self.exprs_ge)), hash(frozenset(hash(x) for x in self.exprs_eq)), hash(frozenset(hash(x) for x in self.exprs_gei)), hash(frozenset(hash(x) for x in self.exprs_eqi)), hash(self.aux), hash(self.inp), hash(self.oup), hash(self.auxi) )) def ent_vector_discover_ic(v, x, eps = None, skip_simplify = False): if eps is None: eps = PsiOpts.settings["eps_check"] n = len(x) mask_all = (1 << n) - 1 if len(v) == (1 << n) - 1: v = [0.0] + list(v) r = Region.universe() for mask in range(1 << n): for i in range(n): if not (1 << i) & mask: if v[mask | (1 << i)] - v[mask] <= eps: r.exprs_eq.append(Expr.Hc(x[i], x.from_mask(mask))) for mask in range(1, 1 << n): bins = dict() for xmask in igen.subset_mask(mask_all - mask): t = v[mask | xmask] - v[xmask] if t <= eps: continue tbin0 = int(t / eps + 0.5) for tbin in range(max(tbin0 - 1, 0), tbin0 + 2): if tbin in bins: for ymask in bins[tbin]: if ymask & xmask == ymask and xmask - ymask < mask: r.exprs_eq.append(Expr.Ic(x.from_mask(mask), x.from_mask(xmask - ymask), x.from_mask(ymask))) if tbin == tbin0: bins[tbin].append(xmask) else: if tbin == tbin0: bins[tbin] = [xmask] if not skip_simplify: 
r.simplify_bayesnet(reduce_ic = True) r.simplify() return r class RegionOp(Region): """A region which is the union/intersection of a list of regions.""" def __init__(self, rtype, regs, auxs, inp = None, oup = None, meta = None): self.rtype = rtype self.regs = regs self.auxs = auxs self.inp = Comp.empty() if inp is None else inp self.oup = Comp.empty() if oup is None else oup self.meta = meta def get_type(self): return self.rtype def isnormalcons(self): return False def isuniverse(self, sgn = True, canon = False): if canon: if sgn: return self.get_type() == RegionType.INTER and len(self.regs) == 0 and len(self.auxs) == 0 else: return self.get_type() == RegionType.UNION and len(self.regs) == 0 and len(self.auxs) == 0 #return False isunion = (self.get_type() == RegionType.UNION) for x, c in self.regs: if isunion ^ sgn ^ x.isuniverse(not c ^ sgn): return not isunion ^ sgn return isunion ^ sgn def isempty(self): return self.isuniverse(False) def iseq(self): """Is this pure equality. """ return False def isineq(self): """Is this pure inequality. """ return False def expr(self): """Returns the sum of the expressions in this region. 
""" return sum((x.expr() for x, c in self.regs), Expr.zero()) def isplain(self): return False def copy(self): return RegionOp(self.rtype, [(x.copy(), c) for x, c in self.regs], [(x.copy(), c) for x, c in self.auxs], self.inp.copy(), self.oup.copy(), iutil.copy(self.meta)) def copy_noaux(self): return RegionOp(self.rtype, [(x.copy_noaux(), c) for x, c in self.regs], [], self.inp.copy(), self.oup.copy(), iutil.copy(self.meta)) def noaux(self): return self.copy_noaux() def copy_(self, other): self.rtype = other.rtype self.regs = [(x.copy(), c) for x, c in other.regs] self.auxs = [(x.copy(), c) for x, c in other.auxs] self.inp = other.inp.copy() self.oup = other.oup.copy() self.meta = iutil.copy(other.meta) def imp_flip(self): if self.get_type() == RegionType.INTER: self.rtype = RegionType.UNION elif self.get_type() == RegionType.UNION: self.rtype = RegionType.INTER for x, c in self.regs: x.imp_flip() return self def imp_flipped(self): r = self.copy() r.imp_flip() return r def universe_type(rtype): return RegionOp(rtype, [(Region.universe(), True)], []) def union(xs = None, tosimple = False): if xs is None: xs = [] if tosimple and len(xs) == 1: return xs[0].copy() return RegionOp(RegionType.UNION, [(x.copy(), True) for x in xs], []) def inter(xs = None, tosimple = False): if xs is None: xs = [] if tosimple and len(xs) == 1: return xs[0].copy() return RegionOp(RegionType.INTER, [(x.copy(), True) for x in xs], []) def __len__(self): return len(self.regs) def __getitem__(self, key): r = self.regs[key] if isinstance(r, list): return RegionOp(self.rtype, r, []) else: if r[1]: return r[0] else: return RegionOp(self.rtype, [r], []) def union_list(self, distribute = True): if distribute: r = self.copy() r.distribute() return r.union_list(distribute = False) if self.get_type() == RegionType.UNION: return list(self.copy()) return [self.copy()] def add_meta(self, key, value, children = True): if not children: return IBaseObj.add_meta(self, key, value) for x, c in self.regs: 
x.add_meta(key, value, children = children) return self def get_meta(self, key): t = IBaseObj.get_meta(self, key) if t is not None: return t for x, c in self.regs: t = x.get_meta(key) if t is not None: return t return None def remove_meta(self, key): IBaseObj.remove_meta(self, key) for x, c in self.regs: x.remove_meta(key) return self def ispresent(self, x): """Return whether any variable in x appears here.""" for z, c in self.regs: if z.ispresent(x): return True for z, c in self.auxs: if z.ispresent(x): return True return False def __contains__(self, other): x = other if isinstance(x, str): if x == "exists": if any(c for z, c in self.auxs): return True if x == "forall": if any(not c for z, c in self.auxs): return True for z, c in self.regs: if x in z: return True for z, c in self.auxs: if x in z: return True return False def allcomprealvar(self): r = Comp.empty() for z, c in self.regs: r += z.allcomprealvar() return r def allcompreal_exprlist(self): r = ExprArray([]) for z, c in self.regs: r.iadd_noduplicate(z.allcompreal_exprlist()) return r def allcomprealvar_exprlist(self): r = ExprArray([]) for z, c in self.regs: r.iadd_noduplicate(z.allcomprealvar_exprlist()) return r def rename_var(self, name0, name1): for x, c in self.regs: x.rename_var(name0, name1) for x, c in self.auxs: x.rename_var(name0, name1) def rename_map(self, namemap): for x, c in self.regs: x.rename_map(namemap) for x, c in self.auxs: x.rename_map(namemap) def getaux(self): r = Comp.empty() for x, c in self.auxs: if c: r += x for x, c in self.regs: if c: r += x.getaux() else: r += x.getauxi() return r def getauxi(self): r = Comp.empty() for x, c in self.auxs: if not c: r += x for x, c in self.regs: if c: r += x.getauxi() else: r += x.getaux() return r def getauxall(self): r = Comp.empty() for x, c in self.auxs: r += x for x, c in self.regs: r += x.getauxall() return r def getauxs(self): return [(x.copy(), c) for x, c in self.auxs] @property def aux(self): return self.getaux() @property def 
auxi(self): return self.getauxi() @fcn_substitute def substitute(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound), in place.""" for x, c in self.regs: x.substitute(v0, v1) if not isinstance(v0, Expr): for x, c in self.auxs: x.substitute(v0, v1) if iutil.check_meta_subs_criteria(v0, v1, self): iutil.substitute(self.meta, v0, v1) return self @fcn_substitute def substitute_whole(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound), in place.""" for x, c in self.regs: x.substitute_whole(v0, v1) if not isinstance(v0, Expr): for x, c in self.auxs: x.substitute_whole(v0, v1) if iutil.check_meta_subs_criteria(v0, v1, self): iutil.substitute_whole(self.meta, v0, v1) return self @fcn_substitute def substitute_aux(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound), and remove auxiliary v0, in place.""" for x, c in self.regs: x.substitute_aux(v0, v1) self.auxs = [(x - v0, c) for x, c in self.auxs if not (x - v0).isempty()] if iutil.check_meta_subs_criteria(v0, v1, self): iutil.substitute(self.meta, v0, v1) return self def remove_present(self, v): for x, c in self.regs: x.remove_present(v) self.auxs = [(x, c) for x, c in self.auxs if not x.ispresent(v)] def remove_notpresent(self, v): for x, c in self.regs: x.remove_notpresent(v) self.auxs = [(x, c) for x, c in self.auxs if x.ispresent(v)] def remove_relax(self, v, sn = 1): for x, c in self.regs: x.remove_relax(v, sn * (1 if c else -1)) self.auxs = [(x, c) for x, c in self.auxs if not x.ispresent(v)] def condition(self, b): """Condition on random variable b, in place.""" for x, c in self.regs: x.condition(b) return self def symm_sort(self, terms): """Sort the random variables in terms assuming symmetry among those terms.""" for x, c in self.regs: x.symm_sort(terms) def relax(self, w, gap): """Relax real variables in w by gap, in place""" for x, c in self.regs: if c: x.relax(w, gap) else: x.relax(w, -gap) return self def record_to(self, index): for x, c in self.regs: 
x.record_to(index) for x, c in self.auxs: x.record_to(index) def pack_type_self(self, totype): if len(self.auxs) == 0: ctype = self.get_type() if ctype == totype: return self if ctype == RegionType.UNION or ctype == RegionType.INTER: if len(self.regs) == 1: self.rtype = totype return self self.regs = [(self.copy(), True)] self.auxs = [] self.rtype = totype return self def pack_type(x, totype): if len(x.getauxs()) == 0: ctype = x.get_type() if ctype == totype: return x.copy() if ctype == RegionType.UNION or ctype == RegionType.INTER: if len(x.regs) == 1: r = x.copy() r.rtype = totype return r return RegionOp(totype, [(x.copy(), True)], []) def iand_norename(self, other): self.pack_type_self(RegionType.INTER) other = RegionOp.pack_type(other, RegionType.INTER) self.regs += other.regs self.auxs += other.auxs self.inter_compress() return self def inter_compress(self): if self.get_type() != RegionType.INTER and self.get_type() != RegionType.UNION: return curc = self.get_type() == RegionType.INTER cons = Region.universe() for x, c in self.regs: if c == curc and x.isnormalcons() and x.getauxall().isempty(): cons &= x x.setuniverse() self.regs = [(cons, curc)] + self.regs self.regs = [(x, c) for x, c in self.regs if c != curc or not x.isuniverse()] def normalcons_sort(self): if self.get_type() != RegionType.INTER and self.get_type() != RegionType.UNION: return curc = self.get_type() == RegionType.INTER regsf = [(x, c) for x, c in self.regs if c == curc and x.isnormalcons()] regsb = [(x, c) for x, c in self.regs if c != curc and x.isnormalcons()] self.regs = regsf + [(x, c) for x, c in self.regs if not x.isnormalcons()] + regsb def __iand__(self, other): if isinstance(other, bool): if not other: return Region.empty() return self other = iutil.ensure_region(other) if other.isuniverse(canon = True): return self self.pack_type_self(RegionType.INTER) other = RegionOp.pack_type(other, RegionType.INTER) self.aux_avoid(other) self.regs += other.regs self.auxs += other.auxs 
self.inter_compress() return self def __and__(self, other): r = self.copy() r &= other return r def __rand__(self, other): r = self.copy() r &= other return r def __ior__(self, other): if isinstance(other, bool): if other: return Region.universe() return self other = iutil.ensure_region(other) if other.isuniverse(sgn = False, canon = True): return self other = RegionOp.pack_type(other, RegionType.UNION) self.pack_type_self(RegionType.UNION) self.aux_avoid(other) self.regs += other.regs self.auxs += other.auxs return self def ior_norename(self, other): other = iutil.ensure_region(other) other = RegionOp.pack_type(other, RegionType.UNION) self.pack_type_self(RegionType.UNION) self.regs += other.regs self.auxs += other.auxs return self def rior(self, other): other = iutil.ensure_region(other) other = RegionOp.pack_type(other, RegionType.UNION) self.pack_type_self(RegionType.UNION) self.aux_avoid(other) self.regs = other.regs + self.regs self.auxs = other.auxs + self.auxs return self def __or__(self, other): r = self.copy() r |= other return r def __ror__(self, other): r = self.copy() r |= other return r def append_avoid(self, x, c = True): y = x.copy() self.aux_avoid(y) self.regs.append((y, c)) def __imul__(self, other): for i in range(len(self.regs)): self.regs[i][0] *= other return self def negate(self): if self.get_type() == RegionType.UNION: self.rtype = RegionType.INTER elif self.get_type() == RegionType.INTER: self.rtype = RegionType.UNION self.regs = [(x, not c) for x, c in self.regs] self.auxs = [(x, not c) for x, c in self.auxs] return self def __invert__(self): r = self.copy() r = r.negate() return r def negateornot(x, c): if c: return x.copy() else: return ~x def setuniverse(self): self.rtype = RegionType.INTER self.regs = [] def setempty(self): self.rtype = RegionType.UNION self.regs = [] def universe(): return RegionOp(RegionType.INTER, [], []) def empty(): return RegionOp(RegionType.UNION, [], []) #return ~Region.universe() def z3(self, vardict = None): 
if z3 is None: return None if vardict is None: vardict = iutil.z3_vardict(self.rvs, self.aux, self.auxi, self.reals) r = [] for x, c in self.regs: t = x.z3(vardict) if not c: t = z3.Not(t) r.append(t) if len(r) == 1: r = r[0] else: if self.get_type() == RegionType.UNION: r = z3.Or(r) elif self.get_type() == RegionType.INTER: r = z3.And(r) for x, c in self.auxs: if not c: # raise ValueError("Universally-quantified random variables are not supported for Z3. Use another solver.") r = z3.ForAll([y.z3(vardict) for y in x], r) else: r = z3.Exists([y.z3(vardict) for y in x], r) return r def sum_minkowski(self, other): cs = self.copy() cs.distribute() other = other.copy() other.distribute() # warnings.warn("Minkowski sum of union or intersection regions is unsupported.", RuntimeWarning) auxs = [(x.copy(), c) for x, c in cs.getauxs() + other.getauxs()] if cs.get_type() == RegionType.UNION: return RegionOp(RegionType.UNION, [(x.sum_minkowski(other), c) for x, c in cs.regs], auxs) if other.get_type() == RegionType.UNION: return RegionOp(RegionType.UNION, [(cs.sum_minkowski(x), c) for x, c in other.regs], auxs) # The following are technically wrong if cs.get_type() == RegionType.INTER: return RegionOp(RegionType.INTER, [(x.sum_minkowski(other), c) for x, c in cs.regs], auxs) if other.get_type() == RegionType.INTER: return RegionOp(RegionType.INTER, [(cs.sum_minkowski(x), c) for x, c in other.regs], auxs) return cs.copy() def implicate(self, other, skip_simplify = False): other = iutil.ensure_region(other) cs = self cs |= ~other return cs def implicate_norename(self, other, skip_simplify = False): other = iutil.ensure_region(other) cs = self cs.ior_norename(~other) return cs def implicated(self, other, skip_simplify = False): other = iutil.ensure_region(other) r = self.copy() r = r.implicate(other, skip_simplify) return r def sum_entrywise(self, other): r = RegionOp(self.rtype, [], []) for ((x1, c1), (x2, c2)) in zip(self.regs, other.regs): r.regs.append((x1.sum_entrywise(x2), 
c1)) for ((x1, c1), (x2, c2)) in zip(self.auxs, other.auxs): r.regs.append((x1.interleaved(x2), c1)) return r def corners_optimum(self, w, sn): if self.get_type() == RegionType.UNION: r = self.copy() r.regs = [] did = False for x, c in self.regs: if not c: r.append_avoid(x.copy(), c) continue t = x.corners_optimum(w, sn) if t.isuniverse(): r.append_avoid(x.copy(), c) else: r.append_avoid(t, c) did = True if did: return r return Region.universe() elif self.get_type() == RegionType.INTER: r = RegionOp.union([]) r.auxs = [(x.copy(), c) for x, c in self.auxs] for x, c in self.regs: if not c: continue t = x.corners_optimum(w, sn) if t.isuniverse(): continue else: tt = self.copy() for i in range(len(tt.regs)): if tt.regs[i] is x: tt.regs[i] = t break r.append_avoid(tt) if len(r.regs) > 0: return r return Region.universe() return Region.universe() def balance(self, v = None, w = None, skip_simplify = False): if w is None: w = self.allcomprv() for x, c in self.regs: x.balance(v, w, skip_simplify = True) if not skip_simplify: self.simplify() return self def sign_present(self, term): r = [False] * 2 for x, c in self.regs: t = x.sign_present(term) if not c: t.reverse() r[0] |= t[0] r[1] |= t[1] if r[0] and r[1]: break return r def substitute_sign(self, v0, v1s): r = [False] * 2 t = [False] * 2 for x, c in self.regs: if c: t = x.substitute_sign(v0, v1s) else: t = x.substitute_sign(v0, list(reversed(v1s)) if v1s else v1s) t.reverse() r[0] |= t[0] r[1] |= t[1] return r def substitute_duplicate(self, v0, v1s): for x, c in self.regs: x.substitute_duplicate(v0, v1s) def flatten_minmax(self, term, sn, bds): # for x, c in self.regs: # x.flatten_minmax(term, sgn, bds) latex_name = term.tostring(style = "latex") v1s = [Expr.real(self.name_avoid(str(term) + "_L") + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + latex_name), Expr.real(self.name_avoid(str(term) + "_U") + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + latex_name)] # v1s = [Expr.real(self.name_avoid(str(term) + "_L")), # 
Expr.real(self.name_avoid(str(term) + "_U"))] sn_present = self.substitute_sign(Expr.fromterm(term), v1s) if sn < 0: sn_present.reverse() v1s.reverse() if sn_present[1]: treg = Region.universe() for b in bds: treg &= v1s[1] * sn <= b * sn self &= treg # print(self) # print(v1s[1]) self.eliminate(v1s[1]) if sn_present[0]: tself = RegionOp.pack_type(self.copy(), RegionType.UNION) if any(not c for x, c in tself.regs): tself = RegionOp.union([tself]) self.copy_(RegionOp.union([])) for t2 in tself: if t2.ispresent(v1s[0]): for b in bds: t3 = t2.copy() t3.substitute(v1s[0], b) self |= t3 else: self |= t2 # print("AFTER FLATTEN REGTERM") # print(self) # print(sn_present) # print() return self

    def lowest_present(self, v, sn):
        """Find the lowest sub-region in which v appears.

        If exactly one constituent contains v, recurse into it. The sign flag
        is flipped (XOR) when that constituent is negated. If the recursion
        yields nothing, or several constituents contain v, return self when
        sn is truthy and None otherwise.
        NOTE(review): the placement of the final `if sn` was reconstructed
        from collapsed source as method-level; confirm against upstream.
        """
        ps = []
        cpres = False
        for x, c in self.regs:
            if x.ispresent(v):
                ps.append(x)
                cpres = c
        if len(ps) == 0:
            return None
        if len(ps) == 1:
            t = ps[0].lowest_present(v, sn ^ (not cpres))
            if t is not None:
                return t
        if sn:
            return self
        return None

    def term_sn_present(self, term, sn):
        """Return whether term appears with sign sn (combined with term.sn).

        Uses the flag pair from substitute_sign(..., None), swapped when the
        combined sign is negative, and returns its first entry.
        """
        sn *= term.sn
        sn_present = self.substitute_sign(Expr.fromterm(term), None)
        if sn < 0:
            sn_present.reverse()
        return sn_present[0]

    def term_sn_present_both(self, term):
        """Return whether both substitute_sign flags are set for term."""
        sn_present = self.substitute_sign(Expr.fromterm(term), None)
        return sn_present[0] and sn_present[1]

    def flatten_regterm(self, term, isimp = True, minmax_elim = False):
        """Expand a term defined by a region (term.reg) into this region, in place.

        The term is replaced by two fresh real variables (suffixes "_L"/"_U")
        according to the signs in which it occurs. Then term.reg, with those
        variables substituted, is either implied (isimp) or intersected in.
        With minmax_elim, a term whose get_reg_sgn_bds() gives a universe
        region and nonzero sign is delegated to flatten_minmax instead.
        Returns None when term.reg is None; otherwise returns self or the
        result of simplify_quick().
        """
        if term.reg is None:
            return
        # print("FLAT " + str(minmax_elim))
        if minmax_elim:
            treg, tsgn, tbds = term.get_reg_sgn_bds()
            if treg.isuniverse() and tsgn != 0:
                self.flatten_minmax(term, tsgn, tbds)
                return self.simplify_quick()

        sn = term.sn
        latex_name = term.tostring(style = "latex")
        v1s = [Expr.real(self.name_avoid(str(term) + "_L") + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + latex_name),
               Expr.real(self.name_avoid(str(term) + "_U") + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + latex_name)]
        sn_present = self.substitute_sign(Expr.fromterm(term), v1s)
        if sn < 0:
            sn_present.reverse()
            v1s.reverse()
        # A sign-0 term is handled only through the v1s[0] branch below.
        if sn == 0:
            sn_present[1] = False
            sn_present[0] = True

        if sn_present[1]:
            term2 = term.copy()
            self.aux_avoid(term2.reg)
            rsb = term2.get_reg_sgn_bds()
            if (PsiOpts.settings["flatten_distribute"] and rsb is not None and not rsb[0].imp_present()
                    and isimp and (PsiOpts.settings["flatten_distribute_multi"] or len(rsb[2]) == 1)):
                # Substitute the bounds directly at the lowest sub-region
                # containing v1s[1], then eliminate the term's auxiliaries.
                treg, tsgn, tbds = rsb
                lpres = self.lowest_present(v1s[1], True)
                lpres.substitute_duplicate(v1s[1], tbds)
                taux = treg.aux.copy()
                treg.aux = Comp.empty()
                lpres &= treg
                lpres.eliminate(taux)
            else:
                reg2 = term.reg.broken_present(Expr.fromterm(term), flipped = True)
                reg2.substitute(Expr.fromterm(term), v1s[1])
                if isimp:
                    self |= reg2
                else:
                    self &= ~reg2

        if sn_present[0]:
            reg2 = term.reg.copy()
            reg2.substitute(Expr.fromterm(term), v1s[0])
            if isimp:
                self.implicate(reg2)
            else:
                self &= reg2

        # print("AFTER FLATTEN REGTERM")
        # print(self)
        # print(sn_present)
        # print()
        return self

    def flatten_ivar(self, ivar, isimp = True):
        """Expand an IVar defined by a region (ivar.reg) into this region, in place.

        A fresh copy of the variable (without its defining region) replaces
        ivar. If the definition is not deterministic (ivar.reg_det false), an
        Ic(...) == 0 constraint is added to the defining region unless that
        expression simplifies to zero. The defining region is then implied
        (isimp) or intersected in. Returns None when ivar.reg is None,
        otherwise self.
        """
        if ivar.reg is None:
            return
        newvar = Comp([ivar.copy_noreg()])
        reg2 = ivar.reg.copy()
        self.aux_avoid(reg2)
        if not ivar.reg_det:
            newindep = Expr.Ic(reg2.getaux() + newvar,
                               self.allcomprv().reg_excluded() - self.getaux() - reg2.allcomprv() - newvar,
                               reg2.allcomprv_noaux() - newvar).simplified_quick()
            if not newindep.iszero():
                reg2.iand_norename(newindep == 0)
                #self.implicate_norename(newindep == 0)

        if isimp:
            #self.implicate_norename(reg2.exists(newvar), skip_simplify = True)
            self.implicate_norename(reg2, skip_simplify = True)
        else:
            self &= reg2

        self.substitute(Comp([ivar]), newvar)
        return self

    def isregtermpresent(self):
        """Return True if any constituent reports isregtermpresent()."""
        for x, c in self.regs:
            if x.isregtermpresent():
                return True
        return False

    def regtermmap(self, cmap, recur):
        """Collect region-defined terms into cmap.

        Scans the constituents, then the auxiliaries; the latter go through
        the region H(all auxiliaries) >= 0.
        """
        for x, c in self.regs:
            x.regtermmap(cmap, recur)
        aux = Comp.empty()
        for x, c in self.auxs:
            aux += x
        (H(aux) >= 0).regtermmap(cmap, recur)

    def regterm_split(self, term): if not (self.term_sn_present_both(term) or term.reg_outer is not None): return False # sn = term.sn # terml = term.copy() # terml.substitute(term.x[0], Comp.real(self.name_avoid(str(term) + "_L0"))) # if
terml.reg_outer is not None: # terml.reg = terml.reg_outer # terml.reg_outer = None # termr = term.copy() # termr.substitute(term.x[0], Comp.real(self.name_avoid(str(term) + "_R0"))) # termr.reg_outer = None # v1s = [Expr.fromterm(terml), Expr.fromterm(termr)] # if sn < 0: # v1s.reverse() v1s = [Expr.fromterm(term.upper_bound(name = self.name_avoid(str(term) + "_L0"))), Expr.fromterm(term.lower_bound(name = self.name_avoid(str(term) + "_R0")))] self.substitute_sign(Expr.fromterm(term), v1s) return True

    def flatten(self, minmax_elim = False):
        """Repeatedly expand region-defined terms and IVars until none remain.

        Each round collects terms with regtermmap. Terms that also appear
        inside another term's region (regterms_exc) are skipped. First a
        regterm_split is tried. If none applies, one term is expanded via
        flatten_ivar / flatten_regterm; IVars are only handled in pass 0, and
        in pass 0 terms for which term_sn_present(term, 1) holds are skipped.
        When proof writing is enabled, an "Expand definitions:" step is
        recorded. Returns self.
        """
        verbose = PsiOpts.settings.get("verbose_flatten", False)
        write_pf_enabled = PsiOpts.settings.get("proof_enabled", False) and PsiOpts.settings.get("proof_step_expand_def", False)

        if write_pf_enabled:
            prevself = self.copy()

        did = True
        didall = False
        while did:
            did = False
            regterms = {}
            # self.regtermmap(regterms, True)
            # for (name, term) in regterms.items():
            #     regterms_in = {}
            #     term.reg.regtermmap(regterms_in, False)
            #     if not regterms_in:
            regterms_exc = {}
            self.regtermmap(regterms, False)
            for (name, term) in regterms.items():
                term.reg.regtermmap(regterms_exc, True)

            for (name, term) in regterms.items():
                if name not in regterms_exc:
                    if isinstance(term, IVar):
                        pass
                    else:
                        if self.regterm_split(term):
                            did = True
                            break

            if did:
                continue

            for cpass, (name, term) in itertools.product(range(2), regterms.items()):
                if name not in regterms_exc:
                    if isinstance(term, IVar):
                        if cpass == 1:
                            continue
                    else:
                        if cpass == 0 and self.term_sn_present(term, 1):
                            continue

                    if verbose:
                        print("========= flatten op ========")
                        print(self)
                        print("========= term ========")
                        print(term)
                        print("========= region ========")
                        print(term.reg)

                    if isinstance(term, IVar):
                        self.flatten_ivar(term)
                    else:
                        self.flatten_regterm(term, minmax_elim = minmax_elim)
                    did = True
                    didall = True

                    if verbose:
                        print("========= to ========")
                        print(self)

                    # Only one term per round; regterms is recomputed next round.
                    break

        if write_pf_enabled:
            if didall:
                # pf = ProofObj.from_region(prevself, c = "Expand definitions")
                # pf += ProofObj.from_region(self, c = "Expanded definitions to")
                pf = ProofObj.from_region(("equiv", prevself, self), c = "Expand definitions:")
                PsiOpts.set_setting(proof_add = pf)

        return self

    def tosimple(self):
        """Convert to a plain Region if possible, else return None.

        Works for an empty union, a union with a single positive constituent,
        or an intersection of positive constituents that convert to Regions
        without imp_present(). Universally quantified auxiliaries (c False)
        make the conversion fail; existential ones are eliminated.
        """
        r = Region.universe()
        if self.get_type() == RegionType.UNION:
            if len(self.regs) == 0:
                r = Region.empty()
            elif len(self.regs) == 1 and self.regs[0][1]:
                r = self.regs[0][0].tosimple()
                if r is None:
                    return None
            else:
                return None
        elif self.get_type() == RegionType.INTER:
            for x, c in self.regs:
                if not c:
                    return None
                t = x.tosimple()
                if t is None or t.imp_present():
                    return None
                r.iand_norename(t)
        for x, c in self.auxs:
            if not c:
                return None
            r.eliminate(x)
        return r

    def tosimple_safe(self):
        """Like tosimple(), but return None if getauxi() is nonempty or any constituent has aux_present()."""
        if not self.getauxi().isempty():
            return None
        for x, c in self.regs:
            if x.aux_present():
                return None
        return self.tosimple()

    def tonormal_safe(self):
        """Return tosimple_safe() if it succeeds, otherwise a copy of self."""
        t = self.tosimple_safe()
        if t is not None:
            return t
        return self.copy()

    def level(self, mode = ""):
        """The level of a region in the linear entropy hierarchy.

        Combines the constituents' levels with | (inverting negated ones),
        then applies exists(not c) once per auxiliary group.
        """
        r = Level()
        for x, c in self.regs:
            t = x.level(mode)
            if not c:
                t = ~t
            r = r | t
        for x, c in self.auxs:
            r = r.exists(not c)
        return r

    def complexity(self):
        """Heuristic size: constituent complexities plus 100 per auxiliary group and per auxiliary variable."""
        return sum(x.complexity() for x, c in self.regs) + len(self.auxs) * 100 + sum(len(x) for x, c in self.auxs) * 100

    def sorting_priority(self):
        """Sort key; currently equal to complexity()."""
        return self.complexity()

    def var_neighbors(self, v):
        """Return v together with the var_neighbors(v) of every constituent."""
        r = v.copy()
        for x, c in self.regs:
            r += x.var_neighbors(v)
        return r

    def one_flipped(self):
        """Not supported for RegionOp; always returns None."""
        return None

    def distribute(self, force_split = False):
        """Expand to a single union layer.

        Union: negated RegionOp constituents are negated in place, distributed
        recursively, and positive union results are spliced into this layer.
        Intersection: this layer becomes a union, and the product of the
        constituents' union layers is built by intersecting every combination.
        Inner auxiliaries are prepended to self.auxs. With force_split, normal
        regions are first converted via toregionop_split when they have
        imp_present() or are negated. Modifies and returns self.
        """
        # print(self)
        # if self.get_type() == RegionType.UNION:
        #     print("UNION")
        # if self.get_type() == RegionType.INTER:
        #     print("INTER")
        # print()
        if self.get_type() == RegionType.UNION:
            tregs = []
            for x, c in self.regs:
                if force_split and x.get_type() == RegionType.NORMAL:
                    if x.imp_present():
                        x = x.toregionop_split()
                    elif not c:
                        x = x.toregionop_split(force_split = True)

                if force_split or isinstance(x, RegionOp):
                    if not c:
                        x = x.negate()
                        c = True
                    x.distribute(force_split = force_split)
                if c and x.get_type() == RegionType.UNION:
                    tregs += x.regs
                    self.auxs = x.getauxs() + self.auxs
                else:
                    tregs.append((x, c))
            self.regs = tregs
            return self

        if self.get_type() == RegionType.INTER:
            # Start from a single universe term and multiply out each factor.
            tregs = [(Region.universe(), True)]
            self.rtype = RegionType.UNION
            for x, c in self.regs:
                if force_split and x.get_type() == RegionType.NORMAL:
                    if x.imp_present():
                        x = x.toregionop_split()
                    elif not c:
                        x = x.toregionop_split(force_split = True)

                if force_split or isinstance(x, RegionOp):
                    if not c:
                        x = x.negate()
                        c = True
                    x.distribute(force_split = force_split)
                if c and x.get_type() == RegionType.UNION:
                    tregs2 = []
                    for y, cy in x.regs:
                        for a, ca in tregs:
                            tregs2.append((RegionOp.negateornot(a, ca) & RegionOp.negateornot(y, cy), True))
                    tregs = tregs2
                    self.auxs = x.getauxs() + self.auxs
                else:
                    tregs = [(RegionOp.negateornot(a, ca) & RegionOp.negateornot(x, c), True) for a, ca in tregs]
                    self.auxs = x.getauxs() + self.auxs
            self.regs = tregs
            return self

        return self

    def tounion(self):
        """Return a distributed copy of self packaged as a union."""
        r = self.copy()
        r = r.distribute()
        if r.get_type() == RegionType.UNION:
            return r
        return RegionOp.pack_type(r, RegionType.UNION)

    def aux_appearance(self, curc):
        """List (aux group, quantifier flag, visible variables) tuples for all auxiliaries.

        Nested RegionOps are handled recursively, with curc flipped under
        negation. For plain regions, auxi groups get the opposite flag of aux
        groups. Groups from self.auxs come last, in their original order. Each
        group's third entry is the set of variables not bound by self (plus
        the groups listed after it in self.auxs).
        """
        allcomprv = self.allcomprv() - self.getauxall()
        r = []
        for x, c in self.regs:
            if isinstance(x, RegionOp):
                r += x.aux_appearance(curc ^ (not c))
            else:
                xallcomprv = x.allcomprv() - x.getauxall()
                if not x.auxi.isempty():
                    r.append((x.auxi.copy(), curc ^ (not c) ^ True, xallcomprv.copy()))
                if not x.aux.isempty():
                    r.append((x.aux.copy(), curc ^ (not c), xallcomprv.copy()))

        rt = []
        # Walk from the last auxiliary group backwards so each group sees the
        # variables of the groups listed after it.
        for x, c in self.auxs[::-1]:
            rt.append((x.copy(), curc ^ (not c), allcomprv.copy()))
            allcomprv += x
        r += rt[::-1]
        return r

    def aux_remove(self):
        """Strip all auxiliary declarations from self and nested regions, in place."""
        self.auxs = []
        for x, c in self.regs:
            if isinstance(x, RegionOp):
                x.aux_remove()
            else:
                x.aux = Comp.empty()
                x.auxi = Comp.empty()
        return self

    def aux_collect(self):
        """Collect auxiliaries to outermost layer.

        Quantifier flags are XORed with (not c) for negated constituents. The
        collected groups are placed before self's existing auxiliaries.
        """
        iaux = []
        for x, c in self.regs:
            if isinstance(x, RegionOp):
                x.aux_collect()
                iaux += [(x2, c2 ^ (not c)) for x2, c2 in x.auxs]
                x.auxs = []
            else:
                iaux += [(x2, c2 ^ (not c)) for x2, c2 in x.getauxs()]
                x.aux = Comp.empty()
                x.auxi = Comp.empty()
        self.auxs = iaux + self.auxs

    def break_imp(self):
        """Rewrite normal constituents that have implications as RegionOps, in place.

        Such a constituent becomes union([consonly()]) implicated by
        imp_flippedonly() (with its aux cleared). Its auxi is kept as a
        universally quantified (False) auxiliary group. Then break_imp is
        applied recursively to RegionOp constituents.
        """
        for i in range(len(self.regs)):
            if self.regs[i][0].get_type() == RegionType.NORMAL and self.regs[i][0].imp_present():
                rc = self.regs[i][0].consonly()
                ri = self.regs[i][0].imp_flippedonly()
                ri.aux = Comp.empty()
                r = RegionOp.union([rc])
                if not ri.isuniverse():
                    r = r.implicated(ri)
                if not self.regs[i][0].auxi.isempty():
                    r.auxs = [(self.regs[i][0].auxi.copy(), False)]
                self.regs[i] = (r, self.regs[i][1])
            if isinstance(self.regs[i][0], RegionOp):
                self.regs[i][0].break_imp()

    def break_imp_old(self):
        """Older break_imp variant that copies all of getauxs() onto the new RegionOp."""
        for i in range(len(self.regs)):
            if self.regs[i][0].get_type() == RegionType.NORMAL and self.regs[i][0].imp_present():
                auxs = self.regs[i][0].getauxs()
                rc = self.regs[i][0].consonly()
                rc.aux = Comp.empty()
                ri = self.regs[i][0].imp_flippedonly()
                ri.aux = Comp.empty()
                r = RegionOp.union([rc])
                if not ri.isuniverse():
                    r = r.implicated(ri)
                r.auxs = auxs
                self.regs[i] = (r, self.regs[i][1])
            if isinstance(self.regs[i][0], RegionOp):
                self.regs[i][0].break_imp()

    def from_region(x):
        """Convert region x to a RegionOp with implications broken out.

        A RegionOp input is simply copied. A one-element union wrapper around
        a positive RegionOp is unwrapped.
        """
        if isinstance(x, RegionOp):
            return x.copy()
        r = RegionOp.union([x])
        r.break_imp()
        if len(r.regs) == 1 and r.regs[0][1] and isinstance(r.regs[0][0], RegionOp):
            return r.regs[0][0]
        return r

    def break_present(self, w, flipped = True): for i in range(len(self.regs)): x, c = self.regs[i] if isinstance(x, RegionOp): x.break_present(w, flipped) else: self.regs[i] =
(x.broken_present(w, flipped), c) return self

    def broken_present(self, w, flipped = True):
        """Return a copy of self with break_present(w, flipped) applied."""
        r = self.copy()
        r = r.break_present(w, flipped)
        return r

    def aux_clean(self):
        """Merge consecutive auxiliary groups that share the same quantifier flag, in place."""
        auxs = self.auxs
        self.auxs = []
        for x, c in auxs:
            if len(self.auxs) > 0 and self.auxs[-1][1] == c:
                self.auxs[-1]= (self.auxs[-1][0] + x, self.auxs[-1][1])
            else:
                self.auxs.append((x, c))

    def emptycons_present(self):
        """Return True if some constituent has no negated (consequence) part.

        An INTER constituent counts as having one if any of its entries is
        negated; a NORMAL constituent counts if it is itself negated.
        """
        for x, c in self.regs:
            cons_present = False
            if x.get_type() == RegionType.INTER:
                for y, cy in x.regs:
                    if not cy:
                        cons_present = True
                        break
            elif x.get_type() == RegionType.NORMAL:
                if not c:
                    cons_present = True
            if not cons_present:
                return True
        return False

    def sign_has_component(self, sn):
        """Return True if some normal region occurs with effective sign sn (recursing through negations)."""
        for x, c in self.regs:
            if x.get_type() == RegionType.NORMAL:
                if c == sn:
                    return True
            else:
                if x.sign_has_component(sn ^ (not c)):
                    return True
        return False

    def sign_only_inplace(self, sn, outermost = False):
        """Keep only the parts with effective sign sn, in place.

        Normal constituents are kept when c == sn. Nested RegionOps are
        filtered recursively. At the outermost level, a nested RegionOp is kept
        only if c == sn and it has a matching component. Auxiliary groups are
        kept when their flag equals sn.
        """
        t = []
        for x, c in self.regs:
            if x.get_type() == RegionType.NORMAL:
                if c == sn:
                    t.append((x, c))
            elif outermost:
                if c == sn and x.sign_has_component(sn ^ (not c)):
                    x.sign_only_inplace(sn ^ (not c))
                    t.append((x, c))
            else:
                x.sign_only_inplace(sn ^ (not c))
                t.append((x, c))
        self.regs = t
        self.auxs = [(x, c) for x, c in self.auxs if c == sn]

    def sign_only(self, sn, outermost = False):
        """Return a copy of self with sign_only_inplace(sn, outermost) applied."""
        r = self.copy()
        r.sign_only_inplace(sn, outermost)
        return r

    def consonly(self):
        """Return the positively-signed part (sign_only(True, outermost = True))."""
        return self.sign_only(True, outermost = True)

    def imponly(self):
        """Return the negatively-signed part (sign_only(False, outermost = True))."""
        return self.sign_only(False, outermost = True)

    def process_pm(self, fcn1, fcn0):
        """Apply fcn1 to normal regions with positive effective sign and fcn0 to negative ones.

        The two callbacks are swapped when recursing into a negated RegionOp.
        The callbacks' return values are discarded (bound only to the loop
        variable), so they must act in place.
        """
        for x, c in self.regs:
            if c:
                if x.get_type() == RegionType.NORMAL:
                    x = fcn1(x)
                else:
                    x.process_pm(fcn1, fcn0)
            else:
                if x.get_type() == RegionType.NORMAL:
                    x = fcn0(x)
                else:
                    x.process_pm(fcn0, fcn1)

    def presolve_strengthen(self):
        """Run presolve_process on normal regions: relax=True on positive, relax=False on negative."""
        def fcn1(r):
            r.presolve_process(relax = True)
            return r

        def fcn0(r):
            r.presolve_process(relax = False)
            return r

        self.process_pm(fcn1, fcn0)

    def presolve(self):
        """Normalize self, in place, into a single union layer ready for the aux search.

        Steps: break implications, simplify_quick, flatten definitions, break
        implications again, and strengthen. Then the auxiliary quantifiers are
        recorded, stripped, and re-attached at the outer layer after
        distributing. Constituents are sorted, and an empty constituent is
        appended if every constituent already has a consequence part.
        Returns self.
        """
        self.break_imp()
        self.simplify_quick(zero_group = 1)
        self.flatten(minmax_elim = PsiOpts.settings["flatten_minmax_elim"])
        self.break_imp()
        self.presolve_strengthen()
        aux_ap = self.aux_appearance(True)
        self.aux_remove()
        self.distribute()

        # Disabled code path (kept for reference).
        if False:
            t = Comp.empty()
            for x, c, ccomp in aux_ap:
                t += x
            t = self.allcomprv() - t
            if not t.isempty():
                #aux_ap = [(t, False, Comp.empty())] + aux_ap
                aux_ap.append((t, False, Comp.empty()))

        #self.inter_compress()
        self.normalcons_sort()
        for x, c in self.regs:
            if isinstance(x, RegionOp):
                x.inter_compress()

        self.auxs = [(x, c) for x, c, ccomp in aux_ap]
        self.aux_clean()
        if not self.emptycons_present():
            self.regs.append((Region.empty(), True))
        return self

    def to_cause_consequence(self):
        """Split ~self (presolved) into (requirement, consequence, aux) triples.

        For each constituent, positive parts are intersected into the
        requirement and negated parts collected as consequences. A single
        consequence is used directly; several (or none) become a union.
        Positive empty NORMAL constituents are skipped. The third entry is the
        intersection of cs.getaux() with x.allcomp().
        """
        cs = ~self
        cs.presolve()
        r = []
        #allauxs = cs.auxs
        auxi = cs.getaux()
        for x, c in cs.regs:
            #allcomprv = x.allcomprv()
            #cur_auxs_incomp = RegionOp.auxs_incomp(allauxs, allcomprv)
            req = Region.universe()
            cons = []
            if x.get_type() == RegionType.INTER:
                for y, cy in x.regs:
                    if cy:
                        req &= y
                    else:
                        cons.append(y)
            elif x.get_type() == RegionType.NORMAL:
                if c:
                    if x.isempty():
                        continue
                    req &= x
                else:
                    cons.append(x)

            ccons = None
            if len(cons) == 1:
                ccons = cons[0]
            else:
                ccons = RegionOp.union(cons)

            cauxi = auxi.inter(x.allcomp())
            # cauxi = x.auxi
            # cauxi = x.aux
            r.append((req, ccons, cauxi))
        return r

    def get_var_avoid(self, a):
        """Intersect the non-None get_var_avoid(a) results of all constituents (None if all are None)."""
        r = None
        for x, c in self.regs:
            t = x.get_var_avoid(a)
            if r is None:
                r = t
            elif t is not None:
                r = r.inter(t)
        return r

    def auxs_icreg(auxs, othercomp, exclcomp):
        """Build conditional-independence constraints for universally quantified auxiliaries.

        Walks auxs from last to first. Each group with flag False contributes
        Region.Ic(a - exclcomp, othercomp, ccomp), where ccomp accumulates the
        groups already visited.
        NOTE(review): whether `ccomp += a` runs for every group or only for
        False groups was reconstructed from collapsed source; confirm upstream.
        """
        r = Region.universe()
        ccomp = Comp.empty()
        for a, c in reversed(auxs):
            if not c:
                r.iand_norename(Region.Ic(a - exclcomp, othercomp, ccomp))
            ccomp += a
        return r

    def auxs_incomp(auxs, x):
        """Restrict each auxiliary group to components in x, dropping groups that become empty."""
        r = [(a.inter(x), c) for a, c in auxs]
        return [(a, c) for a, c in r if not a.isempty()]
        # r2 = []
        # for a, c in r:
        #     if not a.isempty():
        #         if len(r2) == 0 or r2[-1][1] != c:
        #             r2.append((a, c))
        #         else:
        #             r2[-1] = (r2[-1][0] + a, c)
        # return r2

    def calc_req_cons(self):
        """Presolve self in place, then return a (requirement, [consequences]) pair per constituent."""
        self.presolve()
        n = len(self.regs)
        r = []
        for i in range(n):
            x, c = self.regs[i]
            req = Region.universe()
            cons = []
            if x.get_type() == RegionType.INTER:
                for y, cy in x.regs:
                    if cy:
                        req &= y
                    else:
                        cons.append(y)
            elif x.get_type() == RegionType.NORMAL:
                if c:
                    req &= x
                else:
                    cons.append(x)
            r.append((req, cons))
        return r

    def get_req_cons(self):
        """Return (calc_req_cons() on a copy, that copy's auxs); self is not modified."""
        cs = self.copy()
        r = cs.calc_req_cons()
        return (r, cs.auxs)

    def check_getaux_inplace(self, must_include = None, single_include = None, hint_pair = None, hint_aux = None, hint_aux_avoid = None, max_iter = None, leaveone = None, strengthen = None, get_info = None): verbose = PsiOpts.settings.get("verbose_auxsearch", False) verbose_step = PsiOpts.settings.get("verbose_auxsearch_step", False) verbose_op = PsiOpts.settings.get("verbose_auxsearch_op", False) verbose_op_step = PsiOpts.settings.get("verbose_auxsearch_op_step", False) verbose_op_detail = PsiOpts.settings.get("verbose_auxsearch_op_detail", False) verbose_op_detail2 = PsiOpts.settings.get("verbose_auxsearch_op_detail2", False) ignore_must = PsiOpts.settings["ignore_must"] forall_multiuse = PsiOpts.settings["forall_multiuse"] forall_multiuse_numsave = PsiOpts.settings["forall_multiuse_numsave"] auxsearch_local = PsiOpts.settings["auxsearch_local"] init_leaveone = PsiOpts.settings["init_leaveone"] auxsearch_aux_strengthen = PsiOpts.settings["auxsearch_aux_strengthen"] if leaveone is None: leaveone = PsiOpts.settings["auxsearch_leaveone"] save_res = auxsearch_local if max_iter is None: max_iter = PsiOpts.settings["auxsearch_max_iter"] if hint_aux is None: hint_aux = [] if hint_aux_avoid is None: hint_aux_avoid = [] if must_include is None: must_include = Comp.empty() if strengthen is None: strengthen = PsiOpts.settings["auxsearch_strengthen"] casesteplimit = PsiOpts.settings["auxsearch_op_casesteplimit"] caselimit = PsiOpts.settings["auxsearch_op_caselimit"] write_pf_enabled = PsiOpts.settings.get("proof_enabled", False) write_pf_repeat_claim = PsiOpts.settings.get("proof_repeat_implicant", False) yield_one = write_pf_enabled and PsiOpts.settings.get("proof_yield_one", False) if verbose_op: print("========= aux search op
========") print(self) self.presolve() if verbose_op: print("========= expanded ========") print(self) print() csallcomprv = self.allcomprv() csaux = Comp.empty() csauxi = Comp.empty() for a, ca in self.auxs: if not ca: csauxi += a else: csaux += a write_pf_twopass = write_pf_enabled and not csaux.isempty() if write_pf_twopass: PsiOpts.settings["proof_enabled"] = False write_pf_enabled_inner = PsiOpts.settings.get("proof_enabled", False) max_numupdate_one = write_pf_enabled_inner and PsiOpts.settings.get("auxsearch_max_numupdate_one_if_proof", False) allauxs = self.auxs csnonaux = csallcomprv - csaux - csauxi if not csnonaux.isempty(): allauxs.append((csnonaux, False)) csauxiall = csauxi + csnonaux csallcomprv = csaux + csauxiall csaux_id = IVarIndex() csaux_id.record(csaux) csauxiall_id = IVarIndex() csauxiall_id.record(csauxiall) csall_id = IVarIndex() csall_id.record(csallcomprv) csaux_is = [csall_id.get_index(a) for a in csaux.varlist] csauxiall_is = [csall_id.get_index(a) for a in csauxiall.varlist] csauxidep = Comp.empty() nvar = len(csallcomprv) n = len(self.regs) nreqcheck = 0 nfinal = 0 depgraph = [[False] * nvar for j in range(nvar)] xallcomprv = [] xconscomprv = [] xreq = [] xcons = [] xmultiuse = [] xleaveone = [] xaux = [] xauxi = [] xaux_avoid = [] xoneuse_aux = [] xcforall = [] xauxs_incomp = [] xreqcomprv = [] xconsonly = [] xconsonly_init = [] init_reg = Region.universe() for i in range(n): x, c = self.regs[i] allcomprv = x.allcomprv() cur_auxs_incomp = RegionOp.auxs_incomp(allauxs, allcomprv) req = Region.universe() cons = [] if x.get_type() == RegionType.INTER: for y, cy in x.regs: if cy: req &= y else: cons.append(y) elif x.get_type() == RegionType.NORMAL: if c: req &= x else: cons.append(x) conscomprv = Comp.empty() for con in cons: conscomprv += con.allcomprv() cur_multiuse = forall_multiuse aux = Comp.empty() #auxi = csauxi.inter(conscomprv) auxi = csauxi.inter(allcomprv) aux_avoid = [] cavoid = Comp.empty() cavoidmask = 0 cforall_c = 
Comp.empty() cforall = Comp.empty() for a, ca in allauxs: b = a.inter(allcomprv) bmask = csall_id.get_mask(b) if not b.isempty(): if ca: aux += b aux_avoid.append((b.copy(), cavoid.copy())) for j in range(nvar): if bmask & (1 << j) != 0: for j2 in range(nvar): if cavoidmask & (1 << j2) != 0: depgraph[j2][j] = True if not cforall_c.isempty(): cur_multiuse = False csauxidep += cforall_c else: cforall_c += b.inter(auxi) cavoid += b cavoidmask |= bmask #cavoid += a if len(cons) == 0: cur_multiuse = True cforall_c = Comp.empty() for a, ca in allauxs[::-1]: if ca: if not a.inter(allcomprv).isempty(): cforall = cforall_c.copy() else: cforall_c += a cur_leaveone = leaveone and cur_multiuse and len(cons) == 0 innermost_auxi = Comp.empty() if auxsearch_aux_strengthen and len(cur_auxs_incomp) and not cur_auxs_incomp[0][1]: innermost_auxi = cur_auxs_incomp[0][0].copy() - csnonaux cons2 = cons cons = [] for x in cons2: y = x.copy() if not innermost_auxi.isempty(): innermost_auxi_int = innermost_auxi.inter(y.allcomprv()) if not innermost_auxi_int.isempty(): y.aux += innermost_auxi_int y.aux_strengthen(req.allcomprv() + csnonaux) y.aux -= innermost_auxi_int for v in aux: y.substitute(v, Comp.rv(v.get_name() + "_R" + str(i))) cons.append(y) oneuse_aux = MHashSet() oneuse_aux.add(None) if req.isuniverse() and aux.isempty(): pass else: nreqcheck += 1 if len(cons) == 0: nfinal += 1 if init_leaveone and aux.isempty() and csnonaux.super_of(auxi): ofl = req.one_flipped() if ofl is not None: init_reg &= ofl.add_meta("pf_note", ["trivial otherwise"]) cur_consonly = req.isuniverse() and aux.isempty() cur_consonly_init = cur_consonly and len(cons) == 1 and csnonaux.super_of(auxi) # if cur_consonly_init: # init_reg &= cons[0] xallcomprv.append(allcomprv) xconscomprv.append(conscomprv) xreq.append(req) xcons.append(cons) xmultiuse.append(cur_multiuse) xleaveone.append(cur_leaveone) xaux.append(aux) xauxi.append(auxi) xaux_avoid.append(aux_avoid) xoneuse_aux.append(oneuse_aux) 
xcforall.append(cforall) xauxs_incomp.append(cur_auxs_incomp) xreqcomprv.append(req.allcomprv()) xconsonly.append(cur_consonly) xconsonly_init.append(cur_consonly_init) if verbose_op: print("========= #" + iutil.strpad(str(i), 3, " requires ========")) print(req) if len(cons) > 0: print("========= consequences ========") print("\nOR\n".join(str(con) for con in cons)) if not aux.isempty() or not auxi.isempty(): print("========= auxiliary ========") print(" ".join(("|" if c else "&") + str(a) for a, c in cur_auxs_incomp)) print("Multiuse = " + str(cur_multiuse), ", Leave one = " + str(cur_leaveone)) print("Cons only = " + str(cur_consonly)) print() hint_aux_avoid = hint_aux_avoid + self.get_aux_avoid_list() mustcomp = Comp.empty() if not ignore_must: for a in csauxi: cmarkers = a.get_markers() cdict = {v: w for v, w in cmarkers} if cdict.get("mustuse", False): mustcomp += a rcases = MHashSet() rcases.add((init_reg, [False] * n)) oneuse_added = [False] * n res = MHashList() rcases_hashset = set() #rcases_hashset.add(hash(rcases)) oneuse_set = MHashSet() oneuse_set.add(([None] * len(csaux), [])) max_iter_pow = 4 #cur_max_iter = 200 cur_max_iter = 800 max_yield_pow = 3 #cur_max_yield = 5 cur_max_yield = 16 max_yield_add = 0 max_numupdate = 1000000000000 if max_numupdate_one: max_numupdate = 1 if nreqcheck <= 1: cur_max_iter *= 1000000000 cur_max_yield *= 1000000000 if yield_one: cur_max_yield = 1 max_yield_pow = 1 max_yield_add = 1 did = True prev_did = True caselimit_warned = False caselimit_reached = False if verbose_op: print("========= markers ========") for a in csallcomprv: cmarkers = a.get_markers() if len(cmarkers) > 0: print(iutil.strpad(str(a), 6, " : " + str(cmarkers))) print("Must use: " + str(mustcomp)) print("csnonaux: " + str(csnonaux)) print("csauxidep: " + str(csauxidep)) print("Max iter: " + str(cur_max_iter) + " Max yield: " + str(cur_max_yield)) print("========= init region ========") print(init_reg) print("========= begin search ========") cnstep = 
0 consonly_reg = None #while did and (max_iter <= 0 or cur_max_iter < max_iter): while (did or cnstep <= 1) and (max_iter <= 0 or cur_max_iter < max_iter): if PsiOpts.is_timer_ended(): break cnstep += 1 prev_did = did did = False #rcases3 = rcases nonsimple_did = False numupdate = 0 if cnstep >= 10: max_numupdate = 1000000000000 for i in range(n): if len(rcases) == 0: break cur_consonly = xconsonly[i] cur_consonly_init = xconsonly_init[i] if cnstep > 1 and cur_consonly: continue if cnstep == 1 and not cur_consonly: continue x, c = self.regs[i] allcomprv = xallcomprv[i] conscomprv = xconscomprv[i] req = xreq[i] cons = xcons[i] cur_multiuse = xmultiuse[i] cur_leaveone = xleaveone[i] aux = xaux[i] auxi = xauxi[i] aux_avoid = xaux_avoid[i] oneuse_aux = xoneuse_aux[i] cforall = xcforall[i] cur_auxs_incomp = xauxs_incomp[i] reqcomprv = xreqcomprv[i] if not cur_consonly: if not aux.isempty() or len(cons) > 0 or cur_leaveone: nonsimple_did = True if len(rcases) > caselimit: caselimit_reached = True if not caselimit_warned: caselimit_warned = True warnings.warn("Max number of cases " + str(caselimit) + " reached. 
May give false reject.", RuntimeWarning) if len(cons) >= 2 and caselimit_reached: continue rcases2 = rcases rcases = MHashSet() for rcase_tuple in rcases2: rcase = rcase_tuple[0] rcase_vis = rcase_tuple[1] rcur = None if aux.isempty(): rcur = req.implicated(rcase) else: rcase_t = rcase if strengthen: rcase_t = rcase_t.copy() rcase_t.aux += csauxi.inter(rcase_t.allcomprv()) rcase_t.simplify_aux_eq() rcase_t.aux = Comp.empty() #rcur = req.implicated(rcase).exists(aux).forall(csnonaux) # rcur = req.implicated(rcase_t).exists(aux).forall(csauxiall) rcur = req.exists(aux).implicated(rcase_t).forall(csauxiall) rcases_toadd = MHashSet() rcases_toadd.add((rcase.copy(), rcase_vis[:])) cur_yield = 0 for oneaux in reversed(oneuse_set): auxvis = oneaux[1] if i in auxvis: continue auxmasks = oneaux[0] auxlist = [(None if a is None else csauxiall.from_mask(a)) for a in auxmasks] mustleft = Comp.empty() if nfinal == 1 and len(cons) == 0 and not mustcomp.isempty(): cmustcomp = Comp.empty() for i2 in auxvis: cmustcomp += xauxi[i2].inter(mustcomp) if not cmustcomp.isempty(): auxlistall = sum((a for a in auxlist if a is not None), Comp.empty()) mustleft = cmustcomp - auxlistall if aux.isempty() and not mustleft.isempty(): #print("MUST NOT") #print(str(mustcomp) + " " + str(cmustcomp) + " " + str(auxlistall)) #print(rcur) continue rcur2 = rcur.copy() if verbose_op_detail: print("========= #" + iutil.strpad(str(i), 3, " step ========")) print("SUB " + " ".join(str(csaux[j]) + ":" + str(auxlist[j]) for j in range(len(csaux)) if auxlist[j] is not None)) print("DEP " + " ".join(str(i2) for i2 in auxvis)) #print("=====================") #print(rcur2) if verbose_op_detail2: print("========= #" + iutil.strpad(str(i), 3, " before indep ====")) print(rcur2) #clcomp = csnonaux.copy() clcomp = csauxiall - csauxidep #clcomp = rcur2.imp_flippedonly().allcomprv() + csnonaux #for i2 in auxvis: # clcomp -= xallcomprv[i2] #for i2 in reversed(tsorted): # if i2 != i and oneauxs[i2] is None: # clcomp += 
xallcomprv[i2] for i2 in auxvis: #tauxi = xauxi[i2] - clcomp #tcond = xallcomprv[i2] - xauxi[i2] #rcur2.iand_norename(Region.Ic(tauxi, clcomp, tcond).imp_flipped()) #print("TREG") #print(Expr.Ic(tauxi, clcomp, tcond)) rcur2.iand_norename(RegionOp.auxs_icreg(xauxs_incomp[i2], clcomp - xallcomprv[i2], clcomp + xreqcomprv[i2]).imp_flipped()) if verbose_op_detail2: treg2 = RegionOp.auxs_icreg(xauxs_incomp[i2], clcomp - xallcomprv[i2], clcomp) if not treg2.isuniverse(): print("========= #" + iutil.strpad(str(i), 3, " indep " + str(i2) + " =====")) print(treg2) print("clcomp=" + str(clcomp)) clcomp += xallcomprv[i2] - csaux for i2 in range(n): if i2 not in auxvis: if rcase_vis[i2]: rcur2.remove_present(xaux[i2].added_suffix("_R" + str(i2))) else: for v in xaux[i2]: w = auxlist[csaux.varlist.index(v.varlist[0])] if w is not None: rcur2.substitute_aux(Comp.rv(v.get_name() + "_R" + str(i2)), w) for j in range(len(csaux)): if auxlist[j] is not None: rcur2.substitute_aux(csaux[j], auxlist[j]) #print("SUB " + "; ".join([str(v) + ":" + str(w) for v, w in oneaux])) if verbose_op_detail2: print("========= #" + iutil.strpad(str(i), 3, " after rename ====")) print(rcur2) #print(rcur2.getaux()) #print(rcur2.getaux().get_markers()) #if i == 4: # return None hint_aux_add = [(csaux[i], auxlist[i]) for i in range(len(csaux)) if auxlist[i] is not None] hint_aux_avoid_add = [(csaux[i], csallcomprv - auxlist[i]) for i in range(len(csaux)) if auxlist[i] is not None] #print(rcur2) cdepgraph = [a[:] for a in depgraph] for j in range(len(csaux)): mask = auxmasks[j] if mask is not None: for j2 in range(len(csauxiall_is)): if mask & (1 << j2) != 0: cdepgraph[j][csauxiall_is[j2]] = True for j in range(len(csaux)): if auxlist[j] is None and aux.ispresent(csaux[j]): cmask = 0 for j2 in range(len(csauxiall_is)): tdepgraph = [a[:] for a in cdepgraph] tdepgraph[j][csauxiall_is[j2]] = True if iutil.iscyclic(tdepgraph): cmask |= 1 << j2 if cmask != 0: hint_aux_avoid_add.append((csaux[j], 
csauxiall.from_mask(cmask))) if verbose_op_detail2: print("AVOID " + str(csaux[j]) + " : " + str(csauxiall.from_mask(cmask))) #rcaseallcomprv = rcase.allcomprv() + csnonaux rcaseallcomprv = rcur2.imp_flippedonly().allcomprv() + csnonaux #rcaseallcomprv = csnonaux creg_indep = None t_cur_leaveone = cur_leaveone if numupdate >= max_numupdate: did = True t_cur_leaveone = False oproof = None if write_pf_enabled_inner: oproof = PsiOpts.get_proof().copy() curtermdid = False for rr in rcur2.check_getaux_inplace_gen(must_include = must_include + mustleft, single_include = single_include, hint_pair = hint_pair, hint_aux = hint_aux + hint_aux_add, hint_aux_avoid = hint_aux_avoid + hint_aux_avoid_add, max_iter = cur_max_iter, leaveone = t_cur_leaveone): cur_yield += 1 if cur_yield > cur_max_yield: did = True break stype = iutil.signal_type(rr) if stype == "": if len(cons) > 0 and iutil.list_iscomplex(rr): continue t_multiuse = cur_multiuse if t_multiuse and len(cons) > 0 and not csauxidep.isempty() and any(csauxidep.ispresent(w) for v, w in rr): t_multiuse = False if t_multiuse: if creg_indep is None: creg_indep = RegionOp.auxs_icreg(cur_auxs_incomp, rcaseallcomprv - allcomprv, rcaseallcomprv + reqcomprv) #creg_indep.simplify() #creg_indep = RegionOp.auxs_icreg(cur_auxs_incomp, clcomp - allcomprv, clcomp) #cauxi = auxi - rcaseallcomprv #ccond = allcomprv - auxi #ccompleft = rcaseallcomprv - cauxi - ccond #creg_indep = Region.Ic(cauxi, ccompleft, ccond) #for v in aux: # creg_indep.substitute(v, Comp.rv(str(v) + "_R" + str(i))) #print("CREG") #print(rcur2) #print(creg_indep) rcases_toadd2 = rcases_toadd rcases_toadd = MHashSet() for rcase_toadd_tuple in rcases_toadd2: rcase_toadd = rcase_toadd_tuple[0] rcase_toadd_vis = rcase_toadd_tuple[1] for con in cons: ccon = con.copy() ccon.iand_norename(creg_indep) Comp.substitute_list(ccon, rr, suffix = "_R" + str(i)) crcase = rcase_toadd.copy() #crcase.iand_norename(ccon) #crcase.simplify_quick(zero_group = 1) 
crcase.iand_simplify_quick(ccon) rcases_toadd.add((crcase, rcase_toadd_vis[:])) if rcases_toadd != rcases_toadd2: curtermdid = True tauxlist = [(None if w is None else w.copy()) for w in auxlist] for v, w in rr: #print(">>>>" + str(csaux) + " " + str(v) + " " + str(w)) tauxlist[csaux.varlist.index(v.varlist[0])] = w.copy() rr2 = [(csaux[j], tauxlist[j]) for j in range(len(csaux)) if tauxlist[j] is not None] if len(rr2) > 0: res.add(rr2) if verbose_op_step: print("ADD " + " ".join([str(v) + ":" + str(w) for v, w in rr2]) + " y=" + str(cur_yield) + "/" + str(cur_max_yield)) else: oneuse_added[i] = True #oneuse_aux.clear() rcases_toadd2 = rcases_toadd rcases_toadd = MHashSet() for rcase_toadd_tuple in rcases_toadd2: rcase_toadd = rcase_toadd_tuple[0] rcase_toadd_vis = rcase_toadd_tuple[1] if rcase_toadd_vis[i]: rcases_toadd.add(rcase_toadd_tuple) else: for con in cons: crcase = rcase_toadd.copy() #crcase.iand_norename(con) #crcase.simplify_quick(zero_group = 1) crcase.iand_simplify_quick(con) rcases_toadd.add((crcase, [rcase_toadd_vis[i2] or i2 == i for i2 in range(n)])) tauxmasks = auxmasks[:] for v, w in rr: tauxmasks[csaux.varlist.index(v.varlist[0])] = csauxiall_id.get_mask(w) #print("; ".join(str(v) + ":" + str(w) for v, w in rr)) #print(cdepends) if oneuse_set.add((tauxmasks, auxvis + [i])): did = True curtermdid = True if verbose_op_step: tauxlist = [(None if w is None else w.copy()) for w in auxlist] for v, w in rr: tauxlist[csaux.varlist.index(v.varlist[0])] = w.copy() rr2 = [(csaux[j], tauxlist[j]) for j in range(len(csaux)) if tauxlist[j] is not None] # if len(cons) == 0: # if len(rr2) > 0: # res.add(rr2) print("ONE " + " ".join([str(v) + ":" + str(w) for v, w in rr2])) if len(cons) == 0: break elif stype == "leaveone": if len(cons) == 0: rcases_toadd2 = rcases_toadd rcases_toadd = MHashSet() casedid = False for rcase_toadd_tuple in rcases_toadd2: rcase_toadd = rcase_toadd_tuple[0] rcase_toadd_vis = rcase_toadd_tuple[1] #crcase = rcase_toadd & (rr[2] >= 0) 
#crcase.simplify_quick(zero_group = 1) crcase = rcase_toadd.copy() if not crcase.implies_ineq_quick(rr[2], ">="): crcase.iand_simplify_quick((rr[2] >= 0).add_meta("pf_note", ["case"])) casedid = True rcases_toadd.add((crcase, rcase_toadd_vis)) if casedid and rcases_toadd != rcases_toadd2: curtermdid = True tauxlist = [(None if w is None else w.copy()) for w in auxlist] for v, w in rr[1]: tauxlist[csaux.varlist.index(v.varlist[0])] = w.copy() rr2 = [(csaux[j], tauxlist[j]) for j in range(len(csaux)) if tauxlist[j] is not None] if len(rr2) > 0: res.add(rr2) if verbose_op_step: print("LVO " + " ".join([str(v) + ":" + str(w) for v, w in rr2])) numupdate += 1 if numupdate >= max_numupdate: did = True break elif stype == "max_iter_reached": did = True if len(rcases_toadd) > casesteplimit: break if len(rcases_toadd) == 0: break if not curtermdid and oproof is not None: PsiOpts.set_proof(oproof) if len(rcases_toadd) > casesteplimit: break if len(rcases_toadd) == 0: break if cur_yield > cur_max_yield: did = True break rcases += rcases_toadd if verbose_op_detail and i < n - 1: print("========= cases ========") print("\nOR\n".join(str(rcase[0]) for rcase in rcases)) if cnstep == 1 and get_info is not None: consonly_reg = RegionOp.union([a for a, _ in rcases], tosimple = True) if cnstep != 1 and not nonsimple_did: break # if not nonsimple_did: # break rcases_hash = hash(rcases) if rcases_hash not in rcases_hashset: did = True rcases_hashset.add(rcases_hash) if verbose_op_step: print("========= cases ========") print("\nOR\n".join(str(rcase[0]) for rcase in rcases)) if len(rcases) == 0: break cur_max_iter = int(cur_max_iter * max_iter_pow) cur_max_yield = int(cur_max_yield * max_yield_pow + max_yield_add) if get_info is not None: for i in range(len(get_info)): if get_info[i] == "assumption_init": get_info[i] = consonly_reg.copy() get_info[i] = get_info[i].exists(csauxi.inter(get_info[i].allcomprv())) elif get_info[i] == "assumption_final": get_info[i] = RegionOp.union([a for a, 
_ in rcases], tosimple = True) get_info[i] = get_info[i].exists(csauxi.inter(get_info[i].allcomprv())) elif get_info[i] == "assumption_new": if consonly_reg.get_type() == RegionType.NORMAL: bnet = consonly_reg.get_bayesnet() tlist = [] for ccase, _ in rcases: tr = ccase.copy() tr.simplify_redundant(reg = consonly_reg, quick = True, bnet = bnet) tr.simplify(reg = consonly_reg) tlist.append(tr) get_info[i] = RegionOp.union(list(tlist), tosimple = True) get_info[i] = get_info[i].exists(csauxi.inter(get_info[i].allcomprv())) else: get_info[i] = Region.universe() PsiOpts.settings["proof_enabled"] = write_pf_enabled if len(rcases) == 0: if verbose_op: print("========= success ========") print(iutil.list_tostr_std(res.x)) resrr = None #resrr = res.x if len(res.x) == 1: resrr = res.x[0] else: resrr = res.x if write_pf_enabled: if write_pf_twopass: if verbose_op: print("Proof consonly:") print(self.consonly()) print() print(self.consonly().simplified_quick()) print() pf = ProofObj.from_region(self if write_pf_repeat_claim else self.consonly().simplified_quick(), c = "Claim:") PsiOpts.set_setting(proof_step_in = pf) resrr_dict = Comp.substitute_list_to_dict(resrr, multi = True) if not Comp.substitute_dict_ismulti(resrr_dict): pf = ProofObj.from_region(None, c = ["Substitute ", CompArray(resrr_dict).add_meta("omit_bracket", True).add_meta("subs", True), ":"]) PsiOpts.set_setting(proof_add = pf) # cs = self.copy() cs = RegionOp.union([]) for i in range(n): req = xreq[i] cons = xcons[i] tadd = (req & ~RegionOp.union(cons)).noaux() for v in xaux[i]: tadd.substitute_aux(Comp.rv(v.get_name() + "_R" + str(i)), v) tadd = tadd.substituted_dict_union_plain(resrr_dict) cs |= tadd cs = cs.copy() # Comp.substitute_list(cs, resrr, isaux = True) cs = cs.noaux() if verbose_op: print("Proof reconstruction:") print(cs) print() if cs.getaux().isempty(): with PsiOpts(proof_enabled = True): cs.check() PsiOpts.set_setting(proof_step_out = True) return resrr return None def 
check_getaux_op_inplace(self, hint_pair = None, hint_aux = None): """Return whether implication is true, with auxiliary search result.""" r = [] #print("") #print(self) for x in self.regs: #print(x) if x.get_type() == RegionType.NORMAL: t = x.check_getaux(hint_pair, hint_aux) else: t = x.check_getaux_op_inplace(hint_pair, hint_aux) if self.get_type() == RegionType.UNION and t is not None: return t if self.get_type() == RegionType.INTER and t is None: return None r.append(t) if self.get_type() == RegionType.INTER: return r return None def check_getaux(self, hint_pair = None, hint_aux = None): """Return whether implication is true, with auxiliary search result.""" truth = PsiOpts.settings["truth"] if truth is not None: with PsiOpts(truth = None): return (truth >> self).check_getaux(hint_pair, hint_aux) indreg = self.get_indreg_checked() if indreg is not None: with PsiOpts(indreg_enabled = False): return (indreg >> self).check_getaux(hint_pair, hint_aux) cs = self.copy() return cs.check_getaux_inplace(hint_pair = hint_pair, hint_aux = hint_aux) def check_getaux_gen(self, hint_pair = None, hint_aux = None): """Return whether implication is true, with auxiliary search result.""" rr = self.check_getaux(hint_pair = hint_pair, hint_aux = hint_aux) if rr is not None: yield rr def check(self): """Return whether implication is true.""" if iutil.get_solver() == "z3": return self.check_z3() return self.check_getaux() is not None def assumption(self, mode = None): """Retrieve the strengthened assumptions for proving this region. The implication in this region must be true if this assumption does not hold. 
""" cmode = "assumption_final" if mode == "new": cmode = "assumption_new" elif mode == "init": cmode = "assumption_init" get_info = [cmode] cs = self.copy() with PsiOpts(cases = True): cs.check_getaux_inplace(get_info = get_info) return get_info[0] def evalcheck(self, f): truth = PsiOpts.settings["truth"] if truth is not None: with PsiOpts(truth = None): return (truth >> self).evalcheck(f) indreg = self.get_indreg_checked() if indreg is not None: with PsiOpts(indreg_enabled = False): return (indreg >> self).evalcheck(f) ceps = PsiOpts.settings["eps_check"] isunion = (self.get_type() == RegionType.UNION) for x, c in self.regs: if isunion ^ c ^ x.evalcheck(f): return isunion return not isunion def eval_max_violate(self, f): truth = PsiOpts.settings["truth"] if truth is not None: with PsiOpts(truth = None): return (truth >> self).eval_max_violate(f) indreg = self.get_indreg_checked() if indreg is not None: with PsiOpts(indreg_enabled = False): return (indreg >> self).eval_max_violate(f) ceps = PsiOpts.settings["eps_check"] if self.get_type() == RegionType.INTER: r = 0.0 for x, c in self.regs: t = x.eval_max_violate(f) if c: r = max(r, t) else: if t <= ceps: return numpy.inf return r elif self.get_type() == RegionType.UNION: r = numpy.inf for x, c in self.regs: t = x.eval_max_violate(f) if c: r = min(r, t) else: if t > ceps: return 0.0 return r return 0.0 def eval_sum_violate(self, f, pow = 1, leak = 0.1): # truth = PsiOpts.settings["truth"] # if truth is not None: # with PsiOpts(truth = None): # return (truth >> self).eval_sum_violate(f, pow = pow, leak = leak) ceps = PsiOpts.settings["eps_check"] if self.get_type() == RegionType.INTER: r = 0.0 for x, c in self.regs: t = x.eval_sum_violate(f, pow = pow, leak = leak) if c: r = r + t else: if t <= ceps: return numpy.inf return r elif self.get_type() == RegionType.UNION: r = numpy.inf for x, c in self.regs: t = x.eval_sum_violate(f, pow = pow, leak = leak) if c: # r = r * t r = min(r, t) else: if t > ceps: return 0.0 
return r return 0.0 def implies_getaux(self, other, hint_pair = None, hint_aux = None): """Whether self implies other, with auxiliary search result.""" return (self <= other).check_getaux(hint_pair, hint_aux) def istight(self, canon = False): return all(x.istight(canon) for x, c in self.regs) def tighten(self): for x, c in self.regs: x.tighten() def add_meta_present(self, b, key, value): for x, c in self.regs: x.add_meta_present(b, key, value) def simplify_op(self): if self.isuniverse(): self.setuniverse() return self if self.isempty(): self.setempty() return self if self.get_type() == RegionType.INTER or self.get_type() == RegionType.UNION: for i in range(len(self.regs)): x, c = self.regs[i] if not c and x.get_type() == RegionType.NORMAL: t = x.try_negate(eps_only = True) if t is not None: self.regs[i] = (t, not c) if len(self.auxs) == 0: tregs = [] for x, c in self.regs: if x.isuniverse(c ^ (self.get_type() == RegionType.UNION)): continue if c and x.get_type() == self.get_type() and len(x.auxs) == 0: tregs += x.regs self.auxs = x.auxs + self.auxs else: tregs.append((x, c)) self.regs = tregs self.aux_clean() return self def simplify_quick(self, reg = None, zero_group = 0): """Simplify a region in place, without linear programming. Optional argument reg with constraints assumed to be true. zero_group = 2: group all nonnegative terms as a single inequality. """ if not PsiOpts.settings.get("simplify_enabled", False): return self #self.distribute() #self.remove_missing_aux() for x, c in self.regs: x.simplify_quick(reg, zero_group) self.simplify_op() return self def aux_push(self): for i in range(len(self.regs)): for x, c in self.auxs: if c: self.regs[i] = (self.regs[i][0].exists(x), self.regs[i][1]) else: self.regs[i] = (self.regs[i][0].forall(x), self.regs[i][1]) self.auxs = [] return self def simplify_union(self, reg = None): """Simplify a union region in place. May take much longer than Region.simplify(). Optional argument reg with constraints assumed to be true. 
""" if reg is None: reg = Region.universe() # print("simplify_union pre distribute") # print(self) self.distribute() # print("simplify_union post distribute") # print(self) if not self.get_type() == RegionType.UNION: return if any(not c for x, c in self.auxs): return #self.aux_push() aux = Comp.empty() for x, c in self.auxs: aux += x regc = [i for i in range(len(self.regs)) if self.regs[i][1]] for i in regc: self.regs[i][0].eliminate(aux) self.regs[i][0].simplify() self.regs[i][0].remove_aux(aux) regs_rem = [False for x, c in self.regs] for i, j in itertools.permutations(regc, 2): if regs_rem[i] or regs_rem[j]: continue #print("###") if (self.regs[i][0].exists(aux.inter(self.regs[i][0].allcomprv())) & reg).implies(self.regs[j][0].exists(aux.inter(self.regs[j][0].allcomprv()))): regs_rem[i] = True self.regs = [(x, c) for i, (x, c) in enumerate(self.regs) if not regs_rem[i]] # print("simplify_union output") # print(self) return self def imp_present(self): return True def var_mi_only(self, v): return all(x.var_mi_only(v) for x, c in self.regs) def sort(self): pass def simplify(self, reg = None, zero_group = 0, **kwargs): """Simplify a region in place. Optional argument reg with constraints assumed to be true. zero_group = 2: group all nonnegative terms as a single inequality. 
""" if kwargs: r = None with PsiOpts(**{"simplify_" + key: val for key, val in kwargs.items()}): r = self.simplify(reg, zero_group) return r if not PsiOpts.settings.get("simplify_enabled", False): return self simplify_redundant_op = (PsiOpts.settings.get("simplify_redundant_op", False) and not PsiOpts.settings.get("simplify_quick", False)) #self.distribute() self.remove_missing_aux() r_assumed = None if reg is None: r_assumed = Region.universe() else: r_assumed = reg.copy() if PsiOpts.settings.get("simplify_regterm", False): self.simplify_regterm(reg) with PsiOpts(simplify_regterm = False): if self.get_type() == RegionType.INTER or self.get_type() == RegionType.UNION: isunion = (self.get_type() == RegionType.UNION) regs_sorted = list(self.regs) def reg_sort_priority(t): x, c = t return (c == (not isunion), x.complexity()) regs_sorted.sort(key = reg_sort_priority) for i0, (x0, c0) in enumerate(regs_sorted): t_assumed = r_assumed.copy() for i, (x, c) in enumerate(regs_sorted): if i0 == i: continue if c ^ isunion: t = x.tosimple_noaux() if t is not None: t_assumed &= t else: flipped = x.tosimple_noaux() if flipped is not None: flipped = flipped.one_flipped() if flipped is not None: t_assumed &= flipped x0.simplify(t_assumed, zero_group) # def reg_sort_priority(t): # x, c = t # if c == (not isunion): # return 0 # if x.one_flipped() is not None: # return 1 # return 2 # regs_sorted.sort(key = reg_sort_priority) # for x, c in regs_sorted: # x.simplify(r_assumed, zero_group) # if c ^ isunion: # t = x.tosimple_noaux() # if t is not None: # r_assumed &= t # else: # flipped = x.tosimple_noaux() # if flipped is not None: # flipped = flipped.one_flipped() # if flipped is not None: # r_assumed &= flipped # for c_pass, one_flipped_need in [(not isunion, None), (isunion, True), (isunion, False)]: # for x, c in self.regs: # if c != c_pass: # continue # if one_flipped_needed is not None: # if c == c_pass and (one_flipped_need is None or ): # x.simplify(r_assumed, zero_group) # if c 
^ isunion: # t = x.tosimple_noaux() # if t is not None: # r_assumed &= t if simplify_redundant_op: regs_s = [(x.tosimple_noaux() if not c ^ isunion else None) for x, c in self.regs] regs_rem = [False for x, c in self.regs] # for i, j in itertools.permutations([i for i in range(len(self.regs)) if regs_s[i] is not None], 2): for i in range(len(self.regs)): if PsiOpts.is_timer_ended(): break if regs_rem[i] or regs_s[i] is None: continue #print("###") t_assumed = r_assumed.copy() for k, (x, c) in enumerate(self.regs): if k == i or regs_rem[k]: continue if c ^ isunion: t = x.tosimple_noaux() if t is not None: t_assumed &= t else: flipped = x.tosimple_noaux() if flipped is not None: flipped = flipped.one_flipped() if flipped is not None: t_assumed &= flipped # for k in range(len(self.regs)): # if k != i and not regs_rem[k] and not self.regs[k][1] ^ isunion: # tf = self.regs[k][0].one_flipped() # if tf is not None: # t_assumed &= tf for j in range(len(self.regs)): if PsiOpts.is_timer_ended(): break if i == j or regs_rem[j] or regs_s[j] is None: continue if (regs_s[i] & t_assumed).implies(regs_s[j]): regs_rem[i] = True break self.regs = [(x, c) for i, (x, c) in enumerate(self.regs) if not regs_rem[i]] self.simplify_op() if PsiOpts.settings.get("simplify_union", False): self.simplify_union(reg) return self def simplified_quick(self, reg = None, zero_group = 0): """Returns the simplified region Optional argument reg with constraints assumed to be true zero_group = 2: group all nonnegative terms as a single inequality """ if reg is None: reg = Region.universe() r = self.copy() r.simplify_quick(reg, zero_group) t = r.tosimple_safe() if t is not None: return t return r def simplified(self, reg = None, zero_group = 0, **kwargs): """Returns the simplified region Optional argument reg with constraints assumed to be true zero_group = 2: group all nonnegative terms as a single inequality """ if reg is None: reg = Region.universe() r = self.copy() r.simplify(reg, zero_group, 
**kwargs) t = r.tosimple_safe() if t is not None: return t return r def add_aux(self, aux, c): if len(self.auxs) > 0 and self.auxs[-1][1] == c: self.auxs[-1]= (self.auxs[-1][0] + aux, self.auxs[-1][1]) else: self.auxs.append((aux.copy(), c)) def remove_aux(self, w): t = self.auxs self.auxs = [] for x, c in t: y = x - w if not y.isempty(): self.auxs.append((y, c)) def remove_missing_aux(self): #return t = self.auxs self.auxs = [] allcomp = self.allcomprv() for x, c in t: y = x.inter(allcomp) if not y.isempty(): self.auxs.append((y, c)) def eliminate(self, w, reg = None, toreal = False, forall = False, quick = False, method = "", reg_record = None): w = Region.get_allcomp(w) toelim = Comp.empty() for v in w.allcomp(): if toreal or v.get_type() == IVarType.REAL: toelim += v elif v.get_type() == IVarType.RV: self.add_aux(v, not forall) simplify_needed = False if not toelim.isempty(): simplify_needed = True if forall: self.negate() # print(self) self.distribute(force_split = True) # print(self) for x, c in self.regs: x.simplify_quick() if not c and x.ispresent(toelim): if forall: self.setempty() else: self.setuniverse() return self for x, c in self.regs: if x.ispresent(toelim): x.eliminate(toelim, reg = reg, toreal = toreal, forall = not c, quick = quick, method = method, reg_record = reg_record) if forall: self.negate() if simplify_needed: if quick: self.simplify_quick(reg) else: self.simplify(reg) return self.tonormal_safe() return self def eliminate_quick(self, w, reg = None, toreal = False, forall = False, method = ""): return self.eliminate(w, reg = reg, toreal = toreal, forall = forall, quick = True, method = method) def marginal_eliminate(self, w): for x in self.regs: x.marginal_eliminate(w) def kernel_eliminate(self, w): for x in self.regs: x.kernel_eliminate(w) def tostring(self, style = 0, tosort = False, lhsvar = "real", inden = 0, add_bracket = False, small = False, skip_outer_exists = False): """Convert to string. 
Parameters: style : Style of string conversion STR_STYLE_STANDARD : I(X,Y;Z|W) STR_STYLE_PSITIP : I(X+Y&Z|W) """ style = iutil.convert_str_style(style) if isinstance(lhsvar, str) and lhsvar == "real": lhsvar = self.allcomprealvar() curadd_bracket = True if style & PsiOpts.STR_STYLE_LATEX: if len(self.regs) == 1: curadd_bracket = add_bracket r = "" interstr = "" nlstr = "\n" notstr = "NOT" spacestr = " " if style & PsiOpts.STR_STYLE_PSITIP: notstr = "~" elif style & PsiOpts.STR_STYLE_LATEX: notstr = "\\lnot" if style & PsiOpts.STR_STYLE_LATEX_ARRAY: nlstr = "\\\\\n" spacestr = "\\;" if self.get_type() == RegionType.UNION: if style & PsiOpts.STR_STYLE_PSITIP: interstr = "|" elif style & PsiOpts.STR_STYLE_LATEX: interstr = PsiOpts.settings["latex_or"] else: interstr = "OR" if self.get_type() == RegionType.INTER: if style & PsiOpts.STR_STYLE_PSITIP: interstr = "&" elif style & PsiOpts.STR_STYLE_LATEX: interstr = PsiOpts.settings["latex_and"] else: interstr = "AND" if self.isuniverse(sgn = True, canon = True): if style & PsiOpts.STR_STYLE_PSITIP: return spacestr * inden + "RegionOp.universe()" elif style & PsiOpts.STR_STYLE_LATEX: return spacestr * inden + PsiOpts.settings["latex_region_universe"] else: return spacestr * inden + "Universe" if self.isuniverse(sgn = False, canon = True): if style & PsiOpts.STR_STYLE_PSITIP: return spacestr * inden + "RegionOp.empty()" elif style & PsiOpts.STR_STYLE_LATEX: return spacestr * inden + PsiOpts.settings["latex_region_empty"] else: return spacestr * inden + "{}" for x, c in reversed(self.auxs): if c: if style & PsiOpts.STR_STYLE_PSITIP: pass elif style & PsiOpts.STR_STYLE_LATEX: if not style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER: r += PsiOpts.settings["latex_exists"] + " " r += x.tostring(style = style, tosort = tosort) r += PsiOpts.settings["latex_quantifier_sep"] + " " else: if style & PsiOpts.STR_STYLE_PSITIP: pass elif style & PsiOpts.STR_STYLE_LATEX: if not style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER: r += 
PsiOpts.settings["latex_forall"] + " " r += x.tostring(style = style, tosort = tosort) r += PsiOpts.settings["latex_quantifier_sep"] + " " inden_inner = inden inden_inner1 = inden + 2 if style & PsiOpts.STR_STYLE_PSITIP: r += spacestr * inden + "(" + nlstr elif style & PsiOpts.STR_STYLE_LATEX: if style & PsiOpts.STR_STYLE_LATEX_ARRAY: r += spacestr * inden if curadd_bracket: r += "\\left\\{" r += "\\begin{array}{l}\n" inden_inner = 0 inden_inner1 = 0 else: r += spacestr * inden + "\\{" + nlstr else: r += spacestr * inden + "{" + nlstr rlist = [spacestr * inden_inner1 + ("" if c else " " + notstr) + x.tostring(style = style, tosort = tosort, lhsvar = lhsvar, inden = inden_inner1, add_bracket = True, small = small).lstrip() for x, c in self.regs] if tosort: rlist = zip(rlist, [any(x.ispresent(t) for t in lhsvar) for x, c in self.regs]) rlist = sorted(rlist, key=lambda a: (not a[1], len(a[0]), a[0])) rlist = [x for x, t in rlist] r += (nlstr + spacestr * inden_inner + " " + interstr + nlstr).join(rlist) r += nlstr + spacestr * inden_inner if style & PsiOpts.STR_STYLE_PSITIP: r += ")" elif style & PsiOpts.STR_STYLE_LATEX: if style & PsiOpts.STR_STYLE_LATEX_ARRAY: r += "\\end{array}" if curadd_bracket: r += "\\right\\}" else: r += "\\}" else: r += "}" for x, c in self.auxs: if c: if style & PsiOpts.STR_STYLE_PSITIP: r += ".exists(" elif style & PsiOpts.STR_STYLE_LATEX: if style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER: r += " , " + PsiOpts.settings["latex_exists"] + " " else: continue else: r += " , exists " else: if style & PsiOpts.STR_STYLE_PSITIP: r += ".forall(" elif style & PsiOpts.STR_STYLE_LATEX: if style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER: r += " , " + PsiOpts.settings["latex_forall"] + " " else: continue else: r += " , forall " r += x.tostring(style = style, tosort = tosort) if style & PsiOpts.STR_STYLE_PSITIP: r += ")" return r def __hash__(self): #return hash(self.tostring(tosort = True)) return hash((self.rtype, hash(frozenset((hash(x), c) for x, c in 
self.regs)), hash(tuple((hash(x), c) for x, c in self.auxs)), hash(self.inp), hash(self.oup) )) class MonotoneSet: def __init__(self, sgn = 1): self.sgn = sgn self.cache = [] def add(self, x): if self.sgn > 0: self.cache = [y for y in self.cache if not y >= x] elif self.sgn < 0: self.cache = [y for y in self.cache if not y <= x] self.cache.append(x) def __contains__(self, x): for i in range(len(self.cache)): if (self.sgn > 0 and x >= self.cache[i]) or (self.sgn < 0 and x <= self.cache[i]): for j in range(i, 0, -1): self.cache[j - 1], self.cache[j] = self.cache[j], self.cache[j - 1] return True return False def __len__(self): return len(self.cache) def __getitem__(self, key): return self.cache[key] class IBaseArray(IBaseObj): def __init__(self, x = None, shape = None, meta = None): if x is None: x = [] if isinstance(x, dict): d = x x = [] for key, value in d.items(): if isinstance(value, list): for t in value: x.append([key, t]) else: x.append([key, value]) cshape = None if isinstance(x, type(self).entry_cls): self.x = [x] cshape = tuple() elif iutil.istensor(x): self.x = [] cshape = tuple(x.shape) for xs in itertools.product(*[range(t) for t in cshape]): self.x.append(self.entry_convert(x[xs])) else: self.x = [] tshape = [] def recur(i, y): if isinstance(y, IBaseArray): y = y.tolist() if not isinstance(y, (list, tuple)): self.x.append(self.entry_convert(y)) return if len(tshape) <= i: tshape.append(len(y)) else: if len(y) != tshape[i]: raise ValueError("Shape mismatch.") return for z in y: recur(i + 1, z) recur(0, x) cshape = tuple(tshape) if shape is not None: self.shape = shape else: self.shape = cshape if meta is None: self.meta = dict() else: self.meta = meta @property def shape(self): return self._shape # if len(self._shape) == 0: # return self._shape # return self._shape + (len(self.x) // iutil.product(self._shape),) @shape.setter def shape(self, value): if isinstance(value, int): value = (value,) for i in range(len(value)): if value[i] < 0: tvalue = 
list(value) tvalue[i] = len(self.x) // (iutil.product(value[:i]) * iutil.product(value[i+1:])) value = tuple(tvalue) break self._shape = tuple(value) def reshaped(self, newshape): r = self.copy() r.shape = newshape return r def copy(self): return type(self)([a.copy() for a in self.x], shape = self.shape, meta = iutil.copy(self.meta)) @classmethod def empty(cls, shape = 0): if isinstance(shape, int): shape = (shape,) n = iutil.product(shape) return cls([cls.entry_cls_zero() for i in range(n)], shape = shape) @classmethod def zeros(cls, shape = 0): return cls.empty(shape = shape) @classmethod def ones(cls, shape = 0, x = None): if isinstance(shape, int): shape = (shape,) n = iutil.product(shape) if x is None: return cls([cls.entry_cls_one() for i in range(n)], shape = shape) else: return cls([x.copy() for i in range(n)], shape = shape) @classmethod def eye(cls, n, x = None): r = cls.zeros((n, n)) for i in range(n): if x is None: r[i, i] = cls.entry_cls_one() else: r[i, i] = x.copy() return r @classmethod def make(cls, *args): r = cls.empty() for a in args: t = cls.entry_convert(a) if t is not None: r.append(t) else: for b in a: r.append(b.copy()) return r @classmethod def isthis(cls, x): if isinstance(x, cls): return True if isinstance(x, list) and len(x) > 0 and isinstance(x[0], cls.entry_cls): return True return False def tolist(self, key = None, process = None): if key is None: key = tuple() shape = self.shape i = len(key) if i == len(shape): if process is not None: return process(self[key]) else: return self[key] r = [] for j in range(shape[i]): r.append(self.tolist(key + (j,), process = process)) return r def to_numpy(self): return numpy.array(self.tolist()) def to_dict(self): shape = self.shape if len(shape) < 2: return dict() r = dict() for xs in itertools.product(*[range(t) for t in shape[:-1]]): r[self[xs + (0,)]] = self[xs + (1,)] return r def find_dict(self, key): shape = self.shape if len(shape) < 2: return None r = None for xs in 
itertools.product(*[range(t) for t in shape[:-1]]): if self[xs + (0,)] == key: r = self[xs + (1,)] return r def iadd_noduplicate(self, x): nameset = set(a.get_name() for a in self.x) for a in x: cname = a.get_name() if cname not in nameset: nameset.add(cname) self.append(a) def subsets(self, ndim = 1, minsize = 0, maxsize = 100000, size = None, reverse = False): """Subsets of this array along the first ndim (default = 1) dimensions, as a generator """ shape = self.shape t = [] for xs in itertools.product(*[range(t) for t in shape[:ndim]]): t.append(self[tuple(xs)]) return igen.subset(t, minsize = minsize, maxsize = maxsize, size = size, reverse = reverse) def allcomp(self): return sum([a.allcomp() for a in self.x], Comp.empty()) def find(self, *args): return self.allcomp().find(*args) def from_mask(self, mask): """Return subset using bit mask.""" r = type(self).entry_cls_zero() for i in range(len(self.x)): if mask & (1 << i) != 0: r += self.x[i] return r def append(self, a): if len(self.shape) != 1: raise ValueError("Can only append 1D array.") return self.x.append(a) self.shape = (self.shape[0] + 1,) def swapped_id(self, i, j): if i >= len(self.x) or j >= len(self.x): return self.copy() r = self.copy() r.x[i], r.x[j] = r.x[j], r.x[i] return r def transpose(self): r = type(self).zeros(tuple(reversed(self.shape))) for xs in itertools.product(*[range(t) for t in self.shape]): r[tuple(reversed(xs))] = self[xs] return r @fcn_substitute def substitute(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound), in place.""" for i in range(len(self.x)): self.x[i].substitute(v0, v1) return self @fcn_substitute def substitute_whole(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound), in place.""" for i in range(len(self.x)): self.x[i].substitute_whole(v0, v1) return self @fcn_substitute def substitute_aux(self, v0, v1): """Substitute variable v0 by v1 (v1 can be compound), and remove auxiliary v0, in place.""" for i in range(len(self.x)): 
self.x[i].substitute_aux(v0, v1) return self def substituted_aux(self, *args, **kwargs): """Substitute variable v0 by v1 (v1 can be compound), and remove auxiliary v0, return result""" r = self.copy() r.substitute_aux(*args, **kwargs) return r def set_len(self, n): if n < len(self.x): self.x = self.x[:n] return while n > len(self.x): self.x.append(type(self).entry_cls_zero()) def __neg__(self): return type(self)([-a for a in self.x], shape = self.shape) def __iadd__(self, other): if isinstance(other, (tuple, list, ConcDist)): other = type(self)(other) if iutil.istensor(other): if self.shape != other.shape: raise ValueError("Shape mismatch.") return for xs in itertools.product(*[range(t) for t in self.shape]): self[xs] += other[xs] return self # if isinstance(other, IBaseArray): # r = [] # for i in range(len(other.x)): # if i < len(self.x): # self.x[i] += other.x[i] # else: # self.x.append(other.x[i].copy()) # return self for i in range(len(self.x)): self.x[i] += other return self def __imul__(self, other): if isinstance(other, (tuple, list, ConcDist)): other = type(self)(other) if iutil.istensor(other): if self.shape != other.shape: raise ValueError("Shape mismatch.") return for xs in itertools.product(*[range(t) for t in self.shape]): self[xs] *= other[xs] return self for i in range(len(self.x)): self.x[i] *= other return self def __itruediv__(self, other): if isinstance(other, (tuple, list, ConcDist)): other = type(self)(other) if iutil.istensor(other): if self.shape != other.shape: raise ValueError("Shape mismatch.") return for xs in itertools.product(*[range(t) for t in self.shape]): self[xs] /= other[xs] return self for i in range(len(self.x)): self.x[i] /= other return self def __ipow__(self, other): if isinstance(other, (tuple, list, ConcDist)): other = type(self)(other) if iutil.istensor(other): if self.shape != other.shape: raise ValueError("Shape mismatch.") return for xs in itertools.product(*[range(t) for t in self.shape]): self[xs] **= other[xs] return 
self for i in range(len(self.x)): self.x[i] **= other return self def __mul__(self, other): r = self.copy() r *= other return r def __rmul__(self, other): r = self.copy() r *= other return r def __truediv__(self, other): r = self.copy() r /= other return r def __rtruediv__(self, other): r = type(self).ones(len(self.x)) r *= other r /= self return r def __pow__(self, other): r = self.copy() r **= other return r def __rpow__(self, other): r = type(self).ones(len(self.x)) r *= other r **= self return r def __add__(self, other): if isinstance(other, int) and other == 0: return self.copy() if isinstance(other, (tuple, list, ConcDist)): other = type(self)(other) r = self.copy() r += other return r def __radd__(self, other): if isinstance(other, int) and other == 0: return self.copy() if isinstance(other, (tuple, list, ConcDist)): other = type(self)(other) r = self.copy() r += other return r def __isub__(self, other): if isinstance(other, (tuple, list, ConcDist)): other = type(self)(other) self += -other return self def __sub__(self, other): if isinstance(other, (tuple, list, ConcDist)): other = type(self)(other) r = self.copy() r += -other return r def __rsub__(self, other): if isinstance(other, (tuple, list, ConcDist)): other = type(self)(other) r = -self r += other return r def __len__(self): return len(self.x) def product(self): if len(self.x) == 0: return 1 r = None for a in self.x: if r is None: r = a.copy() else: r *= a return r def key_to_id(self, key): if isinstance(key, int): key = (key,) if isinstance(key, tuple) and all(isinstance(xs, int) for xs in key): shape = self.shape if len(key) != len(shape): raise IndexError("Dimension mismatch.") return i = 0 for s, k in zip(shape, key): if k < 0 or k >= s: raise IndexError("Index out of bound.") return i = i * s + k return i return 0 def getitem_str(self, key): key = key.lower() r = type(self).entry_cls_zero() for c in key: if c == "0" or c == "c": r += self.x[0] elif c == "p" and len(self.x) > 1: r += self.x[1] 
elif c == "f" and len(self.x) > 2: r += self.x[2] elif c == "a": if len(self.x) > 1: r += self.x[1] for i, y in enumerate(self.x): if i != 1: r += y return r def __getitem__(self, key): if isinstance(key, int): return self.x[key] if isinstance(key, str): return self.getitem_str(key) if isinstance(key, IBaseObj): return self.find_dict(key) if isinstance(key, slice): r = self.x[key] if isinstance(r, list): return type(self)(r) return r if isinstance(key, tuple) and len(key) < len(self.shape): return type(self)(self.tolist(key)) return self.x[self.key_to_id(key)] def __setitem__(self, key, item): if isinstance(key, int): self.x[key] = self.entry_convert(item) self.x[self.key_to_id(key)] = self.entry_convert(item) def dot(self, other): """ Dot product like numpy.dot. """ if isinstance(other, list): other = type(self)(other) if not iutil.istensor(other): return self * other selfshape = self.shape othershape = other.shape self_np = max(len(selfshape) - 1, 0) other_np = max(len(othershape) - 2, 0) if selfshape[self_np] != othershape[other_np]: raise ValueError("Shape mismatch.") return r = type(self).zeros(selfshape[:self_np] + othershape[:other_np] + othershape[other_np+1:]) for xs in itertools.product(*[range(t) for t in selfshape[:self_np]]): for ys in itertools.product(*[range(t) for t in othershape[:other_np]]): for i in range(selfshape[self_np]): for zs in itertools.product(*[range(t) for t in othershape[other_np+1:]]): r[xs + ys + zs] += self[xs + (i,)] * other[ys + (i,) + zs] if len(r.shape) == 0: return r[tuple()] return r def __matmul__(self, other): return self.dot(other) def __imatmul__(self, other): return (type(self)(other)).dot(self) def trace_mat(self): """Trace of matrix. """ selfshape = self.shape r = type(self).zeros(selfshape[2:]) for i in range(min(selfshape[0], selfshape[1])): for xs in itertools.product(*[range(t) for t in selfshape[2:]]): r[xs] += self[(i, i) + xs] if len(r.shape) == 0: return r[tuple()] return r def trace(self): """Trace. 
""" selfshape = self.shape n = min(selfshape) return sum((self[(i,) * len(selfshape)] for i in range(n)), type(self).entry_cls_zero()) def diag(self): """Return diagonal. """ selfshape = self.shape n = min(selfshape) r = [] for i in range(n): r.append(self[(i,) * len(selfshape)]) return type(self)(r) def record_to(self, index): for a in self.x: a.record_to(index) def isregtermpresent(self): for a in self.x: if a.isregtermpresent(): return True return False def get_sum(self): """Sum of all entries. """ return sum(self.x, type(self).entry_cls_zero()) def avg(self): """Average of all entries. """ return self.get_sum() / len(self) # def add_meta(self, key, value): # if self.meta is None: # self.meta = {} # self.meta[key] = value # return self # def get_meta(self, key): # if self.meta is None: # return None # if key not in self.meta: # return None # return self.meta[key] # def remove_meta(self, key): # if self.meta is None: # return self # self.meta.pop(key, None) # return self def tostring(self, style = 0, tosort = False): """Convert to string Parameters: style : Style of string conversion STR_STYLE_STANDARD : I(X,Y;Z|W) STR_STYLE_PSITIP : I(X+Y&Z|W) """ if len(self.shape) == 0: return "" style = iutil.convert_str_style(style) is_subs = len(self) > 0 and not (style & PsiOpts.STR_STYLE_PSITIP) and self.get_meta("subs") is True nlstr = "\n" if style & PsiOpts.STR_STYLE_LATEX: nlstr = "\\\\\n" latex_hline = "\\hline\n" shape = self.shape r = "" add_bracket = True list_bracket0 = "[" list_bracket1 = "]" interstr = "" if style & PsiOpts.STR_STYLE_LATEX: # list_bracket0 = "" # list_bracket1 = "" list_bracket0 = PsiOpts.settings["latex_list_bracket_l"] list_bracket1 = PsiOpts.settings["latex_list_bracket_r"] omit_bracket = len(self) > 0 and (self.get_meta("omit_bracket") is True) if not (style & PsiOpts.STR_STYLE_PSITIP) and omit_bracket: list_bracket0 = "" list_bracket1 = "" if is_subs: if style & PsiOpts.STR_STYLE_LATEX: interstr = PsiOpts.settings["latex_subs"] if 
len(shape) >= 2 and any(tshape > 1 for tshape in shape[:-1]): list_bracket0 = PsiOpts.settings["latex_subs_bracket_l"] list_bracket1 = PsiOpts.settings["latex_subs_bracket_r"] else: list_bracket0 = "" list_bracket1 = "" else: interstr = ":=" list_bracket0 = "" list_bracket1 = "" if style & PsiOpts.STR_STYLE_PSITIP: r += type(self).cls_name + "(" if len(shape) > 1: r += nlstr add_bracket = False isarray = False if style & PsiOpts.STR_STYLE_LATEX: if list_bracket0 != "": r += "\\left" + list_bracket0 + " " if not (len(self.shape) in [1, 2] and not any(tshape > 1 for tshape in shape[:-1])): isarray = True if is_subs and self.shape[-1] >= 2: r += "\\begin{array}{" + "l" * self.shape[-1] + "}\n" # r += "\\begin{array}{r" + "l" * (self.shape[-1] - 1) + "}\n" else: r += "\\begin{array}{" + "c" * self.shape[-1] + "}\n" # if style & PsiOpts.STR_STYLE_PSITIP: # r += type(self).cls_name + "([ " # add_bracket = False # elif style & PsiOpts.STR_STYLE_LATEX: # if style & PsiOpts.STR_STYLE_LATEX_ARRAY: # r += "\\left\\[\\begin{array}{l}\n" # add_bracket = False # else: # r += "\\[ " # else: # r += "[ " # float_style = self.get_meta("float_style") # cont = PsiOpts() # if float_style is not None: # cont = PsiOpts(float_style = float_style) # with cont: for xs in itertools.product(*[range(t) for t in shape]): si0 = len(shape) while si0 > 0 and xs[si0 - 1] == 0: si0 -= 1 si1 = len(shape) while si1 > 0 and xs[si1 - 1] == shape[si1 - 1] - 1: si1 -= 1 if si0 < len(shape) and not style & PsiOpts.STR_STYLE_LATEX: r += " " * si0 + list_bracket0 * (len(shape) - si0) if type(self).tostring_bracket_needed: r += self[xs].tostring(style = style, tosort = tosort, add_bracket = add_bracket) else: r += self[xs].tostring(style = style, tosort = tosort) if si1 < len(shape) and not style & PsiOpts.STR_STYLE_LATEX: r += list_bracket1 * (len(shape) - si1) if style & PsiOpts.STR_STYLE_LATEX: if si1 > 0: if si1 == len(shape): if isarray: r += " & " elif is_subs: r += " " else: r += " \;\; " if interstr 
!= "": r += interstr + " " else: if latex_hline is not None: cc = len(shape) - si1 r += nlstr if cc > 1: r += latex_hline if cc > 2: r += nlstr * (cc - 2) r += latex_hline else: r += nlstr * (len(shape) - si1) else: if si1 > 0: r += "," if si1 == len(shape): r += " " if interstr != "": r += interstr + " " else: r += nlstr * (len(shape) - si1) if style & PsiOpts.STR_STYLE_PSITIP: r += ")" if style & PsiOpts.STR_STYLE_LATEX: if isarray: r += "\\end{array}" if list_bracket1 != "": r += "\\right" + list_bracket1 return r def __str__(self): return self.tostring(PsiOpts.settings["str_style"], PsiOpts.settings["str_tosort"]) def __repr__(self): return self.tostring(PsiOpts.settings["str_style_repr"]) @latex_postprocess def _latex_(self): return self.tostring(iutil.convert_str_style("latex")) class CompArray(IBaseArray): cls_name = "CompArray" entry_cls = Comp entry_cls_zero = Comp.empty entry_cls_one = None tostring_bracket_needed = True @staticmethod def entry_convert(a): if isinstance(a, Comp): return a.copy() return None def arg_convert(b): if isinstance(b, list) or isinstance(b, Comp): return CompArray.make(*b) return b def get_comp(self): return sum(self.x, Comp.empty()) def get_term(self): return Term([a.copy() for a in self.x]) def series(self, vdir): """Get past or future sequence. 
Parameters: vdir : Direction, 1: future non-strict, 2: future strict, -1: past non-strict, -2: past strict """ if vdir == 1: return CompArray([sum(self.x[i:], Comp.empty()) for i in range(len(self.x))]) elif vdir == 2: return CompArray([sum(self.x[i+1:], Comp.empty()) for i in range(len(self.x))]) elif vdir == -1: return CompArray([sum(self.x[:i+1], Comp.empty()) for i in range(len(self.x))]) elif vdir == -2: return CompArray([sum(self.x[:i], Comp.empty()) for i in range(len(self.x))]) return self.copy() def past_ns(self): return self.series(-1) def past(self): return self.series(-2) def future_ns(self): return self.series(1) def future(self): return self.series(2) @property def p(self): return self.past() @property def f(self): return self.future() def series_list(self, name = None, suf0 = "Q", sufp = "P", suff = "F"): if name is None: if len(self.x) == 0: name = "" else: name = self.x[0].get_name() for a in self.x[1:]: tname = a.get_name() while not tname.startswith(name): name = name[:-1] r = [] if suf0 is not None: r.append((Comp.rv(name + suf0), self.copy())) if sufp is not None: r.append((Comp.rv(name + sufp), self.past())) if suff is not None: r.append((Comp.rv(name + suff), self.future())) return r @staticmethod def series_sym(x, sufp = "P", suff = "F"): if isinstance(x, str): x = Comp.rv(x) r = CompArray.empty() r.append(x) rename_char = PsiOpts.settings["rename_char"] if sufp is not None: r.append(Comp.rv(iutil.set_suffix_num(x.get_name(), sufp, rename_char, replace_mode = "append"))) if suff is not None: r.append(Comp.rv(iutil.set_suffix_num(x.get_name(), suff, rename_char, replace_mode = "append"))) return r def __and__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return ExprArray([a & b for a, b in zip(self.x, other.x)], shape = self.shape) return ExprArray([a & other for a in self.x], shape = self.shape) def __or__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return ExprArray([a | b 
for a, b in zip(self.x, other.x)], shape = self.shape) return ExprArray([a | other for a in self.x], shape = self.shape) def __rand__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return ExprArray([b & a for a, b in zip(self.x, other.x)], shape = self.shape) return ExprArray([other & a for a in self.x], shape = self.shape) def __ror__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return ExprArray([b | a for a, b in zip(self.x, other.x)], shape = self.shape) return ExprArray([other | a for a in self.x], shape = self.shape) def mark(self, *args, **kwargs): for a in self.x: a.mark(*args, **kwargs) return self def set_card(self, m): for a in self.x: a.set_card(m) return self def get_card(self): return self.get_comp().get_card() def get_shape(self): r = [] for a in self.x: t = a.get_card() if t is None: raise ValueError("Cardinality of " + str(a) + " not set. Use " + str(a) + ".set_card(m) to set cardinality.") return r.append(t) return tuple(r) class CheckResult(CompArray): def __init__(self, *args, truth = None, method = None, reg = None, getaux = None, proof = None, model = None, display_reg = False, **kwargs): self.truth = truth self.method = method self.reg = reg self.getaux = getaux self.proof = proof self.model = model self.display_reg = display_reg CompArray.__init__(self, *args, **kwargs) self.add_meta("omit_bracket", True) self.add_meta("subs", True) def tostring(self, style = 0, *args, **kwargs): style = iutil.convert_str_style(style) r = "" nlstr = "\n" if style & PsiOpts.STR_STYLE_LATEX: nlstr = "\\\\\n" if style & PsiOpts.STR_STYLE_LATEX: r += "\\begin{array}{l}\n" if self.display_reg and self.reg is not None: r += self.reg.tostring(style) if style & PsiOpts.STR_STYLE_LATEX: r += "\\;\\mathrm{is}\\;" else: r += " is " truthstr = "" if self.truth is True: truthstr = "True" elif self.truth is False: truthstr = "False" else: truthstr = "Unknown" if style & PsiOpts.STR_STYLE_LATEX: r += 
"\\mathrm{" + truthstr + "}" else: r += truthstr r += nlstr addproof = self.proof is not None if not addproof and len(self) > 0: r += CompArray.tostring(self, style = style, *args, **kwargs) r += nlstr if self.model is not None: r += nlstr r += self.model.tostring(style = style) r += nlstr if style & PsiOpts.STR_STYLE_LATEX: r += "\\end{array}" if self.proof is not None and not self.proof.isempty(): r += nlstr r = iutil.latex_concat(style, [r, self.proof.tostring(style)]) return r def __getitem__(self, key): if len(self): return CompArray.__getitem__(self, key) if self.model is not None: return self.model[key] return None def __bool__(self): return bool(self.truth) class ExprArray(IBaseArray): cls_name = "ExprArray" entry_cls = Expr entry_cls_zero = Expr.zero entry_cls_one = Expr.one tostring_bracket_needed = False @staticmethod def entry_convert(a): if isinstance(a, Expr): return a.copy() elif isinstance(a, Term): return a.copy() elif isinstance(a, (int, float)): return Expr.const(a) elif iutil.istensor(a) and len(a.shape) == 0: return Expr.const(float(a)) return None def set_float(self, force_float = True): for x in self: x.add_meta("float_style", force_float) def get_expr(self): return sum(self.x, Expr.zero()) def series(self, vdir): """Get past or future sequence. 
Parameters: vdir : Direction, 1: future non-strict, 2: future strict, -1: past non-strict, -2: past strict """ if vdir == 1: return ExprArray([sum(self.x[i:], Expr.zero()) for i in range(len(self.x))]) elif vdir == 2: return ExprArray([sum(self.x[i+1:], Expr.zero()) for i in range(len(self.x))]) elif vdir == -1: return ExprArray([sum(self.x[:i+1], Expr.zero()) for i in range(len(self.x))]) elif vdir == -2: return ExprArray([sum(self.x[:i], Expr.zero()) for i in range(len(self.x))]) return self.copy() def past_ns(self): return self.series(-1) def past(self): return self.series(-2) def future_ns(self): return self.series(1) def future(self): return self.series(2) def isconst(self): return all(a.get_const() is not None for a in self.x) def to_numpy(self): if self.isconst(): return numpy.array(self.tolist(process = lambda a: a.get_const())) return IBaseArray.to_numpy(self) def __abs__(self): return eabs(self) def __and__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return ExprArray([a & b for a, b in zip(self.x, other.x)], shape = self.shape) return ExprArray([a & other for a in self.x], shape = self.shape) def __or__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return ExprArray([a | b for a, b in zip(self.x, other.x)], shape = self.shape) return ExprArray([a | other for a in self.x], shape = self.shape) def __rand__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return ExprArray([b & a for a, b in zip(self.x, other.x)], shape = self.shape) return ExprArray([other & a for a in self.x], shape = self.shape) def __ror__(self, other): if isinstance(other, CompArray) or isinstance(other, ExprArray): return ExprArray([b | a for a, b in zip(self.x, other.x)], shape = self.shape) return ExprArray([other | a for a in self.x], shape = self.shape) def ge_region(self): r = Region.universe() for a in self.x: r.exprs_ge.append(a.copy()) return r def eq_region(self): r = 
Region.universe() for a in self.x: r.exprs_eq.append(a.copy()) return r def __ge__(self, other): # if not isinstance(other, ExprArray): # other = ExprArray(other) return (self - other).ge_region() def __le__(self, other): # if not isinstance(other, ExprArray): # other = ExprArray(other) return (other - self).ge_region() def agree_shape_in(self, other): shape_in = self.meta.get("prob_shape_in", None) if shape_in is None: return None if isinstance(other, ExprArray): other_shape_in = other.meta.get("prob_shape_in", None) if other_shape_in is None or other_shape_in != shape_in: return None return shape_in def __eq__(self, other): # if not isinstance(other, ExprArray): # other = ExprArray(other) return (self - other).eq_region() def discover_ic_lex(self, x): """Discover conditional independence relations among random variables in x using this entropy vector. The entropy vector must be ordered in lexicographical order (e.g. obtained by ent_vector_lex(x)). """ return Region.ent_vector_discover_ic(self, x) def discover_lex(self, x): """Discover relations among random variables in x using this entropy vector. The entropy vector must be ordered in lexicographical order (e.g. obtained by ent_vector_lex(x)). 
""" return (self.discover_ic_lex(x) & (ent_vector_lex(*x) == self)).simplified() def fcn(fcncall, shape): if isinstance(shape, int): shape = (shape,) return ExprArray([Expr.fcn( (lambda txs: (lambda P: fcncall(P, txs[0] if len(txs) == 1 else txs)))(xs)) for xs in itertools.product(*[range(t) for t in shape])], shape = shape) # def tostring(self, style = 0, tosort = False): # """Convert to string # Parameters: # style : Style of string conversion # STR_STYLE_STANDARD : I(X,Y;Z|W) # STR_STYLE_PSITIP : I(X+Y&Z|W) # """ # style = iutil.convert_str_style(style) # nlstr = "\n" # if style & PsiOpts.STR_STYLE_LATEX: # nlstr = "\\\\\n" # r = "" # add_bracket = True # if style & PsiOpts.STR_STYLE_PSITIP: # r += "ExprArray([ " # add_bracket = False # elif style & PsiOpts.STR_STYLE_LATEX: # if style & PsiOpts.STR_STYLE_LATEX_ARRAY: # r += "\\left\\[\\begin{array}{l}\n" # add_bracket = False # else: # r += "\\[ " # else: # r += "[ " # for i, a in enumerate(self.x): # if i: # if style & PsiOpts.STR_STYLE_LATEX: # if style & PsiOpts.STR_STYLE_LATEX_ARRAY: # r += nlstr # else: # r += ", " # else: # r += ", " # r += a.tostring(style = style, tosort = tosort) # if style & PsiOpts.STR_STYLE_PSITIP: # r += " ])" # elif style & PsiOpts.STR_STYLE_LATEX: # if style & PsiOpts.STR_STYLE_LATEX_ARRAY: # r += "\\end{array}\\right\\]" # else: # r += " \\]" # else: # r += " ]" # return r # def __str__(self): # return self.tostring(PsiOpts.settings["str_style"], PsiOpts.settings["str_tosort"]) # def __repr__(self): # return self.tostring(PsiOpts.STR_STYLE_PSITIP) class ivenn: def ellipse_contains(el, p): if len(el) < 4: el = (el[0], el[1], numpy.linalg.inv(el[1]), True) elcen, elm, elmi, elinc = el return not elinc ^ (numpy.linalg.norm(elmi.dot(p - elcen)) <= 1.0) def ellipse_angle(el, p): if len(el) < 4: el = (el[0], el[1], numpy.linalg.inv(el[1]), True) elcen, elm, elmi, elinc = el t = elmi.dot(p - elcen) return numpy.arctan2(t[1], t[0]) def ellipse_clamp(el, p, border, inc, is_it): if len(el) 
< 4: el = (el[0], el[1], numpy.linalg.inv(el[1]), True) elcen, elm, elmi, elinc = el pme = p - elcen t = elmi.dot(pme) tn = numpy.linalg.norm(t) if tn <= 1e-9: return None if is_it: ndiv = 600 md = 1e20 mg = numpy.array([0.0, 0.0]) for i in range(ndiv + 1): ia = i * numpy.pi * 2 / ndiv g = elm.dot(numpy.array([numpy.cos(ia), numpy.sin(ia)])) cd = numpy.linalg.norm(g - pme) if cd < md: md = cd mg = g if md < 1e-9: return None if (inc ^ (tn > 1.0)) and md >= border: return None if inc ^ (tn > 1.0): return (pme - mg) * (border / md) + mg + elcen else: return -(pme - mg) * (border / md) + mg + elcen else: border /= numpy.linalg.norm(pme) / tn if inc: if tn <= 1.0 - border: return None t = t * (1.0 - border) / tn else: if tn >= 1.0 + border: return None t = t * (1.0 + border) / tn return elm.dot(t) + elcen def ellipse_intersect_it(els, r, ndiv = 500, maxnp = 10000, cp = None, stop0 = None, stop1 = None, stopp = None): elcen, elm, elmi, elinc = els[0] started = False pcontain = True pcontainxi = -1 angle0 = 0.0 anglesn = 1.0 if elinc else -1.0 if cp is not None: angle0 = ivenn.ellipse_angle(els[0], cp) for i in range(ndiv + 1): ia = anglesn * i * numpy.pi * 2 / ndiv + angle0 p = elm.dot(numpy.array([numpy.cos(ia), numpy.sin(ia)])) + elcen ccontains = [e is els[0] or ivenn.ellipse_contains(e, p) for e in els] ccontain = all(ccontains) ccontainxi = -1 for ei in range(len(els)): if not ccontains[ei]: ccontainxi = ei if cp is None: if ccontain and not pcontain: ivenn.ellipse_intersect_it(els, r, ndiv, maxnp, p, els[pcontainxi], els[0], p) return else: if started and not ccontain: for ei in range(len(els)): if not ccontains[ei]: if els[0] is stop0 and els[ei] is stop1 and numpy.linalg.norm(p - stopp) <= 0.02: return cels = [els[ei]] + [e for e in els if e is not els[ei]] ivenn.ellipse_intersect_it(cels, r, ndiv, maxnp, p, stop0, stop1, stopp) return return if ccontain: started = True if started: r.append(p) if len(r) >= maxnp: print("Venn diagram intersection failure!") 
return pcontain = ccontain pcontainxi = ccontainxi if cp is None: if len(els) >= 2: cels = els[1:] + [els[0]] ivenn.ellipse_intersect_it(cels, r, ndiv, maxnp, cp, stop0, stop1, stopp) return # print("Venn diagram includes all!") for i in range(ndiv): ia = anglesn * i * numpy.pi * 2 / ndiv + angle0 p = elm.dot(numpy.array([numpy.cos(ia), numpy.sin(ia)])) + elcen r.append(p) return def ellipse_intersect(els, ndiv = 700): cels = [(e[0], e[1], numpy.linalg.inv(e[1]), e[2] if len(e) >= 3 else True) for e in els] r = [] ivenn.ellipse_intersect_it(cels, r, ndiv) return r def ellipses(n): if n == 1: return [(numpy.array([0.0, 0.0]), numpy.eye(2))] elif n == 2: ratio = 0.68 return [(numpy.array([ratio - 1.0, 0.0]), numpy.eye(2) * ratio), (numpy.array([-ratio + 1.0, 0.0]), numpy.eye(2) * ratio)] elif n == 3: ratio = 0.65 rad = -ratio + 1.0 return [(numpy.array([-numpy.sin(i * numpy.pi * 2 / 3) * rad, numpy.cos(i * numpy.pi * 2 / 3) * rad]), numpy.eye(2) * ratio) for i in range(3)] elif n == 4: mtilt0 = numpy.array([[1.0, 0.0], [-0.7, 1.0]]) mtilt1 = numpy.array([[1.0, 0.0], [-0.8, 1.0]]) mtilt2 = numpy.array([[1.0, 0.0], [0.8, 1.0]]) mtilt3 = numpy.array([[1.0, 0.0], [0.7, 1.0]]) return [(numpy.array([-0.4 * 1.011, -0.17]), mtilt0 * 0.7), (numpy.array([-0.1, 0.15]), mtilt1 * 0.62), (numpy.array([0.1, 0.15]), mtilt2 * 0.62), (numpy.array([0.4 * 1.011, -0.17]), mtilt3 * 0.7)] elif n == 5: # 5-set Venn diagram by Branko Grunbaum ratio = 0.85 r = [] for i in range(5): a = -i * numpy.pi * 2 / 5 + 0.2 cosa = numpy.cos(a) sina = numpy.sin(a) rmat = numpy.array([[cosa, -sina], [sina, cosa]]) r.append((rmat.dot(numpy.array([0.04 * 2, 0.087 * 2])) * ratio, rmat.dot(numpy.array([[0.63, 0.0], [0.0, 1.0]])) * ratio)) return r def intersect(n, inc): els = ivenn.ellipses(n) # ndiv = 500 # if n >= 4: # ndiv = 700 # return ivenn.ellipse_intersect([e for i, e in enumerate(els) if inc[i]], ndiv = ndiv) return ivenn.ellipse_intersect([e for i, e in enumerate(els) if inc[i]]) def 
calc_ellipses(n): els = ivenn.ellipses(n) elsi = [(el[0], el[1], numpy.linalg.inv(el[1]), True) for el in els] r = [[None, None, None] for i in range(1 << n)] tpolys = [None] * (1 << n) for mask in range((1 << n) - 1, 0, -1): maskbc = iutil.bitcount(mask) # r[mask][0] = ivenn.ellipse_intersect([e for i, e in enumerate(els) if mask & (1 << i)]) r[mask][0] = ivenn.ellipse_intersect( [(e[0], e[1], bool(mask & (1 << i))) for i, e in enumerate(els)]) if len(r[mask][0]) == 0: print("Venn diagram intersection empty!") if False: cen = sum(r[mask][0], numpy.array([0.0, 0.0])) / len(r[mask][0]) # print(cen) subcen = numpy.array([0.0, 0.0]) mask_other = ((1 << n) - 1) ^ mask for mask2 in igen.subset_mask(mask_other): if mask2 == 0: continue mask3 = mask | mask2 subcen += r[mask3][1] # cborder = 0.55 # if n == 3: # if maskbc == 2: # cborder = 0.7 # elif n == 4: # cborder = 0.8 #cborder = 0.4 * 3.0 / (n + 2.0) cborder_max = 0.4 nit = 500 cborder_shrink = 0.05 walk_ratio = 1.0 # if n == 4: # nit += 400 # # nit = 3 # cborder_shrink = 0.975 cborder_step_init = 0.02 cborder_step_shrink = 0.1 if n == 5: # nit += 400 # nit = 3 cborder_shrink = 0.005 if mask_other != 0: subcen /= ((1 << iutil.bitcount(mask_other)) - 1) # r[mask][1] = cen * 5 - subcen * 4 r[mask][1] = cen # tpoly = ivenn.ellipse_intersect([(e[0], e[1], bool(mask & (1 << i))) for i, e in enumerate(els)], ndiv = 2000) # r[mask][1] = sum(tpoly, numpy.array([0.0, 0.0])) / len(tpoly) # tpolys[mask] = tpoly cborder = 0.0 for it in range(nit): # cborder = cborder_max * (cborder_shrink ** (it * 1.0 / nit)) cborder_step = cborder_step_init * (cborder_step_shrink ** (it * 1.0 / nit)) tgts = [] for i, e in enumerate(elsi): t = ivenn.ellipse_clamp(e, r[mask][1], cborder, bool(mask & (1 << i)), n >= 4) if t is not None: tgts.append(t) if len(tgts): r[mask][1] = (sum(tgts, numpy.array([0.0, 0.0])) / len(tgts)) * walk_ratio + r[mask][1] * (1.0 - walk_ratio) cborder = max(cborder - cborder_step, 0.0) # print(" " + str(it) + ": " + 
str(cborder) + " " + str(r[mask][1])) else: cborder += cborder_step else: r[mask][1] = cen # r[mask][1] = cen * (1 << iutil.bitcount(mask_other)) - subcen # print(r[mask][1]) print("r[" + str(mask) + "][1] = numpy." + repr(r[mask][1])) for i, e in enumerate(elsi): # if ivenn.ellipse_contains(e, r[mask][1]) ^ bool(mask & (1 << i)): # print("FAIL " + str(i)) if ivenn.ellipse_clamp(e, r[mask][1], 0.0, bool(mask & (1 << i)), n >= 4) is not None: print("FAIL " + str(i)) r[mask][2] = numpy.array([0.0, 0.0]) textps = [] if n == 1: textps = [(0.0, 1.1)] elif n == 2: textps = [(-0.45, 0.8), (0.45, 0.8)] elif n == 3: textps = [(0.0, 1.1), (-1.0, -0.6), (1.0, -0.6)] elif n == 4: textps = [(-0.95, 0.85), (-0.4, 1.1), (0.4, 1.1), (0.95, 0.85)] elif n == 5: textps = [(-numpy.sin(-i * numpy.pi * 2 / 5 + 0.12) * 1.1, numpy.cos(-i * numpy.pi * 2 / 5 + 0.12) * 1.1) for i in range(5)] if True: if n == 1: r[1][1] = numpy.array([0.0, 0.0]) elif n == 2: r[3][1] = numpy.array([-5.19979865e-06, 0.0]) r[2][1] = numpy.array([6.39999982e-01, 0.0]) r[1][1] = numpy.array([-6.39999982e-01, 0.0]) elif n == 3: r[7][1] = numpy.array([-0.00011776, 0.00020427]) r[6][1] = numpy.array([ 8.10570556e-05, -4.61882452e-01]) r[5][1] = numpy.array([0.40013006, 0.2308303 ]) r[4][1] = numpy.array([ 0.60932452, -0.35185475]) r[3][1] = numpy.array([-0.40005325, 0.23087825]) r[2][1] = numpy.array([-0.6093162 , -0.35186916]) r[1][1] = numpy.array([1.71157142e-05, 7.03617892e-01]) elif n == 4: r[15][1] = numpy.array([ 0.0, -0.24960125]) r[14][1] = numpy.array([0.24123917, 0.14045396 - 0.04]) r[13][1] = numpy.array([-0.18729334, -0.53280546]) r[12][1] = numpy.array([0.51662578, 0.40616016]) r[11][1] = numpy.array([ 0.19303297, -0.51624427]) r[10][1] = numpy.array([ 0.41826399 + 0.01, -0.41951261]) r[9][1] = numpy.array([ 0.0, -0.78949216]) r[8][1] = numpy.array([0.85559034, 0.22166595 - 0.02]) r[7][1] = numpy.array([-0.24252916, 0.14045396 - 0.04]) r[6][1] = numpy.array([0.0, 0.36160384 + 0.04]) r[5][1] = 
numpy.array([-0.41826399 - 0.01, -0.41951261]) r[4][1] = numpy.array([0.40438341, 0.74829018]) r[3][1] = numpy.array([-0.51706528, 0.40634087]) r[2][1] = numpy.array([-0.4085948 , 0.74881886]) r[1][1] = numpy.array([-0.85559034, 0.22166595 - 0.02]) elif n == 5: r[31][1] = numpy.array([ 0.00123121, -0.00035098]) r[30][1] = numpy.array([-0.50522841, -0.15738005]) r[29][1] = numpy.array([-0.30334848, 0.43405041]) r[28][1] = numpy.array([-0.5759732 , 0.08623396]) r[27][1] = numpy.array([0.31906663, 0.42263051]) r[26][1] = numpy.array([-0.56718932, -0.29605626]) r[25][1] = numpy.array([-0.09599438, 0.57443462]) r[24][1] = numpy.array([-0.66609369, -0.20490312]) r[23][1] = numpy.array([ 0.49741082, -0.18604116]) r[22][1] = numpy.array([ 0.63288148, -0.09387176]) r[21][1] = numpy.array([-0.45683543, 0.44794225]) r[20][1] = numpy.array([-0.57279059, 0.30624 ]) r[19][1] = numpy.array([0.51664827, 0.26884256]) r[18][1] = numpy.array([0.64339464, 0.0889563 ]) r[17][1] = numpy.array([-0.36041924, 0.57743233]) r[16][1] = numpy.array([-0.8230906 , 0.14612619]) r[15][1] = numpy.array([-0.00456696, -0.52900733]) r[14][1] = numpy.array([-0.25995857, -0.52117389]) r[13][1] = numpy.array([ 0.10629509, -0.63091551]) r[12][1] = numpy.array([-0.01095303, -0.69681429]) r[11][1] = numpy.array([0.28484849, 0.57289808]) r[10][1] = numpy.array([-0.46825182, -0.45012615]) r[9][1] = numpy.array([0.11424952, 0.63938959]) r[8][1] = numpy.array([-0.39334952, -0.73764132]) r[7][1] = numpy.array([ 0.41524152, -0.40832889]) r[6][1] = numpy.array([ 0.63606 , -0.25034334]) r[5][1] = numpy.array([ 0.28339764, -0.58443058]) r[4][1] = numpy.array([ 0.5799869 , -0.60204133]) r[3][1] = numpy.array([0.43456033, 0.52772247]) r[2][1] = numpy.array([0.75179328, 0.36557041]) r[1][1] = numpy.array([-0.11537642, 0.82796063]) if False: for mask in range(1, (1 << n) - 1): maski = mask for i in range(1, 5): maski = maski << 1 if maski & (1 << n): maski -= 1 << n maski += 1 a = i * numpy.pi * 2 / 5 cosa = 
numpy.cos(a) sina = numpy.sin(a) r[maski][1] = numpy.array([[cosa, -sina], [sina, cosa]]).dot(r[mask][1]) for mask in range((1 << n) - 1, 0, -1): mdist = 1e10 for x in r[mask][0]: cdist = x[1] cdist += numpy.linalg.norm(r[mask][1] - x) if cdist < mdist: mdist = cdist r[mask][2] = x return (r, textps) def patch(n, inc, **kwargs): return matplotlib.patches.Polygon(ivenn.intersect(n, inc), True, **kwargs) class CellTable: def __init__(self, x): self.x = x self.cells = [{} for i in range(1 << len(self.x))] self.fontsize = 22 self.linewidth = 1.5 self.exprs = [] def set_enabled(self, mask, enabled = True): self.cells[mask]["enabled"] = enabled def get_enabled(self, mask): return self.cells[mask].get("enabled", True) def set_attr(self, mask, key, val): self.cells[mask][key] = val def get_attr(self, mask, key, default_val = None): return self.cells[mask].get(key, default_val) def add_expr(self, expr, cval): self.exprs.append({"expr": expr, "cval": cval}) def set_expr_val(self, mask, val): self.cells[mask]["val_" + str(len(self.exprs) - 1)] = val def get_pos(self, mask): n = len(self.x) nv = n // 2 maskv = mask & ((1 << nv) - 1) maskh = mask >> nv return (iutil.gray_to_bin(maskh), iutil.gray_to_bin(iutil.bit_reverse(maskv, nv))) # nh = (n + 1) // 2 # maskh = mask & ((1 << nh) - 1) # maskv = mask >> nh # return (iutil.gray_to_bin(iutil.bit_reverse(maskh, nh)), iutil.gray_to_bin(maskv)) def get_x_poss(self, xi): n = len(self.x) nv = n // 2 cn = nv ax = 1 if xi >= nv: xi -= nv ax = 0 cn = n - nv else: xi = nv - 1 - xi r = [] for mask in range(1, 1 << cn): if iutil.bin_to_gray(mask) & (1 << xi): r.append(mask) return (ax, r) # nh = (n + 1) // 2 # cn = nh # ax = 0 # if xi >= nh: # xi -= nh # ax = 1 # cn = n - nh # else: # xi = nh - 1 - xi # r = [] # for mask in range(1, 1 << cn): # if iutil.bin_to_gray(mask) & (1 << xi): # r.append(mask) # return (ax, r) @staticmethod def get_color(x): if x is None: return None if isinstance(x, str): mode = None x = x.lower().strip() if 
x.startswith("l_"): x = x[2:] mode = "l" elif x.startswith("d_"): x = x[2:] mode = "d" colors = dict(matplotlib.colors.BASE_COLORS, **matplotlib.colors.CSS4_COLORS) colors["r"] = (1.0, 0.0, 0.0) colors["g"] = (0.0, 1.0, 0.0) colors["b"] = (0.0, 0.0, 1.0) colors["c"] = (0.0, 1.0, 1.0) colors["m"] = (1.0, 0.0, 1.0) colors["y"] = (1.0, 1.0, 0.0) colors["k"] = (0.0, 0.0, 0.0) colors["w"] = (1.0, 1.0, 1.0) if x not in colors: return None r = colors[x] if not isinstance(r, tuple): r = tuple(matplotlib.colors.to_rgba(r)[:3]) if mode == "l": return (r[0] * 0.5 + 0.5, r[1] * 0.5 + 0.5, r[2] * 0.5 + 0.5) elif mode == "d": return (r[0] * 0.5, r[1] * 0.5, r[2] * 0.5) else: return r return x @staticmethod def color_blend(cols, style): if len(cols) == 0: return (0.0, 0.0, 0.0) if len(cols) == 1: return tuple(cols[0]) if style == "avghsv": r = [CellTable.color_blend(cols, "hsv"), CellTable.color_blend(cols, "avg")] return tuple([sum(c[i] for c in r) / len(r) for i in range(3)]) if style == "hsv": hsv = [matplotlib.colors.rgb_to_hsv(c) for c in cols] s = sum(t[1] for t in hsv) / len(hsv) v = sum(t[2] for t in hsv) / len(hsv) hvecs = [(math.cos((t[0] - 0.5) * 2 * math.pi) * t[1] * t[2], math.sin((t[0] - 0.5) * 2 * math.pi) * t[1] * t[2]) for t in hsv] hvec = [sum(t[i] for t in hvecs) / len(hvecs) for i in range(2)] h = 0.0 if hvec[0]**2 + hvec[1]**2 > 0.03**2: h = math.atan2(hvec[1], hvec[0]) / (2 * math.pi) + 0.5 else: s = 0.0 return matplotlib.colors.hsv_to_rgb((h, s, v)) return tuple([sum(c[i] for c in cols) / len(cols) for i in range(3)]) def get_expr_color(self, i, color_shift): r = None if i is not None: r = CellTable.get_color(self.exprs[i].get("color", None)) if r is None: # return [(1, 0.5, 0.5), (0.5, 1, 0.5), (0.5, 0.5, 1), (1, 1, 0.2), (1, 0.3, 1), (0.3, 1, 1), # (1, 0.8, 0.3), (0.6, 0.6, 0.6)][i] # clist = [(1, 0.5, 0.5), (0.5, 0.5, 1), (0.5, 1, 0.5), (0.6, 0.6, 0.6), (1, 0.3, 1), (1, 1, 0.2), (0.3, 1, 1), # (1, 0.8, 0.3)] clist = [(1, 0.5, 0.5), (0.5, 0.5, 1), (0.5, 
1, 0.5), (1, 1, 0.2), (1, 0.3, 1), (0.3, 1, 1), (1, 0.8, 0.3), (0.6, 0.6, 0.6)] clist += [tuple(max(t[i] - 0.3, 0.0) for i in range(3)) for t in clist] cid = (0 if i is None else i) + color_shift return clist[cid % len(clist)] return r def plot(self, style = "hsplit", legend = True, use_latex = True, figsize = None, color = None, cval_color = None): label_interval = 0.32 label_width = 0.29 fontsize = self.fontsize fontsize_in_mul = 1.0 linewidth = self.linewidth if figsize is None: figsize = PsiOpts.settings["figsize"] if figsize is None: figsize = [10, 8] if color is None: color = [] elif isinstance(color, str): color = color.split(",") color_list = [CellTable.get_color(c) for c in color] # print(color_list) rcParams_orig = None if use_latex: rcParams_orig = plt.rcParams.copy() plt.rcParams.update({"text.usetex": True, "font.family": "sans-serif", "font.sans-serif": ["Helvetica"]}) fig, ax = plt.subplots(figsize = figsize) # patches = [] xlim = [0, 0] ylim = [0, 0] expr_hatch = [None] * len(self.exprs) mask_nexpr = [0] * (1 << len(self.x)) hatches = ['//', '\\\\', '||', '--', 'o', 'O', '.', '*', '+', 'x'] cval_present = False cval_min = 0.0 # 1e20 cval_max = 0.0 # -1e20 for mask in range(1, 1 << len(self.x)): cval = self.get_attr(mask, "cval") if cval is not None: cval = float(cval) cval_present = True cval_min = min(cval_min, cval) cval_max = max(cval_max, cval) rect_draw = True pm = False text_draw = True val_outline = None neg_hatch = None is_blend = None blend_style = "avghsv" is_venn = False is_venn_overlap = False text_sub_draw = True neg_hatch_style = "//" hatch_all = True color_shift = 0 numdp = 4 cval_color_enabled = True cval_ignore = False style_split = style.split(",") style = None for cstyle2 in style_split: cstyle = cstyle2.strip() if len(cstyle) == 0: continue if cstyle == "nofill": rect_draw = False elif cstyle == "pm": pm = True elif cstyle == "num": pm = False elif cstyle == "text": text_draw = True elif cstyle == "notext": text_draw = False elif 
cstyle == "sign": text_sub_draw = False elif cstyle == "nosign": text_sub_draw = False elif cstyle == "signhatch": neg_hatch = True elif cstyle == "nosignhatch": neg_hatch = False elif cstyle == "legend": legend = True elif cstyle == "nolegend": legend = False elif cstyle == "val_outline": val_outline = True elif cstyle == "val_nooutline": val_outline = False elif cstyle == "nocval": cval_ignore = True elif cstyle == "nocval_color": cval_color_enabled = False elif cstyle == "dp0": numdp = 0 elif cstyle == "dp1": numdp = 1 elif cstyle == "dp2": numdp = 2 elif cstyle == "dp3": numdp = 3 elif cstyle == "dp4": numdp = 4 elif cstyle == "dp5": numdp = 5 elif cstyle == "dp6": numdp = 6 elif cstyle == "dp7": numdp = 7 elif cstyle == "dp8": numdp = 8 elif cstyle == "dp9": numdp = 9 elif cstyle == "venn": is_venn = True elif cstyle == "blend_hsv": style = "blend" blend_style = "hsv" elif cstyle == "blend_avg": style = "blend" blend_style = "avg" elif cstyle == "blend_avghsv": style = "blend" blend_style = "avghsv" else: style = cstyle if style is None: style = "hsplit" if cval_ignore: cval_present = False if cval_min >= cval_max: cval_color_enabled = False if cval_present: if cval_color_enabled: if cval_color is None: cval_color = self.get_expr_color(None, color_shift) color_shift += 1 else: cval_color = CellTable.get_color(cval_color) else: cval_color_enabled = False if cval_color_enabled: style = "hatch" if len(self.x) > 5: is_venn = False if is_venn: if len(self.x) == 4: fontsize_in_mul = 0.9 if len(self.x) == 5: fontsize_in_mul = 0.8 if val_outline is None: val_outline = style == "hatch" or style == "text" or style == "blend" if neg_hatch is None: neg_hatch = style == "hsplit_fixed" or style == "hsplit" or style == "blend" if is_blend is None: is_blend = style == "blend" for mask in range(1, 1 << len(self.x)): expr_pres = [ei for ei in range(len(self.exprs)) if self.get_attr(mask, "val_" + str(ei)) is not None] mask_nexpr[mask] = len(expr_pres) if style == "hatch": if 
len(expr_pres) >= 2: for ei in expr_pres: expr_hatch[ei] = True if style == "hatch": if hatch_all or cval_color_enabled: for ei in range(len(self.exprs)): expr_hatch[ei] = True nover = 0 for ei in range(len(self.exprs)): if expr_hatch[ei]: expr_hatch[ei] = hatches[nover % len(hatches)] nover += 1 n = len(self.x) axlen = [0, 0] axslen = [0, 0] xstr = [] for xi in range(n): if use_latex: xstr.append("$" + self.x[xi].latex() + "$") else: xstr.append(str(self.x[xi])) polys = [None] * (1 << n) poly_els = [None] * n textps = [None] * n if is_venn: polys, textps = ivenn.calc_ellipses(n) poly_els = [ivenn.intersect(n, [i == j for j in range(n)]) for i in range(n)] for xi in range(n): ax.text(textps[xi][0], textps[xi][1], xstr[xi], horizontalalignment="center", verticalalignment="center", fontsize = fontsize) if not is_venn: for xi in range(n): cax, clist = self.get_x_poss(xi) axslen[cax] += 1 for xi in range(n): cax, clist = self.get_x_poss(xi) avgpos = [0.0, 0.0] caxlen = axlen[cax] if cax: caxlen = axslen[cax] - 1 - caxlen for v in clist: pos = (v, -(caxlen + 1) * label_interval) size = (1, label_width) if cax: pos = (pos[1], pos[0]) size = (size[1], size[0]) pos = (pos[0], -pos[1] - size[1]) avgpos[0] += pos[0] + size[0] * 0.5 avgpos[1] += pos[1] + size[1] * 0.5 xlim[0] = min(xlim[0], pos[0]) xlim[1] = max(xlim[1], pos[0] + size[0]) ylim[0] = min(ylim[0], pos[1]) ylim[1] = max(ylim[1], pos[1] + size[1]) ax.add_patch(matplotlib.patches.Rectangle( pos, size[0], size[1], facecolor = "lightgray")) avgpos[0] /= len(clist) avgpos[1] /= len(clist) ax.text(avgpos[0], avgpos[1], xstr[xi], horizontalalignment="center", verticalalignment="center", fontsize = fontsize) axlen[cax] += 1 # passes = [0, 4] # if is_venn: passes = [0, 2, 4] for cpass in passes: for mask in range(1, 1 << len(self.x)): if cpass == 2: if is_venn and iutil.bitcount(mask) != 1: continue pos = None size = None if is_venn: pos = polys[mask][1] size = 1.2 / n if n == 4: size *= 0.95 elif n == 5: nbit = 
iutil.bitcount(mask) if nbit >= 2 and nbit <= 4: size *= 0.55 elif nbit == 5: size *= 1.5 else: size *= 1.2 size = (size, size) pos = (pos[0] - size[0] * 0.5, pos[1] - size[1] * 0.5) else: pos = self.get_pos(mask) pos = (pos[0], -pos[1] - 1) size = (1, 1) xlim[0] = min(xlim[0], pos[0]) xlim[1] = max(xlim[1], pos[0] + size[0]) ylim[0] = min(ylim[0], pos[1]) ylim[1] = max(ylim[1], pos[1] + size[1]) isenabled = self.get_enabled(mask) cval = None if cval_present: cval = self.get_attr(mask, "cval") if cval is not None: cval = float(cval) color = "none" if not isenabled: color = "k" elif (cval is not None) and cval_color_enabled: cval_scaled = (cval - cval_min) / (cval_max - cval_min) color = tuple(cval_color[i] * cval_scaled + 1.0 - cval_scaled for i in range(3)) if cpass == 0 and color != "none": params = { "facecolor": color } if is_venn: ax.add_patch(matplotlib.patches.Polygon( polys[mask][0], True, **params)) else: ax.add_patch(matplotlib.patches.Rectangle( pos, size[0], size[1], **params)) if cpass == 0: params = { "facecolor": "white", "edgecolor": "none" } if is_venn and is_venn_overlap: ax.add_patch(matplotlib.patches.Polygon( polys[mask][0], True, **params)) if cpass == 2: params = { "facecolor": "none", "linestyle": "-", "linewidth": linewidth, "edgecolor": "k" } if is_venn: # ax.add_patch(matplotlib.patches.Polygon( # polys[mask][0], True, **params)) for i in range(n): if mask & (1 << i): ax.add_patch(matplotlib.patches.Polygon( poly_els[i], True, **params)) break else: ax.add_patch(matplotlib.patches.Rectangle( pos, size[0], size[1], **params)) continue if isenabled: cnexpr = 0 snexpr = mask_nexpr[mask] colsum = [0.0, 0.0, 0.0] colsumsol = [0.0, 0.0, 0.0] collist = [] collistsol = [] nsol = 0 if cval is not None: if cpass == 4: if text_draw: text_y = 0.7 if len(self.exprs) == 0: text_y = 0.5 ctext = ("{:." 
+ str(numdp) + "f}").format(cval) ax.text(pos[0] + 0.5 * size[0], pos[1] + size[1] * text_y, ctext, horizontalalignment="center", verticalalignment="center", fontsize = fontsize * fontsize_in_mul, color = "k") for ei in range(len(self.exprs)): v = self.get_attr(mask, "val_" + str(ei)) if v is None: continue if pm: ctext = iutil.float_tostr(abs(v), bracket = False) if ctext == "1": ctext = "" if v >= 0: ctext = "+" + ctext else: ctext = "-" + ctext else: ctext = iutil.float_tostr(v, bracket = False) ccolor = None if ei < len(color_list): ccolor = color_list[ei] else: ccolor = self.get_expr_color(ei, color_shift) chatch = expr_hatch[ei] hatch_invert = False if chatch is not None: hatch_invert = True if neg_hatch and v < 0: chatch = neg_hatch_style rect_x = 0.0 rect_w = 1.0 text_x = (cnexpr + 0.5) * 1.0 / snexpr text_y = 0.5 if cval is not None: text_y = 0.3 if style == "hsplit_fixed": rect_x = ei * 1.0 / len(self.exprs) rect_w = 1.0 / len(self.exprs) text_x = (ei + 0.5) * 1.0 / len(self.exprs) elif style == "hsplit": rect_x = cnexpr * 1.0 / snexpr rect_w = 1.0 / snexpr if isinstance(ccolor, tuple): colsum[0] += ccolor[0] colsum[1] += ccolor[1] colsum[2] += ccolor[2] collist.append(ccolor) if not (neg_hatch and v < 0): colsumsol[0] += ccolor[0] colsumsol[1] += ccolor[1] colsumsol[2] += ccolor[2] collistsol.append(ccolor) nsol += 1 if cpass == 0: if rect_draw and (not is_blend or cnexpr == snexpr - 1): hatch_col = (1.0, 1.0, 1.0) if is_blend and cnexpr == snexpr - 1: # ccolor = (colsum[0] / snexpr, colsum[1] / snexpr, colsum[2] / snexpr) ccolor = CellTable.color_blend(collist, blend_style) if nsol > 0: # hatch_col = (colsumsol[0] / nsol, colsumsol[1] / nsol, colsumsol[2] / nsol) hatch_col = CellTable.color_blend(collistsol, blend_style) if abs(ccolor[0] - hatch_col[0]) + abs(ccolor[1] - hatch_col[1]) + abs(ccolor[2] - hatch_col[2]) <= 0.001: chatch = None else: chatch = neg_hatch_style params = { "facecolor": "none" if hatch_invert else ccolor, "hatch": chatch, 
"edgecolor": ccolor if hatch_invert else hatch_col if chatch else "none", "linewidth": 0 } if is_venn: # print(mask) # print(polys[mask]) ax.add_patch(matplotlib.patches.Polygon( polys[mask][0], True, **params)) else: ax.add_patch(matplotlib.patches.Rectangle( (pos[0] + rect_x * size[0], pos[1]), size[0] * rect_w, size[1], **params)) # ax.add_patch(matplotlib.patches.Rectangle( # (pos[0] + rect_x * size[0], pos[1]), # size[0] * rect_w, size[1], # facecolor = "none" if hatch_invert else ccolor, hatch = chatch, # edgecolor = ccolor if hatch_invert else "white" if chatch else "none", # linewidth = 0)) elif cpass == 4: if text_draw: ax.text(pos[0] + text_x * size[0], pos[1] + size[1] * text_y, ctext, horizontalalignment="center", verticalalignment="center", fontsize = fontsize * fontsize_in_mul, color = ccolor if val_outline else "k", path_effects = [matplotlib.patheffects.withStroke(linewidth = 3.5, foreground = "k")] if val_outline else None) cnexpr += 1 if cpass == 0: remtext = "" if self.get_attr(mask, "ispos"): remtext += "+" if self.get_attr(mask, "isneg"): remtext += "-" if text_sub_draw and len(remtext): rempos = None if is_venn: # vpos = -0.08 # if n == 5: # vpos = 0.05 # rempos = (pos[0] + size[0] * 0.5, pos[1] + size[1] * vpos) rempos = polys[mask][2] rempos = (rempos[0], rempos[1] + size[1] * 0.033) else: rempos = (pos[0] + size[0] * 0.015, pos[1] + size[1] * 0.02) ax.text(rempos[0], rempos[1], remtext, horizontalalignment="center" if is_venn else "left", verticalalignment="bottom", fontsize = fontsize * fontsize_in_mul * 0.85) if is_venn: xlim = [-1.1, 1.1] ylim = [-1.1, 1.1] if legend and len(self.exprs): legends = [] for ei in range(len(self.exprs)): ccolor = None if ei < len(color_list): ccolor = color_list[ei] else: ccolor = self.get_expr_color(ei, color_shift) clabel = "" cexpr = self.exprs[ei].get("expr") if cexpr is not None: with PsiOpts(str_lhsreal = False): if use_latex: clabel = "$" + cexpr.latex(skip_simplify = True) + "$" else: clabel = 
str(cexpr) # if isinstance(cexpr, Region): # clabel = cexpr.tostring(lhsvar = None) # else: # clabel = str(cexpr) cval = self.exprs[ei].get("cval") if cval_present and (cval is not None): if not isinstance(cval, bool): cval = float(cval) if isinstance(cval, float): clabel += " = " + ("{:." + str(numdp) + "f}").format(cval) else: clabel += " = " + str(cval) if not len(clabel): continue # legends.append(matplotlib.lines.Line2D([0], [0], color=ccolor, lw=4, label=clabel)) legends.append(matplotlib.patches.Patch(facecolor=ccolor, label=clabel)) ax.legend(handles = legends, fontsize = fontsize, bbox_to_anchor=(0.5, -0.03), loc="upper center", frameon=False) if is_venn: ax.set_aspect("equal") # ax.add_collection(PatchCollection(patches)) plt.axis("off") plt.xlim([xlim[0] - 0.011, xlim[1] + 0.011]) plt.ylim([ylim[0] - 0.011, ylim[1] + 0.011]) plt.show() fig.tight_layout() if use_latex: plt.rcParams = rcParams_orig.copy() class ProofObj(IBaseObj): def __init__(self, claim, desc = None, steps = None, parent = None, meta = None): self.claim = claim self.desc = desc if steps is None: steps = [] self.steps = steps self.parent = parent if meta is None: meta = {} self.meta = meta @staticmethod def empty(): return ProofObj(None, None, []) def isempty(self): return self.claim is None and self.desc is None and len(self.steps) == 0 def copy(self): # return ProofObj([(x.copy(), list(d), c) for x, d, c in self.steps]) return ProofObj(iutil.copy(self.claim), iutil.copy(self.desc), iutil.copy(self.steps), self.parent, iutil.copy(self.meta)) def copy_(self, other): self.claim = iutil.copy(other.claim) self.desc = iutil.copy(other.desc) self.steps = iutil.copy(other.steps) self.parent = other.parent self.meta = iutil.copy(other.meta) @staticmethod def from_region(x, c = ""): # return ProofObj([(x.copy(), [], c)]) return ProofObj(iutil.copy(x), c) def get_case(self): if "case" in self.meta: return self.meta["case"] return Region.universe() def set_case(self, reg): self.meta["case"] = reg 
def __iadd__(self, other): # n = len(self.steps) # self.steps += [(x.copy(), [di + n for di in d], c) for x, d, c in other.steps] other = other.copy() other.parent = self self.steps.append(other) return self def insert_step(self, other, pos): other = other.copy() other.parent = self self.steps.insert(pos, other) return self def __add__(self, other): r = self.copy() r += other return r def step_in(self, other): other = other.copy() other.parent = self self.steps.append(other) return other def step_out(self): return self.parent def clear(self): self.steps = [] # def tostring(self, style = 0, prev = None): # """Convert to string. # Parameters: # style : Style of string conversion # STR_STYLE_STANDARD : I(X,Y;Z|W) # STR_STYLE_PSITIP : I(X+Y&Z|W) # """ # style = iutil.convert_str_style(style) # r = "" # start_n = 0 # if prev is not None: # start_n = len(prev.steps) # for i, (x, d, c) in enumerate(self.steps): # if i > 0: # r += "\n" # r += "STEP #" + str(start_n + i) # if c != "": # r += " " + c # if x is None or x.isuniverse(): # r += "\n" # else: # r += ":\n" # r += x.tostring(style = style) # r += "\n" # return r def step_regions(self, flat = True): r = [] if self.desc is None: pass else: if self.claim is not None: if isinstance(self.claim, tuple): if self.claim[0] == "equiv": r.append(equiv(self.claim[1], self.claim[2])) elif self.claim[0] == "implies": r.append(self.claim[1] >> self.claim[2]) else: if isinstance(self.claim, list): isle = any(eqnstr == "<=" for expr, eqnstr, note in self.claim) prev = None for i, a in enumerate(self.claim): expr, eqnstr, note = a if prev is not None: r.append(iutil.op_str(eqnstr, prev, expr, isle = isle)) prev = expr else: r.append(self.claim) for i, x in enumerate(self.steps): t = x.step_regions(flat=flat) if flat: r += t else: r.append(t) return r def tostring(self, style = 0, prefix = "", inden = 0, skip_env = False, outermost = True): style = iutil.convert_str_style(style) slstr = "" nlstr = "\n" if style & 
PsiOpts.STR_STYLE_LATEX: slstr = "&" nlstr = "\\\\\n" spacestr = " " if style & PsiOpts.STR_STYLE_LATEX: spacestr = "\\;" slstr = slstr + spacestr * inden inden2 = inden if self.desc is not None: inden2 += 2 r = "" cinden = 4 cprefix = prefix if style & PsiOpts.STR_STYLE_LATEX: if not skip_env: r += "\\begin{align*}\n" if self.desc is None: cinden = 0 else: if not outermost or self.steps: outermost = False r += slstr if prefix: r += prefix + "." + spacestr cprefix += "." r += iutil.tostring_join(self.desc, style, nlstr = nlstr + slstr + spacestr * 8) if self.claim is not None: r += nlstr if self.claim is not None: if isinstance(self.claim, tuple): if self.claim[0] == "equiv": r += slstr + self.claim[1].tostring(style = style) if style & PsiOpts.STR_STYLE_PSITIP: r += nlstr + slstr + "==" + nlstr elif style & PsiOpts.STR_STYLE_LATEX: r += nlstr + slstr + PsiOpts.settings["latex_equiv"] + nlstr else: r += nlstr + slstr + "<=>" + nlstr r += slstr + self.claim[2].tostring(style = style) elif self.claim[0] == "implies": r += slstr + self.claim[1].tostring(style = style) if style & PsiOpts.STR_STYLE_PSITIP: r += nlstr + slstr + ">>" + nlstr elif style & PsiOpts.STR_STYLE_LATEX: r += nlstr + slstr + PsiOpts.settings["latex_implies"] + nlstr else: r += nlstr + slstr + "=>" + nlstr r += slstr + self.claim[2].tostring(style = style) else: if isinstance(self.claim, list): note_color = None note_newline = False if "note_color" in self.meta: note_color = self.meta["note_color"] if "note_newline" in self.meta: note_newline = self.meta["note_newline"] if style & PsiOpts.STR_STYLE_LATEX and PsiOpts.settings["latex_line_len"] is not None: if note_newline is False: note_newline = 100000000 if note_newline is not True: note_newline = min(note_newline, PsiOpts.settings["latex_line_len"]) # print(self.claim) for i, a in enumerate(self.claim): expr, eqnstr, note = a if i: r += nlstr cline = expr.tostring_line_len(eqnstr = eqnstr, style = style) r += slstr + cline if note is not None and 
not (isinstance(note, list) and len(note) == 0): cnote = iutil.tostring_join(note, style) cnumspace = 3 if note_newline is True or (note_newline is not False and isinstance(note_newline, int) and iutil.latex_len(cline + cnote) > note_newline): r += nlstr + slstr cnumspace = 8 if style & PsiOpts.STR_STYLE_LATEX: r += "\\;" * cnumspace if note_color is not None: r += "{\\color{" + str(note_color) + "}{" else: r += " " * cnumspace r += cnote if style & PsiOpts.STR_STYLE_LATEX: if note_color is not None: r += "}}" else: r += slstr + self.claim.tostring(style = style) r += nlstr if self.steps: r += nlstr if len(self.steps) >= 2: outermost = False for i, x in enumerate(self.steps): if i: r += nlstr # r += iutil.str_inden(x.tostring(style, cprefix + str(i + 1), skip_env = True), cinden, spacestr = spacestr, slstr = slstr) r += x.tostring(style, cprefix + str(i + 1), inden = inden2, skip_env = True, outermost = outermost) if style & PsiOpts.STR_STYLE_LATEX: if not skip_env: r += "\\end{align*}\n" return r def __str__(self): return self.tostring(PsiOpts.settings["str_style"]) @latex_postprocess def _latex_(self): return self.tostring(iutil.convert_str_style("latex")) class CodingNode(IBaseObj): def __init__(self, rv_out, aux_out = None, aux_dec = None, aux_ndec = None, rv_in_causal = None, rv_in_scausal = None, ndec_mode = None, label = None, rv_ndec_force = None): self.rv_out = rv_out if isinstance(aux_out, tuple): aux_out = [aux_out] if isinstance(aux_out, list): self.aux_out_sublist = aux_out self.aux_out = Comp.empty() for v0, v1 in self.aux_out_sublist: self.aux_out += v0 else: self.aux_out_sublist = [] self.aux_out = aux_out self.aux_dec = aux_dec self.aux_ndec = aux_ndec if rv_in_causal is None: self.rv_in_causal = Comp.empty() else: self.rv_in_causal = rv_in_causal.copy() if rv_in_scausal is None: self.rv_in_scausal = Comp.empty() else: self.rv_in_scausal = rv_in_scausal.copy() self.aux_ndec_try = Comp.empty() self.aux_ndec_force = Comp.empty() self.aux_borrow = 
Comp.empty() if rv_ndec_force is None: self.rv_ndec_force = Comp.empty() else: self.rv_ndec_force = rv_ndec_force.copy() self.ndec_mode = ndec_mode self.label = label def copy(self): r = CodingNode(self.rv_out.copy()) r.aux_out = iutil.copy(self.aux_out) r.aux_out_sublist = [(a.copy(), b.copy()) for a, b in self.aux_out_sublist] r.aux_dec = iutil.copy(self.aux_dec) r.aux_ndec = iutil.copy(self.aux_ndec) r.rv_in_causal = iutil.copy(self.rv_in_causal) r.aux_ndec_try = iutil.copy(self.aux_ndec_try) r.aux_ndec_force = iutil.copy(self.aux_ndec_force) r.aux_borrow = iutil.copy(self.aux_borrow) r.rv_ndec_force = iutil.copy(self.rv_ndec_force) r.ndec_mode = self.ndec_mode r.label = self.label return r def clear(self): self.aux_out = None self.aux_out_sublist = [] self.aux_dec = None self.aux_ndec = None self.aux_ndec_try = Comp.empty() self.aux_ndec_force = Comp.empty() self.aux_borrow = Comp.empty() # self.rv_ndec_force = Comp.empty() def record_to(self, index, skip_msg = False): self.rv_out.record_to(index) if self.aux_out is not None: self.aux_out.record_to(index) if self.aux_dec is not None: self.aux_dec.record_to(index) if self.aux_ndec is not None: self.aux_ndec.record_to(index) class CodingModel(IBaseObj): def __init__(self, bnet = None, sublist = None, nodes = None, reg = None, partition = None): if bnet is None: self.bnet = BayesNet() else: if isinstance(bnet, BayesNet): self.bnet = bnet else: self.bnet = BayesNet(bnet) if reg is None: self.reg = Region.universe() else: self.reg = reg if sublist is None: self.sublist = [] else: self.sublist = sublist self.sublist_rate = [] self.sublist_const = [] self.nodes = [] if nodes is None: pass else: for node in nodes: if isinstance(node, CodingNode): self.nodes.append(node) else: self.nodes.append(CodingNode(node)) self.inner_mode = "plain" self.bnet_out = None self.bnet_in = None self.bnet_out_arrays = None self.partition = partition self.use_union = True self.aux_importance = [] self.aux_dummy = Comp.empty() 
self.created_rv = Comp.empty() def copy(self): r = CodingModel() r.bnet = iutil.copy(self.bnet) r.reg = iutil.copy(self.reg) r.sublist = [(a.copy(), b.copy()) for a, b in self.sublist] r.sublist_rate = [(a.copy(), b.copy()) for a, b in self.sublist_rate] r.sublist_const = [(a.copy(), b.copy()) for a, b in self.sublist_const] r.nodes = [a.copy() for a in self.nodes] r.inner_mode = self.inner_mode r.bnet_out = iutil.copy(self.bnet_out) r.bnet_in = iutil.copy(self.bnet_in) r.partition = iutil.copy(self.partition) r.use_union = self.use_union r.aux_importance = iutil.copy(self.aux_importance) r.aux_dummy = iutil.copy(self.aux_dummy) r.created_rv = iutil.copy(self.created_rv) return r def __iadd__(self, other): if isinstance(other, CodingNode): self.nodes.append(other) else: self.bnet += other return self def __iand__(self, other): if isinstance(other, Region): self.reg = self.reg & other return self def add_edge(self, a, b, coded = False, children_edge = None, is_fcn = False, rv_in_causal = None, rv_in_scausal = None, **kwargs): if rv_in_causal is None: rv_in_causal = Comp.empty() if rv_in_scausal is None: rv_in_scausal = Comp.empty() ex = Comp.empty() b2 = Comp.empty() for c in b: if self.bnet.index.get_index(c) >= 0: ex += c else: b2 += c ac = a.copy() + rv_in_causal - rv_in_scausal acb = Comp.empty() if not ex.isempty(): # cname = str(ex) + "?@@100@@#EX_" # for c in a: # cname += str(c) # cname += "_" # for c in ex: # cname += str(c) cname = str(ex) + "?" 
cname += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" cname += ex.tostring(style = PsiOpts.STR_STYLE_LATEX) cname += "?[" + a.tostring(style = PsiOpts.STR_STYLE_LATEX) + "]" exc = Comp.rv(cname) self.created_rv += exc self.bnet.add_edge(ac, exc) if is_fcn: self.bnet.set_fcn(exc) # if children_edge: # ac += ex for c in ex: if children_edge is True or (children_edge is None and not self.is_rate(c)): acb += c if coded: self.nodes.append(CodingNode(exc, rv_in_causal = rv_in_causal, rv_in_scausal = rv_in_scausal, **kwargs)) self.sublist.append((exc, ex)) for c in b2: if children_edge is True or (children_edge is None and not self.is_rate(c)): self.bnet.add_edge(ac + acb, c) acb += c else: self.bnet.add_edge(ac, c) if is_fcn: self.bnet.set_fcn(c) if coded: self.nodes.append(CodingNode(c, rv_in_causal = rv_in_causal, rv_in_scausal = rv_in_scausal, **kwargs)) def add_node(self, a, b, **kwargs): self.add_edge(a, b, coded = True, **kwargs) def set_rate(self, a, rate): if isinstance(rate, (int, float)): const_rate = rate rate = Expr.real("#TMPRATE" + str(a)) self.sublist_const.append((rate, Expr.const(const_rate))) self.sublist_rate.append((a, rate)) def get_aux(self): r = Comp.empty() for node in self.nodes: r += node.aux_out return r def get_parents_noaux(self, x, bnet): if not self.get_aux().ispresent(x): return x.copy() t = bnet.get_parents(x) return sum((self.get_parents_noaux(y, bnet) for y in t), Comp.empty()) def get_node_rv_out_from_aux_dec(self, x): r = Comp.empty() for node in self.nodes: if (node.aux_dec is not None and node.aux_dec.ispresent(x)) or (node.aux_ndec is not None and node.aux_ndec.ispresent(x)): r += node.rv_out return r def is_aux_dummy(self, x): return self.aux_dummy.ispresent(x) def get_rv_rates(self, v): v = v.copy() r = Expr.zero() for v0, v1 in self.sublist + self.sublist_rate: if v0.ispresent(v): if isinstance(v1, Expr): r += v1 else: v.substitute(v0, v1) return r def is_rate(self, v): return not self.get_rv_rates(v).iszero() def 
is_rate_zero(self, v): if not self.is_rate(v): return False v = self.get_rv_rates(v) for v0, v1 in self.sublist_const: v.substitute(v0, v1) v.simplify_quick() return v.iszero() def get_rate_rvs(self): r = Comp.empty() for a in self.bnet.index.comprv: if self.is_rate(a): r += a return r def is_src_rate(self, v): if not self.is_rate(v): return False rate = self.get_rv_rates(v) if not rate.isrealvar(): return False if self.bnet.get_parents(v).isempty(): return False count = 0 for v0, v1 in self.sublist + self.sublist_rate: if v1.ispresent(v): return False if v1.ispresent(rate): count += 1 if count >= 2: return False return True def get_refs(self, v): r = v.copy() for a in self.bnet.index.comprv: if self.get_rv_ratervs(a).ispresent(v): r += a return r def is_ch_rate(self, v): if not self.is_rate(v): return False rate = self.get_rv_rates(v) if not rate.isrealvar(): return False if not self.bnet.get_parents(v).isempty(): return False g = self.get_refs(v) - v for a in g: if not self.bnet.get_children(a).isempty(): return False return True # def is_src_rate(self, v): # if not self.is_rate(v): # return False # rate = self.get_rv_rates(v) # if not rate.isrealvar(): # return False # if self.find_node_rv_out(v) is None: # return False # return True def get_rv_sub(self, v): v = v.copy() for v0, v1 in self.sublist + self.sublist_rate: if v0.ispresent(v): if isinstance(v1, Expr): v.substitute(v0, v0[0]) # v = v0[0].copy() else: v.substitute(v0, v1) return v def get_rv_ratervs(self, v): v = v.copy() r = Comp.empty() for v0, v1 in self.sublist + self.sublist_rate: if v0.ispresent(v): if isinstance(v1, Expr): r += v0[0] else: v.substitute(v0, v1) return r def is_rv_original(self, v): for v0, v1 in self.sublist + self.sublist_rate: if v0.ispresent(v) and not isinstance(v1, Expr): return False return True def get_node_immediate_rates(self, node): x = node.rv_out + self.bnet.get_parents(node.rv_out) return self.get_rv_rates(x) def get_node_immediate_ratervs(self, node, original = 
False, include_ndec_force = False): x = node.rv_out + self.bnet.get_parents(node.rv_out) if original: x = sum((y for y in x if self.is_rv_original(y)), Comp.empty()) if include_ndec_force: x += node.rv_ndec_force return self.get_rv_ratervs(x) def get_node_descendants(self, node): x = Comp.empty() nodes = [] if isinstance(node, CodingNode): nodes = [node] else: nodes = node for node in nodes: x += self.bnet.get_descendants(node.rv_out) for node in nodes: x -= node.rv_out r = [] for node2 in self.nodes: if x.ispresent(node2.rv_out): r.append(node2) return r def get_node_dependent_after(self, node): found = False r = [] for node2 in self.nodes: if found: if not self.bnet.check_ic(Expr.I(node.rv_out, node2.rv_out)): r.append(node2) if node2 is node: found = True return r def get_nodes_rv_out(self): r = Comp.empty() for node in self.nodes: r += node.rv_out return r def find_node_rv_out(self, x): for node in self.nodes: if x.ispresent(node.rv_out): return node return None def find_nodes_immediate_ratervs(self, a): r = [] for node in self.nodes: if not self.get_node_descendants(node): continue ratervs = self.get_node_immediate_ratervs(node) if ratervs.ispresent(a): r.append(node) return r def is_src(self, x): if self.is_rate(x) or self.find_node_rv_out(x) is not None: return False for y in self.bnet.get_parents(x): if not self.is_src(y): return False return True def get_srcs(self): r = Comp.empty() for a in self.bnet.index.comprv: if self.is_src(a): r += a return r def chan_in_get_out(self, x): if any(self.is_rate(a) for a in x): return Comp.empty() anc = sum((self.bnet.get_ancestors(y) for y in x), Comp.empty()) - x r = x.copy() did = True while did: did = False for a in self.bnet.index.comprv - r - anc: if self.is_rate(a) or self.find_node_rv_out(a) is not None: continue if not r.super_of(self.bnet.get_parents(a)): continue r += a did = True r -= x return r def get_chans(self): r = [] for x in igen.subset(self.bnet.index.comprv, minsize=1): y = self.chan_in_get_out(x) 
if not y.isempty(): r.append((x, y)) return r def rv_single_node(self, a): nodes = self.find_nodes_immediate_ratervs(a) node2 = self.find_node_rv_out(a) c = len(nodes) if node2 is not None and node2 not in nodes: c += 1 return c <= 1 def remove_aux(self, x): for node in self.nodes: node.aux_out -= x node.aux_dec -= x node.aux_ndec -= x node.aux_ndec_try -= x node.aux_ndec_force -= x node.aux_borrow -= x def get_node_union_parents(self, nodes): parents = None for node in nodes: t = self.bnet.get_parents(node.rv_out) - node.rv_in_causal if parents is None: parents = t else: parents = parents.inter(t) return parents def get_aux_union_parents(self, aux): nodes = [node for node in self.nodes if node.aux_out.ispresent(aux) or node.aux_borrow.ispresent(aux)] return self.get_node_union_parents(nodes) def get_node_unions(self): n = len(self.nodes) vis = set() r = [] for mask in range((1 << n) - 1, 0, -1): nodes = [node for i, node in enumerate(self.nodes) if mask & (1 << i)] parents = self.get_node_union_parents(nodes) # print(mask) # print(parents) # print() if parents is None or parents.isempty(): continue parents_mask = self.bnet.index.get_mask(parents) if parents_mask in vis: continue vis.add(parents_mask) r.append(nodes) return r def get_index(self): index = IVarIndex() self.bnet.index.comprv.record_to(index) for node in self.nodes: if node.aux_out is not None: node.aux_out.record_to(index) for v0, v1 in self.sublist + self.sublist_rate: v0.record_to(index) v1.record_to(index) return index def name_avoid(self, name): return self.get_index().name_avoid(name) def calc_node_union(self, nodes): inner_mode = self.inner_mode for node in nodes: if node.aux_out is None: node.aux_out = Comp.empty() ratervs = sum((self.get_node_immediate_ratervs(node, original = True) for node in nodes), Comp.empty()) rv_out = sum((node.rv_out for node in nodes), Comp.empty()) des = self.get_node_descendants(nodes) ratervs = sum((c for c in ratervs if self.is_ch_rate(c) or 
any(self.get_node_immediate_ratervs(d, include_ndec_force = True).ispresent(c) for d in des)), Comp.empty()) # des = self.get_node_dependent_after(node) crvs = list(ratervs) if len(crvs) == 0: crvs.append(Comp.empty()) inner_mode = "combinations" if inner_mode == "combinations": for c in crvs: for mask in range(1, 1 << len(des)): cname = "A_" + rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD) # cname_l = ("A_{" + rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX) # + ("" if c.isempty() else "," + c.tostring(style = PsiOpts.STR_STYLE_LATEX)) + "}") cname_l = "A_{" + rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX) if len(des) > 1: cname_l += " \\to " ccount = 0 for i, d in enumerate(des): if mask & (1 << i): cname += "_" + d.rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD) if ccount: cname_l += "," ccount += 1 cname_l += d.rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX) cname_l += "}" cname += ("" if c.isempty() else "_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD)) cname_l += ("" if c.isempty() else "^{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}") cname += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + cname_l cname = self.name_avoid(cname) a = Comp.rv(cname) # self.aux_importance.append((a, -iutil.bit)) for i, node in enumerate(nodes): if i == 0: node.aux_out += a node.aux_out_sublist.append((a, a + c)) else: node.aux_borrow += a for i, d in enumerate(des): if mask & (1 << i): if d.aux_dec is None: d.aux_dec = Comp.empty() d.aux_dec += a else: d.aux_ndec_try += a if d.rv_ndec_force.ispresent(c): d.aux_ndec_force += a elif inner_mode == "plain": if len(des): for c in crvs: single_node = self.rv_single_node(c) if single_node: cname = ("A_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD)) else: cname = ("A_" + rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD) + ("" if c.isempty() or len(crvs) == 1 else "_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD))) cname += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" # print(str(c) + " " + 
str(len(self.find_nodes_immediate_ratervs(c)))) if single_node: cname += ("A_{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}") else: cname += ("A_{" + rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}" + ("" if c.isempty() or len(crvs) == 1 else "^{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}")) cname = self.name_avoid(cname) a = Comp.rv(cname) if self.is_rate_zero(c): self.aux_dummy += a for i, node in enumerate(nodes): if i == 0: node.aux_out += a node.aux_out_sublist.append((a, a + c)) else: # pass node.aux_borrow += a for d in des: if self.get_node_immediate_ratervs(d).ispresent(c): if d.aux_dec is None: d.aux_dec = Comp.empty() d.aux_dec += a else: d.aux_ndec_try += a if d.rv_ndec_force.ispresent(c): d.aux_ndec_force += a def calc_node(self, node): inner_mode = self.inner_mode if node.aux_out is None: node.aux_out = Comp.empty() ratervs = self.get_node_immediate_ratervs(node) des = self.get_node_descendants(node) # des = self.get_node_dependent_after(node) crvs = list(ratervs) if len(crvs) == 0: crvs.append(Comp.empty()) inner_mode = "combinations" if inner_mode == "combinations": for c in crvs: for mask in range(1, 1 << len(des)): cname = "A_" + node.rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD) # cname_l = ("A_{" + node.rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX) # + ("" if c.isempty() else "," + c.tostring(style = PsiOpts.STR_STYLE_LATEX)) + "}") cname_l = "A_{" + node.rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX) for i, d in enumerate(des): if mask & (1 << i): cname += "_" + d.rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD) cname_l += "," + d.rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX) cname_l += "}" cname += ("" if c.isempty() else "_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD)) cname_l += ("" if c.isempty() else "^{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}") cname += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + cname_l cname = self.name_avoid(cname) a = Comp.rv(cname) # 
self.aux_importance.append((a, -iutil.bit)) node.aux_out += a node.aux_out_sublist.append((a, a + c)) for i, d in enumerate(des): if mask & (1 << i): if d.aux_dec is None: d.aux_dec = Comp.empty() d.aux_dec += a else: d.aux_ndec_try += a if d.rv_ndec_force.ispresent(c): d.aux_ndec_force += a elif inner_mode == "plain": if len(des): for c in crvs: single_node = self.rv_single_node(c) if single_node: cname = ("A_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD)) else: cname = ("A_" + node.rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD) + ("" if c.isempty() or len(crvs) == 1 else "_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD))) cname += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" # print(str(c) + " " + str(len(self.find_nodes_immediate_ratervs(c)))) if single_node: cname += ("A_{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}") else: cname += ("A_{" + node.rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}" + ("" if c.isempty() or len(crvs) == 1 else "^{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}")) cname = self.name_avoid(cname) a = Comp.rv(cname) node.aux_out += a node.aux_out_sublist.append((a, a + c)) for d in des: if self.get_node_immediate_ratervs(d).ispresent(c): if d.aux_dec is None: d.aux_dec = Comp.empty() d.aux_dec += a else: d.aux_ndec_try += a if d.rv_ndec_force.ispresent(c): d.aux_ndec_force += a def calc_node_unions(self): if all(node.aux_out is not None for node in self.nodes): return unions = self.get_node_unions() for node in self.nodes: if node.aux_out is None: node.aux_out = Comp.empty() for nodes in unions: self.calc_node_union(nodes) def calc_nodes(self): for node in self.nodes: self.calc_node(node) def clear_cache(self): for node in self.nodes: node.clear() def presolve(self): self.sublist += self.sublist_rate self.sublist_rate = [] def get_region(self, oneshot = False): r = self.bnet.copy() tosublist = [] for node in self.nodes: tosublist += node.aux_out_sublist tosublist += self.sublist + self.sublist_rate convexify 
= None if not oneshot: convexify = Comp.rv("#TMPVAR_GET_REGION") for node in self.nodes: r += (convexify, node.rv_out) r = r.contracted_node(convexify) for v0, v1 in tosublist: if isinstance(v1, Expr) and not isinstance(v0, Expr): r = r.contracted_node(v0) else: r = r.contracted_node(v0) if self.partition is not None: reg = Region.universe() for i, a in enumerate(self.partition): a2 = sum((b for j, b in enumerate(self.partition) if j != i), Comp.empty()) reg &= r.eliminated(a2).get_region() return reg return r.get_region() def is_netcode(self): """Whether this setting is a network coding setting, i.e., all random variables are messages with rates. """ for a in self.bnet.index.comprv: if not self.is_rate(a): return False return True def get_netcode_region(self, convexify = None, mi = True, skip_simplify = False): """Rate region of network coding setting. If mi = False, use the Yan-Yeung-Zhang region (CAUTION: Convexification and closure is not performed): X. Yan, R. W. Yeung, and Z. Zhang, "An implicit characterization of the achievable rate region for acyclic multisource multisink network coding," IEEE Trans. Inf. Theory, vol. 58, no. 9, pp. 5625-5639, 2012. 
""" self.presolve() if self.bnet.tsorted() is None: raise ValueError("Non-strictly-causal cycle detected.") return None if convexify is None: convexify = True if convexify is not False: if convexify is True: convexify = Comp.index("Q_i") if not isinstance(convexify, Comp): convexify = Comp.empty() r = Region.universe() msgs = Comp.empty() for a in self.bnet.index.comprv: pa = self.bnet.get_parents(a) if pa.isempty(): if not mi: r &= Expr.I(msgs, a) == 0 msgs += a rate = self.get_rv_rates(a) if not mi: r &= Expr.H(a) >= rate r &= rate >= 0 else: a2 = self.get_rv_ratervs(a) if a == a2: rate = self.get_rv_rates(a) if mi: r &= Expr.I(a, pa) <= rate else: r &= Expr.Hc(a, pa) == 0 r &= Expr.H(a) <= rate else: for a3 in a2: rate = self.get_rv_rates(a3) if mi: r &= Expr.I(a3, pa) >= rate else: r &= Expr.Hc(a3, pa) == 0 if mi: tbnet = self.bnet.copy() for a in tbnet.index.comprv: tbnet.set_fcn(a, False) a2 = self.get_rv_ratervs(a) if a != a2: tbnet = tbnet.contracted_node(a) r &= tbnet.get_region() if False: tosublist = [] for node in self.nodes: tosublist += node.aux_out_sublist tosublist += self.sublist for v0, v1 in tosublist: if not (isinstance(v1, Expr) and not isinstance(v0, Expr)): r.substitute(v0, v1) for tozero, toconst in self.sublist_const: r.substitute(tozero, toconst) r = r.exists(self.bnet.index.comprv.inter(r.allcomp())) if not skip_simplify: r.simplify_aux_commonpart(maxlen = 3, forbid_h = False) r = r.simplified() return r def src_combine(self): r = [] groups = [] for a in self.bnet.index.comprv: if self.is_src_rate(a): p = self.bnet.get_parents(a) c = self.bnet.get_children(a) for g in groups: p2 = p.copy() for a2, c2 in g[1]: p2 -= a2 if g[0] == p2: g[1].append((a, c)) break else: groups.append((p, [(a, c)])) for p, g in groups: mask_set = set() for a, c in g: mask_set.add(self.bnet.index.get_mask(c)) mask_list = list(mask_set) for mmask in range(1, 1 << len(mask_list)): cmask = None for i, mask in enumerate(mask_list): if not mmask & (1 << i): continue 
if cmask is None: cmask = mask else: cmask |= mask if cmask == 0 or cmask in mask_set: continue mask_set.add(cmask) cc = self.bnet.index.from_mask(cmask) cg = [a for a, c in g if cc.super_of(c)] # cname = "(" + str(sum(cg)) +")" + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + "(" + sum(cg).tostring(style = "latex") + ")" cname = str(sum(cg)) + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + sum(cg).tostring(style = "latex") cname = self.name_avoid(cname) crv = Comp.rv(cname) self.set_rate(crv, 0) self.add_node(p, crv) for cc2 in cc: self.bnet.add_edge(crv, cc2) self.bnet = self.bnet.tsorted() return self def ch_combine(self, min_ndummy = 0): groups = [] for a in self.bnet.index.comprv: if self.is_ch_rate(a): p = self.bnet.get_children(a) for g in groups: if g[0] == p: g[1].append(a) break else: groups.append((p, [a])) for p, g in groups: decs = [] for i, a in enumerate(g): for b in self.get_refs(a) - a: bp = self.bnet.get_parents(b) if bp not in decs: decs.append(bp) # Whether y is degraded compared to x deg = [[self.bnet.check_ic(Expr.Ic(p, x, y)) for y in decs] for x in decs] idecs = [0 for i in range(len(g))] for i, a in enumerate(g): for b in self.get_refs(a) - a: bp = self.bnet.get_parents(b) bpi = decs.index(bp) for j in range(len(decs)): if deg[bpi][j]: idecs[i] |= 1 << j # print(idecs) dec_vis = list(idecs) ndummy = 0 for mask in range(1, 1 << len(g)): cdec = 0 cg = Comp.empty() for i, a in enumerate(g): if mask & (1 << i): cdec |= idecs[i] cg += a # print(str(mask) + " " + str(cdec)) if cdec in dec_vis: continue # cname = "(" + str(sum(g)) + ")" + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + "(" + sum(g).tostring(style = "latex") + ")" cname = str(sum(cg, Comp.empty())) + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + sum(cg, Comp.empty()).tostring(style = "latex") cname = self.name_avoid(cname) crv = Comp.rv(cname) # crate = Expr.real("MRATE_" + str(p) + "_" + str(mask)) # self.set_rate(crv, crate) self.set_rate(crv, 0) for pb in p: self.bnet.add_edge(crv, pb) 
# self.add_edge(crv, p, children_edge = False) # for j in range(len(decs)): # if cdec & (1 << j): # self.add_node(decs[j], crv) ndummy += 1 for i, a in enumerate(g): if idecs[i] | cdec == cdec: # print("TRY " + str(a) + " " + str(crv)) for b in self.get_refs(a) - a: node = self.find_node_rv_out(b) if node is not None: node.rv_ndec_force += crv # print("ADD " + str(node.rv_out) + " " + str(crv)) # self.sublist_zero.append(crate) dec_vis.append(cdec) for i in range(min_ndummy - ndummy): cname = str(sum(g, Comp.empty())) + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + sum(g, Comp.empty()).tostring(style = "latex") cname = self.name_avoid(cname) crv = Comp.rv(cname) self.set_rate(crv, 0) for pb in p: self.bnet.add_edge(crv, pb) self.bnet = self.bnet.tsorted() return self def get_src_splice(self): r = [] groups = [] for a in self.bnet.index.comprv: if self.is_src_rate(a): p = self.bnet.get_parents(a) c = self.bnet.get_children(a) for g in groups: p2 = p.copy() for a2, c2 in g[1]: p2 -= a2 if g[0] == p2: g[1].append((a, c)) break else: groups.append((p, [(a, c)])) def covered(g, i, mask): ci = g[i][1] cj = Comp.empty() for j in range(len(g)): if mask & (1 << j): cj += g[j][1] return cj.super_of(ci) for p, g in groups: for i in range(len(g)): ci = g[i][1] for mask in igen.subset_mask((1 << len(g)) - 1 - (1 << i)): if not covered(g, i, mask): continue bad = False for j in range(len(g)): if mask & (1 << j) and covered(g, i, mask - (1 << j)): bad = True break if bad: continue r.append((self.get_rv_rates(g[i][0]), [self.get_rv_rates(g[j][0]) for j in range(len(g)) if mask & (1 << j)])) return r def get_inner(self, convexify = None, ndec_mode = "all", skip_simplify = False, skip_aux = False, splice = True, combine = True, is_proof = False, num_dummy = 0, target = None): """Get an inner bound of the capacity region via [Lee-Chung 2015] together with simplification procedures. Si-Hyeon Lee, and Sae-Young Chung. "A unified approach for network information theory." 
2015 IEEE International Symposium on Information Theory (ISIT). IEEE, 2015. """ if target is not None: skip_aux = False eliminate_quick = False if self.is_netcode(): return self.get_netcode_region(convexify = convexify, skip_simplify = skip_simplify) if combine: cs = self.copy() cs.ch_combine(num_dummy) cs.src_combine() return cs.get_inner(convexify, ndec_mode, skip_simplify, skip_aux, splice, False, is_proof, 0, target) verbose = PsiOpts.settings.get("verbose_codingmodel", False) self.presolve() if self.bnet.tsorted() is None: raise ValueError("Non-strictly-causal cycle detected.") return None if convexify is None: convexify = self.convexify_needed() if convexify is not False: if convexify is True: convexify = Comp.index("Q_i") if self.use_union: self.calc_node_unions() else: self.calc_nodes() for node in self.nodes: if node.aux_ndec is None: r = RegionOp.union([]) cndec_mode = node.ndec_mode if cndec_mode is None: cndec_mode = ndec_mode if cndec_mode == "all": # print("NODE " + str(node.rv_out) + " " + str(node.aux_ndec_force)) for cndec0 in igen.subset(node.aux_ndec_try - node.aux_ndec_force): cndec = node.aux_ndec_force + cndec0 # print(str(node.rv_out) + " " + str(cndec)) node.aux_ndec = cndec if isinstance(cndec, Comp) else Comp.empty() t = self.get_inner(convexify, ndec_mode = ndec_mode, skip_simplify = True, skip_aux = True, splice = splice, combine = combine, is_proof = is_proof, num_dummy = num_dummy, target = target) if t is not None: if target is not None: r = t break r |= t elif cndec_mode == "min" or cndec_mode == "none": # node.aux_ndec = Comp.empty() node.aux_ndec = node.aux_ndec_force.copy() r = self.get_inner(convexify, ndec_mode = ndec_mode, skip_simplify = True, skip_aux = True, splice = splice, combine = combine, is_proof = is_proof, num_dummy = num_dummy, target = target) elif cndec_mode == "max": node.aux_ndec = node.aux_ndec_try.copy() r = self.get_inner(convexify, ndec_mode = ndec_mode, skip_simplify = True, skip_aux = True, splice = 
splice, combine = combine, is_proof = is_proof, num_dummy = num_dummy, target = target) node.aux_ndec = None if r is None: return None if target is not None: return r if not skip_aux: aux = self.get_aux() caux = aux.copy() if convexify is not False: caux += convexify r = r.eliminate(caux.inter(r.allcomprv()), quick = eliminate_quick) if skip_simplify: return r else: r.remove_missing_aux() r.simplify() r.simplify_union() # print(r) r.remove_missing_aux() r = r.simplified_quick() if verbose: print("============== Simplified ==============") print(r) return r aux = self.get_aux() # aux_rates = [Expr.real("#TR" + str(i)) for i in range(len(aux))] cindex = self.get_index() aux_rates = [] for a in aux: cname = iutil.name_prefix_suffix(a.get_name(), "R_{", "}") cname = cindex.name_avoid(cname) aux_rates.append(Expr.real(cname)) aux_rates[-1].record_to(cindex) cbnet = self.bnet.copy() r = Region.universe() for rate in aux_rates: r &= rate >= 0 nodes_code = [[] for node in self.nodes] for node, node_code in zip(self.nodes, nodes_code): y = self.bnet.get_parents(node.rv_out) - node.rv_in_causal aux_dec = Comp.empty() if node.aux_dec is not None: aux_dec = node.aux_dec auxlist = [(auxx, self.get_aux_union_parents(auxx)) for auxx in node.aux_out] # auxlist.sort(key = lambda auxx, auxp: ) for i, (auxx, auxp) in enumerate(auxlist): caux = Comp.empty() for (aux2, aux2p) in auxlist[:i]: if auxp.super_of(aux2p): caux += aux2 cbnet.add_edge(aux_dec + auxp + node.aux_borrow + caux, auxx) cbnet.add_edge(aux_dec + y + node.aux_borrow + node.aux_out, node.rv_out) # Decoding constraints dec_inp = y + node.aux_borrow for cdec in igen.subset(aux_dec, minsize = 1): for cndec in igen.subset(node.aux_ndec): cd = cdec + cndec expr = Expr.zero() cc = Comp.empty() for c in cd: ci = aux.index_of(c) expr += aux_rates[ci] expr -= I(c & cc + (aux_dec + node.aux_ndec - cd) + dec_inp) cc += c pf_note = ["dec ", dec_inp.copy(), " to ", "#CMD_EXCLUDE_PREV", cdec.copy()] if not cndec.isempty(): pf_note 
+= [" ", "nonuniq ", cndec.copy()] r &= (expr <= 0).add_meta("pf_note", pf_note) if not aux_dec.isempty(): node_code.append({ "mode": "decode", "inp": dec_inp.copy(), "oup": aux_dec.copy(), "ndec": node.aux_ndec.copy() }) # Encoding constraints for i, (auxx, auxp) in enumerate(auxlist): caux = Comp.empty() caux_out = Comp.empty() bad = False for j, (aux2, aux2p) in enumerate(auxlist): if auxp == aux2p: if j > i: bad = True break caux_out += aux2 elif j < i and auxp.super_of(aux2p): caux += aux2 if bad: continue enc_inp = aux_dec + auxp + node.aux_borrow + caux for ce in igen.subset(caux_out, minsize = 1): expr = Expr.zero() cc = Comp.empty() for c in ce: ci = aux.index_of(c) expr += aux_rates[ci] expr -= I(c & cc + enc_inp) cc += c # if self.is_rate(auxp): # r &= (expr >= 0) # else: pf_note = ["enc ", enc_inp.copy(), " to ", "#CMD_EXCLUDE_PREV", ce.copy()] r &= (expr >= 0).add_meta("pf_note", pf_note).add_meta("pf_note_priority", -1) if not caux_out.isempty(): node_code.append({ "mode": "encode", "inp": enc_inp.copy(), "oup": caux_out.copy() }) if not node.rv_out.isempty(): node_code.append({ "mode": "generate", "inp": sum((cbnet.get_parents(x) for x in node.rv_out), Comp.empty()), "oup": node.rv_out.copy() }) # Check dummy duplicates for a in aux: if not self.is_aux_dummy(a): continue a_parents = self.get_parents_noaux(a, cbnet) a_rvout = self.get_node_rv_out_from_aux_dec(a) if a_rvout.isempty(): return None for b in aux - a: if a_parents == self.get_parents_noaux(b, cbnet): b_rvout = self.get_node_rv_out_from_aux_dec(b) if a_rvout == b_rvout: return None if not self.is_aux_dummy(b) and b_rvout.super_of(a_rvout): return None if convexify is not False: r.condition(convexify) for node in self.nodes: cbnet.add_edge(convexify, node.aux_out + node.rv_out) self.bnet_in = cbnet r &= cbnet.get_region() allrv = r.allcomprv() if is_proof: r.add_meta("inner_rv_list", CompArray(list(allrv)), children = False) tosublist = [] for node in self.nodes: tosublist += 
node.aux_out_sublist tosublist += self.sublist if verbose: print("============== Nodes ==============") for node in self.nodes: print(str(self.bnet.get_parents(node.rv_out)) + (", (borrow " + str(node.aux_borrow) + ")" if not node.aux_borrow.isempty() else "") + " -> " + str(node.rv_out) + ", aux " + str(node.aux_out) + ", dec " + str(node.aux_dec) + (", (ndec " + str(node.aux_ndec) + ")" if not node.aux_ndec.isempty() else "")) print("============== Bayes net ==============") print(cbnet) print("============== Aux ==============") for a in aux: print(str(a) + (" (dummy)" if self.is_aux_dummy(a) else "") + ": " + str(self.get_parents_noaux(a, cbnet)) + " -> " + str(self.get_node_rv_out_from_aux_dec(a))) print("============== Inner bound ==============") # print(r) r.print(note=True) with PsiOpts(meta_subs_criteria = "eqtype"): for v0, v1 in tosublist: r.substitute(v0, v1) if isinstance(v1, Expr) and not isinstance(v0, Expr): r &= v1 >= 0 # if iutil.check_meta_subs_criteria(v0, v1, self): # for node, node_code in zip(self.nodes, nodes_code): # iutil.substitute(node_code, v0, v1) if verbose: print("============== Substitute ==============") for v0, v1 in tosublist: print(str(v0) + " -> " + str(v1)) print("============== After substitute ==============") # print(r) r.print(note=True) reg_record = None if is_proof: reg_record = [] r = r.eliminate(sum(aux_rates, Expr.zero()), quick = eliminate_quick, reg_record = reg_record) # if convexify is not False: # r.condition(convexify) # rallcomprv = r.allcomprv() # encout = self.get_nodes_rv_out().inter(rallcomprv) # r &= markov(rallcomprv - encout - convexify, encout, convexify) caux = aux.copy() if convexify is not False: caux += convexify # r_prior = self.bnet.get_region().copy() # assume_reg = Region.universe() if not self.reg.isuniverse(): # if eliminate_quick and not skip_simplify: # r = r.simplified() if eliminate_quick: r = r.simplified() # r = r & self.reg # r = r.and_cause_consequence(self.reg, avoid = caux) # r = 
r.and_cause_consequence(self.reg, added_reg = r_prior) r = r.and_cause_consequence(self.reg) if self.partition is not None: r &= indep(*(self.partition)).add_meta("pf_note", ["partition"]) if not skip_aux: r = r.eliminate(caux.inter(r.allcomprv()), quick = eliminate_quick) if verbose: print("============== Elim aux rate ==============") # print(r) r.print(note=True) if splice: splice_list = self.get_src_splice() for w0, w1 in splice_list: r.splice_rate(w0, w1) if verbose: print("============== After splice ==============") # print(r) r.print(note=True) for tozero, toconst in self.sublist_const: r.substitute(tozero, toconst) with PsiOpts(meta_subs_criteria = "eqtype"): for a in cbnet.index.comprv: if self.is_rate_zero(a): r.substitute(a, Comp.empty()) did_simplify = False if target is not None: aux_assign = (target >> r).check_getaux() if aux_assign is None: return None with PsiOpts(meta_subs_criteria = "eqtype"): Comp.substitute_list(r, aux_assign, isaux = True) r = r.eliminate(target.aux.inter(r.allcomprv()), quick = eliminate_quick) r = r.simplified() did_simplify = True # print(r) if skip_simplify: pass # return r else: # r = r.simplified(reg = r_prior) r = r.simplified() # r = r.simplified(self.reg) r.remove_missing_aux() did_simplify = True if verbose: print("============== Simplified ==============") # print(r) r.print(note=True) if is_proof: if skip_aux or not did_simplify: if skip_aux: r = r.eliminate(caux.inter(r.allcomprv()), quick = eliminate_quick) # with PsiOpts(meta_subs_criteria = "eqtype"): # r = r.simplified() r = r.simplified() # r = r.simplified(self.reg) r.remove_missing_aux() if skip_aux: for c in caux: r.substitute_aux(c, c) did_simplify = True if verbose: print("============== Simplified (proof) ==============") # print(r) r.print(note=True) allrv_after = r.get_meta("inner_rv_list") for a in self.bnet.index.comprv + aux + allrv_after.allcomprv(): if self.is_rate_zero(a): iutil.substitute(allrv_after, a, Comp.empty()) for node, node_code in 
zip(self.nodes, nodes_code): iutil.substitute(node_code, a, Comp.empty()) if verbose: print("============== rv after ==============") print(allrv_after) rnote = [] # print(r) # print(repr(aux_after)) # print() rate_reg_mode = "one" def postprocess(x): for a, b in zip(allrv, allrv_after): iutil.substitute(x, a, b.copy()) if isinstance(x, Region): for v0, v1 in tosublist: x.substitute(v0, v1) for tozero, toconst in self.sublist_const: x.substitute(tozero, toconst) rates_wanted = Expr.zero() for a, b in zip(allrv, allrv_after): if aux.ispresent(a) and not b.isempty(): crate = aux_rates[aux.index_of(a)] rates_wanted += crate # print(rates_wanted) # print(reg_record) # print() Region.eliminate_reg_record_clean(reg_record, rates_wanted, cross=True) # print(reg_record) # print() for node, node_code in zip(self.nodes, nodes_code): for ccode in node_code: if "oup" in ccode: ccode["oup_cb"] = ccode["oup"].copy() ccode["oup_id"] = ccode["oup"].copy() if "ndec" in ccode: ccode["ndec_cb"] = ccode["ndec"].copy() ccode["ndec_id"] = ccode["ndec"].copy() cnote = ["Codebook:"] rnote.append(cnote) cnaux = 0 for a, b in zip(allrv, allrv_after): b_cb = b.copy() b_id = b.copy() if aux.ispresent(a): id_name = "" if len(str(cnaux + 1)) >= 2: id_name = "i_{" + str(cnaux + 1) + "}" else: id_name = "i_" + str(cnaux + 1) b_cb = sum((Comp.rv(iutil.name_prefix_suffix(x.get_name(), "", "[" + id_name + "]")) for x in b), Comp.empty()) b_id = Comp.rv(id_name) if not b.isempty(): crate = aux_rates[aux.index_of(a)] found = False if rate_reg_mode == "detail": cnote = [" " * 2, b_cb, ",", " ", "rate = ", crate] rnote.append(cnote) for reg_var, reg_r in reg_record: if reg_var.get_name() == crate.get_name(): reg_r2 = reg_r.copy() reg_r2.remove_meta("pf_note") postprocess(reg_r2) reg_r2.simplify(reg = r) if reg_r2.isuniverse(): continue if rate_reg_mode == "one": one_bound = reg_r2.get_one_bound(crate, 2) if one_bound is None: continue cnote = [" " * 2, b_cb, ",", " ", "rate = ", one_bound] 
rnote.append(cnote) found = True elif rate_reg_mode == "detail": cnote = [" " * 4, reg_r2] rnote.append(cnote) if rate_reg_mode == "one" and not found: cnote = [" " * 2, b_cb] rnote.append(cnote) cnaux += 1 # print(a) # print(b) # print(b_cb) # print(b_id) # print() for node, node_code in zip(self.nodes, nodes_code): for ccode in node_code: for key, value in ccode.items(): if key in ["oup_cb", "ndec_cb"]: iutil.substitute(value, a, b_cb.copy()) elif key in ["oup_id", "ndec_id"]: iutil.substitute(value, a, b_id.copy()) else: iutil.substitute(value, a, b.copy()) for node, node_code in zip(self.nodes, nodes_code): nline = 0 for ccode in node_code: cnote = [] if ccode["mode"] in ("encode", "decode"): if ccode["oup"].isempty(): continue if nline == 0: if node.label is not None and node.label != "": cnote += [node.label, " "] else: cnote += [" " * 4] # cnote += ["decodes", " ", ccode["inp"], " ", "to", " ", ccode["oup"]] cnote += ["finds", " ", ccode["oup_id"], ":", " "] if "ndec" in ccode and not ccode["ndec"].isempty(): cnote += ["exists", " ", ccode["ndec_id"], ":", " "] cnote += ["(", ccode["inp"]] if "ndec" in ccode and not ccode["ndec"].isempty(): cnote += [",", ccode["ndec_cb"]] cnote += [",", ccode["oup_cb"], ")", "is typical"] rnote.append(cnote) nline += 1 elif ccode["mode"] == "generate": oup = ccode["oup"] - ccode["inp"] inp = ccode["inp"] - self.get_rate_rvs() if oup.isempty(): continue if nline == 0: if node.label is not None and node.label != "": cnote += [node.label, " "] else: cnote += [" " * 4] cnote += ["generates", " ", oup, " ", "from", " ", inp] rnote.append(cnote) nline += 1 # print(repr(rnote)) r.add_meta("pf_note_pre", rnote, children = False) return r def nfold(self, n = 2, natural_eqprob = True, all_eqprob = True): """ Return the n-letter setting. Parameters ---------- n : int The blocklength. Returns ------- CodingModel. 
""" r = self.copy() bnet_out = BayesNet() reg_out_map = {} clen = n toeqprob = Comp.empty() for x in self.bnet.allcomp() + self.reg.allcomprv(): t = None if self.is_rate(x): t = x.copy() else: t = rv_array(x.get_name(), n) t[0] = x.copy() if natural_eqprob or all_eqprob: toeqprob += x reg_out_map[x.get_name()] = t for (a, b) in self.bnet.edges(): if not all_eqprob: toeqprob -= b am = reg_out_map[a.get_name()] bm = reg_out_map[b.get_name()] node = self.find_node_rv_out(b) if node is None: bnet_out += (am, bm) else: if node.rv_in_causal.ispresent(a) and isinstance(am, CompArray) and isinstance(bm, CompArray): for t1 in range(clen): for t2 in range(t1 + 1, clen): bnet_out += (am[t1]+bm[t1], bm[t2]) else: bnet_out += (am.allcomp(), bm.allcomp()) for x in self.bnet.allcomp(): if self.bnet.is_fcn(x): xm = reg_out_map[x.get_name()] bnet_out.set_fcn(xm) r.bnet = bnet_out r.reg = Region.universe() for i in range(n): tr = self.reg.copy() for x in self.reg.allcomprv(): v = reg_out_map[x.get_name()] if isinstance(v, CompArray): tr.substitute(x, v[i]) r.reg = r.reg & tr if not toeqprob.isempty(): for i in range(1, n): v0 = CompArray.empty() vi = CompArray.empty() for x in toeqprob: v0.append(reg_out_map[x.get_name()][0]) vi.append(reg_out_map[x.get_name()][i]) r.reg = r.reg & (ent_vector(*v0) == ent_vector(*vi)) def map_getlist_id(a, i): if isinstance(a, Expr): return a.copy() t = None for c in a: cm = reg_out_map[c.get_name()] if isinstance(cm, CompArray): cm = cm[i] if t is None: t = cm.copy() else: t += cm return t r.sublist = [] for i in range(n): for v0, v1 in self.sublist: r.sublist.append((map_getlist_id(v0, i), map_getlist_id(v1, i))) r.nodes = [] for i in range(n): for node in self.nodes: v = reg_out_map[node.rv_out.get_name()] if isinstance(v, CompArray): v = v[i] else: if i != 0: continue cnode = node.copy() cnode.clear() cnode.rv_out = v.copy() r.nodes.append(cnode) return r def convexify_needed(self): return self.get_outer(convexify = False, convexify_test = True) 
def get_outer(self, aux = None, oneshot = False, future = True, convexify = None, add_csiszar_sum = True, leaf_remove = True, node_fcn = True, node_fcn_force = False, skip_simplify = True, convexify_test = False, is_proof = False, include_nondecode_series = False, include_last_future = False, remove_created = True, full = False): """Get an outer bound of the capacity region. """ if full: future = True add_csiszar_sum = True include_nondecode_series = True include_last_future = True decoding_rate = True # is_proof future_also = is_proof nat_msg_rate = False # not is_proof coded_combination = is_proof # if self.is_netcode(): # return self.get_netcode_region(convexify = convexify) self.presolve() if self.bnet.tsorted() is None: raise ValueError("Non-strictly-causal cycle detected.") return None if convexify is None: convexify = self.convexify_needed() if convexify is not False: if convexify is True: convexify = Comp.index("Q_o") # auxs = self.get_aux() # reg_out_aux = auxs.copy() reg_out_aux = Comp.empty() bnet_out = BayesNet() reg_out_map = {} clen = 1 if oneshot else 3 if future else 2 ttr = [1, 0, 2] # for v0, v1 in self.sublist: # print(str(v0) + " " + str(v1)) rv_wfuture = Comp.empty() to_csiszar = [] rv_series = Comp.empty() rv_timedep = Comp.empty() for x in self.bnet.allcomp(): removable = False removed = False if leaf_remove and self.bnet.get_children(x).isempty() and x.ispresent(self.get_rv_sub(x)): removable = True # print(str(x) + " " + str(self.is_rate(x)) + " " + str(self.get_rv_rates(x))) t = None if self.is_rate(x): # t = CompArray([x.copy()]) t = x.copy() elif oneshot: t = CompArray([x.copy()]) elif future: t = CompArray.series_sym(x) if removable: t[1] = Comp.empty() t[2] = Comp.empty() removed = True else: rv_wfuture += x else: t = CompArray.series_sym(x, suff = None) reg_out_map[x.get_name()] = t if isinstance(t, CompArray): for tt in t: rv_timedep += tt for tt in t[1:]: rv_series += tt if not removed: to_csiszar.append(t) for (a, b) in 
self.bnet.edges(): # if not self.get_rv_sub(b).ispresent(b): # continue am = reg_out_map[a.get_name()] bm = reg_out_map[b.get_name()] node = self.find_node_rv_out(b) if node is None: bnet_out += (am.swapped_id(0, 1), bm.swapped_id(0, 1)) else: # if convexify is False and self.is_rate(b) and am.allcomp().ispresent(rv_series): # rv_series += bm.allcomp() if node.rv_in_causal.ispresent(a) and isinstance(am, CompArray) and isinstance(bm, CompArray): for t1 in range(clen): for t2 in range(t1, clen): if convexify is False: bnet_out += (am[ttr[t1]] + bm[ttr[t1]], bm[ttr[t2]]) else: bnet_out += (am[ttr[t1]], bm[ttr[t2]]) else: if convexify is False: bnet_out += (am.swapped_id(0, 1).allcomp(), bm.swapped_id(0, 1).allcomp()) else: for tbm in bm.swapped_id(0, 1).allcomp(): bnet_out += (am.swapped_id(0, 1).allcomp(), tbm) for b in self.bnet.allcomp(): bm = reg_out_map[b.get_name()] if not isinstance(bm, CompArray): continue node = self.find_node_rv_out(b) if node is None: continue for a in node.rv_in_scausal: am = reg_out_map[a.get_name()] if not isinstance(am, CompArray): continue for t1 in range(clen): for t2 in range(t1, clen): if t1 == 1 and t2 == 1: continue if convexify is False: bnet_out += (am[ttr[t1]] + bm[ttr[t1]], bm[ttr[t2]]) else: bnet_out += (am[ttr[t1]], bm[ttr[t2]]) if convexify is False: for x in self.bnet.allcomp(): parent_empty = self.bnet.get_parents(x).isempty() if parent_empty: xm = reg_out_map[x.get_name()] if isinstance(xm, CompArray) and len(xm) >= 3: bnet_out += (xm[1], xm[2]) rvnats = Comp.empty() for x in self.bnet.allcomp(): parent_empty = self.bnet.get_parents(x).isempty() if parent_empty or self.find_node_rv_out(x) is not None: xm = reg_out_map[x.get_name()] if isinstance(xm, CompArray): for txm in (xm[1:] if parent_empty else xm): if txm.isempty(): continue series_parent = bnet_out.get_parents(txm).ispresent(rv_series) if series_parent: continue rvnats += txm if convexify is not False: bnet_out += (convexify, txm) else: if not 
txm.ispresent(rv_series[0]): if convexify_test: # print((rv_series[0], txm)) return True bnet_out += (rv_series[0], txm) # print(rv_series) r_fcn = Region.universe() for x in self.bnet.allcomp(): if self.bnet.is_fcn(x): xm = reg_out_map[x.get_name()] bnet_out.set_fcn(xm) elif node_fcn and self.find_node_rv_out(x) is not None: xm = reg_out_map[x.get_name()] if isinstance(xm, CompArray): for txm in xm: if txm.isempty(): continue series_parent = bnet_out.get_parents(txm).ispresent(rv_series) if node_fcn_force or convexify is not False or series_parent: bnet_out.set_fcn(txm) # print(txm) if convexify is False: for b in rv_series: b2 = bnet_out.get_parents(txm) - xm.allcomp() + b if not b2.super_of(txm): r_fcn &= Expr.Hc(txm, b2) == 0 else: if future: bnet_out.set_fcn(xm) bnet_out_final = bnet_out.copy() if remove_created: for x in self.bnet.allcomp(): if self.created_rv.ispresent(x) and self.is_rate(x): xm = reg_out_map[x.get_name()] for a in xm: bnet_out_final = bnet_out_final.eliminated(a) bnet_out_final = bnet_out_final.scc() # if convexify_test: # return False if not convexify_test: # print(bnet_out) self.bnet_out_final = bnet_out_final self.bnet_out = bnet_out self.bnet_out_arrays = [] for x in self.bnet.allcomp(): xm = reg_out_map[x.get_name()] if isinstance(xm, CompArray): t = xm.swapped_id(0, 1).allcomp() if len(t) > 1: self.bnet_out_arrays.append(t) r = bnet_out_final.get_region().add_meta("pf_note", ["Bayesian network"]) # if not self.reg.isuniverse(): # # r = r & self.reg # r = r.and_cause_consequence(self.reg) if not convexify_test and not oneshot and future and add_csiszar_sum: # r &= csiszar_sum(*(list(x for _, x in reg_out_map.items()))) r &= csiszar_sum(*to_csiszar).add_meta("pf_note", ["Csiszar sum"]).add_meta("dual_weight", 5.0) r &= r_fcn if convexify is not False: reg_out_aux += convexify for x in rv_series: r &= Expr.Hc(convexify, x) == 0 reg_out_aux += bnet_out.allcomp() - self.bnet.allcomp() # for b in self.bnet.allcomp(): # b2 = 
self.get_rv_sub(b) # if b2.ispresent(b): # continue # if self.is_rate(b): # rate = self.get_rv_rate(b) # r &= rate >= 0 def map_getlist(a): t = None for c in a: cm = reg_out_map[c.get_name()].copy() if t is None: t = cm.copy() else: if isinstance(t, Comp) and isinstance(cm, CompArray): t, cm = cm, t if isinstance(cm, Comp) and isinstance(t, CompArray): t[0] = t[0] + cm else: t = t + cm return t def map_substitute(r, a, b): am = map_getlist(a) bm = map_getlist(b) if isinstance(am, Comp) or isinstance(bm, Comp): r.substitute(am.allcomp(), bm.allcomp()) else: for a2, b2 in zip(am, bm): r.substitute(a2, b2) # Source-channel data processing if (convexify_test or (not oneshot and convexify is not False)) and all(node.rv_in_scausal.isempty() for node in self.nodes): cconvexify = Comp.empty() if convexify is not False: cconvexify = convexify seqs = sum((a for a in self.bnet.allcomp() if isinstance(reg_out_map[a.get_name()], CompArray)), Comp.empty()) srcs = self.get_srcs() chans = self.get_chans() for sa in igen.subset(srcs.inter(seqs), minsize=1): sam = map_getlist(sa) for sb, sc2 in chans: if not seqs.super_of(sb): continue sbm = map_getlist(sb) for sc in igen.subset(sc2.inter(seqs), minsize=1): scm = map_getlist(sc) if self.bnet.check_ic(Expr.Ic(sa, sc, sb)): for sd in igen.subset(seqs - sa - sb, minsize=1): sdm = map_getlist(sd) if self.bnet.check_ic(Expr.Ic(sa + sb, sd, sc)): cexpr = Expr.Ic(sbm[0], scm[0], cconvexify) - Expr.Ic(sam[0], sdm[0], cconvexify) cexpr.simplify_quick() if cexpr.isnonneg(): continue # print(sa, " - ", sb, " - ", sc, " - ", sd) if convexify_test: return True r &= (cexpr >= 0).add_meta("pf_note", ["data proc. 
", sa, " - ", sb, " - ", sc, " - ", sd]) # for sa in igen.subset(seqs, minsize=1): # sam = map_getlist(sa) # if bnet_out.check_ic(Expr.Ic(sam[0], sam[1], cconvexify)): # for sc in igen.subset(seqs - sa, minsize=1): # sb = sum((self.bnet.get_parents(c) for c in sc), Comp.empty()) - sc # if sb.isempty() or not seqs.super_of(sb): # continue # sbm = map_getlist(sb) # scm = map_getlist(sc) # if bnet_out.check_ic(Expr.Ic(scm[0], sum(sbm) + scm[1], sbm[0] + cconvexify)): # if bnet_out.check_ic(Expr.Ic(sum(sam), sum(scm), sum(sbm))): # for sd in igen.subset(seqs - sa - sb, minsize=1): # sdm = map_getlist(sd) # if bnet_out.check_ic(Expr.Ic(sum(sam) + sum(sbm), sum(sdm), sum(scm))): # if convexify_test: # return True # r &= (Expr.Ic(sam[0], sdm[0], cconvexify) <= Expr.Ic(sbm[0], scm[0], cconvexify)).add_meta("pf_note", ["data proc. ", sa, " - ", sb, " - ", sc, " - ", sd]) if convexify_test: return False msgs = Comp.empty() msgnats = Comp.empty() for v0, v1 in self.sublist: if isinstance(v1, Expr) and not isinstance(v0, Expr): msgs += reg_out_map[v0[0].get_name()] vout = self.bnet.get_parents(v0[0]) if vout.isempty(): msgnats += reg_out_map[v0[0].get_name()] if decoding_rate: for a in self.bnet.index.comprv: a2 = self.get_rv_ratervs(a) if a != a2: vout = self.bnet.get_parents(a) voutm = map_getlist(vout) voutm0 = Comp.empty() voutm1 = Comp.empty() voutm2 = None if isinstance(voutm, CompArray): voutm0 = voutm[0] if len(voutm) >= 2: voutm1 = voutm[1] if len(voutm) >= 3: voutm2 = voutm[2] else: continue for a3 in a2: rate = self.get_rv_rates(a3) for cmsg in igen.subset(msgnats - a3, minsize = 1): r &= (Expr.Ic(a3, voutm0, voutm1 + cmsg) >= Expr.Ic(a3, voutm0, voutm1) ).add_meta("pf_note", ["indep. of msgs ", a3, ", ", cmsg]) if future_also and voutm2 is not None: r &= (Expr.Ic(a3, voutm0, voutm2 + cmsg) >= Expr.Ic(a3, voutm0, voutm2) ).add_meta("pf_note", ["indep. 
of msgs ", a3, ", ", cmsg]) r &= (Expr.Ic(a3, voutm0, voutm1) >= rate ).add_meta("pf_note", ["decode ", a3]).add_meta("dual_weight", 0.1) if future_also and voutm2 is not None: r &= (Expr.Ic(a3, voutm0, voutm2) >= rate ).add_meta("pf_note", ["decode ", a3]).add_meta("dual_weight", 0.15) codedmsgs = [] for v0, v1 in self.sublist: if isinstance(v0, Comp): rv_wfuture -= v0 if isinstance(v1, Expr) and not isinstance(v0, Expr): map_substitute(r, v0, v0[0]) isnat = False vout = self.bnet.get_parents(v0[0]) if vout.isempty(): vout = self.bnet.get_children(v0[0]) isnat = True voutm = map_getlist(vout) voutm0 = Comp.empty() voutm1 = Comp.empty() voutm2 = None if isinstance(voutm, CompArray): voutm0 = voutm[0] if len(voutm) >= 2: voutm1 = voutm[1] if len(voutm) >= 3: voutm2 = voutm[2] else: voutm0 = voutm am = reg_out_map[v0[0].get_name()] if isnat: if oneshot: r &= Expr.H(am) >= v1 r &= v1 >= 0 else: for cmsg in igen.subset(msgnats - am, minsize = 1): r &= (Expr.Ic(am, voutm0, voutm1 + cmsg) >= Expr.Ic(am, voutm0, voutm1) ).add_meta("pf_note", ["indep. 
of msgs ", am, ", ", cmsg]) if nat_msg_rate: r &= (Expr.Ic(am, voutm0, voutm1) >= v1 ).add_meta("pf_note", ["rate of ", am]) r &= v1 >= 0 # r &= Expr.Ic(am, voutm0, voutm1) == v1 else: if oneshot: r &= Expr.H(am) <= v1 else: codedmsgs.append((am, voutm0, voutm1, voutm2)) for ctv in igen.subset(rv_wfuture - vout): tva = map_getlist(ctv) tv = None if tva is None: tv = Comp.empty() else: tv = tva[0] + tva[1] + tva[2] # print(am) # print(ctv) # print(tva) # print(rv_wfuture) # print() for cmsg in igen.subset(msgs - am): r &= (Expr.Ic(am, voutm0, voutm1 + tv + cmsg) <= v1 ).add_meta("pf_note", ["rate of ", am]) if coded_combination and tva is not None: r &= (Expr.Ic(am, voutm0 + tva[0], voutm1 + tva[1] + cmsg) <= v1 ).add_meta("pf_note", ["rate of ", am]).add_meta("dual_weight", 0.5) if future_also and voutm2 is not None: r &= (Expr.Ic(am, voutm0, voutm2 + tv + cmsg) <= v1 ).add_meta("pf_note", ["rate of ", am]) if coded_combination and tva is not None: r &= (Expr.Ic(am, voutm0 + tva[0], voutm2 + tva[2] + cmsg) <= v1 ).add_meta("pf_note", ["rate of ", am]).add_meta("dual_weight", 0.5) # r &= Expr.Ic(am, voutm0, voutm1) == v1 # print(v0) # print(v1) reg_out_aux += am else: map_substitute(r, v0, v1) # print(r) for tozero, toconst in self.sublist_const: r.substitute(tozero, toconst) if not include_nondecode_series or not include_last_future: rv_decode = Comp.empty() rv_nondecode = Comp.empty() for a in self.bnet.index.comprv: alist = map_getlist(a) if not (isinstance(alist, CompArray) and len(alist) > 1): continue if any(self.find_node_rv_out(x) is not None for x in self.bnet.get_children(a)): rv_decode += a else: rv_nondecode += a rv_toelim = Comp.empty() if not include_nondecode_series: for a in rv_nondecode: alist = map_getlist(a) for i, b in enumerate(alist): if i: rv_toelim += b if not include_last_future: if len(rv_decode): alist = map_getlist(rv_decode[len(rv_decode) - 1]) if len(alist) >= 3: rv_toelim += alist[2] if not rv_toelim.isempty(): r = 
r.exists_quick(rv_toelim, method = "ci") if not self.reg.isuniverse(): # r = r & self.reg r = r.and_cause_consequence(self.reg) reg_out_aux = reg_out_aux.inter(r.allcomprv()) r = r.exists(reg_out_aux) if aux is not None: aux_pairs = [] aux_force = msgs if convexify is not False: aux_force = aux_force + convexify channel_ins = Comp.empty() for a in self.bnet.index.comprv: voutm = map_getlist(a) if isinstance(voutm, CompArray) and len(voutm) >= 3: if r.ispresent(voutm[1]) and r.ispresent(voutm[2]): aux_pairs.append((voutm[1], voutm[2])) pa = self.bnet.get_parents(a) if pa.ispresent(msgnats): channel_ins += sum(voutm, Comp.empty()) msgs_sorted = Comp.empty() for a in self.bnet.index.comprv: if msgs.ispresent(a): msgs_sorted += a msgs_sorted += msgs def score_fcn(a): r = 0 msgs_vis = [False] * len(msgs_sorted) hastime = False hassingleconvexify = False if len(a) >= 2: for b1, b2 in itertools.permutations(a, 2): if b1.super_of(b2): return None for b in a: hasconvexify = False hasseries = b.ispresent(rv_series) if hasseries: hastime = True if convexify is not False and b.ispresent(convexify): hasconvexify = True hastime = True r -= 1 if hasseries: return None nmsg = 0 for i, msg in enumerate(msgs_sorted): if b.ispresent(msg): if msgs_vis[i]: r += 100 else: r -= i msgs_vis[i] = True nmsg += 1 if nmsg == 0: if not hasconvexify: r += 20 elif nmsg > 1: r += nmsg * 70 if len(b) == 1: if not hasconvexify: r += 5 elif len(b) > 2: r += len(b) * 5 if b.ispresent(channel_ins): r += 30 if hasconvexify: if len(b) == 1: hassingleconvexify = True if len(b) > 1: r += len(b) * 10 for t in msgs_vis: if not t: r += 10 if not hastime: r += 20 return (r, -1 if hassingleconvexify else 0) r = r.aux_reduced(aux, aux_pairs = aux_pairs, aux_force = aux_force, score_fcn = score_fcn) if not skip_simplify: return r.simplified() else: return r def get_outer_nfold(self, n = 2, **kwargs): return self.nfold(n).get_outer(**kwargs) / n def proof_inner(self, r, *args, **kwargs): """Proof of an inner 
bound. """ r2 = r & self.get_region() if not self.reg.isuniverse(): r2 = r2.and_cause_consequence(self.reg) r_inner = None # with PsiOpts(simplify_aux_all = False): # r_inner = self.get_inner(*args, **kwargs) r_inner = self.get_inner(*args, is_proof = True, target = r2, **kwargs) if r_inner is None: return None r_inner = r_inner.tounion() for x, c in r_inner.regs: pf = None with PsiOpts(proof_new = True): if r2 >> x.exists(r_inner.aux): return PsiOpts.get_proof() return None def proof_outer(self, r, future = None, *args, **kwargs): """Proof of an outer bound. """ if future is None: future = [False, True] elif isinstance(future, bool): future = [future] for f in future: r_outer = self.get_outer(*args, is_proof = True, future = f, **kwargs) with PsiOpts(proof_new = True): if r_outer >> r: return PsiOpts.get_proof() return None def optimum(self, v, b, sn, name = None, inner = True, outer = True, inner_kwargs = None, outer_kwargs = None, tighten = False, **kwargs): """Return the variable obtained from maximizing (sn=1) or minimizing (sn=-1) the expression v over variables b (Comp, Expr or list) """ # if self.is_netcode(): # outer = False if inner_kwargs is None: inner_kwargs = dict() if outer_kwargs is None: outer_kwargs = dict() if inner: inner = self.get_inner(**inner_kwargs) else: inner = None if outer: outer = self.get_outer(**outer_kwargs) else: outer = None if name is not None: name += PsiOpts.settings["fcn_suffix"] if inner is not None: return inner.optimum(v, b, sn, name = name, reg_outer = outer, tighten = tighten, quick = None, quick_outer = True, **kwargs) if outer is not None: return outer.optimum(v, b, sn, name = name, tighten = tighten, quick = True, **kwargs) return None def maximum(self, expr, vs, **kwargs): """Return the variable obtained from maximizing the expression expr over variables vs (Comp, Expr or list) """ return self.optimum(expr, vs, 1, **kwargs) def minimum(self, expr, vs, **kwargs): """Return the variable obtained from minimizing the 
expression expr over variables vs (Comp, Expr or list) """ return self.optimum(expr, vs, -1, **kwargs) def node_groups(self): r = [] for node in self.nodes: a = self.bnet.get_parents(node.rv_out) b = node.rv_out.copy() for t in r: if a.super_of(t[1]) and (t[1] + t[0]).super_of(a): t[0] += b break else: r.append([b, a, node]) return r def node_group_info(self): groups = self.node_groups() return groups def graph(self, lr = True, enc_node = True, ortho = False, **kwargs): """Return the graphviz digraph of the network that can be displayed in the console. """ if graphviz is None: raise RuntimeError("Requires graphviz. Please install it first.") r = graphviz.Digraph() if lr: r.graph_attr["rankdir"] = "LR" if ortho: r.graph_attr["splines"] = "ortho" for key, value in kwargs.items(): r.graph_attr[key] = str(value) groups = self.node_groups() rvs = self.bnet.allcomp() for a in self.bnet.allcomp(): shape = "plaintext" #"oval" node = self.find_node_rv_out(a) label = str(self.get_rv_sub(a)) if node is not None and not enc_node: shape = "rect" if shape == "plaintext": r.node(a.get_name(), label, shape = shape, margin = "0") else: r.node(a.get_name(), label, shape = shape) if enc_node: for i, (b, a, node) in enumerate(groups): cname = "enc_" + str(a) + "_" + str(b) label = str(i + 1) if node.label is not None: label = node.label r.node(cname, label, shape = "rect") for ai in a: r.edge(ai.get_name(), cname) for bi in b: r.edge(cname, bi.get_name()) for ai in node.rv_in_scausal: r.edge(ai.get_name(), cname, style = "dashed") rvs -= b for (a, b) in self.bnet.edges(): if b in rvs: r.edge(a.get_name(), b.get_name()) return r def graph_outer(self, **kwargs): return self.bnet_out.graph(groups = self.bnet_out_arrays, **kwargs) class CommEnc(IBaseObj): def __init__(self, rv_out, msgs, rv_in_causal = None): # self.rv_in = Comp.empty() # self.msgs = [] # for x in list_in: # if isinstance(x, Comp): # self.rv_in += x # else: # self.msgs.append(x) self.msgs = msgs self.rv_out = rv_out if 
rv_in_causal is None: self.rv_in_causal = Comp.empty() else: self.rv_in_causal = rv_in_causal def record_to(self, index, skip_msg = False): # self.rv_in.record_to(index) self.rv_out.record_to(index) self.rv_in_causal.record_to(index) if not skip_msg: for x in self.msgs: for y in x: y.record_to(index) class CommDec(IBaseObj): def __init__(self, rv_in, msgs, msgints = None): self.rv_in = rv_in self.msgs = msgs self.msgints = msgints def record_to(self, index): self.rv_in.record_to(index) for x in self.msgs: for y in x: y.record_to(index) if self.msgints is not None: for x in self.msgints: for y in x: y.record_to(index) class CommModel(IBaseObj): def __init__(self, bnet = None, reg = None, nature = None): if bnet is None: self.bnet = BayesNet() else: self.bnet = bnet if reg is None: self.reg = Region.universe() else: self.reg = reg self.bnet_in = None self.reg_in = None self.bnet_out = None self.reg_out = None self.reg_out_aux = None self.reg_out_tsrv = None self.reg_out_vs = None self.reg_out_map = None self.msgs = None self.maux_rt = None self.maux_rtsub = None self.decs = None if nature is None: self.nature = Comp.empty() else: self.nature = nature self.enclist = [] self.declist = [] def __iadd__(self, other): if isinstance(other, CommEnc): self.enclist.append(other) elif isinstance(other, CommDec): self.declist.append(other) else: self.bnet += other return self def create_reg_in(self, name_prefix = "A_"): self.bnet_in = self.bnet.copy() self.msgs = [] self.decs = [] index = IVarIndex() self.reg.record_to(index) for x in self.enclist: x.record_to(index) for x in self.declist: x.record_to(index) for x in self.enclist: cauxall = Comp.empty() rv_in = self.bnet.get_parents(x.rv_out) rv_in = rv_in - x.rv_in_causal for m in x.msgs: caux = None if isinstance(m, Expr): tname = name_prefix + str(m) tname = index.name_avoid(tname) caux = Comp.rv(tname) caux.record_to(index) else: caux = m[1].copy() self.msgs.append([rv_in.copy(), m[0], caux, x.rv_out.copy()]) cauxall += caux 
self.bnet_in += (rv_in, cauxall) self.bnet_in += (cauxall, x.rv_out) # self.reg_in &= (Expr.Ic(cauxall, index.comprv - cauxall - (x.rv_in + x.rv_out), # x.rv_in + x.rv_out) == 0) # for x in self.declist: # self.decs.append([x.rv_in.copy(), x.msgs.copy(), iutil.copy(x.msgints)]) self.reg_in = self.bnet_in.get_region() & self.reg def get_encout(self): r = Comp.empty() for x in self.enclist: r += x.rv_out return r def enc_id_get_aux(self, i): r = Comp.empty() for rv_in, rtbs, b, rv_out in self.msgs: if self.enclist[i].rv_out.ispresent(rv_out): r += b return r def create_reg_out(self, future = True): self.create_reg_in("M_") auxs = sum(x[2] for x in self.msgs) self.reg_out_aux = auxs.copy() #self.reg_out_tsrv = Comp.rv("#TS") self.bnet_out = BayesNet() self.reg_out_map = {} clen = 3 if future else 2 ttr = [1, 0, 2] for x in self.bnet.allcomp(): t = None if future: t = CompArray.series_sym(x) else: t = CompArray.series_sym(x, suff = None) self.reg_out_map[x.get_name()] = t for (a, b) in self.bnet.edges(): self.bnet_out += (self.reg_out_map[a.get_name()], self.reg_out_map[b.get_name()]) for i in range(len(self.enclist)): auxs = self.enc_id_get_aux(i) rv_out = self.enclist[i].rv_out rv_in_causal = self.enclist[i].rv_in_causal rv_out_map = self.reg_out_map[rv_out.get_name()] for x in rv_out_map: self.bnet_out += (auxs, x) for pa in self.bnet.get_parents(rv_out): pa_map = self.reg_out_map[pa.get_name()] if rv_in_causal.ispresent(pa): for t1 in range(clen): for t2 in range(t1 + 1, clen): self.bnet_out += (pa_map[ttr[t1]], rv_out_map[ttr[t2]]) else: for t1 in range(clen): for t2 in range(clen): if t1 != t2: self.bnet_out += (pa_map[t1], rv_out_map[t2]) self.reg_out = self.bnet_out.get_region() & self.reg if future: self.reg_out &= csiszar_sum(*(list(x for _, x in self.reg_out_map.items()) + list(auxs))) self.reg_out_aux += self.bnet_out.allcomp() - self.bnet.allcomp() def get_outer(self, future = True, convexify = False, skip_simplify = False): self.create_reg_out(future = 
future) r = self.reg_out.copy() rts = self.get_rates() for rt in rts: r &= rt >= 0 for dec in self.declist: a = dec.rv_in rts = dec.msgs ap = self.reg_out_map[a.get_name()][1] for rt in rts: for i, (_, rtbs, b, _) in enumerate(self.msgs): if rtbs.ispresent(rt): r &= rt <= Expr.Ic(b, a, ap) aux = self.reg_out_aux # if convexify is not False: # if convexify is True: # convexify = Comp.index("Q_T") # r.condition(convexify) # encout = self.get_encout() # encout_map # r &= markov(r.allcomprv() - encout - convexify, encout, convexify) # aux += convexify r = r.exists(aux) if not skip_simplify: return r.simplified() else: return r def get_decreq(self): r = [] for dec in self.declist: a = dec.rv_in rts = dec.msgs reqs = [] reqints = None if dec.msgints is not None: reqints = [] for i, (_, rtbs, b, _) in enumerate(self.msgs): if any(rtbs.ispresent(rt) for rt in rts): reqs.append(i) if dec.msgints is not None and any(rtbs.ispresent(rt) for rt in dec.msgints): reqints.append(i) # cknown = a.copy() for j in range(len(reqs) - 1, -1, -1): # cdec = a.copy() # for cc in cknown: # ccan = self.bnet_in.get_ancestors(cc) # if cknown.super_of(ccan): # cdec += cc # r.append((reqs[j], cdec)) # TODO ????????????? 
r.append((reqs[j], a, reqs[j+1:], None if reqints is None else [a for a in reqints if a < reqs[j]])) # r.append((reqs[j], a+cknown)) # cknown += self.msgs[reqs[j]][2] return r def get_ratereq(self, known, reqprevs, i, tlist): verbose = PsiOpts.settings.get("verbose_commmodel", False) r = Region.universe() lmask = (1 << i) + sum(1 << t for t in tlist) submask = sum(1 << j for j in range(i) if self.maux_rtsub[j] is not None) for tmask in range(1 << len(tlist)): mask = (1 << i) + sum(1 << tlist[j] for j in range(len(tlist)) if tmask & (1 << j)) expr = Expr.zero() for j in range(i + 1): if mask & (1 << j): expr += self.maux_rt[j] if self.maux_rtsub[j] is not None: expr += self.maux_rtsub[j] expr -= mi_rect_max(self.msgs[j][2], [known] + [(self.msgs[j2][2], self.maux_rtsub[j2]) if self.maux_rtsub[j2] is not None else self.msgs[j2][2] for j2 in reqprevs] + [(self.msgs[j2][2], self.maux_rtsub[j2]) if self.maux_rtsub[j2] is not None else self.msgs[j2][2] for j2 in range(j) if lmask & (1 << j2)]) expr += mi_rect_max([self.msgs[j][0]] + [(self.msgs[j2][2], self.maux_rtsub[j2]) if self.maux_rtsub[j2] is not None else self.msgs[j2][2] for j2 in range(j)], self.msgs[j][2]) expr.simplify_quick() r &= (expr <= 0) r_st = None if verbose: print("========== inner bound ==========") print("known:" + str(known) + " decoded:" + str(sum(self.msgs[t][2] for t in reqprevs)) + " nonunique:" + str(sum(self.msgs[t][2] for t in tlist)) + " target:" + str(self.msgs[i][2])) r_st = str(r) print(r_st) r = r.flattened(minmax_elim = True) if verbose: r_st2 = str(r) if r_st2 != r_st: print("expanded:") print(r_st2) print("") return r def get_ratereq_old(self, known, i, tlist): r = Region.universe() lmask = (1 << i) + sum(1 << t for t in tlist) submask = sum(1 << j for j in range(i) if self.maux_rtsub[j] is not None) for tmask in range(1 << len(tlist)): mask = (1 << i) + sum(1 << tlist[j] for j in range(len(tlist)) if tmask & (1 << j)) for cp in itertools.product(*[igen.subset_mask(submask & ((1 << 
j) - 1)) for j in range(i + 1) if mask & (1 << j)]): expr = Expr.zero() caux = Comp.empty() #cauxmiss = Comp.empty() ji = 0 for j in range(i + 1): if mask & (1 << j): # TODO ?????????????????????????????????????? expr += self.maux_rt[j] - Expr.I(self.msgs[j][2], known + caux) if self.maux_rtsub[j] is not None: expr += self.maux_rtsub[j] cauxmiss = self.msgs[j][0].copy() for j2 in range(j): if submask & (1 << j2): if cp[ji] & (1 << j2): expr -= self.maux_rtsub[j2] cauxmiss += self.msgs[j2][2] else: cauxmiss += self.msgs[j2][2] if not cauxmiss.isempty(): expr += Expr.I(self.msgs[j][2], cauxmiss) ji += 1 if lmask & (1 << j): caux += self.msgs[j][2] #cauxmiss += self.msgs[j][2] expr.simplify_quick() r &= (expr <= 0) # print("==========") # print(str(known) + " " + str(self.msgs[i][2]) + " " + str(sum(self.msgs[t][2] for t in tlist))) # print(r) return r def get_rates(self): r = Expr.zero() for _, rts, a, _ in self.msgs: for rt in rts: if not r.ispresent(rt): r += rt return r def get_aux(self): r = Comp.empty() for _, rts, a, _ in self.msgs: r += a return r def get_inner_iter(self, subcodebook = True, skip_simplify = False): self.maux_rt = [rt[0] for _, rt, a, _ in self.msgs] self.maux_rtsub = [None] * len(self.msgs) if subcodebook is True: self.maux_rtsub = [Expr.real("#RS_" + x.get_name()) for _, rt, x, _ in self.msgs] self.maux_rtsub[-1] = None elif isinstance(subcodebook, Comp): self.maux_rtsub = [Expr.real("#RS_" + x.get_name()) if subcodebook.ispresent(x) else None for _, rt, x, _ in self.msgs] self.maux_rtsub[-1] = None subcodebook = True reqs = self.get_decreq() #r = Region.universe() r = self.reg_in.copy() rts = self.get_rates() for rt in rts: r &= rt >= 0 if subcodebook: for rt in self.maux_rtsub: if rt is not None: r &= rt >= 0 for reqi, known, reqprevs, reqints in reqs: r2 = Region.empty() if reqints is None: for mask in range(1 << reqi): tlist = [] for i in range(reqi): if mask & (1 << i): tlist.append(i) r2 |= self.get_ratereq(known, reqprevs, reqi, tlist) 
else: r2 = self.get_ratereq(known, reqprevs, reqi, reqints) if not skip_simplify: r2 = r2.simplified() r &= r2 if subcodebook and any(x is not None for x in self.maux_rtsub): r.eliminate(sum(x for x in self.maux_rtsub if x is not None)) return r def rate_splitting(self, r): if self.msgs is None: self.create_reg_in() rttmp = [Expr.real("#RT_" + x.get_name()) for _, rt, x, _ in self.msgs] for rt1, rt2 in zip((rt for _, rt, x, _ in self.msgs), rttmp): r.substitute(rt1, rt2) decflag = [0] * len(self.msgs) for i, (_, rtbs, b, _) in enumerate(self.msgs): for j, dec in enumerate(self.declist): a = dec.rv_in rts = dec.msgs if any(rtbs.ispresent(rt) for rt in rts): decflag[i] |= 1 << j rtauxs = [] n = len(self.msgs) rn = Region.universe() bexps = [Expr.zero() for i in range(n)] for i, (_, rtbs, b, _) in enumerate(self.msgs): vis = [] cexp = Expr.zero() for mask in range(1, 1 << n): if any(mask | v == mask for v in vis): continue cdecflag = 0 for i2 in range(n): if mask & (1 << i2): cdecflag |= decflag[i2] if cdecflag | decflag[i] != cdecflag: continue vis.append(mask) crtaux = Expr.real("#RTA_" + str(i) + "_" + str(mask)) rn &= crtaux >= 0 for i2 in range(n): if mask & (1 << i2): bexps[i2] += crtaux cexp += crtaux rtauxs.append(crtaux) rn &= rtbs <= cexp for i in range(n): rn &= bexps[i] <= rttmp[i] r &= rn rts = self.get_rates() for rt in rts: r &= rt >= 0 # print(r) r.eliminate(sum(rttmp + rtauxs, Expr.zero())) # print(r) return r def get_inner(self, subcodebook = True, rate_split = False, shuffle = False, convexify = False, convexify_diag = False, skip_simplify = False, skip_simplify_iter = False): r = None self.create_reg_in() aux = self.get_aux() if shuffle: r = Region.empty() tmaux = self.msgs for cmaux in itertools.permutations(tmaux): tbnet = self.bnet.copy() for i in range(len(cmaux) - 1): tbnet += (cmaux[i][3], cmaux[i+1][3]) if tbnet.iscyclic(): # print("CYCLIC " + str(cmaux)) continue self.msgs = list(cmaux) r |= self.get_inner_iter(subcodebook = subcodebook, 
skip_simplify = skip_simplify_iter) self.msgs = tmaux else: r = self.get_inner_iter(subcodebook = subcodebook, skip_simplify = skip_simplify_iter) if convexify_diag: r = r.convexified_diag(self.get_rates(), skip_simplify = skip_simplify) if convexify is not False: if convexify is True: convexify = Comp.index("Q_i") r.condition(convexify) encout = self.get_encout() r &= markov(r.allcomprv() - encout - convexify, encout, convexify) aux += convexify if rate_split: r = self.rate_splitting(r) if not skip_simplify: #return r.distribute().simplified() r = r.exists(aux) # print(r) r.simplify_union() # print(r) r.remove_missing_aux() return r.simplified_quick() else: return r.exists(aux) # Generators class igen: def subset_mask(mask): x = 0 while True: yield x if x >= mask: break x = ((x | ~mask) + 1) & mask def partition_mask(mask, n, max_mask = None): if n == 1: if mask != 0 and not (max_mask is not None and mask > max_mask): yield (mask,) return for xmask in igen.subset_mask(mask): if xmask == 0: continue if max_mask is not None and xmask > max_mask: break for t in igen.partition_mask(mask & ~xmask, n - 1, xmask): yield t + (xmask,) def partition(x, n): m = len(x) for mask in igen.partition_mask((1 << m) - 1, n): yield tuple(sum(t for i, t in enumerate(x) if xmask & (1 << i)) for xmask in mask) def subset(x, minsize = 0, maxsize = 100000, size = None, reverse = False, coeffmode = 0, replacement = False, zero_elem = 0): """Iterate sum of subset. Parameters: coeffmode : Set to 0 to only allow positive terms. Set to 1 to allow positive/negative terms, but not all negative. Set to 2 to allow positive/negative terms. Set to -1 to allow positive/negative terms, but not all positive/negative. 
""" #if isinstance(x, types.GeneratorType): if not hasattr(x, '__len__'): x = list(x) if size is not None: minsize = size maxsize = size iterfcn = None if replacement: iterfcn = itertools.combinations_with_replacement else: iterfcn = itertools.combinations n = len(x) cr = None if reverse: maxsize = min(maxsize, n) cr = range(maxsize, minsize - 1, -1) else: cr = range(minsize, maxsize + 1) for s in cr: if s > n: return if s == 0: if len(x) > 0: if isinstance(x[0], Comp): yield Comp.empty() elif isinstance(x[0], Expr): yield Expr.zero() elif isinstance(x[0], IBaseArray): yield type(x[0]).zeros(shape = x[0].shape) else: yield 0 else: if isinstance(x, Comp): yield Comp.empty() elif isinstance(x, Expr): yield Expr.zero() elif isinstance(x, IBaseArray): yield type(x).zeros(shape = x.shape) else: yield 0 elif s == 1: if coeffmode == 2: for y in x: yield y yield -y else: for y in x: yield y else: if coeffmode == 0: for comb in iterfcn(x, s): yield sum(comb, zero_elem) else: for comb in iterfcn(x, s): for sflag in range(int(coeffmode == -1), (1 << s) - int(coeffmode == 1 or coeffmode == -1)): yield sum((comb[i] * (-1 if sflag & (1 << i) else 1) for i in range(s)), zero_elem) def pm(x): """Plus or minus. """ for a in x: yield a yield -a def sI(x, maxsize = 100000, cond = True, ent = True, pm = False): """Iterate mutual information of variables in x. Parameters: cond : Whether to include conditional mutual information. ent : Whether to include entropy. pm : Whether to include both positive and negative. 
""" n = len(x) def xmask(mask): return sum((x[i] for i in range(n) if mask & (1 << i)), Comp.empty()) for mask in igen.subset([1 << i for i in range(n)], minsize = 1, maxsize = maxsize): bmask = mask if not cond: if ent: yield Expr.H(xmask(bmask)) if pm: yield -Expr.H(xmask(bmask)) while True: if not cond: if bmask <= 0: break amask = mask - bmask if amask > bmask: break if amask > 0: yield Expr.I(xmask(amask), xmask(bmask)) if pm: yield -Expr.I(xmask(amask), xmask(bmask)) else: a0mask = mask - bmask if ent: if a0mask > 0: yield Expr.Hc(xmask(a0mask), xmask(bmask)) if pm: yield -Expr.Hc(xmask(a0mask), xmask(bmask)) while a0mask > 0: a1mask = mask - bmask - a0mask if a1mask > a0mask: break if a1mask > 0: yield Expr.Ic(xmask(a1mask), xmask(a0mask), xmask(bmask)) if pm: yield -Expr.Ic(xmask(a1mask), xmask(a0mask), xmask(bmask)) a0mask = (a0mask - 1) & (mask - bmask) if bmask <= 0: break bmask = (bmask - 1) & mask def test(x, fcn, sgn = 0, yield_set = False): """Test fcn for values in generator x. Set sgn = 1 if fcn is increasing. Set sgn = -1 if fcn is decreasing. 
""" ub = MonotoneSet(sgn = 1) lb = MonotoneSet(sgn = -1) for a in x: if sgn != 0: if a in ub: continue if a in lb: continue if fcn(a): if not yield_set: yield a if sgn > 0: ub.add(a) if yield_set: yield (a, ub) elif sgn < 0: lb.add(a) if yield_set: yield (a, lb) else: if sgn > 0: lb.add(a) elif sgn < 0: ub.add(a) def remove_comments(s): com_chars = ("#", "%") lines = s.split("\n") r = "" for x in lines: if x.lstrip().startswith(com_chars): continue mi = len(x) for c in com_chars: i = x.find(c) if i >= 0: mi = min(mi, i) if r != "": r += "\n" r += x[:mi] return r def parse_sanitize(s, grammar = False): if not grammar: s = remove_comments(s) c = "" if not grammar: c = " " s = s.replace(r"\{", c + "BBRACE_L" + c) s = s.replace(r"\}", c + "BBRACE_R" + c) s = s.replace(r"\,", " ") s = s.replace(r"\!", " ") s = s.replace(r"\;", " ") s = s.replace(r"\\", " ") s = s.replace(r"\[", " ") s = s.replace(r"\]", " ") s = s.replace(r"\quad", " ") s = s.replace(r"\qquad", " ") if not grammar: s = s.replace(r"\\", " ") s = s.replace(r"\leftrightarrow", c + "lrarrow" + c) s = s.replace(r"\leftarrow", c + "larrow" + c) s = s.replace(r"\rightarrow", c + "rarrow" + c) s = s.replace(r"\left", " ") s = s.replace(r"\right", " ") s = s.replace(r"\big", " ") s = s.replace(r"\Big", " ") s = s.replace(r"\bigg", " ") s = s.replace(r"\Bigg", " ") s = s.replace(r"\small", " ") s = s.replace(r"\tiny", " ") s = s.replace(r"\scriptsize", " ") s = s.replace(r"\footnotesize", " ") s = s.replace(r"\normalsize", " ") s = s.replace(r"\large", " ") s = s.replace(r"\Large", " ") s = s.replace(r"\LARGE", " ") s = s.replace(r"\huge", " ") s = s.replace(r"\Huge", " ") s = s.replace(r"\displaystyle", " ") s = s.replace(r"\nonumber", " ") s = s.replace(r"\notag", " ") s = s.replace(r"\exists", c + "exists" + c) s = s.replace(r"\forall", c + "forall" + c) for i in range(5): s = s.replace(r"," + " " * i + r"exists", "comma_exists") s = s.replace(r"," + " " * i + r"forall", "comma_forall") s = s.replace(r"{\perp" 
+ " " * i + r"\perp}", c + "\perp" + c) s = s.replace("\\", c + "BACKSLASH_") return s @lark_v_args(inline = True) class RegionTree(lark_Transformer): grammar = r""" ?start: region_or_expr | NAME "=" region_or_expr -> set_var | "check" region_or_expr -> check | "latex" region_or_expr -> latex | "assume" region -> assume | "clear" "assume" -> clear_assume | "style" NAME -> style ?region_or_expr: region | expr | "simplify" region_or_expr -> simplified ?region: region_aux | region (">>" | "implies" | "\Rightarrow" | "\to" | "\rightarrow") region_aux -> rshift ?region_aux: region_union | "exists" comp ("." | ":" | "st") region_aux -> exists_r | "forall" comp ("." | ":" | "st") region_aux -> forall_r | region_aux ", exists" comp -> exists | region_aux ", forall" comp -> forall ?region_union: region_inter | region_union "|" region_inter -> or_ | region_union "or" region_inter -> or_ | region_union "OR" region_inter -> or_ | region_union "\vee" region_inter -> or_ ?region_inter: region_atom | region_inter "&" region_atom -> and_ | region_inter_name "and" region_atom -> and_ | region_inter_name "AND" region_atom -> and_ | region_inter_name "\wedge" region_atom -> and_ | region_inter "," region_atom -> and_ ?region_inter_name: region_inter | NAME -> region_var ?region_atom: "(" region ")" | "{" region "}" | "\{" region "\}" | "\begin" "{" dis_latex_block "}" ["{" NAME "}"] region "\end" "{" dis_latex_block "}" -> taken2 | expr ( ["&"] rel expr)+ -> rels | "~" region_atom -> not_ | "not" region_atom -> not_ | "NOT" region_atom -> not_ | "\lnot" region_atom -> not_ | "markov" "(" comp_closer ("," comp_closer)+ ")" -> markov | "indep" "(" comp_closer ("," comp_closer)+ ")" -> indep | region_atom "." "exists" comp -> exists | region_atom "." "forall" comp -> forall | region_atom "." "simplified" "(" ")" -> simplified | region_atom "." 
"latex" "(" ")" -> latex | (comp_single | comp_comma_b) (("markov" | "->" | "<->" | "\to" | "\rightarrow" | "\leftrightarrow") (comp_single | comp_comma_b))+ -> markov | (comp_single | comp_comma_b) (("indep" | "\perp" | ("\perp" "\perp") | ".") (comp_single | comp_comma_b))+ -> indep | "assumption" -> assumption | "@" NAME -> region_var ?region_atom_name: region_atom | NAME -> region_var ?rel: ("==" | "=") -> rel_eq | ("!=" | "\neq") -> rel_ne | "<" -> rel_lt | ">" -> rel_gt | ("<=" | "\le") -> rel_le | (">=" | "\ge") -> rel_ge ?expr: expr_product | expr "+" expr_product -> add | expr "-" expr_product -> sub ?expr_product: expr_atom | expr_number expr_atom -> mul | expr_product ("*" | "\cdot" | "\times") expr_atom -> mul | expr_product "/" expr_atom -> truediv ?expr_atom: expr_number | "H" "(" comp ")" -> ent | "H" "(" comp "|" comp ")" -> entc | "I" "(" comp_and ")" -> mi | "I" "(" comp_and "|" comp ")" -> mic | "-" expr_atom -> neg | NAME -> expr_var | NAME "{" (NAME | NUMBER) "}" -> expr_var ?expr_number: cnumber | "\frac" "{" expr "}" "{" expr "}" -> truediv | "(" expr ")" | "{" expr "}" | "\{" expr "\}" ?cnumber: NUMBER -> number ?comp_and: comp | comp_and "&" comp -> and_ | comp_and "\wedge" comp -> and_ | comp_and ";" comp -> and_ | comp_and ":" comp -> and_ ?comp: comp_closer | comp "," comp_closer -> add ?comp_closer: comp_atom | comp_closer "+" comp_atom -> add | comp_closer comp_atom -> add ?comp_atom: comp_single | "(" comp ")" | "{" comp "}" | "\{" comp "\}" ?comp_single: NAME -> comp_var | NAME "{" (NAME | NUMBER) "}" -> comp_var | NAME "^" "{" (NAME | NUMBER) "}" -> comp_var | NAME "^" (NAME | NUMBER) -> comp_var ?comp_comma: "(" comp_comma ")" | (comp_single | comp_comma) "," comp -> add ?comp_comma_b: "(" comp_comma ")" ?dis_latex_block: ("array" | "align" | ("align" "*") | "equation" | ("equation" "*") | "eqnarray" | ("eqnarray" "*") | "split" | ("split" "*") | "multline" | ("multline" "*")) %import common.CNAME -> NAME %import common.NUMBER 
%import common.WS %ignore WS %ignore "$" %ignore "$$" """ grammar = parse_sanitize(grammar, grammar = True) from operator import eq, ne, lt, le, gt, ge, and_, or_, not_, add, sub, mul, truediv, neg, rshift number = float def __init__(self): self.varmap = {} def clear(self): self.varmap = {} def take1(self, *args): return args[1] def taken1(self, *args): return args[-1] def taken2(self, *args): return args[-2] def set_var(self, name, value): self.varmap[name] = value return value def type_var(self, *args, cur_type = "comp"): name = "" if len(args) == 2 and args[0].endswith("_"): name = str(args[0]) + "{" + str(args[1]) + "}" else: name = "".join(str(a) for a in args) if name in self.varmap: return self.varmap[name] if cur_type == "expr": r = Expr.real(name) elif cur_type == "region": r = Region.universe() else: r = Comp.rv(name) self.varmap[name] = r return r def ent(self, x): return Expr.H(x) def entc(self, x, y): return Expr.Hc(x, y) def mi(self, x): return I(x) def mic(self, x, y): return I(x | y) def expr_var(self, *args): return self.type_var(*args, cur_type = "expr") def comp_var(self, *args): return self.type_var(*args, cur_type = "comp") def region_var(self, *args): return self.type_var(*args, cur_type = "region") def simplified(self, x): return x.simplified_truth() def check(self, x): return x.check() def latex(self, x): return x.latex() def assume(self, x): x.assume() return self.assumption() def assumption(self): truth = PsiOpts.settings["truth"] if truth is not None: return truth.copy() else: return Region.universe() def clear_assume(self): PsiOpts.setting(truth = None) def add_list(self, x, y): if isinstance(x, Comp): x = [x] if isinstance(y, Comp): y = [y] return x + y def markov(self, *x): return markov(*x) def indep(self, *x): return indep(*x) def exists(self, x, y): return x.exists(y) def forall(self, x, y): return x.forall(y) def exists_r(self, x, y): return y.exists(x) def forall_r(self, x, y): return y.forall(x) def rels(self, *x): r = None for i 
in range(0, len(x) - 2, 2): a, rel, b = x[i], x[i + 1], x[i + 2] t = None if rel == "eq": t = a == b elif rel == "ne": t = a != b elif rel == "lt": t = a < b elif rel == "gt": t = a > b elif rel == "le": t = a <= b elif rel == "ge": t = a >= b if r is None: r = t else: r &= t return r def rel_eq(self): return "eq" def rel_ne(self): return "ne" def rel_lt(self): return "lt" def rel_gt(self): return "gt" def rel_le(self): return "le" def rel_ge(self): return "ge" def style(self, s): if s == "std" or s == "standard": PsiOpts.setting(str_style = "std") elif s == "code" or s == "psitip": PsiOpts.setting(str_style = "code") elif s == "latex": PsiOpts.setting(str_style = "latex") else: return "Unrecognized style: " + s + ". Options are: std, code, latex" class RegionParser: default_parser = None def __init__(self): if lark is None: raise RuntimeError("Requires Lark. Please install it first.") self.tree = RegionTree() # self.parser = lark.Lark(RegionTree.grammar, parser = "earley", transformer = self.tree) self.parser = lark.Lark(RegionTree.grammar, parser = "lalr", transformer = self.tree) def parse(self, s): s = parse_sanitize(s) if "\n" in s: try: return self.parser.parse(s) except lark.exceptions.LarkError as err: lines = [x.strip() for x in s.split("\n")] lines = [x for x in lines if len(x)] s = ",\n".join(lines) return self.parser.parse(s) def clear(self): self.tree.clear() @staticmethod def parse_default(s, clear = True): if RegionParser.default_parser is None: RegionParser.default_parser = RegionParser() if clear: RegionParser.default_parser.clear() return RegionParser.default_parser.parse(s) # Shortcuts def alland(a): """Intersection of elements in list a (using operator &).""" r = None for x in a: if r is None: r = x else: r &= x return r def anyor(a): """Union of elements in list a (using operator |).""" r = None for x in a: if r is None: r = x else: r |= x return r def rv(*args, latex = None, index = False, split = True, alg = None, **kwargs): """Random 
variable""" if len(args) == 0: return Comp.empty() if split: args = [x for b in args for x in iutil.split_comma(b)] if isinstance(latex, str): if split: latex = iutil.split_comma(latex) else: latex = [latex] if latex is not None: args = [x + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + y for x, y in zip(args, latex)] r = Comp.empty() for a in args: t = iutil.ensure_comp(a) if t is not None: r += t # if isinstance(a, str): # r += Comp.rv(a) # elif isinstance(a, IBaseObj): # r += a.allcomprv_noaux() if index: r.add_markers([("index_shift", 0)]) for key, value in kwargs.items(): r.add_markers([(key, value)]) if alg is not None: r.set_algtype(alg) return r def rv_seq(name, st, en = None, alg = None): """Sequence of random variables""" r = Comp.array(name, st, en) if alg is not None: r.set_algtype(alg) return r def rv_array(name, st, en = None, alg = None): """Array of random variables""" r = rv_seq(name, st, en, alg) return CompArray.make(*r) def rv_series(name, future = True, sufp = "P", suff = "F"): """Array of random variables, with past, current and future random variables""" if not future: suff = None return CompArray.series_sym(name, sufp = sufp, suff = suff) def real(*args, latex = None, split = True): """Real variable""" if split: args = [x for b in args for x in iutil.split_comma(b)] if isinstance(latex, str): if split: latex = iutil.split_comma(latex) else: latex = [latex] if latex is not None: args = [x + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + y for x, y in zip(args, latex)] is_one = len(args) == 1 r = [] for a in args: if isinstance(a, str): r.append(Expr.real(a)) elif iutil.isnumeric(a): r.append(Expr.const(a)) elif isinstance(a, IBaseObj): for b in a.allcomprealvar_exprlist(): r.append(b.copy()) is_one = False if is_one and len(r) == 1: return r[0] return ExprArray(r) def real_array(name, st, en = None): """Array of random variables""" t = rv_seq(name, st, en) return ExprArray([Expr.real(a.name) for a in t.varlist]) def real_seq(name, st, en = 
None): """Array of random variables""" t = rv_seq(name, st, en) return ExprArray([Expr.real(a.name) for a in t.varlist]) def expr(*args): """Convert to an expression""" if len(args) == 0: return Expr.zero() r = None for a in args: t = iutil.ensure_expr(a, strict = False) if t is not None: if r is None: r = t else: r += t return r def region(*args): """Convert to a region""" if len(args) == 0: return Region.universe() r = None for a in args: t = iutil.ensure_region(a) if t is not None: if r is None: r = t else: r &= t return r def rv_empty(): """Empty random variable""" return Comp.empty() def zero(): """Zero expression""" return Expr.zero() def universe(): """Universal set. Returns a region with no constraints. """ return Region.universe() def empty(): """Empty set. Returns a region that is empty. """ return RegionOp.empty() def emptyset(): """Empty set. Returns a region that is empty. """ return RegionOp.empty() def H_inner(*args, prefer_multi = False): if len(args) == 1: x = args[0] if isinstance(x, Comp): if x.isempty(): return Expr.zero() return Expr.H(x) if isinstance(x, ConcDist): return x.entropy() if isinstance(x, Term): return Expr([(x.copy(), 1.0)]) return Expr.H(iutil.ensure_comp(x)) else: return Expr.fromterm(Term.from_symbols(args, prefer_multi=prefer_multi)) @fcn_list_to_list def H(*args): """Entropy. Returns a symbolic expression for the entropy e.g. use H(X + Y | Z + W) for H(X,Y|Z,W) """ return H_inner(*args, prefer_multi=False) @fcn_list_to_list def I(*args): """Mutual information. Returns a symbolic expression for the mutual information e.g. use I(X + Y & Z | W) for I(X,Y;Z|W) """ return H_inner(*args, prefer_multi=True) def I0(*args): """Shorthand for I(...)==0. e.g. use I0(X + Y & Z | W) for I(X,Y;Z|W)==0 """ return H(*args) == 0 @fcn_list_to_list def Hc(x, z): """Conditional entropy. Hc(X, Z) is the same as H(X | Z) """ return Expr.Hc(x, z) @fcn_list_to_list def Ic(x, y, z): """Conditional mutual information. 
Ic(X, Y, Z) is the same as I(X & Y | Z) """ return Expr.Ic(x, y, z) @fcn_list_to_list def indep(*args): """Return Region where the arguments are independent.""" args = [iutil.ensure_comp(t) for t in args] r = Region.universe() for i in range(1, len(args)): r &= Expr.I(iutil.sumlist(args[:i]), iutil.sumlist(args[i])) == 0 return r def indep_across(*args): """Take several arrays, return Region where entries are independent across dimension.""" n = max([len(a) for a in args]) vec = [iutil.sumlist([a[i] for a in args if i < len(a)]) for i in range(n)] return indep(*vec) @fcn_list_to_list def equiv(*args): """Return Region where the arguments contain the same information.""" args = iutil.type_coerce(args) if len(args) <= 1: return Region.universe() r = Region.universe() for i in range(1, len(args)): if isinstance(args[0], Comp): r &= (Expr.Hc(args[i], args[0]) == 0) & (Expr.Hc(args[0], args[i]) == 0) elif isinstance(args[0], Expr): r &= args[0] == args[i] elif isinstance(args[0], Region): r &= args[0] == args[i] return r @fcn_list_to_list def markov(*args): """Return Region where the arguments form a Markov chain.""" args = [iutil.ensure_comp(t) for t in args] r = Region.universe() for i in range(2, len(args)): r &= Expr.Ic(iutil.sumlist(args[:i-1]), iutil.sumlist(args[i]), iutil.sumlist(args[i-1])) == 0 return r def eqdist(*args): """Return Region where the argument lists have the same distribution. Only equalities of entropies are enforced. e.g. eqdist([X, Y], [Z, W]) """ m = min(len(a) for a in args) r = Region.universe() for i in range(1, len(args)): for mask in range(1, 1 << m): x = Comp.empty() y = Comp.empty() for j in range(m): if mask & (1 << j) != 0: x += args[0][j] y += args[i][j] if x != y: r &= Expr.H(x) == Expr.H(y) return r def eqdist_across(*args): """Take several arrays, return Region where entries have the same distribution across dimension. Only equalities of entropies are enforced. 
""" n = min([len(a) for a in args]) vec = [[a[i] for a in args] for i in range(n)] return eqdist(*vec) def exchangeable(*args): """Return Region where the arguments are exchangeable random variables. Only equalities of entropies are enforced. e.g. exchangeable(X, Y, Z) """ r = Region.universe() for tsize in range(1, len(args)): cvar = sum(args[:tsize]) for comb in itertools.combinations(args, tsize): tvar = sum(comb) if tvar != cvar: r &= Expr.H(cvar) == Expr.H(tvar) return r def iidseq(*args): """Return Region where the arguments form an i.i.d. sequence. Only equalities of entropies are enforced. e.g. iidseq(X, Y, Z), iidseq([X1,X2], [Y1,Y2], [Z1,Z2]) """ return indep(*args) & eqdist(*args) def iidseq_across(*args): """Take several arrays, return Region where entries are i.i.d. across dimension. Only equalities of entropies are enforced. """ n = min([len(a) for a in args]) vec = [[a[i] for a in args] for i in range(n)] return indep(*vec) & eqdist(*vec) def bnet(*args): """Create a Bayesian network with a list of edges (each edge is a tuple), or obtain the Bayesian network of a region. """ r = BayesNet() for a in args: if isinstance(a, Region): r += a.get_bayesnet() else: r += a return r @fcn_list_to_list def sfrl_cons(x, y, u, k = None, gap = None): """Strong functional representation lemma. Li, C. T., & El Gamal, A. (2018). Strong functional representation lemma and applications to coding theorems. IEEE Trans. Info. Theory, 64(11), 6967-6978. """ if k is None: r = (Expr.Hc(y, x + u) == 0) & (Expr.I(x, u) == 0) if gap is not None: if not isinstance(gap, Expr): gap = Expr.const(gap) r &= Expr.Ic(x, u, y) <= gap return r else: r = (Expr.Hc(k, x + u) == 0) & (Expr.Hc(y, u + k) == 0) & (Expr.I(x, u) == 0) if gap is not None: if not isinstance(gap, Expr): gap = Expr.const(gap) r &= Expr.H(k) <= Expr.I(x, y) + gap return r @fcn_list_to_list def sunflower(*args): """Sunflower dependency. 
""" r = Region.universe() for i in range(len(args)): r &= indep(*(args[:i] + args[i+1:])).conditioned(args[i]) return r @fcn_list_to_list def stardep(*args): """Star dependency. """ SU = Comp.rv("SU").avoid(*args) r = Region.universe() for a in args: r &= H(SU | a) == 0 return (r & indep(*args).conditioned(SU)).exists(SU) @fcn_list_to_list def cardbd(x, n): """Return Region where the cardinality of x is upper bounded by n.""" if n <= 1: return H(x) == 0 loge = PsiOpts.settings["ent_coeff"] V = rv_seq("V", 0, n-1).avoid(x) r = Expr.H(V[n - 2]) == 0 r2 = Region.universe() for i in range(0, n - 1): r2 &= Expr.Hc(V[i], V[i - 1] if i > 0 else x) == 0 r |= Expr.Hc(V[i - 1] if i > 0 else x, V[i]) == 0 r = r.implicated(r2, skip_simplify = True).forall(V) return r & (H(x) <= numpy.log(n) * loge) @fcn_list_to_list def isbin(x): """Return Region where x is a binary random variable.""" return cardbd(x, 2) @fcn_list_to_list def isuniform(x): """Return Region where x is a uniformly distributed random variable. Zhen Zhang and Raymond W Yeung, "A non-Shannon-type conditional inequality of information quantities", IEEE Trans. Inf. Theory 43, 6 (1997), pp. 1982-1986. """ U = rv("U").avoid(x) V = rv("V").avoid(x) r = (H(x | U+V) == 0) & (H(U | x+V) == 0) & (H(V | x+U) == 0) r &= indep(x, U) & indep(x, V) & indep(U, V) return r.exists(U+V) def sfrl(gap = None): """Strong functional representation lemma. Li, C. T., & El Gamal, A. (2018). Strong functional representation lemma and applications to coding theorems. IEEE Trans. Info. Theory, 64(11), 6967-6978. 
""" disjoint_id = iutil.get_count() # SX, SY, SU = rv("SX", "SY", "SU") SX = rv("SX", latex = "S_X") SY = rv("SY", latex = "S_Y") SU = rv("SU", latex = "S_U") SX.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) SY.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) #SU.add_markers([("mustuse", 1)]) r = ((Expr.Hc(SY, SX + SU) == 0).add_meta("pf_note", ["SFRL fcn"]) & (Expr.I(SX, SU) == 0).add_meta("pf_note", ["SFRL indep."])) if gap is not None: if not isinstance(gap, Expr): gap = Expr.const(gap) r &= (Expr.Ic(SX, SU, SY) <= gap).add_meta("pf_note", ["SFRL gap"]) return r.exists(SU).forall(SX + SY) def copylem(n = 2, m = 1): """Copy lemma: for any A, B, there exists C such that (A, B) has the same distribution as (A, C), and B-A-C forms a Markov chain. n, m are the dimensions of A, B respectively. Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities," IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998. Randall Dougherty, Chris Freiling, and Kenneth Zeger. "Non-Shannon information inequalities in four random variables." arXiv preprint arXiv:1104.3602 (2011). """ disjoint_id = iutil.get_count() symm_id_x = iutil.get_count() symm_id_y = iutil.get_count() X = rv_seq("A", 0, n) for i in range(n): X[i].add_markers([("disjoint", disjoint_id), ("symm", symm_id_x), ("symm_nonempty", 1)]) Y = rv_seq("B", 0, m) Z = rv_seq("C", 0, m) for i in range(m): Y[i].add_markers([("disjoint", disjoint_id), ("symm", symm_id_y), ("symm_nonempty", 1)]) return (eqdist(X + Y, X + Z) & markov(Y, X, Z)).exists(Z).forall(X + Y).add_meta("pf_note", ["copy lemma"]) def dblmarkov(): """Double Markov property: If X-Y-Z and Y-X-Z are Markov chains, then there exists W that is a function of X, a function of Y, and (X,Y)-W-Z is Markov chain. Imre Csiszar and Janos Korner. Information theory: coding theorems for discrete memoryless systems. Cambridge University Press, 2011. 
""" symm_id_x = iutil.get_count() nonsubset_id_x = iutil.get_count() X = rv("DX", latex = "D_X") Y = rv("DY", latex = "D_Y") Z = rv("DZ", latex = "D_Z") W = rv("DW", latex = "D_W") X.add_markers([("symm", symm_id_x), ("nonsubset", nonsubset_id_x), ("nonempty", 1)]) Y.add_markers([("symm", symm_id_x), ("nonsubset", nonsubset_id_x), ("nonempty", 1)]) Z.add_markers([("nonempty", 1)]) return ((markov(X, Y, Z) & markov(Y, X, Z)) >> ((H(W|X) == 0) & (H(W|Y) == 0) & markov(X+Y, W, Z)).exists(W)).forall(X+Y+Z).add_meta("pf_note", ["double Markov"]) def mmrv_thm(n = 2): """The non-Shannon inequality in the paper: Makarychev, K., Makarychev, Y., Romashchenko, A., & Vereshchagin, N. (2002). A new class of non-Shannon-type inequalities for entropies. Communications in Information and Systems, 2(2), 147-166. """ disjoint_id = iutil.get_count() symm_id_x = iutil.get_count() symm_id_u = iutil.get_count() X = rv_seq("CX", 0, n) for i in range(n): X[i].add_markers([("disjoint", disjoint_id), ("symm", symm_id_x), ("symm_nonempty", 2)]) U = Comp.rv("CU") V = Comp.rv("CV") Z = Comp.rv("CZ") U.add_markers([("nonempty", 1), ("symm", symm_id_u)]) V.add_markers([("nonempty", 1), ("symm", symm_id_u)]) Z.add_markers([("nonempty", 1)]) expr = H(X) + n * I(U & V & Z) expr -= sum(I(U & V | Xi) for Xi in X) expr -= sum(H(Xi) for Xi in X) expr -= I(U+V & Z) return (expr <= 0).forall(X + U + V + Z) def zydfz_thm(mode = ""): """The non-Shannon inequalities in the paper: Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities," IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998. Randall Dougherty, Christopher Freiling, and Kenneth Zeger. "Six new non-Shannon information inequalities." 2006 IEEE International Symposium on Information Theory. IEEE, 2006. 
""" disjoint_id = iutil.get_count() A = Comp.rv("A") B = Comp.rv("B") C = Comp.rv("C") D = Comp.rv("D") A.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) B.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) C.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) D.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) r = Region.universe() if mode.lower() != "dfz": r &= (2*I(C&D) <= I(A&B) + I(A&C+D) + 3*I(C&D|A) + I(C&D|B) ).add_meta("pf_note", ["ZY ineq."]) # ZY if mode.lower() != "zy": r &= (2*I(A&B) <= 3*I(A&B|C) + 3*I(A&C|B) + 3*I(B&C|A) + 2*I(A&D) + 2*I(B&C|D) ).add_meta("pf_note", ["DFZ ineq. 1"]) # DFZ1 r &= (2*I(A&B) <= 4*I(A&B|C) + I(A&C|B) + 2*I(B&C|A) + 3*I(A&B|D) + I(B&D|A) + 2*I(C&D)).add_meta("pf_note", ["DFZ ineq. 2"]) # DFZ2 r &= (2*I(A&B) <= 3*I(A&B|C) + 2*I(A&C|B) + 4*I(B&C|A) + 2*I(A&C|D) + I(A&D|C) + 2*I(B&D) + I(C&D|A)).add_meta("pf_note", ["DFZ ineq. 3"]) # DFZ3 r &= (2*I(A&B) <= 5*I(A&B|C) + 3*I(A&C|B) + I(B&C|A) + 2*I(A&D) + 2*I(B&C|D) ).add_meta("pf_note", ["DFZ ineq. 4"]) # DFZ4 r &= (2*I(A&B) <= 4*I(A&B|C) + 4*I(A&C|B) + I(B&C|A) + 2*I(A&D) + 3*I(B&C|D) + I(C&D|B) ).add_meta("pf_note", ["DFZ ineq. 5"]) # DFZ5 r &= (2*I(A&B) <= 3*I(A&B|C) + 2*I(A&C|B) + 2*I(B&C|A) + 2*I(A&B|D) + I(A&D|B) + I(B&D|A) + 2*I(C&D)).add_meta("pf_note", ["DFZ ineq. 6"]) # DFZ6 return r.forall(A+B+C+D) def dfz_thm(ncopy = 3): """The non-Shannon inequalities in the paper: Dougherty, Randall, Chris Freiling, and Kenneth Zeger. "Non-Shannon information inequalities in four random variables." arXiv preprint arXiv:1104.3602 (2011). 
""" disjoint_id = iutil.get_count() A = Comp.rv("A") B = Comp.rv("B") C = Comp.rv("C") D = Comp.rv("D") A.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) B.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) C.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) D.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) r = Region.universe() cl = [] cl += [(4, 2, 1, 3, 1, 0, 2)] if ncopy == 2: cl += [ (5, 3, 1, 2, 0, 0, 2), (4, 4, 1, 2, 1, 1, 2), (3, 3, 3, 2, 0, 0, 2), (3, 4, 2, 3, 1, 0, 2), (3, 2, 2, 2, 1, 1, 2) ] for a, b, c, d, e, f, g in cl: r &= (2*I(A&B) <= a*I(A&B|C) + b*I(A&C|B) + c*I(B&C|A) + d*I(A&B|D) + e*I(A&D|B) + f*I(B&D|A) + g*I(C&D) ).add_meta("pf_note", ["DFZ ineq. 2 copy"]) if ncopy >= 3: cl = [ (3, 4, 4, 4, 3, 1, 1, 3, 0), (3, 9, 6, 1, 3, 0, 0, 3, 0), (3, 4, 6, 6, 3, 0, 0, 3, 0), (2, 3, 3, 1, 5, 2, 0, 2, 0), (3, 4, 3, 3, 3, 3, 3, 3, 0), (3, 6, 3, 1, 6, 3, 0, 3, 0), (4, 5, 8, 8, 4, 1, 1, 4, 0), (2, 4, 2, 1, 2, 0, 0, 2, 3), (4, 5, 5, 5, 4, 4, 4, 4, 0), (2, 3, 3, 2, 2, 0, 0, 2, 0), (3, 7, 5, 1, 3, 1, 1, 3, 0), (4, 6, 4, 3, 4, 2, 1, 4, 0), (2, 5, 2, 1, 2, 0, 0, 2, 0), (2, 4, 3, 1, 2, 0, 0, 2, 0), (2, 4, 1, 2, 2, 3, 0, 2, 0), (2, 4, 2, 1, 2, 4, 1, 2, 0) ] cl += [ (3, 4, 9, 3, 6, 3, 0, 3, 0), (3, 7, 4, 1, 4, 1, 0, 3, 0), (3, 4, 6, 4, 4, 1, 0, 3, 0), (4, 5, 17, 6, 6, 7, 0, 4, 0), (4, 5, 17, 13, 6, 2, 0, 4, 0), (3, 4, 7, 5, 3, 1, 0, 3, 0), (6, 8, 9, 9, 6, 10, 1, 6, 0), (6, 13, 20, 2, 9, 3, 0, 6, 0), (4, 10, 15, 1, 4, 2, 2, 4, 0), (4, 6, 11, 3, 6, 2, 0, 4, 0), (3, 6, 6, 1, 5, 4, 0, 3, 0), (3, 6, 8, 1, 3, 2, 2, 3, 0), (4, 5, 6, 6, 4, 2, 2, 4, 0), (3, 8, 6, 1, 3, 1, 0, 3, 0), (4, 14, 10, 1, 6, 2, 0, 4, 0), (3, 4, 4, 3, 3, 4, 2, 3, 0), (4, 13, 9, 1, 7, 3, 0, 4, 0), (6, 8, 16, 7, 6, 3, 3, 6, 0) ] for a, b, c, d, e, f, g, h, i in cl: r &= (a*I(A&B) <= b*I(A&B|C) + c*I(A&C|B) + d*I(B&C|A) + e*I(A&B|D) + f*I(A&D|B) + g*I(B&D|A) + h*I(C&D) + i*I(C&D|A) ).add_meta("pf_note", ["DFZ ineq. 
3 copy"]) return r.forall(A+B+C+D) def matus_thm(min_s, max_s): """The non-Shannon inequalities in the paper: F. Matus, "Infinitely many information inequalities", Proc. IEEE International Symposium on Information Theory, 2007 """ disjoint_id = iutil.get_count() A = Comp.rv("A") B = Comp.rv("B") C = Comp.rv("C") D = Comp.rv("D") A.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) B.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) C.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) D.add_markers([("disjoint", disjoint_id), ("nonempty", 1)]) r = Region.universe() for s in range(min_s, max_s + 1): r &= (s*I(A&B) <= (s*(s+3)/2) * I(A&B|C) + (s*(s+1)/2) * I(A&C|B) + I(B&C|A) + s*I(A&B|D) + s*I(C&D) ).add_meta("pf_note", ["Matus ineq. s = " + str(s)]) return r.forall(A+B+C+D) def ainfdiv(n = 2, coeff = None, cadd = None): """The approximate infinite divisibility of information. C. T. Li, "Infinite Divisibility of Information," arXiv preprint arXiv:2008.06092 (2020). """ X = Comp.rv("X") Z = rv_seq("Z", 0, n) if coeff is None: coeff = 1.0 / (1.0 - (1.0 - 1.0 / n) ** n) / n else: coeff = coeff / n if cadd is None: cadd = 2.43 r = iidseq(*Z) & (H(X | Z) == 0) r &= H(Z[0]) <= H(X) * coeff + cadd return r.exists(Z).forall(X) def exists_xor(): """For any X, there exists Z0 uniformly distributed and Z1 = X XOR Z0. """ X = Comp.rv("X") Z = rv_seq("Z", 0, 2) r = indep(X, Z[0]) & indep(X, Z[1]) & (H(X | Z) == 0) return r.exists(Z).forall(X).add_meta("pf_note", ["exists XOR"]) def exists_linear(n, p = 2, maxsize = None): """There exists linear combinations of n i.i.d. GF(p) elements. p must be a prime. 
""" if maxsize is None: maxsize = n vs = [] for mask in igen.subset([1 << i for i in range(n)], minsize = 1, maxsize = maxsize): gens = [] first = True for i in range(n): if not (mask & (1 << i)): gens.append([0]) continue if first: gens.append([1]) first = False else: gens.append(list(range(1, p))) for t in itertools.product(*gens): vs.append(t) loge = PsiOpts.settings["ent_coeff"] logp = numpy.log(p) * loge r = Region.universe() V = Comp.empty() for v in vs: cname = "" cname_latex = "" for i, a in enumerate(v): if a == 0: continue if cname != "": cname += "(+)" if cname_latex != "": cname_latex += "\oplus " # cname_latex += "+_{" + str(p) + "}" if a != 1: cname += str(a) cname_latex += str(a) cname += "V" + str(i) cname_latex += "V_{" + str(i) + "}" crv = rv(cname, latex = cname_latex, split = False) r &= Expr.H(crv) == logp V += crv for basis in itertools.combinations(range(len(vs)), n): mat = numpy.array([vs[i] for i in basis], dtype = numpy.int64) if int(round(numpy.linalg.det(mat))) % p == 0: continue Vb = sum((V[i] for i in basis), Comp.empty()) r &= indep(*Vb) r &= Expr.Hc(V - Vb, Vb) == 0 return r.exists(V).add_meta("pf_note", ["exists XOR"]) def existence(f, numarg = 2, nonempty = False): """A region for the existence of the random variable f(X, Y). e.g. 
existence(meet), existence(mss)
    """
    T = rv_seq("T", 0, numarg)
    # First call: used only to inspect the markers f puts on its result.
    X = f(*T)
    r = X.varlist[0].reg
    if X.get_marker_key("symm_args") is not None: # r.issymmetric(T):
        # Tag every argument with one shared fresh "symm" marker id.
        symm_id_x = iutil.get_count()
        for i in range(numarg):
            T[i].add_markers([("symm", symm_id_x)])
    if X.get_marker_key("nonsubset_args") is not None:
        # Tag every argument with one shared fresh "nonsubset" marker id.
        nonsubset_id_x = iutil.get_count()
        for i in range(numarg):
            T[i].add_markers([("nonsubset", nonsubset_id_x)])
    # Second call, after the markers on T have been set.
    X = f(*T)
    r = X.varlist[0].reg
    X = X.copy_noreg()
    if nonempty:
        X.add_markers([("nonempty", 1)])
    # NOTE(review): presumably swaps X for its reg-free copy inside r; confirm.
    r.substitute(X, X)
    # For all arguments T, there exists X satisfying f's defining region.
    return r.exists(X).forall(T)

def rv_bit():
    """A random variable with entropy 1."""
    U = Comp.rv("BIT")
    return Comp.rv_reg(U, H(U) == 1)

def exists_bit(n = 1):
    """There exists a random variable with entropy 1.

    With n > 1: there exist n mutually independent variables, each with entropy 1.
    """
    U = rv_seq("BIT", 0, n)
    return (alland([H(x) == 1 for x in U]) & indep(*U)).exists(U)

@fcn_list_to_list
def emin(*args):
    """Return the minimum of the expressions.

    Built as the largest real R with R <= x for every argument x.
    """
    R = real(iutil.fcn_name_maker("min", args, pname = "emin", lname = "\\min", fcn_suffix = False))
    R = real(str(R))
    r = universe()
    for x in args:
        r &= R <= x
    return r.maximum(R, None, allow_reuse = True)

@fcn_list_to_list
def emax(*args):
    """Return the maximum of the expressions.

    Built as the smallest real R with R >= x for every argument x.
    """
    R = real(iutil.fcn_name_maker("max", args, pname = "emax", lname = "\\max", fcn_suffix = False))
    R = real(str(R))
    r = universe()
    for x in args:
        r &= R >= x
    return r.minimum(R, None, allow_reuse = True)

@fcn_list_to_list
def eabs(x):
    """Absolute value of expression: the smallest real R with R >= x and R >= -x."""
    R = real(iutil.fcn_name_maker("abs", x, fcn_suffix = False))
    return ((R >= x) & (R >= -x)).minimum(R, None, allow_reuse = True)

@fcn_list_to_list
def meet(*args):
    """Gacs-Korner common part.
    Peter Gacs and Janos Korner. Common information is far less than mutual information. Problems of Control and Information Theory, 2(2):149-162, 1973.
""" U = Comp.rv(iutil.fcn_name_maker("meet", args)).avoid(*args) V = Comp.rv("V").avoid(*args) U.add_markers([("mustuse", 1), ("symm_args", 1), ("nonsubset_args", 1)]) V.add_markers([("nonempty", 1)]) r = Region.universe() r2 = Region.universe() for a in args: r &= (Expr.Hc(U, a) == 0).add_meta("pf_note", ["meet"]) r2 &= (Expr.Hc(V, a) == 0).add_meta("pf_note", ["meet"]) r = r & (Expr.Hc(V, U) == 0).implicated(r2, skip_simplify = True).forall(V).add_meta("pf_note", ["meet maximal"]) ret = Comp.rv_reg(U, r, reg_det = True) #ret.add_markers([("symm_args", 1), ("nonsubset_args", 1)]) return ret @fcn_list_to_list def mss(x, y): """Minimal sufficient statistic of x about y.""" U = Comp.rv(iutil.fcn_name_maker("mss", [x, y])).avoid(x, y) V = Comp.rv("V").avoid(x, y) U.add_markers([("mustuse", 1), ("nonsubset_args", 1)]) r = (Expr.Hc(U, x) == 0) & (Expr.Ic(x, y, U) == 0) r2 = (Expr.Hc(V, x) == 0) & (Expr.Ic(x, y, V) == 0) r = r & (Expr.Hc(U, V) == 0).implicated(r2, skip_simplify = True).forall(V) ret = Comp.rv_reg(U, r, reg_det = True) #ret.add_markers([("nonsubset_args", 1)]) return ret @fcn_list_to_list def sfrl_rv(x, y, gap = None): """Strong functional representation lemma. Li, C. T., & El Gamal, A. (2018). Strong functional representation lemma and applications to coding theorems. IEEE Trans. Info. Theory, 64(11), 6967-6978. """ U = Comp.rv(iutil.fcn_name_maker("sfrl", [x, y], pname = "sfrl_rv")).avoid(x, y) #U = Comp.rv(y.tostring(add_bracket = True) + "%" + x.tostring(add_bracket = True)) r = (Expr.Hc(y, x + U) == 0) & (Expr.I(x, U) == 0) if gap is not None: if not isinstance(gap, Expr): gap = Expr.const(gap) r &= Expr.Ic(x, U, y) <= gap return Comp.rv_reg(U, r, reg_det = False) def esfrl_rv(x, y, gap = None): """Strong functional representation lemma, extended form. Li, C. T., & El Gamal, A. (2018). Strong functional representation lemma and applications to coding theorems. IEEE Trans. Info. Theory, 64(11), 6967-6978. 
""" U = Comp.rv(iutil.fcn_name_maker("esfrl", [x, y], pname = "esfrl_rv")).avoid(x, y) K = Comp.rv(iutil.fcn_name_maker("esfrl_K", [x, y], pname = "esfrl_rv_K")).avoid(x, y) r = (Expr.Hc(K, x + U) == 0) & (Expr.Hc(Y, U + K) == 0) & (Expr.I(x, U) == 0) if gap is not None: if not isinstance(gap, Expr): gap = Expr.const(gap) r &= Expr.H(K) <= Expr.I(x, y) + gap return (Comp.rv_reg(U, r, reg_det = False), Comp.rv_reg(K, r, reg_det = False)) def copylem_rv(x, y): """Copy lemma: for any X, Y, there exists Z such that (X, Y) has the same distribution as (X, Z), and Y-X-Z forms a Markov chain. Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities," IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998. Randall Dougherty, Chris Freiling, and Kenneth Zeger. "Non-Shannon information inequalities in four random variables." arXiv preprint arXiv:1104.3602 (2011). """ U = Comp.rv(iutil.fcn_name_maker("copy", [x, y], pname = "copylem_rv")).avoid(x, y) r = eqdist(list(x) + [y], list(x) + [U]) & markov(y, sum(x), U) return Comp.rv_reg(U, r, reg_det = False) @fcn_list_to_list def total_corr(*args): """Total correlation. Watanabe S (1960). Information theoretical analysis of multivariate correlation, IBM Journal of Research and Development 4, 66-82. e.g. total_corr(X & Y & Z | W) """ x = Term.from_symbols(args, prefer_multi=True) if isinstance(x, Comp): return Expr.H(x) return sum([Expr.Hc(a, x.z) for a in x.x]) - Expr.Hc(sum(x.x), x.z) @fcn_list_to_list def dual_total_corr(*args): """Dual total correlation. Han T. S. (1978). Nonnegative entropy measures of multivariate symmetric correlations, Information and Control 36, 133-156. e.g. 
dual_total_corr(X & Y & Z | W) """ x = Term.from_symbols(args, prefer_multi=True) if isinstance(x, Comp): return Expr.H(x) r = Expr.Hc(sum(x.x), x.z) for i in range(len(x.x)): r -= Expr.Hc(x.x[i], sum([x.x[j] for j in range(len(x.x)) if j != i]) + x.z) return r def prob_markov_constraints(Xs, U, p): P = ConcModel() p0 = p.marginal(0) P[Xs[0]] = p0 P[U | Xs[0]] = "var,rand" vars = [P[U | Xs[0]]] Xt = Comp.empty() for i, X in enumerate(Xs): if i > 0: P[X | U] = "var,rand" vars.append(P[X | U]) Xt += X cons = (Xt | Xs[0]).pmf() == p / p0 return (P, vars, cons) @fcn_list_to_list def gacs_korner(*args, mi = None): """Gacs-Korner common information. Peter Gacs and Janos Korner. Common information is far less than mutual information. Problems of Control and Information Theory, 2(2):149-162, 1973. e.g. gacs_korner(X & Y & Z | W) """ x = Term.from_symbols(args, prefer_multi=True) U = Comp.rv("U").avoid(x) R = real(iutil.fcn_name_maker("K", x, pname = "gacs_korner" + ("_mi" if mi is True else ""), cropi = True)) r = None ro = None if mi is not True: r = universe() for a in x.x: r &= Expr.Hc(U, a+x.z) == 0 r &= R <= Expr.Hc(U, x.z) if mi is not False: ro = universe() sumx = sum(x.x, Comp.empty()) for a in x.x: ro &= Expr.Ic(U, sumx - a, a + x.z) == 0 ro &= R <= Expr.Ic(U, sumx, x.z) if mi is False: return r.add_meta("pf_note", ["def. Gacs-Korner"]).exists(U).maximum(R, None, allow_reuse = True) elif mi is True: return ro.add_meta("pf_note", ["def. Gacs-Korner"]).exists(U).maximum(R, None, allow_reuse = True) else: return ro.add_meta("pf_note", ["def. Gacs-Korner"]).exists(U).maximum(R, None, allow_reuse = True, reg_outer = r.add_meta("pf_note", ["def. Gacs-Korner"]).exists(U)) @fcn_list_to_list def gacs_korner_mi(*args): """Gacs-Korner common information, alternative characterization via mutual information. Peter Gacs and Janos Korner. Common information is far less than mutual information. Problems of Control and Information Theory, 2(2):149-162, 1973. e.g. 
gacs_korner_mi(X & Y & Z | W) """ x = Term.from_symbols(args, prefer_multi=True) U = Comp.rv("U").avoid(x) R = real(iutil.fcn_name_maker("\\tilde{K}", x, pname = "gacs_korner_mi", cropi = True)) r = universe() sumx = sum(x.x, Comp.empty()) for a in x.x: r &= Expr.Ic(U, sumx - a, a + x.z) == 0 r &= R <= Expr.Ic(U, sumx, x.z) return r.exists(U).maximum(R, None, allow_reuse = True) @fcn_list_to_list def wyner_ci(*args): """Wyner's common information. A. D. Wyner. The common information of two dependent random variables. IEEE Trans. Info. Theory, 21(2):163-179, 1975. e.g. wyner_ci(X & Y & Z | W) """ x = Term.from_symbols(args, prefer_multi=True) def fcncall(xdist): xdist = xdist.flattened_sublen() Xs = Comp.array("X", len(xdist.p.shape)) Xs = sum((X.set_card(s) for X, s in zip(Xs, xdist.p.shape)), Comp.empty()) card_bd = iutil.product(xdist.p.shape) + 1 U = Comp.rv("U").set_card(card_bd) P, vars, cons = prob_markov_constraints(Xs, U, xdist) return P.minimize(I(Xs & U), vars, cons) if isinstance(x, ConcDist): return fcncall(x) U = Comp.rv("U").avoid(x) R = real(iutil.fcn_name_maker("J", x, pname = "wyner_ci", cropi = True)) r = indep(*(x.x)).conditioned(U + x.z) r &= R >= Expr.Ic(U, sum(x.x), x.z) r = r.add_meta("pf_note", ["def. Wyner CI"]).exists(U).minimum(R, None, allow_reuse = True) r.terms[0][0].fcncall = fcncall r.terms[0][0].fcnargs = [x] return r @fcn_list_to_list def exact_ci(*args): """Common entropy (one-shot exact common information). G. R. Kumar, C. T. Li, and A. El Gamal. Exact common information. In Information Theory (ISIT), 2014 IEEE International Symposium on, 161-165. IEEE, 2014. e.g. 
exact_ci(X & Y & Z | W) """ x = Term.from_symbols(args, prefer_multi=True) def fcncall(xdist): xdist = xdist.flattened_sublen() Xs = Comp.array("X", len(xdist.p.shape)) Xs = sum((X.set_card(s) for X, s in zip(Xs, xdist.p.shape)), Comp.empty()) tprod = iutil.product(xdist.p.shape) card_bd = min(tprod, 2 ** (tprod // max(xdist.p.shape)) - 1) U = Comp.rv("U").set_card(card_bd) P, vars, cons = prob_markov_constraints(Xs, U, xdist) return P.minimize(H(U), vars, cons) if isinstance(x, ConcDist): return fcncall(x) U = Comp.rv("U").avoid(x) R = real(iutil.fcn_name_maker("G", x, pname = "exact_ci", cropi = True)) r = indep(*(x.x)).conditioned(U + x.z) r &= R >= Expr.Hc(U, x.z) r = r.add_meta("pf_note", ["def. common entropy"]).exists(U).minimum(R, None, allow_reuse = True) r.terms[0][0].fcncall = fcncall r.terms[0][0].fcnargs = [x] return r @fcn_list_to_list def H_nec(*args): """Necessary conditional entropy. Cuff, P. W., Permuter, H. H., & Cover, T. M. (2010). Coordination capacity. IEEE Transactions on Information Theory, 56(9), 4181-4206. e.g. H_nec(X + Y | W) """ x = Term.from_symbols(args) U = Comp.rv("U").avoid(x) R = real(iutil.fcn_name_maker("Hnec", x, pname = "H_nec", lname = "H^\\dagger", cropi = True)) r = markov(x.z, U, x.x[0]) & (Expr.Hc(U, x.x[0]) == 0) r &= R >= Expr.Hc(U, x.z) return r.add_meta("pf_note", ["def. nec. cond. ent."]).exists(U).minimum(R, None, allow_reuse = True) @fcn_list_to_list def excess_fi(x, y): """Excess functional information. Li, C. T., & El Gamal, A. (2018). Strong functional representation lemma and applications to coding theorems. IEEE Trans. Info. Theory, 64(11), 6967-6978. e.g. excess_fi(X, Y) """ U = Comp.rv("U").avoid(x, y) R = real(iutil.fcn_name_maker("excess_fi", [x, y], pname = "excess_fi", lname = "\\Psi")) r = indep(U, x) r &= R >= Expr.Hc(y, U) - Expr.I(x, y) return r.add_meta("pf_note", ["def. 
excess fcn info"]).exists(U).minimum(R, None, allow_reuse = True) @fcn_list_to_list def korner_graph_ent(x, y): """Korner graph entropy. J. Korner, "Coding of an information source having ambiguous alphabet and the entropy of graphs," in 6th Prague conference on information theory, 1973, pp. 411-425. C. T. Li and A. El Gamal, "Extended Gray-Wyner system with complementary causal side information," IEEE Transactions on Information Theory 64.8 (2017): 5862-5878. e.g. korner_graph_ent(X, Y) """ U = Comp.rv("U").avoid(x, y) R = real(iutil.fcn_name_maker("korner_graph_ent", [x, y], lname = "H_K")) r = markov(U, x, y) & (Expr.Hc(x, y+U) == 0) r &= R >= Expr.I(x, U) return r.exists(U).minimum(R, None, allow_reuse = True) @fcn_list_to_list def perfect_privacy(x, y): """Perfect privacy rate. A. Makhdoumi, S. Salamatian, N. Fawaz, and M. Medard, "From the information bottleneck to the privacy funnel," in Information Theory Workshop (ITW), 2014 IEEE, Nov 2014, pp. 501-505. S. Asoodeh, F. Alajaji, and T. Linder, "Notes on information-theoretic privacy," in Communication, Control, and Computing (Allerton), 2014 52nd Annual Allerton Conference on, Sept 2014, pp. 1272-1278. e.g. perfect_privacy(X, Y) """ U = Comp.rv("U").avoid(x, y) R = real(iutil.fcn_name_maker("perfect_privacy", [x, y], lname = "g_0")) r = markov(x, y, U) & (Expr.I(x, U) == 0) r &= R <= Expr.I(y, U) return r.exists(U).maximum(R, None, allow_reuse = True) @fcn_list_to_list def max_interaction_info(x, y): """Maximal interaction information. C. T. Li and A. El Gamal, "Extended Gray-Wyner system with complementary causal side information," IEEE Transactions on Information Theory 64.8 (2017): 5862-5878. e.g. 
max_interaction_info(X, Y) """ U = Comp.rv("U").avoid(x, y) R = real(iutil.fcn_name_maker("max_interaction_info", [x, y], lname = "G_{NNI}")) r = Region.universe() r &= R <= Expr.Ic(x, y, U) - Expr.I(x, y) return r.exists(U).maximum(R, None, allow_reuse = True) @fcn_list_to_list def asymm_interaction_info(x, y): """Asymmetric private interaction information. C. T. Li and A. El Gamal, "Extended Gray-Wyner system with complementary causal side information," IEEE Transactions on Information Theory 64.8 (2017): 5862-5878. e.g. max_interaction_info(X, Y) """ U = rv("U").avoid(x, y) R = real(iutil.fcn_name_maker("asymm_interaction_info", [x, y], lname = "G_{PNI}")) r = indep(x, U) r &= R <= Expr.Ic(x, y, U) - Expr.I(x, y) return r.exists(U).maximum(R, None, allow_reuse = True) @fcn_list_to_list def symm_interaction_info(x, y): """Symmetric private interaction information. C. T. Li and A. El Gamal, "Extended Gray-Wyner system with complementary causal side information," IEEE Transactions on Information Theory 64.8 (2017): 5862-5878. e.g. max_interaction_info(X, Y) """ U = Comp.rv("U").avoid(x, y) R = real(iutil.fcn_name_maker("symm_interaction_info", [x, y], lname = "G_{PPI}")) r = indep(x, U) & indep(y, U) r &= R <= Expr.Ic(x, y, U) - Expr.I(x, y) return r.exists(U).maximum(R, None, allow_reuse = True) @fcn_list_to_list def minent_coupling(x, y): """Minimum entropy coupling of the distributions p_{Y|X=x}. M. Vidyasagar, "A metric between probability distributions on finite sets of different cardinalities and applications to order reduction," IEEE Transactions on Automatic Control, vol. 57, no. 10, pp. 2464-2477, 2012. A. Painsky, S. Rosset, and M. Feder, "Memoryless representation of Markov processes," in 2013 IEEE International Symposium on Information Theory. IEEE, 2013, pp. 2294-298. M. Kovacevic, I. Stanojevic, and V. Senk, "On the entropy of couplings," Information and Computation, vol. 242, pp. 369-382, 2015. M. Kocaoglu, A. G. Dimakis, S. Vishwanath, and B. 
Hassibi, "Entropic causal inference," in Thirty-First AAAI Conference on Artificial Intelligence, 2017. F. Cicalese, L. Gargano, and U. Vaccaro, "Minimum-entropy couplings and their applications," IEEE Transactions on Information Theory, vol. 65, no. 6, pp. 3436-3451, 2019. Cheuk Ting Li, "Efficient Approximate Minimum Entropy Coupling of Multiple Probability Distributions," https://arxiv.org/abs/2006.07955 , 2020. e.g. minent_coupling(X, Y) """ U = Comp.rv("U").avoid(x, y) R = real(iutil.fcn_name_maker("MEC", [x, y], pname = "minent_coupling", lname = "H_{couple}")) r = indep(U, x) & (Expr.Hc(y, x + U) == 0) r &= R >= Expr.H(U) return r.exists(U).minimum(R, None, allow_reuse = True) @fcn_list_to_list def mutual_dep(x): """Mutual dependence. Csiszar, Imre, and Prakash Narayan. "Secrecy capacities for multiple terminals." IEEE Transactions on Information Theory 50, no. 12 (2004): 3047-3061. """ n = len(x.x) if n <= 2: return I(x) R = real(iutil.fcn_name_maker("MD", x, pname = "mutual_dep", lname = "C_{MD}", cropi = True)) Hall = Expr.Hc(sum(x.x), x.z) r = universe() for part in iutil.enum_partition(n): if len(part) <= 1: continue expr = Expr.zero() for cell in part: xb = Comp.empty() for i in range(n): if cell & (1 << i) != 0: xb += x.x[i] expr += Expr.Hc(xb, x.z) r &= R <= (expr - Hall) / (len(part) - 1) return r.maximum(R, None, allow_reuse = True) @fcn_list_to_list def intrinsic_mi(x): """Intrinsic mutual information. U. Maurer and S. Wolf. "Unconditionally secure key agreement and the intrinsic conditional information." IEEE Transactions on Information Theory 45.2 (1999): 499-514. e.g. 
intrinsic_mi(X & Y | Z) """ U = Comp.rv("U").avoid(x) R = real(iutil.fcn_name_maker("IMI", x, pname = "intrinsic_mi", lname = "I_{intrinsic}", cropi = True)) r = markov(sum(x.x), x.z, U) & (R >= mutual_dep(Term(x.x, U))) return r.exists(U).minimum(R, None, allow_reuse = True) def mi_rect(xs, ys, z = None, sgn = 1): if z is None: z = Comp.empty() if not isinstance(xs, list): xs = [xs] if not isinstance(ys, list): ys = [ys] xs = [x if isinstance(x, tuple) else (x,) for x in xs] ys = [y if isinstance(y, tuple) else (y,) for y in ys] exprs = [] for px in itertools.product(*[range(len(x)) for x in xs]): for py in itertools.product(*[range(len(y)) for y in ys]): cx = sum((x[0] for x, p in zip(xs, px) if p or len(x) == 1), Comp.empty()) cy = sum((y[0] for y, p in zip(ys, py) if p or len(y) == 1), Comp.empty()) cxm = sum((x[1] for x, p in zip(xs, px) if p), Expr.zero()) cym = sum((y[1] for y, p in zip(ys, py) if p), Expr.zero()) if sgn > 0: if cx.isempty() or cy.isempty(): if cxm.iszero() and cym.iszero(): exprs.append(Expr.zero()) else: exprs.append(Expr.Ic(cx, cy, z) - cxm - cym) else: if cx.isempty() or cy.isempty(): exprs.append(-cxm - cym) else: exprs.append(Expr.Ic(cx, cy, z) - cxm - cym) if len(exprs) == 1: return exprs[0] if sgn > 0: return emax(*exprs) else: return emin(*exprs) def mi_rect_max(xs, ys, z = None): return mi_rect(xs, ys, z, 1) def mi_rect_min(xs, ys, z = None): return mi_rect(xs, ys, z, -1) def directed_info(x, y, z = None): """Directed information. Massey, James. "Causality, feedback and directed information." Proc. Int. Symp. Inf. Theory Applic.(ISITA-90). 1990. Parameters can be either Comp or CompArray. 
""" x = CompArray.arg_convert(x) y = CompArray.arg_convert(y) if z is None: return sum(I(x.past_ns() & y | y.past())) else: z = CompArray.arg_convert(z) return sum(I(x.past_ns() & y | y.past() + z.past_ns())) def comp_vector(*args): return CompArray.make(igen.subset(args, minsize = 1)) def comp_vector_lex(*args): n = len(args) return CompArray.make(sum((args[i] for i in range(n) if mask & (1 << i)), Comp.empty()) for mask in range(1, 1 << n)) def ent_vector(*args): """Entropy vector. Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities," IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998. """ if len(args) == 0: return ExprArray.empty() return H(comp_vector(*args)) def ent_vector_lex(*args): """Entropy vector in lexicographical order. Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities," IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998. """ if len(args) == 0: return ExprArray.empty() return H(comp_vector_lex(*args)) def mi_vector(*args, minsize = 1, maxsize = 1000): """Mutual information vector. """ r = ExprArray.empty() for xs in igen.subset([1 << x for x in range(len(args))], minsize = minsize, maxsize = maxsize): r.append(I(alland(args[x] for x in range(len(args)) if xs & (1 << x)))) return r def ent_cells(*args, minsize = 1, maxsize = 1000): """Cells of the I-measure. Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities," IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998. """ allrv = sum(args) r = ExprArray.empty() for xs in igen.subset([1 << x for x in range(len(args))], minsize = minsize, maxsize = maxsize): r.append(I(alland(args[x] for x in range(len(args)) if xs & (1 << x)) | sum(args[x] for x in range(len(args)) if not (xs & (1 << x))))) return r def mi_cells(*args, minsize = 2, maxsize = 1000): """Cells of the I-measure, excluding conditional entropies. Z. Zhang and R. W. 
Yeung, "On characterization of entropy function via information inequalities," IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998. """ return ent_cells(*args, minsize = minsize, maxsize = maxsize) def ent_region(n, real_name = "R", var_name = "X", name_subset = True, st = 0): """Entropy region. Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities," IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998. """ xs = rv_seq(var_name, 0, n) cv = comp_vector(*xs) real_ids = [] if name_subset: real_ids = ["".join(x) for x in igen.subset([[str(i + st)] for i in range(n)], minsize = 1, zero_elem = [])] else: real_ids = range(1, 1 << n) re = real_array(real_name, real_ids) return (re == H(cv)).exists(xs) def csiszar_sum(*args, telescoping = False, seq_cond = False, timerv = None): """Csiszar sum identity. J. Korner and K. Marton, "Images of a set via two channels and their role in multi-user communication," IEEE Trans. Inf. Theory, vol. 23, no. 6, pp. 751-761, 1977. I. Csiszar and J. Korner, "Broadcast channels with confidential messages," IEEE Trans. Inf. Theory, vol. 24, no. 3, pp. 339-348, 1978. 
""" r = Region.universe() vecs = [] vecs_f_mask = 0 ones = [] for a in args: if isinstance(a, CompArray) or isinstance(a, list): vecs.append(a) if len(a) >= 3: vecs_f_mask |= 1 << (len(vecs) - 1) # ones.append(sum(a)) if timerv is not None: for i in range(1, len(a)): r &= H(timerv | a[i]) == 0 else: ones.append(a) for xmask in igen.subset_mask(vecs_f_mask): if xmask == 0: continue x = sum(vecs[i] for i in range(len(vecs)) if xmask & (1 << i)) #for ymask in igen.subset_mask((1 << len(vecs)) - 1 - xmask): for ymask in range(1, 1 << len(vecs)): if ymask == 0: continue y = sum(vecs[i][0:2] for i in range(len(vecs)) if ymask & (1 << i)) vmask_all = 0 if seq_cond: vmask_all = ((1 << len(vecs)) - 1) & ~(xmask | ymask) for vmask in igen.subset_mask(vmask_all): for umask in range(1 << len(ones)): u = (sum((ones[i] for i in range(len(ones)) if umask & (1 << i)), Comp.empty()) + sum(sum(vecs[i]) for i in range(len(vecs)) if vmask & (1 << i))) if telescoping: r &= Expr.Ic(x[0] + x[2], y[1], u) == Expr.Ic(x[2], y[0] + y[1], u) else: r &= Expr.Ic(x[2], y[0], y[1]+u) == Expr.Ic(y[1], x[0], x[2]+u) return r.add_meta("pf_note", ["Csiszar sum"]) def ingleton_term(a1, a2, a3, a4): """The expression in Ingleton inequality. A. W. Ingleton, "Representation of matroids," in Combinatorial mathematics and its applications, D. Welsh, Ed. London: Academic Press, pp. 149-167, 1971. """ return -(Expr.H(a1) + Expr.H(a2) + Expr.H(a1+a2+a3) + Expr.H(a1+a2+a4) + Expr.H(a3+a4) - Expr.H(a1+a2) - Expr.H(a1+a3) - Expr.H(a1+a4) - Expr.H(a2+a3) - Expr.H(a2+a4)) def ingleton_ineq(a1, a2, a3, a4): """Ingleton inequality. A. W. Ingleton, "Representation of matroids," in Combinatorial mathematics and its applications, D. Welsh, Ed. London: Academic Press, pp. 149-167, 1971. """ return (ingleton_term(a1, a2, a3, a4) >= 0).add_meta("pf_note", ["Ingleton ineq."]) def ingleton_bound(*args): """Bound on entropy obtained by Ingleton inequality. A. W. 
Ingleton, "Representation of matroids," in Combinatorial mathematics and its applications, D. Welsh, Ed. London: Academic Press, pp. 149-167, 1971. """ n = len(args) amask = (1 << n) - 1 r = Region.universe() for a1m in igen.subset_mask(amask): if a1m == 0: continue a1 = sum((args[i] for i in range(n) if a1m & (1 << i)), Comp.empty()) for a2m in igen.subset_mask(amask - a1m): if a2m == 0: continue if a2m > a1m: break a2 = sum((args[i] for i in range(n) if a2m & (1 << i)), Comp.empty()) for a3m in igen.subset_mask(amask - a1m - a2m): if a3m == 0: continue a3 = sum((args[i] for i in range(n) if a3m & (1 << i)), Comp.empty()) for a4m in igen.subset_mask(amask - a1m - a2m - a3m): if a4m == 0: continue if a4m > a3m: break a4 = sum((args[i] for i in range(n) if a4m & (1 << i)), Comp.empty()) r &= ingleton_ineq(a1, a2, a3, a4) # return r.simplified() return r def dfz_linear_ineq(A, B, C, D, E): """Linear rank inequalities in: Dougherty, Randall, Chris Freiling, and Kenneth Zeger. "Linear rank inequalities on five or more variables." arXiv preprint arXiv:0910.0284 (2009). 
""" return ( ( I(A&B&C) <= I(A&B|D)+I(C&D|E)+I(A&E) ) &( I(A&B&C) <= I(A&C|D)+I(A&D|E)+I(B&E) ) &( I(A&B&D) <= I(A&C)+I(B&E|C)+I(A&D|C+E) ) &( I(A&B&D+E) <= I(A&C)+I(B&D|C)+I(A&E|C+D) ) &( I(A&B&C+E) <= I(A&C)+I(B&D|C)+I(A&E|D)+I(B&C|D+E) ) &( I(A&B&C+D) <= I(A&C)+I(B&D|E)+I(D&E|C)+I(A&C|D+E) ) &( I(A&B&D+E) <= I(A&C|D)+I(A&E|C)+I(B&D)+I(B&D|C+E) ) &( I(A&B&D)+I(A&B&C) <= I(A&B|E)+I(C&D)+I(C+D&E) ) &( I(A&B&E)+I(A&B&D) <= I(A&C)+I(D&E)+I(B&D+E|C) ) &( I(A&B&D)+I(A&B&C) <= I(C&D)+I(A&E)+I(B&D|E)+I(A&C|D+E) ) &( I(A&B+C) <= I(A&C|B+D)+I(A&C+E)+I(A&B|D+E)+I(B&D|C+E) ) &( I(A&B|C) <= I(A&B|D)+I(A&D|E)+I(B&E|C)+I(A&C|B+E)+I(C&E|B+D) ) &( I(A&B+C) <= I(A&B|D)+I(A&C+E)+I(B&D|C+E)+I(A&C|B+E)+I(C&E|B+D) ) &( I(A&B+C) <= I(A&D)+I(B&E|D)+I(A&B|C+E)+I(A&C|B+D)+I(A&C|D+E) ) &( I(A&B+C) <= I(A&D)+I(B&E|D)+I(A&C|E)+I(A&B|C+D)+I(A&C|B+D)+I(B&D|C+E) ) &( I(A&B+C) <= I(A&B|C+D)+I(A&C|B+D)+I(B+C&D|E)+I(B&C|D+E)+I(A&E) ) &( I(A+B&C|D) <= I(A&D|B+C)+I(B&D|A+C)+I(A&C|B+E)+I(B&C|A+E)+I(A&B|D+E)+I(C&E|D) ) &( I(A&B&D)+I(A&C&D) <= I(B&C)+I(B&D|E)+I(C&D|E)+I(A&E) ) &( I(A&B&E)+I(A&C&D) <= I(B&D)+I(A&C|D)+I(D&E)+I(B&E|C+D)+I(C&D|B+E) ) &( I(A&B&E)+I(A&C&D) <= I(B&C)+I(B&D)+I(A&E|B)+I(C&D|E)+I(B&E|C+D) ) &( I(A&B&C+E)+I(A&C&D) <= I(B&D)+I(A&D|E)+I(C&E)+I(B&C|D+E)+I(B&E|C+D) ) &( I(A&B&D)+I(A&B&C)+I(A&C&E) <= I(C&D)+I(A&D|E)+2*I(B&E)+I(B&C|D+E)+I(C&E|B+D) ) &( I(A&B&D)+I(A&B+C) <= 2*I(A&C|E)+I(B&E)+I(D&E)+I(A&B|C+D)+2*I(B&D|C+E)+I(C&E|B+D) ) &( I(A&C+D)+I(B&C|D) <= I(B&C|E)+I(C&E|D)+I(A&E)+I(A&C|B+D)+I(A+B&D|C)+I(A&D|B+E)+I(A&B|D+E) ) ).add_meta("pf_note", ["DFZ linear ineq."]) def linear_bound(*args): """The outer bound of the region given by linear rank inequalities. Tight for 5 or fewer random variables. Dougherty, Randall, Chris Freiling, and Kenneth Zeger. "Linear rank inequalities on five or more variables." arXiv preprint arXiv:0910.0284 (2009). 
""" n = len(args) amask = (1 << n) - 1 if n <= 3: return Region.universe() r = ingleton_bound(*args) if n <= 4: return r for a1m in igen.subset_mask(amask): if a1m == 0: continue A = sum((args[i] for i in range(n) if a1m & (1 << i)), Comp.empty()) for a2m in igen.subset_mask(amask - a1m): if a2m == 0: continue B = sum((args[i] for i in range(n) if a2m & (1 << i)), Comp.empty()) for a3m in igen.subset_mask(amask - a1m - a2m): if a3m == 0: continue C = sum((args[i] for i in range(n) if a3m & (1 << i)), Comp.empty()) for a4m in igen.subset_mask(amask - a1m - a2m - a3m): if a4m == 0: continue D = sum((args[i] for i in range(n) if a4m & (1 << i)), Comp.empty()) for a5m in igen.subset_mask(amask - a1m - a2m - a3m - a4m): if a5m == 0: continue E = sum((args[i] for i in range(n) if a5m & (1 << i)), Comp.empty()) r &= dfz_linear_ineq(A, B, C, D, E) return r # Numerical functions @fcn_list_to_list def elog(x, base = None): """Logarithm.""" loge = PsiOpts.settings["ent_coeff"] fcnargs = [x] if base is not None: fcnargs.append(base) R = Expr.real(iutil.fcn_name_maker("log", fcnargs, pname = "elog", lname = "\\log")) reg = Region.universe() def fcncall(x, cbase = None): cloge = loge if cbase is not None: if iutil.istorch(cbase): cloge = 1.0 / torch.log(cbase) else: cloge = 1.0 / numpy.log(cbase) if iutil.istorch(x) or iutil.istorch(cloge): return torch.log(x) * cloge else: return numpy.log(x) * cloge return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), reg, 0, fcncall, fcnargs)) @fcn_list_to_list def renyi(*args, order = 2): """Renyi entropy. Renyi, Alfred (1961). "On measures of information and entropy". Proceedings of the fourth Berkeley Symposium on Mathematics, Statistics and Probability 1960. pp. 547-561. 
""" ceps = PsiOpts.settings["eps"] loge = PsiOpts.settings["ent_coeff"] x = Term.from_symbols(args) if isinstance(order, int): order = float(order) def fcncall(xdist, corder): if isinstance(corder, ConcReal): corder = corder.x if isinstance(corder, int): corder = float(corder) if isinstance(corder, float): if abs(order - 1) < ceps: return xdist.entropy() if abs(order - 0) < ceps: r = 0.0 for a in xdist.items(): if a > ceps: r += 1 return numpy.log(r) * loge if numpy.isinf(order): if xdist.istorch(): return -torch.log(torch.max(torch.flatten(xdist.p))) * loge else: return -numpy.log(max(xdist.items())) * loge if xdist.istorch(): return (torch.log(torch.sum(torch.pow(torch.flatten(xdist.p), corder))) / (1.0 - corder) * loge) else: return (numpy.log(sum(numpy.power(a, corder) for a in xdist.items())) / (1.0 - corder) * loge) if isinstance(x, ConcDist): return fcncall(x, order) R = Expr.real(iutil.fcn_name_maker("renyi", [x, order], pname = "renyi", cropi = True)) reg = Region.universe() if isinstance(order, float): if abs(order - 1) < ceps: return H(x) if order > 1: reg = (R <= H(x)) & (R >= 0) else: reg = R >= H(x) return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), reg, 0, fcncall, [x, order])) @fcn_list_to_list def H0(*args): """Max entropy. Renyi, Alfred (1961). "On measures of information and entropy". Proceedings of the fourth Berkeley Symposium on Mathematics, Statistics and Probability 1960. pp. 547-561. """ ceps = PsiOpts.settings["eps"] loge = PsiOpts.settings["ent_coeff"] x = Term.from_symbols(args) x.sort() def fcncall(xdist): r = 0.0 for a in xdist.items(): if a > ceps: r += 1 return numpy.log(r) * loge if isinstance(x, ConcDist): return fcncall(x) R = Expr.real(iutil.fcn_name_maker("H_0", x, pname = "H0", cropi = True)) # reg = R >= H(x) return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), None, 0, fcncall, [x], termtname = "H0")) @fcn_list_to_list def maxcorr(*args): """Maximal correlation. H. O. 
Hirschfeld, "A connection between correlation and contingency," in Mathematical Proceedings of the Cambridge Philosophical Society, vol. 31, no. 04. Cambridge Univ Press, 1935, pp. 520-524. H. Gebelein, "Das statistische problem der korrelation als variations-und eigenwertproblem und sein zusammenhang mit der ausgleichsrechnung," ZAMM-Journal of Applied Mathematics and Mechanics/Zeitschrift fur Angewandte Mathematik und Mechanik, vol. 21, no. 6, pp. 364-379, 1941. A. Renyi, "On measures of dependence," Acta mathematica hungarica, vol. 10, no. 3, pp. 441-451, 1959. """ ceps = PsiOpts.settings["eps"] ceps_d = PsiOpts.settings["opt_eps_denom"] x = Term.from_symbols(args, prefer_multi=True) def fcncall(xdist): xdist = xdist.flattened_sublen() if len(xdist.p.shape) != 2: raise ValueError("Only maximal correlation between two random variables is supported.") return None if xdist.istorch(): px = torch.sum(xdist.p, 1) # return torch.sum(px) py = torch.sum(xdist.p, 0) tmat = torch.zeros(xdist.p.shape, dtype=torch.float64) for x in range(xdist.p.shape[0]): for y in range(xdist.p.shape[1]): rxy = torch.sqrt(px[x] * py[y]) tmat[x, y] = xdist.p[x, y] / (rxy + ceps_d) - rxy return torch.linalg.norm(tmat, 2) else: px = numpy.sum(xdist.p, 1) py = numpy.sum(xdist.p, 0) tmat = numpy.zeros(xdist.p.shape) for x in range(xdist.p.shape[0]): for y in range(xdist.p.shape[1]): rxy = numpy.sqrt(px[x] * py[y]) if rxy > ceps: tmat[x, y] = xdist.p[x, y] / rxy - rxy return numpy.linalg.norm(tmat, 2) if isinstance(x, ConcDist): return fcncall(x) R = Expr.real(iutil.fcn_name_maker("maxcorr", x, pname = "maxcorr", cropi = True)) reg = R >= 0 return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), reg, 0, fcncall, [x])) @fcn_list_to_list def divergence(x, y, mode = "kl"): """ Divergence between probability distributions. Parameters ---------- x : Comp or ConcDist The first distribution (if random variable is given, consider its distribution). 
y : Comp or ConcDist The second distribution (if random variable is given, consider its distribution). mode : str, optional Choices are "kl" (Kullback-Leibler divergence or relative entropy), "tv" (total variation distance), "chi2" (Chi-squared divergence), "hellinger" (Hellinger distance) and "js" (Jensen-Shannon divergence). The default is "kl". Returns ------- Expr, float or torch.Tensor The expression of the divergence. If x,y are ConcDist, gives float or torch.Tensor. """ mode = mode.lower() ceps = PsiOpts.settings["eps"] ceps_d = PsiOpts.settings["opt_eps_denom"] loge = PsiOpts.settings["ent_coeff"] def fcncall(xdist, ydist): r = 0.0 if mode == "kl": for a, b in zip(xdist.items(), ydist.items()): r += iutil.xlogxoy(a, b) * loge elif mode == "tv": for a, b in zip(xdist.items(), ydist.items()): r += abs(a - b) * 0.5 elif mode == "chi2": if xdist.istorch() or ydist.istorch(): for a, b in zip(xdist.items(), ydist.items()): r += (a ** 2) / (b + ceps_d) else: for a, b in zip(xdist.items(), ydist.items()): r += (a ** 2) / b r -= 1.0 elif mode == "hellinger": for a, b in zip(xdist.items(), ydist.items()): r += (iutil.sqrt(a) - iutil.sqrt(b)) ** 2 r = iutil.sqrt(r * 0.5) elif mode == "js": for a, b in zip(xdist.items(), ydist.items()): r += iutil.xlogxoy(a, (a + b) * 0.5) * loge * 0.5 r += iutil.xlogxoy(b, (a + b) * 0.5) * loge * 0.5 return r if isinstance(x, ConcDist) and isinstance(y, ConcDist): return fcncall(x, y) R = Expr.real(iutil.fcn_name_maker("divergence", [x, y, mode], pname = "divergence", cropi = True)) reg = R >= 0 return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), reg, 0, fcncall, [x, y])) @fcn_list_to_list def varent(*args): """Varentropy (variance of self information) and dispersion (variance of information density). Kontoyiannis, Ioannis, and Sergio Verdu. "Optimal lossless compression: Source varentropy and dispersion." 2013 IEEE International Symposium on Information Theory. IEEE, 2013. Polyanskiy, Yury, H. Vincent Poor, and Sergio Verdu. 
"Channel coding rate in the finite blocklength regime." IEEE Transactions on Information Theory 56.5 (2010): 2307-2359. """ ceps = PsiOpts.settings["eps"] ceps_d = PsiOpts.settings["opt_eps_denom"] x = Term.from_symbols(args) def fcncall(xdist): xdist = xdist.flattened_sublen() n = len(xdist.p.shape) pxs = None if xdist.istorch(): pxs = [torch.sum(xdist.p, tuple(j for j in range(n) if j != i)) for i in range(n)] else: pxs = [numpy.sum(xdist.p, tuple(j for j in range(n) if j != i)) for i in range(n)] s1 = 0.0 s2 = 0.0 for zs in itertools.product(*[range(z) for z in xdist.p.shape]): t = 0.0 if n == 1: s1 += -iutil.xlogxoy(xdist.p[zs], 1.0) s2 += iutil.xlogxoy2(xdist.p[zs], 1.0) else: prod = iutil.product(pxs[i][zs[i]] for i in range(n)) s1 += iutil.xlogxoy(xdist.p[zs], prod) s2 += iutil.xlogxoy2(xdist.p[zs], prod) return s2 - s1 ** 2 if isinstance(x, ConcDist): return fcncall(x) R = Expr.real(iutil.fcn_name_maker("varent", x, pname = "varent", cropi = True)) reg = R >= 0 return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), reg, 0, fcncall, [x])) def main(): parser = RegionParser() s = "" PsiOpts.setting(str_style = "std") while True: try: inputstr = "> " if s != "": inputstr = " " t = input(inputstr) except EOFError: break if t == "exit" or t == "quit": break if s != "": s += " " # s += "\n" s += t try: res = parser.parse(s) if res is None: pass elif isinstance(res, (Expr, Region)): print(res.simplified_truth(quick = True)) elif isinstance(res, float): print(iutil.float_tostr(res, bracket = False)) else: print(res) print() s = "" except lark.exceptions.LarkError as err: if t == "": print(err) print() s = "" except Exception as err: print(err) print() s = "" if __name__ == '__main__': main()